feat: auto check & create dataset folder
parent
7603fb87f9
commit
5dab6a87f6
|
|
@ -89,13 +89,8 @@ async def create_toml_file(request: Request, background_tasks: BackgroundTasks):
|
|||
toml_data = await request.body()
|
||||
j = json.loads(toml_data.decode("utf-8"))
|
||||
|
||||
# ok = utils.check_training_params(j)
|
||||
# if not ok:
|
||||
# lock.release()
|
||||
# print("训练目录校验失败,请确保填写的目录存在")
|
||||
# return {"status": "fail", "detail": "训练目录校验失败,请确保填写的目录存在"}
|
||||
|
||||
suggest_cpu_threads = 8 if utils.get_total_images(j["train_data_dir"]) > 100 else 2
|
||||
utils.validate_data_dir(j["train_data_dir"])
|
||||
suggest_cpu_threads = 8 if len(utils.get_total_images(j["train_data_dir"])) > 100 else 2
|
||||
trainer_file = "./sd-scripts/train_network.py"
|
||||
|
||||
if j.pop("model_train_type", "sd-lora") == "sdxl-lora":
|
||||
|
|
@ -117,6 +112,8 @@ async def create_toml_file(request: Request, background_tasks: BackgroundTasks):
|
|||
j["sample_prompts"] = sample_prompts_file
|
||||
log.info(f"Writted promopts to file {sample_prompts_file}")
|
||||
|
||||
|
||||
|
||||
with open(toml_file, "w") as f:
|
||||
f.write(toml.dumps(j))
|
||||
|
||||
|
|
|
|||
|
|
@ -3,12 +3,48 @@ import importlib.util
|
|||
import os
|
||||
import subprocess
|
||||
import sys
|
||||
import re
|
||||
import shutil
|
||||
from typing import Optional
|
||||
from mikazuki.log import log
|
||||
|
||||
python_bin = sys.executable
|
||||
|
||||
|
||||
def validate_data_dir(path):
|
||||
if not os.path.exists(path):
|
||||
log.error(f"Data dir {path} not exists, check your params")
|
||||
return False
|
||||
|
||||
dir_content = os.listdir(path)
|
||||
|
||||
if len(dir_content) == 0:
|
||||
log.error(f"Data dir {path} is empty, check your params")
|
||||
|
||||
subdirs = [f for f in dir_content if os.path.isdir(os.path.join(path, f))]
|
||||
|
||||
if len(subdirs) == 0:
|
||||
log.warn(f"No subdir found in data dir")
|
||||
|
||||
ok_dir = [d for d in subdirs if re.findall(r"^\d+_.+", d)]
|
||||
|
||||
if len(ok_dir) == 0:
|
||||
log.warning(f"No leagal dataset found. Try find avaliable images")
|
||||
imgs = get_total_images(path, False)
|
||||
log.info(f"{len(imgs)} images found")
|
||||
if len(imgs) > 0:
|
||||
dataset_path = os.path.join(path, "1_zkz")
|
||||
os.makedirs(dataset_path)
|
||||
for i in imgs:
|
||||
shutil.move(i, dataset_path)
|
||||
log.info(f"Auto dataset created {dataset_path}")
|
||||
else:
|
||||
log.error("No image found in data dir")
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def check_training_params(data):
|
||||
potential_path = [
|
||||
"train_data_dir", "reg_data_dir", "output_dir"
|
||||
|
|
@ -26,11 +62,16 @@ def check_training_params(data):
|
|||
return True
|
||||
|
||||
|
||||
def get_total_images(path):
|
||||
def get_total_images(path, recursive=True):
|
||||
if recursive:
|
||||
image_files = glob.glob(path + '/**/*.jpg', recursive=True)
|
||||
image_files += glob.glob(path + '/**/*.jpeg', recursive=True)
|
||||
image_files += glob.glob(path + '/**/*.png', recursive=True)
|
||||
return len(image_files)
|
||||
else:
|
||||
image_files = glob.glob(path + '/*.jpg')
|
||||
image_files += glob.glob(path + '/*.jpeg')
|
||||
image_files += glob.glob(path + '/*.png')
|
||||
return image_files
|
||||
|
||||
|
||||
def is_installed(package):
|
||||
|
|
@ -82,6 +123,7 @@ stderr: {result.stderr.decode(encoding="utf8", errors="ignore") if len(result.st
|
|||
def run_pip(command, desc=None, live=False):
|
||||
return run(f'"{python_bin}" -m pip {command}', desc=f"Installing {desc}", errdesc=f"Couldn't install {desc}", live=live)
|
||||
|
||||
|
||||
def check_run(file: str) -> bool:
|
||||
result = subprocess.run([python_bin, file], capture_output=True, shell=False)
|
||||
log.info(result.stdout.decode("utf-8").strip())
|
||||
|
|
|
|||
Loading…
Reference in New Issue