From cf6dc84cfe6ee367f59e9b3801ac532e767acc99 Mon Sep 17 00:00:00 2001
From: Akegarasu
Date: Tue, 18 Mar 2025 22:05:35 +0800
Subject: [PATCH] fix #631

---
 gui.py                        |   2 +-
 mikazuki/utils/train_utils.py | 187 +++++++++++++++++++++++++++-------
 2 files changed, 149 insertions(+), 40 deletions(-)

diff --git a/gui.py b/gui.py
index 80c87e5..e4ea68b 100644
--- a/gui.py
+++ b/gui.py
@@ -85,7 +85,7 @@ def launch():
 
     import uvicorn
     log.info(f"Server started at http://{args.host}:{args.port}")
-    uvicorn.run("mikazuki.app:app", host=args.host, port=args.port, log_level="error")
+    uvicorn.run("mikazuki.app:app", host=args.host, port=args.port, log_level="error", reload=args.dev)
 
 
 if __name__ == "__main__":
diff --git a/mikazuki/utils/train_utils.py b/mikazuki/utils/train_utils.py
index 2e27c97..452b977 100644
--- a/mikazuki/utils/train_utils.py
+++ b/mikazuki/utils/train_utils.py
@@ -4,6 +4,8 @@ import os
 import re
 import shutil
 import sys
+import json
+from typing import Dict, Optional
 
 from mikazuki.log import log
 
@@ -20,6 +22,57 @@ class ModelType(Enum):
     LoRA = 10
 
 
+# Substring signatures used to guess a model's type from its tensor names.
+# Entries are tried in order and the first entry with any matching signature
+# wins, so the generic SD1.5 key "model.diffusion_model" (a substring of the
+# FLUX and SD3 signatures) has to come after those architectures.
+MODEL_SIGNATURE = [
+    {
+        "type": ModelType.FLUX,
+        "signature": [
+            "double_blocks.0.img_mlp.0.weight",
+            "guidance_in.in_layer.weight",
+            "model.diffusion_model.double_blocks",
+            "double_blocks.0.img_attn.norm.query_norm.scale",
+        ]
+    },
+    {
+        "type": ModelType.SD3,
+        "signature": [
+            "model.diffusion_model.x_embedder.proj.weight",
+            "model.diffusion_model.joint_blocks.0.context_block.attn.proj.weight"
+        ]
+    },
+    {
+        "type": ModelType.SDXL,
+        "signature": [
+            "conditioner.embedders.1.model.transformer.resblocks",
+        ]
+    },
+    {
+        "type": ModelType.SD15,
+        "signature": [
+            "model.diffusion_model",
+            "cond_stage_model.transformer.text_model",
+        ]
+    },
+    {
+        "type": ModelType.LoRA,
+        "signature": [
+            "lora_te_text_model_encoder",
+            "lora_unet_up_blocks",
+            "lora_unet_input_blocks_4_1_transformer_blocks_0_attn1_to_k.alpha",
+            "lora_unet_input_blocks_4_1_transformer_blocks_0_attn1_to_k.lora_up.weight",
+
+            # more common signature
+            "lora_unet",
+            "lora_te",
+            "lora_A.weight",
+        ]
+    }
+]
+
+
 def is_promopt_like(s):
     for p in ["--n", "--s", "--l", "--d"]:
         if p in s:
@@ -27,45 +80,7 @@ def is_promopt_like(s):
     return False
 
 
-def validate_model(model_name: str, training_type: str = "sd-lora"):
-    if os.path.exists(model_name):
-        if os.path.isdir(model_name):
-            files = os.listdir(model_name)
-            if "model_index.json" in files or "unet" in model_name:
-                return True, "ok"
-            else:
-                log.warning("Can't find model, is this a huggingface model folder?")
-                return True, "ok"
-        try:
-            with open(model_name, "rb") as f:
-                content = f.read(1024 * 1000)
-                model_type = match_model_type(content)
-
-            if model_type == ModelType.UNKNOWN:
-                log.error(f"Can't match model type from {model_name}")
-
-            if model_type not in [ModelType.SD15, ModelType.SD2, ModelType.SD3, ModelType.SDXL, ModelType.FLUX]:
-                return False, "Pretrained model is not a Stable Diffusion or Flux checkpoint / 校验失败:底模不是 Stable Diffusion 或 Flux 模型"
-
-            if model_type == ModelType.SDXL and training_type == "sd-lora":
-                return False, "Pretrained model is SDXL, but you are training with SD1.5 LoRA / 校验失败:你选择的是 SD1.5 LoRA 训练,但预训练模型是 SDXL。请前往专家模式选择正确的模型种类。"
-
-        except Exception as e:
-            log.warning(f"model file {model_name} can't open: {e}")
-            return True, ""
-
-        return True, "ok"
-
-    # huggingface model repo
-    if model_name.count("/") == 1 \
-            and not model_name[0] in [".", "/"] \
-            and not model_name.split(".")[-1] in ["pt", "pth", "ckpt", "safetensors"]:
-        return True, "ok"
-
-    return False, "model not found"
-
-
-def match_model_type(sig_content: bytes):
+def match_model_type_legacy(sig_content: bytes):
     if b"model.diffusion_model.double_blocks" in sig_content or b"double_blocks.0.img_attn.norm.query_norm.scale" in sig_content:
         return ModelType.FLUX
 
@@ -84,6 +99,100 @@ def match_model_type(sig_content: bytes):
     return ModelType.UNKNOWN
 
 
+def read_safetensors_metadata(path) -> Optional[Dict]:
+    """Read the JSON header of a safetensors file.
+
+    The first 8 bytes are read as a little-endian length, then that many
+    bytes are read and parsed as JSON; nothing past the header is read.
+
+    Returns the parsed header, or None (after logging an error) if `path`
+    does not exist.
+    """
+    if not os.path.exists(path):
+        log.error(f"Can't find safetensors metadata file {path}")
+        return None
+
+    with open(path, "rb") as f:
+        meta_length = int.from_bytes(f.read(8), "little")
+        meta = f.read(meta_length)
+    return json.loads(meta)
+
+
+def guess_model_type(path):
+    """Guess the ModelType of the model file at `path`.
+
+    safetensors files are matched by header key against MODEL_SIGNATURE.
+    Every other file is sniffed from its first bytes with
+    match_model_type_legacy, as validate_model did before the header-based
+    check existed, so .pth and other extensions still get a result.
+
+    Returns ModelType.UNKNOWN when no safetensors signature matches.
+    """
+    if path.endswith("safetensors"):
+        metadata = read_safetensors_metadata(path)
+        if metadata is None:
+            return ModelType.UNKNOWN
+
+        model_keys = "\n".join(metadata.keys())
+        for m in MODEL_SIGNATURE:
+            if any(k in model_keys for k in m["signature"]):
+                return m["type"]
+
+        return ModelType.UNKNOWN
+
+    with open(path, "rb") as f:
+        content = f.read(1024 * 1000)
+    return match_model_type_legacy(content)
+
+
+def validate_model(model_name: str, training_type: str = "sd-lora"):
+    """Check that `model_name` can be used as the pretrained model.
+
+    Returns an (ok, message) tuple. Directories are accepted without
+    inspection, and so is a name with exactly one "/" that doesn't look like
+    a local path or model file (treated as a huggingface repo). An existing
+    file must be recognised by guess_model_type as a Stable Diffusion or
+    Flux checkpoint, and an SDXL checkpoint is rejected when `training_type`
+    is "sd-lora". A file that can't be read is accepted with a warning.
+    """
+    if os.path.exists(model_name):
+        if os.path.isdir(model_name):
+            files = os.listdir(model_name)
+            if "model_index.json" in files or "unet" in model_name:
+                return True, "ok"
+            else:
+                log.warning("Can't find model, is this a huggingface model folder?")
+                return True, "ok"
+
+        model_type = ModelType.UNKNOWN
+
+        try:
+            model_type = guess_model_type(model_name)
+        except Exception as e:
+            # unreadable file: warn but let validation pass
+            log.warning(f"model file {model_name} can't open: {e}")
+            return True, ""
+
+        if model_type == ModelType.UNKNOWN:
+            log.error(f"Can't match model type from {model_name}")
+
+        if model_type not in [ModelType.SD15, ModelType.SD2, ModelType.SD3, ModelType.SDXL, ModelType.FLUX]:
+            return False, "Pretrained model is not a Stable Diffusion or Flux checkpoint / 校验失败:底模不是 Stable Diffusion 或 Flux 模型"
+
+        if model_type == ModelType.SDXL and training_type == "sd-lora":
+            return False, "Pretrained model is SDXL, but you are training with SD1.5 LoRA / 校验失败:你选择的是 SD1.5 LoRA 训练,但预训练模型是 SDXL。请前往专家模式选择正确的模型种类。"
+
+        return True, "ok"
+
+    # huggingface model repo
+    if model_name.count("/") == 1 \
+            and model_name[0] not in [".", "/"] \
+            and model_name.split(".")[-1] not in ["pt", "pth", "ckpt", "safetensors"]:
+        return True, "ok"
+
+    return False, "model not found"
+
+
 def validate_data_dir(path):
     if not os.path.exists(path):
         log.error(f"Data dir {path} not exists, check your params")