feat: auto check & create dataset folder

pull/222/head
akiba 2023-09-01 17:20:57 +08:00
parent 7603fb87f9
commit 5dab6a87f6
No known key found for this signature in database
GPG Key ID: 9D600258808ACBCD
2 changed files with 51 additions and 12 deletions

View File

@ -89,13 +89,8 @@ async def create_toml_file(request: Request, background_tasks: BackgroundTasks):
toml_data = await request.body()
j = json.loads(toml_data.decode("utf-8"))
# ok = utils.check_training_params(j)
# if not ok:
# lock.release()
# print("训练目录校验失败,请确保填写的目录存在")
# return {"status": "fail", "detail": "训练目录校验失败,请确保填写的目录存在"}
suggest_cpu_threads = 8 if utils.get_total_images(j["train_data_dir"]) > 100 else 2
utils.validate_data_dir(j["train_data_dir"])
suggest_cpu_threads = 8 if len(utils.get_total_images(j["train_data_dir"])) > 100 else 2
trainer_file = "./sd-scripts/train_network.py"
if j.pop("model_train_type", "sd-lora") == "sdxl-lora":
@ -117,6 +112,8 @@ async def create_toml_file(request: Request, background_tasks: BackgroundTasks):
j["sample_prompts"] = sample_prompts_file
log.info(f"Writted promopts to file {sample_prompts_file}")
with open(toml_file, "w") as f:
f.write(toml.dumps(j))

View File

@ -3,12 +3,48 @@ import importlib.util
import os
import subprocess
import sys
import re
import shutil
from typing import Optional
from mikazuki.log import log
python_bin = sys.executable
def validate_data_dir(path):
    """Validate a training data directory, auto-creating a dataset subfolder if needed.

    sd-scripts expects at least one subdirectory named "<repeats>_<concept>"
    (e.g. "10_mychar"). If no such subdirectory exists but loose images sit
    directly under `path`, they are moved into an auto-created "1_zkz" folder
    so training can still proceed.

    :param path: training data directory supplied by the user
    :return: True when the directory is usable for training, False otherwise
    """
    if not os.path.exists(path):
        log.error(f"Data dir {path} not exists, check your params")
        return False

    dir_content = os.listdir(path)
    if not dir_content:
        # Nothing to train on at all — bail out instead of falling through
        # to the image auto-move logic (original code kept going here).
        log.error(f"Data dir {path} is empty, check your params")
        return False

    subdirs = [f for f in dir_content if os.path.isdir(os.path.join(path, f))]
    if not subdirs:
        log.warning("No subdir found in data dir")

    # Keep only subdirs matching the "<repeats>_<name>" naming convention.
    ok_dir = [d for d in subdirs if re.findall(r"^\d+_.+", d)]

    if not ok_dir:
        log.warning("No legal dataset found. Try to find available images")
        # Non-recursive scan: only images directly under `path` qualify.
        imgs = get_total_images(path, False)
        log.info(f"{len(imgs)} images found")
        if imgs:
            dataset_path = os.path.join(path, "1_zkz")
            # exist_ok tolerates a leftover folder from a previous partial run.
            os.makedirs(dataset_path, exist_ok=True)
            for img in imgs:
                shutil.move(img, dataset_path)
            log.info(f"Auto dataset created {dataset_path}")
        else:
            log.error("No image found in data dir")
            return False

    return True
def check_training_params(data):
potential_path = [
"train_data_dir", "reg_data_dir", "output_dir"
@ -26,11 +62,16 @@ def check_training_params(data):
return True
def get_total_images(path, recursive=True):
    """Collect image files (jpg/jpeg/png) under `path`.

    :param path: directory to scan
    :param recursive: when True, search all subdirectories; when False,
        only files directly inside `path`
    :return: list of matching file paths

    Note: always returns the list itself (callers use ``len(...)`` on it);
    the recursive branch previously returned a count, which would make
    ``len(get_total_images(...))`` raise TypeError.
    """
    if recursive:
        image_files = glob.glob(path + '/**/*.jpg', recursive=True)
        image_files += glob.glob(path + '/**/*.jpeg', recursive=True)
        image_files += glob.glob(path + '/**/*.png', recursive=True)
    else:
        image_files = glob.glob(path + '/*.jpg')
        image_files += glob.glob(path + '/*.jpeg')
        image_files += glob.glob(path + '/*.png')
    return image_files
def is_installed(package):
@ -82,6 +123,7 @@ stderr: {result.stderr.decode(encoding="utf8", errors="ignore") if len(result.st
def run_pip(command, desc=None, live=False):
    """Run a pip subcommand with the current interpreter via the shared `run` helper.

    :param command: pip arguments, e.g. "install requests"
    :param desc: human-readable name used in progress/error messages
    :param live: stream output live instead of capturing it
    :return: whatever `run` returns for the executed command
    """
    pip_command = f'"{python_bin}" -m pip {command}'
    return run(
        pip_command,
        desc=f"Installing {desc}",
        errdesc=f"Couldn't install {desc}",
        live=live,
    )
def check_run(file: str) -> bool:
result = subprocess.run([python_bin, file], capture_output=True, shell=False)
log.info(result.stdout.decode("utf-8").strip())