From 5dab6a87f6daf4f6e6ac2397d3bd11d4e1cfbaba Mon Sep 17 00:00:00 2001
From: akiba
Date: Fri, 1 Sep 2023 17:20:57 +0800
Subject: [PATCH] feat: auto check & create dataset folder

---
 mikazuki/app.py   | 13 ++++++------
 mikazuki/utils.py | 54 ++++++++++++++++++++++++++++++++++++++++++-----
 2 files changed, 55 insertions(+), 12 deletions(-)

diff --git a/mikazuki/app.py b/mikazuki/app.py
index 26efe61..a319831 100644
--- a/mikazuki/app.py
+++ b/mikazuki/app.py
@@ -89,13 +89,10 @@ async def create_toml_file(request: Request, background_tasks: BackgroundTasks):
     toml_data = await request.body()
     j = json.loads(toml_data.decode("utf-8"))
 
-    # ok = utils.check_training_params(j)
-    # if not ok:
-    #     lock.release()
-    #     print("训练目录校验失败,请确保填写的目录存在")
-    #     return {"status": "fail", "detail": "训练目录校验失败,请确保填写的目录存在"}
-
-    suggest_cpu_threads = 8 if utils.get_total_images(j["train_data_dir"]) > 100 else 2
+    if not utils.validate_data_dir(j["train_data_dir"]):
+        lock.release()
+        return {"status": "fail", "detail": "Data dir validation failed, make sure the directory exists"}
+    suggest_cpu_threads = 8 if len(utils.get_total_images(j["train_data_dir"])) > 100 else 2
 
     trainer_file = "./sd-scripts/train_network.py"
     if j.pop("model_train_type", "sd-lora") == "sdxl-lora":
@@ -117,6 +114,8 @@ async def create_toml_file(request: Request, background_tasks: BackgroundTasks):
         j["sample_prompts"] = sample_prompts_file
         log.info(f"Writted promopts to file {sample_prompts_file}")
 
+
+
     with open(toml_file, "w") as f:
         f.write(toml.dumps(j))
 
diff --git a/mikazuki/utils.py b/mikazuki/utils.py
index 84c1853..66870b9 100644
--- a/mikazuki/utils.py
+++ b/mikazuki/utils.py
@@ -3,12 +3,50 @@ import importlib.util
 import os
 import subprocess
 import sys
+import re
+import shutil
 
 from typing import Optional
 
 from mikazuki.log import log
 
 python_bin = sys.executable
+
+def validate_data_dir(path):
+    if not os.path.exists(path):
+        log.error(f"Data dir {path} does not exist, check your params")
+        return False
+
+    dir_content = os.listdir(path)
+
+    if len(dir_content) == 0:
+        log.error(f"Data dir {path} is empty, check your params")
+        return False
+
+    subdirs = [f for f in dir_content if os.path.isdir(os.path.join(path, f))]
+
+    if len(subdirs) == 0:
+        log.warning("No subdir found in data dir")
+
+    ok_dir = [d for d in subdirs if re.findall(r"^\d+_.+", d)]
+
+    if len(ok_dir) == 0:
+        log.warning("No legal dataset found. Trying to find available images")
+        imgs = get_total_images(path, False)
+        log.info(f"{len(imgs)} images found")
+        if len(imgs) > 0:
+            dataset_path = os.path.join(path, "1_zkz")
+            os.makedirs(dataset_path)
+            for i in imgs:
+                shutil.move(i, dataset_path)
+            log.info(f"Auto dataset created {dataset_path}")
+        else:
+            log.error("No image found in data dir")
+            return False
+
+    return True
+
+
 def check_training_params(data):
     potential_path = [
         "train_data_dir", "reg_data_dir", "output_dir"
@@ -26,11 +64,16 @@ def check_training_params(data):
     return True
 
 
-def get_total_images(path):
-    image_files = glob.glob(path + '/**/*.jpg', recursive=True)
-    image_files += glob.glob(path + '/**/*.jpeg', recursive=True)
-    image_files += glob.glob(path + '/**/*.png', recursive=True)
-    return len(image_files)
+def get_total_images(path, recursive=True):
+    if recursive:
+        image_files = glob.glob(path + '/**/*.jpg', recursive=True)
+        image_files += glob.glob(path + '/**/*.jpeg', recursive=True)
+        image_files += glob.glob(path + '/**/*.png', recursive=True)
+    else:
+        image_files = glob.glob(path + '/*.jpg')
+        image_files += glob.glob(path + '/*.jpeg')
+        image_files += glob.glob(path + '/*.png')
+    return image_files
 
 
 def is_installed(package):
@@ -82,6 +125,7 @@ stderr: {result.stderr.decode(encoding="utf8", errors="ignore") if len(result.st
 def run_pip(command, desc=None, live=False):
    return run(f'"{python_bin}" -m pip {command}', desc=f"Installing {desc}", errdesc=f"Couldn't install {desc}", live=live)
 
+
 def check_run(file: str) -> bool:
     result = subprocess.run([python_bin, file], capture_output=True, shell=False)
     log.info(result.stdout.decode("utf-8").strip())
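
Note on the new helpers: validate_data_dir checks for the sd-scripts dataset layout, where each dataset subfolder is named <repeats>_<name> (this is what the ^\d+_.+ regex matches), and get_total_images now returns the list of image paths instead of a count, so callers take len() themselves. A minimal usage sketch follows; the path is hypothetical and the import assumes the repository root is on sys.path:

# Hypothetical usage sketch of the patched helpers ("/data/train" is a
# made-up path; run from the repository root so mikazuki is importable).
from mikazuki import utils

if utils.validate_data_dir("/data/train"):
    images = utils.get_total_images("/data/train")  # now a list of paths, not a count
    threads = 8 if len(images) > 100 else 2         # mirrors the suggest_cpu_threads rule
    print(f"{len(images)} images found, suggesting {threads} CPU threads")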
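
The auto-create path can be smoke-tested without real data. A self-contained sketch, assuming the patched mikazuki.utils is importable; zero-byte files stand in for images, since only the filenames matter to the globs:

# Smoke-test sketch: a flat temp folder of "images" should get wrapped into an
# auto-created 1_zkz subfolder, and the recursive count should still find them.
import os
import tempfile

from mikazuki import utils

with tempfile.TemporaryDirectory() as root:
    for name in ("a.png", "b.jpg", "c.jpeg"):
        open(os.path.join(root, name), "wb").close()  # zero-byte stand-in image
    assert utils.validate_data_dir(root) is True
    assert sorted(os.listdir(os.path.join(root, "1_zkz"))) == ["a.png", "b.jpg", "c.jpeg"]
    assert len(utils.get_total_images(root)) == 3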