feat: auto check & create dataset folder

pull/222/head
akiba 2023-09-01 17:20:57 +08:00
parent 7603fb87f9
commit 5dab6a87f6
No known key found for this signature in database
GPG Key ID: 9D600258808ACBCD
2 changed files with 51 additions and 12 deletions

View File

@ -89,13 +89,8 @@ async def create_toml_file(request: Request, background_tasks: BackgroundTasks):
toml_data = await request.body() toml_data = await request.body()
j = json.loads(toml_data.decode("utf-8")) j = json.loads(toml_data.decode("utf-8"))
# ok = utils.check_training_params(j) utils.validate_data_dir(j["train_data_dir"])
# if not ok: suggest_cpu_threads = 8 if len(utils.get_total_images(j["train_data_dir"])) > 100 else 2
# lock.release()
# print("训练目录校验失败,请确保填写的目录存在")
# return {"status": "fail", "detail": "训练目录校验失败,请确保填写的目录存在"}
suggest_cpu_threads = 8 if utils.get_total_images(j["train_data_dir"]) > 100 else 2
trainer_file = "./sd-scripts/train_network.py" trainer_file = "./sd-scripts/train_network.py"
if j.pop("model_train_type", "sd-lora") == "sdxl-lora": if j.pop("model_train_type", "sd-lora") == "sdxl-lora":
@ -117,6 +112,8 @@ async def create_toml_file(request: Request, background_tasks: BackgroundTasks):
j["sample_prompts"] = sample_prompts_file j["sample_prompts"] = sample_prompts_file
log.info(f"Writted promopts to file {sample_prompts_file}") log.info(f"Writted promopts to file {sample_prompts_file}")
with open(toml_file, "w") as f: with open(toml_file, "w") as f:
f.write(toml.dumps(j)) f.write(toml.dumps(j))

View File

@ -3,12 +3,48 @@ import importlib.util
import os import os
import subprocess import subprocess
import sys import sys
import re
import shutil
from typing import Optional from typing import Optional
from mikazuki.log import log from mikazuki.log import log
python_bin = sys.executable python_bin = sys.executable
def validate_data_dir(path):
    """Validate a LoRA training data directory, auto-repairing a flat layout.

    sd-scripts expects the data dir to contain subfolders named
    ``<repeats>_<name>`` (e.g. ``1_zkz``). If no such subfolder exists but
    loose images sit directly under *path*, they are moved into an
    auto-created ``1_zkz`` folder so training can proceed.

    :param path: training data directory (``train_data_dir`` from the request)
    :return: True if the directory is usable (possibly after auto-repair),
             False if it is missing, empty, or contains no images.
    """
    if not os.path.exists(path):
        log.error(f"Data dir {path} not exists, check your params")
        return False

    dir_content = os.listdir(path)
    if len(dir_content) == 0:
        log.error(f"Data dir {path} is empty, check your params")
        # Bug fix: previously fell through and kept scanning an empty dir.
        return False

    subdirs = [f for f in dir_content if os.path.isdir(os.path.join(path, f))]
    if len(subdirs) == 0:
        log.warning("No subdir found in data dir")

    # A "legal" dataset folder matches <digits>_<name>, e.g. "10_mychar".
    ok_dir = [d for d in subdirs if re.match(r"^\d+_.+", d)]

    if len(ok_dir) == 0:
        log.warning("No legal dataset found. Try to find available images")
        # Non-recursive scan: only loose images directly under path qualify.
        imgs = get_total_images(path, False)
        log.info(f"{len(imgs)} images found")
        if len(imgs) > 0:
            dataset_path = os.path.join(path, "1_zkz")
            # exist_ok guards against a leftover folder from a previous run.
            os.makedirs(dataset_path, exist_ok=True)
            for img in imgs:
                shutil.move(img, dataset_path)
            log.info(f"Auto dataset created {dataset_path}")
        else:
            log.error("No image found in data dir")
            return False
    return True
def check_training_params(data): def check_training_params(data):
potential_path = [ potential_path = [
"train_data_dir", "reg_data_dir", "output_dir" "train_data_dir", "reg_data_dir", "output_dir"
@ -26,11 +62,16 @@ def check_training_params(data):
return True return True
def get_total_images(path, recursive=True):
    """Collect image file paths (jpg/jpeg/png) under *path*.

    :param path: directory to scan
    :param recursive: when True, descend into subdirectories as well
    :return: list of matching file paths
    """
    found = []
    # Extensions are checked in a fixed order so output grouping is stable.
    for ext in ("jpg", "jpeg", "png"):
        if recursive:
            found += glob.glob(path + "/**/*." + ext, recursive=True)
        else:
            found += glob.glob(path + "/*." + ext)
    return found
def is_installed(package): def is_installed(package):
@ -82,6 +123,7 @@ stderr: {result.stderr.decode(encoding="utf8", errors="ignore") if len(result.st
def run_pip(command, desc=None, live=False): def run_pip(command, desc=None, live=False):
return run(f'"{python_bin}" -m pip {command}', desc=f"Installing {desc}", errdesc=f"Couldn't install {desc}", live=live) return run(f'"{python_bin}" -m pip {command}', desc=f"Installing {desc}", errdesc=f"Couldn't install {desc}", live=live)
def check_run(file: str) -> bool: def check_run(file: str) -> bool:
result = subprocess.run([python_bin, file], capture_output=True, shell=False) result = subprocess.run([python_bin, file], capture_output=True, shell=False)
log.info(result.stdout.decode("utf-8").strip()) log.info(result.stdout.decode("utf-8").strip())