lora-scripts/mikazuki/utils.py

131 lines
3.8 KiB
Python

import glob
import importlib.util
import os
import subprocess
import sys
import re
import shutil
from typing import Optional
from mikazuki.log import log
# Interpreter executing this module; reused below so that pip and helper
# scripts are launched with the same Python installation.
python_bin = sys.executable
def validate_data_dir(path):
    """Validate a training data directory, creating a dataset folder if needed.

    A dataset folder is a direct subdirectory of ``path`` whose name matches
    ``<digits>_<name>`` (for example ``10_cat``). When no such folder exists,
    the images found directly inside ``path`` are moved into a newly created
    ``1_zkz`` subfolder so the directory becomes usable.

    Args:
        path: Directory that is expected to hold the training data.

    Returns:
        True when at least one dataset folder exists or one was created from
        loose images; False when ``path`` is missing, empty, or contains
        neither a dataset folder nor any image.
    """
    if not os.path.exists(path):
        log.error(f"Data dir {path} not exists, check your params")
        return False

    dir_content = os.listdir(path)
    if not dir_content:
        log.error(f"Data dir {path} is empty, check your params")
        # An empty directory can never yield a dataset; stop here instead of
        # emitting the follow-up "no subdir" / "no image" messages as well.
        return False

    subdirs = [f for f in dir_content if os.path.isdir(os.path.join(path, f))]
    if not subdirs:
        # `warning` matches the call used below; `warn` is a deprecated alias
        # on standard logging loggers.
        log.warning("No subdir found in data dir")

    ok_dir = [d for d in subdirs if re.match(r"^\d+_.+", d)]
    if ok_dir:
        return True

    log.warning("No leagal dataset found. Try find avaliable images")
    imgs = get_total_images(path, False)
    log.info(f"{len(imgs)} images found")
    if not imgs:
        log.error("No image found in data dir")
        return False

    # Gather the loose images into a single auto-generated dataset folder.
    dataset_path = os.path.join(path, "1_zkz")
    os.makedirs(dataset_path)
    for img in imgs:
        shutil.move(img, dataset_path)
    log.info(f"Auto dataset created {dataset_path}")
    return True
def check_training_params(data):
    """Tell whether every path-valued training parameter in ``data`` exists.

    Only the keys ``train_data_dir``, ``reg_data_dir``, ``output_dir`` and
    ``sample_prompts`` are inspected; keys absent from ``data`` are skipped.

    Args:
        data: Mapping of training parameter names to their values.

    Returns:
        False as soon as one of the inspected keys points to a path that does
        not exist, True otherwise.
    """
    dir_like_keys = ("train_data_dir", "reg_data_dir", "output_dir")
    file_like_keys = ("sample_prompts",)
    # Keys are visited in declaration order and evaluation stops at the first
    # missing path.
    return all(
        os.path.exists(data[key])
        for key in dir_like_keys + file_like_keys
        if key in data
    )
def get_total_images(path, recursive=True):
    """Collect the ``.jpg``, ``.jpeg`` and ``.png`` files under ``path``.

    Args:
        path: Directory to search, given as a string.
        recursive: When True (default) the whole tree below ``path`` is
            searched; when False only files directly inside ``path`` are.

    Returns:
        A list of matching file paths: all ``.jpg`` matches first, then
        ``.jpeg``, then ``.png``.
    """
    # Escape glob metacharacters so that a directory such as "set[1]" is
    # treated literally instead of being interpreted as a pattern.
    base = glob.escape(path)
    infix = '/**/' if recursive else '/'

    image_files = []
    for ext in ('jpg', 'jpeg', 'png'):
        image_files += glob.glob(f"{base}{infix}*.{ext}", recursive=recursive)
    return image_files
def is_installed(package):
    """Report whether ``package`` can be located by the import system.

    Args:
        package: Module or package name, dotted names included.

    Returns:
        True when ``importlib.util.find_spec`` yields a spec, False when it
        yields ``None`` or raises ``ModuleNotFoundError``.
    """
    try:
        found = importlib.util.find_spec(package)
    except ModuleNotFoundError:
        # Raised for dotted names whose parent package is missing.
        found = None
    return found is not None
def run(command,
        desc: Optional[str] = None,
        errdesc: Optional[str] = None,
        custom_env: Optional[dict] = None,
        live: Optional[bool] = True,
        shell: Optional[bool] = None):
    """Run ``command`` in a subprocess and return its captured stdout.

    Args:
        command: Command passed to ``subprocess.run``.
        desc: Optional text printed before the command starts.
        errdesc: Optional first line of the error message on failure;
            defaults to ``'Error running command'``.
        custom_env: Environment mapping for the child process. The current
            ``os.environ`` is used when it is ``None``.
        live: When truthy, the child's output is not captured and an empty
            string is returned. Otherwise stdout and stderr are captured and
            the decoded stdout is returned.
        shell: Whether to run through the shell. When ``None`` it is enabled
            on every platform except ``win32``.

    Returns:
        ``""`` in live mode, otherwise the child's stdout decoded as UTF-8
        with undecodable bytes dropped.

    Raises:
        RuntimeError: If the command exits with a non-zero return code.

    Note:
        Because the shell is used by default outside Windows, ``command``
        must not be assembled from untrusted input.
    """
    if shell is None:
        shell = sys.platform != "win32"
    if desc is not None:
        print(desc)

    env = os.environ if custom_env is None else custom_env
    headline = errdesc or 'Error running command'

    if live:
        # Output goes straight to the parent's stdout/stderr; nothing to return.
        result = subprocess.run(command, shell=shell, env=env)
        if result.returncode != 0:
            raise RuntimeError(
                f"{headline}.\n"
                f"Command: {command}\n"
                f"Error code: {result.returncode}"
            )
        return ""

    def _as_text(raw):
        # Undecodable bytes are dropped rather than raising.
        return raw.decode(encoding="utf8", errors="ignore")

    result = subprocess.run(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE,
                            shell=shell, env=env)
    if result.returncode != 0:
        captured_out = _as_text(result.stdout) if len(result.stdout) > 0 else '<empty>'
        captured_err = _as_text(result.stderr) if len(result.stderr) > 0 else '<empty>'
        raise RuntimeError(
            f"{headline}.\n"
            f"Command: {command}\n"
            f"Error code: {result.returncode}\n"
            f"stdout: {captured_out}\n"
            f"stderr: {captured_err}\n"
        )
    return _as_text(result.stdout)
def run_pip(command, desc=None, live=False):
    """Invoke ``pip`` through the current interpreter via ``run``.

    Args:
        command: Arguments appended after ``-m pip``.
        desc: Name used in the "Installing ..." and "Couldn't install ..."
            messages.
        live: Forwarded to ``run``.

    Returns:
        Whatever ``run`` returns for the assembled command.
    """
    pip_invocation = f'"{python_bin}" -m pip {command}'
    return run(
        pip_invocation,
        desc=f"Installing {desc}",
        errdesc=f"Couldn't install {desc}",
        live=live,
    )
def check_run(file: str) -> bool:
    """Execute a Python script with the current interpreter and report success.

    The script's stdout is captured and logged at info level. Its stderr is
    captured as well but not logged.

    Args:
        file: Path of the script to execute.

    Returns:
        True if the script exits with return code 0, False otherwise.
    """
    result = subprocess.run([python_bin, file], capture_output=True, shell=False)
    # Drop undecodable bytes, as `run` does, so that output in a non-UTF-8
    # console encoding cannot raise UnicodeDecodeError here.
    log.info(result.stdout.decode("utf-8", errors="ignore").strip())
    return result.returncode == 0