lora-scripts/mikazuki/utils/train_utils.py

119 lines
3.3 KiB
Python

import glob
import os
import re
import shutil
import sys
from mikazuki.log import log
# Path of the Python interpreter running this process, as reported by sys.executable.
python_bin = sys.executable
def is_promopt_like(s):
    """Return True when ``s`` contains at least one prompt option marker.

    The markers looked for are the substrings ``--n``, ``--s``, ``--l`` and
    ``--d``; the result is False when none of them occurs in ``s``.
    """
    markers = ("--n", "--s", "--l", "--d")
    return any(marker in s for marker in markers)
def validate_model(model_name: str):
    """Check whether ``model_name`` is acceptable as a pretrained base model.

    If ``model_name`` exists on disk, the first 200 KiB of the file are
    scanned for well-known key names:

    * Stable Diffusion checkpoint keys -> accepted.
    * LoRA keys (``lora_unet`` / ``lora_te``) -> rejected, because a LoRA
      file cannot serve as the base model.
    * Anything else -> accepted.

    If the path does not exist, a name containing at most one ``/`` is
    accepted (treated as a Hugging Face style repo id); any other name is
    rejected as not found.

    Args:
        model_name: Local file path or repo-style model identifier.

    Returns:
        A ``(ok, message)`` tuple: ``ok`` is a bool, ``message`` a short
        status string.
    """
    if os.path.exists(model_name):
        try:
            # Only the open/read can fail, so keep the try body minimal.
            with open(model_name, "rb") as f:
                header = f.read(1024 * 200)
        except Exception as e:
            # Best effort: an unreadable path (for example a directory) is
            # reported but not rejected.
            log.warning(f"model file {model_name} can't open: {e}")
            return True, ""

        if b"model.diffusion_model" in header or b"cond_stage_model.transformer.text_model" in header:
            return True, "ok"

        if b"lora_unet" in header or b"lora_te" in header:
            return False, "pretrained model is not a Stable Diffusion checkpoint / 校验失败:底模不是 Stable Diffusion 模型"

        return True, "ok"

    # Not a local path: accept names shaped like a Hugging Face repo id.
    if model_name.count("/") <= 1:
        return True, "ok"

    return False, "model not found"
def validate_data_dir(path):
    """Validate the training data directory, auto-creating a dataset folder if needed.

    A dataset folder is a direct subdirectory whose name matches
    ``<digits>_<name>``. When no such subdirectory exists, the images
    (found via ``get_total_images(path, False)``) and ``*.txt`` captions
    lying directly in ``path`` are moved into a new ``<repeat>_zkz``
    subdirectory, where ``<repeat>`` comes from ``suggest_num_repeat``.

    Args:
        path: Directory expected to contain the training data.

    Returns:
        False when ``path`` does not exist, is empty, or holds neither a
        dataset folder nor any top-level image; True otherwise.

    Side effects:
        May create a subdirectory and move image/caption files into it.
    """
    if not os.path.exists(path):
        log.error(f"Data dir {path} not exists, check your params")
        return False

    dir_content = os.listdir(path)

    if not dir_content:
        # Nothing can be found in an empty directory; the later checks would
        # end in False as well, so stop here without the extra log noise.
        log.error(f"Data dir {path} is empty, check your params")
        return False

    subdirs = [name for name in dir_content if os.path.isdir(os.path.join(path, name))]

    if not subdirs:
        log.warning(f"No subdir found in data dir")

    ok_dir = [d for d in subdirs if re.findall(r"^\d+_.+", d)]
    if ok_dir:
        return True

    log.warning(f"No leagal dataset found. Try find avaliable images")
    imgs = get_total_images(path, False)
    # Escape the directory part so glob metacharacters in it ('[', ']',
    # '*', '?') are matched literally instead of being read as a pattern.
    captions = glob.glob(glob.escape(path) + '/*.txt')
    log.info(f"{len(imgs)} images found, {len(captions)} captions found")

    if not imgs:
        log.error("No image found in data dir")
        return False

    num_repeat = suggest_num_repeat(len(imgs))
    dataset_path = os.path.join(path, f"{num_repeat}_zkz")
    os.makedirs(dataset_path)
    for img in imgs:
        shutil.move(img, dataset_path)
    for caption in captions:
        shutil.move(caption, dataset_path)
    log.info(f"Auto dataset created {dataset_path}")

    return True
def suggest_num_repeat(img_count):
    """Suggest a per-image repeat count from the number of images.

    Returns 7 for up to 10 images, 5 for up to 50, 3 for up to 100 and
    1 for anything larger.
    """
    # (inclusive upper bound, repeat count), checked in ascending order.
    tiers = ((10, 7), (50, 5), (100, 3))
    for upper_bound, repeats in tiers:
        if img_count <= upper_bound:
            return repeats
    return 1
def check_training_params(data):
    """Verify that every path-valued training parameter present points to an existing path.

    The keys inspected are ``train_data_dir``, ``reg_data_dir``,
    ``output_dir`` and ``sample_prompts``. Keys missing from ``data`` are
    skipped.

    Returns:
        False as soon as one present key refers to a path that does not
        exist, True otherwise.
    """
    # Directory-like keys first, then file-like keys.
    path_keys = ("train_data_dir", "reg_data_dir", "output_dir", "sample_prompts")
    for key in path_keys:
        if key not in data:
            continue
        if not os.path.exists(data[key]):
            return False
    return True
def get_total_images(path, recursive=True):
    """Collect ``.jpg``, ``.jpeg`` and ``.png`` files under ``path``.

    Args:
        path: Directory to search.
        recursive: When True (default), files in ``path`` and in all nested
            subdirectories are matched; when False, only files directly
            inside ``path``.

    Returns:
        A list of matching file paths: all ``.jpg`` matches first, then
        ``.jpeg``, then ``.png``.
    """
    # Escape the directory part so glob metacharacters in it ('[', ']',
    # '*', '?') are matched literally instead of being read as a pattern.
    base = glob.escape(path)
    prefix = base + '/**/' if recursive else base + '/'

    image_files = []
    for ext in ("jpg", "jpeg", "png"):
        image_files += glob.glob(f"{prefix}*.{ext}", recursive=recursive)
    return image_files