fix #631
parent
bf945edba0
commit
cf6dc84cfe
2
gui.py
2
gui.py
|
|
@ -85,7 +85,7 @@ def launch():
|
|||
|
||||
import uvicorn
|
||||
log.info(f"Server started at http://{args.host}:{args.port}")
|
||||
uvicorn.run("mikazuki.app:app", host=args.host, port=args.port, log_level="error")
|
||||
uvicorn.run("mikazuki.app:app", host=args.host, port=args.port, log_level="error", reload=args.dev)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
|||
|
|
@ -4,6 +4,8 @@ import os
|
|||
import re
|
||||
import shutil
|
||||
import sys
|
||||
import json
|
||||
from typing import Dict
|
||||
|
||||
from mikazuki.log import log
|
||||
|
||||
|
|
@ -20,6 +22,53 @@ class ModelType(Enum):
|
|||
LoRA = 10
|
||||
|
||||
|
||||
MODEL_SIGNATURE = [
|
||||
{
|
||||
"type": ModelType.FLUX,
|
||||
"signature": [
|
||||
"double_blocks.0.img_mlp.0.weight",
|
||||
"guidance_in.in_layer.weight"
|
||||
"model.diffusion_model.double_blocks",
|
||||
"double_blocks.0.img_attn.norm.query_norm.scale",
|
||||
]
|
||||
},
|
||||
{
|
||||
"type": ModelType.SD3,
|
||||
"signature": [
|
||||
"model.diffusion_model.x_embedder.proj.weight",
|
||||
"model.diffusion_model.joint_blocks.0.context_block.attn.proj.weight"
|
||||
]
|
||||
},
|
||||
{
|
||||
"type": ModelType.SDXL,
|
||||
"signature": [
|
||||
"conditioner.embedders.1.model.transformer.resblocks",
|
||||
]
|
||||
},
|
||||
{
|
||||
"type": ModelType.SD15,
|
||||
"signature": [
|
||||
"model.diffusion_model",
|
||||
"cond_stage_model.transformer.text_model",
|
||||
]
|
||||
},
|
||||
{
|
||||
"type": ModelType.LoRA,
|
||||
"signature": [
|
||||
"lora_te_text_model_encoder",
|
||||
"lora_unet_up_blocks"
|
||||
"lora_unet_input_blocks_4_1_transformer_blocks_0_attn1_to_k.alpha",
|
||||
"lora_unet_input_blocks_4_1_transformer_blocks_0_attn1_to_k.lora_up.weight",
|
||||
|
||||
# more common signature
|
||||
"lora_unet",
|
||||
"lora_te",
|
||||
"lora_A.weight",
|
||||
]
|
||||
}
|
||||
]
|
||||
|
||||
|
||||
def is_promopt_like(s):
|
||||
for p in ["--n", "--s", "--l", "--d"]:
|
||||
if p in s:
|
||||
|
|
@ -27,45 +76,7 @@ def is_promopt_like(s):
|
|||
return False
|
||||
|
||||
|
||||
def validate_model(model_name: str, training_type: str = "sd-lora"):
|
||||
if os.path.exists(model_name):
|
||||
if os.path.isdir(model_name):
|
||||
files = os.listdir(model_name)
|
||||
if "model_index.json" in files or "unet" in model_name:
|
||||
return True, "ok"
|
||||
else:
|
||||
log.warning("Can't find model, is this a huggingface model folder?")
|
||||
return True, "ok"
|
||||
try:
|
||||
with open(model_name, "rb") as f:
|
||||
content = f.read(1024 * 1000)
|
||||
model_type = match_model_type(content)
|
||||
|
||||
if model_type == ModelType.UNKNOWN:
|
||||
log.error(f"Can't match model type from {model_name}")
|
||||
|
||||
if model_type not in [ModelType.SD15, ModelType.SD2, ModelType.SD3, ModelType.SDXL, ModelType.FLUX]:
|
||||
return False, "Pretrained model is not a Stable Diffusion or Flux checkpoint / 校验失败:底模不是 Stable Diffusion 或 Flux 模型"
|
||||
|
||||
if model_type == ModelType.SDXL and training_type == "sd-lora":
|
||||
return False, "Pretrained model is SDXL, but you are training with SD1.5 LoRA / 校验失败:你选择的是 SD1.5 LoRA 训练,但预训练模型是 SDXL。请前往专家模式选择正确的模型种类。"
|
||||
|
||||
except Exception as e:
|
||||
log.warning(f"model file {model_name} can't open: {e}")
|
||||
return True, ""
|
||||
|
||||
return True, "ok"
|
||||
|
||||
# huggingface model repo
|
||||
if model_name.count("/") == 1 \
|
||||
and not model_name[0] in [".", "/"] \
|
||||
and not model_name.split(".")[-1] in ["pt", "pth", "ckpt", "safetensors"]:
|
||||
return True, "ok"
|
||||
|
||||
return False, "model not found"
|
||||
|
||||
|
||||
def match_model_type(sig_content: bytes):
|
||||
def match_model_type_legacy(sig_content: bytes):
|
||||
if b"model.diffusion_model.double_blocks" in sig_content or b"double_blocks.0.img_attn.norm.query_norm.scale" in sig_content:
|
||||
return ModelType.FLUX
|
||||
|
||||
|
|
@ -84,6 +95,71 @@ def match_model_type(sig_content: bytes):
|
|||
return ModelType.UNKNOWN
|
||||
|
||||
|
||||
def read_safetensors_metadata(path) -> Dict:
|
||||
if not os.path.exists(path):
|
||||
log.error(f"Can't find safetensors metadata file {path}")
|
||||
return None
|
||||
|
||||
with open(path, "rb") as f:
|
||||
meta_length = int.from_bytes(f.read(8), "little")
|
||||
meta = f.read(meta_length)
|
||||
return json.loads(meta)
|
||||
|
||||
|
||||
def guess_model_type(path):
|
||||
if path.endswith("safetensors"):
|
||||
metadata = read_safetensors_metadata(path)
|
||||
model_keys = "\n".join(metadata.keys())
|
||||
for m in MODEL_SIGNATURE:
|
||||
if any([k in model_keys for k in m["signature"]]):
|
||||
return m["type"]
|
||||
|
||||
return ModelType.UNKNOWN
|
||||
|
||||
if path.endswith("pt") or path.endswith("ckpt"):
|
||||
with open(path, "rb") as f:
|
||||
content = f.read(1024 * 1000)
|
||||
return match_model_type_legacy(content)
|
||||
|
||||
|
||||
def validate_model(model_name: str, training_type: str = "sd-lora"):
|
||||
if os.path.exists(model_name):
|
||||
if os.path.isdir(model_name):
|
||||
files = os.listdir(model_name)
|
||||
if "model_index.json" in files or "unet" in model_name:
|
||||
return True, "ok"
|
||||
else:
|
||||
log.warning("Can't find model, is this a huggingface model folder?")
|
||||
return True, "ok"
|
||||
|
||||
model_type = ModelType.UNKNOWN
|
||||
|
||||
try:
|
||||
model_type = guess_model_type(model_name)
|
||||
except Exception as e:
|
||||
log.warning(f"model file {model_name} can't open: {e}")
|
||||
return True, ""
|
||||
|
||||
if model_type == ModelType.UNKNOWN:
|
||||
log.error(f"Can't match model type from {model_name}")
|
||||
|
||||
if model_type not in [ModelType.SD15, ModelType.SD2, ModelType.SD3, ModelType.SDXL, ModelType.FLUX]:
|
||||
return False, "Pretrained model is not a Stable Diffusion or Flux checkpoint / 校验失败:底模不是 Stable Diffusion 或 Flux 模型"
|
||||
|
||||
if model_type == ModelType.SDXL and training_type == "sd-lora":
|
||||
return False, "Pretrained model is SDXL, but you are training with SD1.5 LoRA / 校验失败:你选择的是 SD1.5 LoRA 训练,但预训练模型是 SDXL。请前往专家模式选择正确的模型种类。"
|
||||
|
||||
return True, "ok"
|
||||
|
||||
# huggingface model repo
|
||||
if model_name.count("/") == 1 \
|
||||
and not model_name[0] in [".", "/"] \
|
||||
and not model_name.split(".")[-1] in ["pt", "pth", "ckpt", "safetensors"]:
|
||||
return True, "ok"
|
||||
|
||||
return False, "model not found"
|
||||
|
||||
|
||||
def validate_data_dir(path):
|
||||
if not os.path.exists(path):
|
||||
log.error(f"Data dir {path} not exists, check your params")
|
||||
|
|
|
|||
Loading…
Reference in New Issue