pull/658/head
Akegarasu 2025-03-18 22:05:35 +08:00
parent bf945edba0
commit cf6dc84cfe
No known key found for this signature in database
GPG Key ID: DACA951FEBA569A2
2 changed files with 116 additions and 40 deletions

2
gui.py
View File

@ -85,7 +85,7 @@ def launch():
import uvicorn
log.info(f"Server started at http://{args.host}:{args.port}")
uvicorn.run("mikazuki.app:app", host=args.host, port=args.port, log_level="error")
uvicorn.run("mikazuki.app:app", host=args.host, port=args.port, log_level="error", reload=args.dev)
if __name__ == "__main__":

View File

@ -4,6 +4,8 @@ import os
import re
import shutil
import sys
import json
from typing import Dict
from mikazuki.log import log
@ -20,6 +22,53 @@ class ModelType(Enum):
LoRA = 10
MODEL_SIGNATURE = [
{
"type": ModelType.FLUX,
"signature": [
"double_blocks.0.img_mlp.0.weight",
"guidance_in.in_layer.weight"
"model.diffusion_model.double_blocks",
"double_blocks.0.img_attn.norm.query_norm.scale",
]
},
{
"type": ModelType.SD3,
"signature": [
"model.diffusion_model.x_embedder.proj.weight",
"model.diffusion_model.joint_blocks.0.context_block.attn.proj.weight"
]
},
{
"type": ModelType.SDXL,
"signature": [
"conditioner.embedders.1.model.transformer.resblocks",
]
},
{
"type": ModelType.SD15,
"signature": [
"model.diffusion_model",
"cond_stage_model.transformer.text_model",
]
},
{
"type": ModelType.LoRA,
"signature": [
"lora_te_text_model_encoder",
"lora_unet_up_blocks"
"lora_unet_input_blocks_4_1_transformer_blocks_0_attn1_to_k.alpha",
"lora_unet_input_blocks_4_1_transformer_blocks_0_attn1_to_k.lora_up.weight",
# more common signature
"lora_unet",
"lora_te",
"lora_A.weight",
]
}
]
def is_promopt_like(s):
for p in ["--n", "--s", "--l", "--d"]:
if p in s:
@ -27,45 +76,7 @@ def is_promopt_like(s):
return False
def validate_model(model_name: str, training_type: str = "sd-lora"):
if os.path.exists(model_name):
if os.path.isdir(model_name):
files = os.listdir(model_name)
if "model_index.json" in files or "unet" in model_name:
return True, "ok"
else:
log.warning("Can't find model, is this a huggingface model folder?")
return True, "ok"
try:
with open(model_name, "rb") as f:
content = f.read(1024 * 1000)
model_type = match_model_type(content)
if model_type == ModelType.UNKNOWN:
log.error(f"Can't match model type from {model_name}")
if model_type not in [ModelType.SD15, ModelType.SD2, ModelType.SD3, ModelType.SDXL, ModelType.FLUX]:
return False, "Pretrained model is not a Stable Diffusion or Flux checkpoint / 校验失败:底模不是 Stable Diffusion 或 Flux 模型"
if model_type == ModelType.SDXL and training_type == "sd-lora":
return False, "Pretrained model is SDXL, but you are training with SD1.5 LoRA / 校验失败:你选择的是 SD1.5 LoRA 训练,但预训练模型是 SDXL。请前往专家模式选择正确的模型种类。"
except Exception as e:
log.warning(f"model file {model_name} can't open: {e}")
return True, ""
return True, "ok"
# huggingface model repo
if model_name.count("/") == 1 \
and not model_name[0] in [".", "/"] \
and not model_name.split(".")[-1] in ["pt", "pth", "ckpt", "safetensors"]:
return True, "ok"
return False, "model not found"
def match_model_type(sig_content: bytes):
def match_model_type_legacy(sig_content: bytes):
if b"model.diffusion_model.double_blocks" in sig_content or b"double_blocks.0.img_attn.norm.query_norm.scale" in sig_content:
return ModelType.FLUX
@ -84,6 +95,71 @@ def match_model_type(sig_content: bytes):
return ModelType.UNKNOWN
def read_safetensors_metadata(path) -> Dict:
if not os.path.exists(path):
log.error(f"Can't find safetensors metadata file {path}")
return None
with open(path, "rb") as f:
meta_length = int.from_bytes(f.read(8), "little")
meta = f.read(meta_length)
return json.loads(meta)
def guess_model_type(path):
if path.endswith("safetensors"):
metadata = read_safetensors_metadata(path)
model_keys = "\n".join(metadata.keys())
for m in MODEL_SIGNATURE:
if any([k in model_keys for k in m["signature"]]):
return m["type"]
return ModelType.UNKNOWN
if path.endswith("pt") or path.endswith("ckpt"):
with open(path, "rb") as f:
content = f.read(1024 * 1000)
return match_model_type_legacy(content)
def validate_model(model_name: str, training_type: str = "sd-lora"):
if os.path.exists(model_name):
if os.path.isdir(model_name):
files = os.listdir(model_name)
if "model_index.json" in files or "unet" in model_name:
return True, "ok"
else:
log.warning("Can't find model, is this a huggingface model folder?")
return True, "ok"
model_type = ModelType.UNKNOWN
try:
model_type = guess_model_type(model_name)
except Exception as e:
log.warning(f"model file {model_name} can't open: {e}")
return True, ""
if model_type == ModelType.UNKNOWN:
log.error(f"Can't match model type from {model_name}")
if model_type not in [ModelType.SD15, ModelType.SD2, ModelType.SD3, ModelType.SDXL, ModelType.FLUX]:
return False, "Pretrained model is not a Stable Diffusion or Flux checkpoint / 校验失败:底模不是 Stable Diffusion 或 Flux 模型"
if model_type == ModelType.SDXL and training_type == "sd-lora":
return False, "Pretrained model is SDXL, but you are training with SD1.5 LoRA / 校验失败:你选择的是 SD1.5 LoRA 训练,但预训练模型是 SDXL。请前往专家模式选择正确的模型种类。"
return True, "ok"
# huggingface model repo
if model_name.count("/") == 1 \
and not model_name[0] in [".", "/"] \
and not model_name.split(".")[-1] in ["pt", "pth", "ckpt", "safetensors"]:
return True, "ok"
return False, "model not found"
def validate_data_dir(path):
if not os.path.exists(path):
log.error(f"Data dir {path} not exists, check your params")