feat: auto check & create dataset folder
parent
7603fb87f9
commit
5dab6a87f6
|
|
@ -89,13 +89,8 @@ async def create_toml_file(request: Request, background_tasks: BackgroundTasks):
|
||||||
toml_data = await request.body()
|
toml_data = await request.body()
|
||||||
j = json.loads(toml_data.decode("utf-8"))
|
j = json.loads(toml_data.decode("utf-8"))
|
||||||
|
|
||||||
# ok = utils.check_training_params(j)
|
utils.validate_data_dir(j["train_data_dir"])
|
||||||
# if not ok:
|
suggest_cpu_threads = 8 if len(utils.get_total_images(j["train_data_dir"])) > 100 else 2
|
||||||
# lock.release()
|
|
||||||
# print("训练目录校验失败,请确保填写的目录存在")
|
|
||||||
# return {"status": "fail", "detail": "训练目录校验失败,请确保填写的目录存在"}
|
|
||||||
|
|
||||||
suggest_cpu_threads = 8 if utils.get_total_images(j["train_data_dir"]) > 100 else 2
|
|
||||||
trainer_file = "./sd-scripts/train_network.py"
|
trainer_file = "./sd-scripts/train_network.py"
|
||||||
|
|
||||||
if j.pop("model_train_type", "sd-lora") == "sdxl-lora":
|
if j.pop("model_train_type", "sd-lora") == "sdxl-lora":
|
||||||
|
|
@ -117,6 +112,8 @@ async def create_toml_file(request: Request, background_tasks: BackgroundTasks):
|
||||||
j["sample_prompts"] = sample_prompts_file
|
j["sample_prompts"] = sample_prompts_file
|
||||||
log.info(f"Writted promopts to file {sample_prompts_file}")
|
log.info(f"Writted promopts to file {sample_prompts_file}")
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
with open(toml_file, "w") as f:
|
with open(toml_file, "w") as f:
|
||||||
f.write(toml.dumps(j))
|
f.write(toml.dumps(j))
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -3,12 +3,48 @@ import importlib.util
|
||||||
import os
|
import os
|
||||||
import subprocess
|
import subprocess
|
||||||
import sys
|
import sys
|
||||||
|
import re
|
||||||
|
import shutil
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
from mikazuki.log import log
|
from mikazuki.log import log
|
||||||
|
|
||||||
python_bin = sys.executable
|
python_bin = sys.executable
|
||||||
|
|
||||||
|
|
||||||
|
def validate_data_dir(path):
    """Validate a training data directory and auto-create a dataset folder if needed.

    A usable layout contains at least one subdirectory named like
    ``<repeats>_<name>`` (e.g. ``1_zkz``), which is the kohya-style dataset
    convention. If no such subdirectory exists but loose image files are found
    directly under *path*, a ``1_zkz`` folder is created and the images are
    moved into it so training can proceed.

    :param path: training data directory (``train_data_dir`` from the request)
    :return: ``True`` when the directory is (or was made) usable, ``False`` otherwise
    """
    if not os.path.exists(path):
        log.error(f"Data dir {path} not exists, check your params")
        return False

    dir_content = os.listdir(path)

    # An empty directory can never yield a dataset — fail fast instead of
    # falling through to the subdir / image scan below.
    if len(dir_content) == 0:
        log.error(f"Data dir {path} is empty, check your params")
        return False

    subdirs = [f for f in dir_content if os.path.isdir(os.path.join(path, f))]

    if len(subdirs) == 0:
        # log.warn is a deprecated alias of log.warning
        log.warning("No subdir found in data dir")

    # A dataset subdir must match "<digits>_<anything>"; re.match is the
    # right tool for a boolean name check (re.findall built a throwaway list).
    ok_dir = [d for d in subdirs if re.match(r"^\d+_.+", d)]

    if len(ok_dir) == 0:
        log.warning("No legal dataset found. Try to find available images")
        # Non-recursive: only images lying directly in path are candidates
        # for the auto-created dataset folder.
        imgs = get_total_images(path, False)
        log.info(f"{len(imgs)} images found")
        if len(imgs) > 0:
            dataset_path = os.path.join(path, "1_zkz")
            os.makedirs(dataset_path, exist_ok=True)
            for img in imgs:
                shutil.move(img, dataset_path)
            log.info(f"Auto dataset created {dataset_path}")
        else:
            log.error("No image found in data dir")
            return False

    return True
|
||||||
|
|
||||||
|
|
||||||
def check_training_params(data):
|
def check_training_params(data):
|
||||||
potential_path = [
|
potential_path = [
|
||||||
"train_data_dir", "reg_data_dir", "output_dir"
|
"train_data_dir", "reg_data_dir", "output_dir"
|
||||||
|
|
@ -26,11 +62,16 @@ def check_training_params(data):
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
|
||||||
def get_total_images(path, recursive=True):
    """Collect image files (``.jpg``, ``.jpeg``, ``.png``) under *path*.

    Both branches return the list of matching paths (not a count): callers
    need the paths themselves to move files into an auto-created dataset
    folder, and can apply ``len()`` when only the count matters. The previous
    recursive branch returned ``len(image_files)``, which was inconsistent
    with the non-recursive branch and broke ``len(get_total_images(...))``
    call sites.

    :param path: directory to scan
    :param recursive: search all subdirectories when True (default);
        only files directly inside *path* when False
    :return: list of image file paths
    """
    if recursive:
        pattern_prefix = path + '/**/*.'
        glob_kwargs = {"recursive": True}
    else:
        pattern_prefix = path + '/*.'
        glob_kwargs = {}

    image_files = []
    # One glob per extension instead of three copy-pasted statements.
    for ext in ("jpg", "jpeg", "png"):
        image_files += glob.glob(pattern_prefix + ext, **glob_kwargs)
    return image_files
|
||||||
|
|
||||||
|
|
||||||
def is_installed(package):
|
def is_installed(package):
|
||||||
|
|
@ -82,6 +123,7 @@ stderr: {result.stderr.decode(encoding="utf8", errors="ignore") if len(result.st
|
||||||
def run_pip(command, desc=None, live=False):
|
def run_pip(command, desc=None, live=False):
|
||||||
return run(f'"{python_bin}" -m pip {command}', desc=f"Installing {desc}", errdesc=f"Couldn't install {desc}", live=live)
|
return run(f'"{python_bin}" -m pip {command}', desc=f"Installing {desc}", errdesc=f"Couldn't install {desc}", live=live)
|
||||||
|
|
||||||
|
|
||||||
def check_run(file: str) -> bool:
|
def check_run(file: str) -> bool:
|
||||||
result = subprocess.run([python_bin, file], capture_output=True, shell=False)
|
result = subprocess.run([python_bin, file], capture_output=True, shell=False)
|
||||||
log.info(result.stdout.decode("utf-8").strip())
|
log.info(result.stdout.decode("utf-8").strip())
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue