# lora-scripts/mikazuki/utils/train_utils.py

from enum import Enum
import glob
import os
import re
import shutil
import sys
from mikazuki.log import log
# Interpreter used to spawn training subprocesses (same venv as this process).
python_bin = sys.executable
class ModelType(Enum):
    """Model architecture detected from weight-key names in a checkpoint header."""

    UNKNOWN = -1  # signature matched nothing known
    SD15 = 1      # Stable Diffusion 1.x
    SD2 = 2       # NOTE(review): never returned by match_model_type below — confirm intended
    SDXL = 3
    SD3 = 4
    FLUX = 5
    LoRA = 10     # a LoRA weight file, not a base model
def is_promopt_like(s):
    """Return True when *s* looks like a sample-prompt string.

    Heuristic: the string contains any of the kohya sample-prompt CLI
    flags (``--n``/``--s``/``--l``/``--d``) as a substring.
    """
    return any(flag in s for flag in ("--n", "--s", "--l", "--d"))
def validate_model(model_name: str, training_type: str = "sd-lora"):
    """Validate that *model_name* points to a usable pretrained base model.

    Accepts a local diffusers-format folder, a local checkpoint file, or a
    huggingface ``org/repo`` id. Validation is best-effort: unreadable files
    and unrecognized folders produce a warning but do not block training.

    :param model_name: path or huggingface repo id of the base model.
    :param training_type: selected training mode; ``"sd-lora"`` means SD1.5 LoRA.
    :return: ``(ok, message)`` tuple; *message* is user-facing on failure.
    """
    if os.path.exists(model_name):
        if os.path.isdir(model_name):
            files = os.listdir(model_name)
            # A diffusers-format folder ships model_index.json and a "unet"
            # component. Fix: the original only checked the *path string* for
            # "unet"; also check the folder contents (kept the old check for
            # backward compatibility with paths pointing at a unet subfolder).
            if "model_index.json" in files or "unet" in files or "unet" in model_name:
                return True, "ok"
            # Unrecognized layout: warn, but let training proceed anyway.
            log.warning("Can't find model, is this a huggingface model folder?")
            return True, "ok"

        try:
            # Read only the first ~1 MB: weight-key names live in the header
            # region, which is enough to fingerprint the architecture.
            with open(model_name, "rb") as f:
                content = f.read(1024 * 1000)
                model_type = match_model_type(content)
                if model_type == ModelType.UNKNOWN:
                    log.error(f"Can't match model type from {model_name}")
                if model_type not in [ModelType.SD15, ModelType.SD2, ModelType.SD3, ModelType.SDXL, ModelType.FLUX]:
                    return False, "Pretrained model is not a Stable Diffusion or Flux checkpoint / 校验失败:底模不是 Stable Diffusion 或 Flux 模型"
                if model_type == ModelType.SDXL and training_type == "sd-lora":
                    return False, "Pretrained model is SDXL, but you are training with SD1.5 LoRA / 校验失败:你选择的是 SD1.5 LoRA 训练,但预训练模型是 SDXL。请前往专家模式选择正确的模型种类。"
        except Exception as e:
            # Unreadable file: don't block training on a validation failure.
            log.warning(f"model file {model_name} can't open: {e}")
            return True, ""
        return True, "ok"

    # Looks like a huggingface repo id ("org/name"): exactly one slash, not a
    # relative/absolute path, and not a weight-file name.
    if model_name.count("/") == 1 \
            and model_name[0] not in [".", "/"] \
            and model_name.split(".")[-1] not in ["pt", "pth", "ckpt", "safetensors"]:
        return True, "ok"

    return False, "model not found"
def match_model_type(sig_content: bytes):
    """Fingerprint a checkpoint architecture from weight-key names.

    *sig_content* is the raw header region of a checkpoint file; the first
    matching signature wins, so more specific architectures are listed first.
    Returns a :class:`ModelType` (``UNKNOWN`` when nothing matches).
    """
    signature_table = (
        (ModelType.FLUX, (b"model.diffusion_model.double_blocks",
                          b"double_blocks.0.img_attn.norm.query_norm.scale")),
        (ModelType.SD3, (b"model.diffusion_model.x_embedder.proj.weight",)),
        (ModelType.SDXL, (b"conditioner.embedders.1.model.transformer.resblocks",)),
        (ModelType.SD15, (b"model.diffusion_model",
                          b"cond_stage_model.transformer.text_model")),
        (ModelType.LoRA, (b"lora_unet", b"lora_te")),
    )
    for model_type, keys in signature_table:
        if any(key in sig_content for key in keys):
            return model_type
    return ModelType.UNKNOWN
def validate_data_dir(path):
    """Check (and best-effort repair) a kohya-style training data directory.

    A legal layout contains subfolders named ``<repeats>_<name>`` (e.g.
    ``5_zkz``). If none exist but loose images sit directly under *path*,
    the images (and any ``.txt`` captions) are MOVED into an auto-created
    ``<n>_zkz`` subfolder, with ``n`` chosen by :func:`suggest_num_repeat`.

    :param path: root of the training data directory.
    :return: True when a usable dataset exists (or was auto-created).
    """
    if not os.path.exists(path):
        log.error(f"Data dir {path} not exists, check your params")
        return False

    dir_content = os.listdir(path)

    if len(dir_content) == 0:
        # NOTE(review): only logs — execution continues and ends up at the
        # "No image found" error below, still returning False.
        log.error(f"Data dir {path} is empty, check your params")

    subdirs = [f for f in dir_content if os.path.isdir(os.path.join(path, f))]
    if len(subdirs) == 0:
        log.warning(f"No subdir found in data dir")

    # "<number>_<name>" is the kohya repeat-count folder convention.
    ok_dir = [d for d in subdirs if re.findall(r"^\d+_.+", d)]

    if len(ok_dir) > 0:
        log.info(f"Found {len(ok_dir)} legal dataset")
        return True

    if len(ok_dir) == 0:
        log.warning(f"No leagal dataset found. Try find avaliable images")
        imgs = get_total_images(path, False)  # non-recursive: top-level images only
        captions = glob.glob(path + '/*.txt')
        log.info(f"{len(imgs)} images found, {len(captions)} captions found")
        if len(imgs) > 0:
            # Auto-create a dataset folder and move images + captions into it.
            num_repeat = suggest_num_repeat(len(imgs))
            dataset_path = os.path.join(path, f"{num_repeat}_zkz")
            os.makedirs(dataset_path)
            for i in imgs:
                shutil.move(i, dataset_path)
            if len(captions) > 0:
                for c in captions:
                    shutil.move(c, dataset_path)
            log.info(f"Auto dataset created {dataset_path}")
        else:
            log.error("No image found in data dir")
            return False
    return True
def suggest_num_repeat(img_count):
    """Suggest a kohya repeat count for a dataset of *img_count* images.

    Smaller datasets get more repeats so each epoch sees enough samples:
    <=10 -> 7, <=50 -> 5, <=100 -> 3, otherwise 1.
    """
    for upper_bound, repeats in ((10, 7), (50, 5), (100, 3)):
        if img_count <= upper_bound:
            return repeats
    return 1
def check_training_params(data):
    """Return True when every path-valued param present in *data* exists on disk.

    Checks directory params (train_data_dir, reg_data_dir, output_dir) first,
    then file params (sample_prompts); keys absent from *data* are skipped.
    """
    dir_params = ("train_data_dir", "reg_data_dir", "output_dir")
    file_params = ("sample_prompts",)
    for param in dir_params + file_params:
        if param in data and not os.path.exists(data[param]):
            return False
    return True
def get_total_images(path, recursive=True):
    """Collect .jpg/.jpeg/.png image paths under *path*.

    :param path: directory to scan.
    :param recursive: when True, descend into subdirectories as well.
    :return: list of matching file paths, grouped by extension
        (jpg first, then jpeg, then png).
    """
    found = []
    for ext in ("jpg", "jpeg", "png"):
        if recursive:
            found += glob.glob(path + '/**/*.' + ext, recursive=True)
        else:
            found += glob.glob(path + '/*.' + ext)
    return found
def fix_config_types(config: dict):
    """Coerce float-only training params back to ``float`` in place.

    Values may arrive as strings or ints from the web UI / TOML config;
    keys absent from *config* are left untouched. Returns None (mutates
    *config* directly).
    """
    for param in ("guidance_scale", "sigmoid_scale", "discrete_flow_shift"):
        if param in config:
            config[param] = float(config[param])