import asyncio
import hashlib
import json
import os
from datetime import datetime
from pathlib import Path

import toml
from fastapi import APIRouter, BackgroundTasks, Request
from starlette.requests import Request

import mikazuki.process as process
from mikazuki import launch_utils
from mikazuki.app.config import app_config
from mikazuki.app.models import (APIResponse, APIResponseFail,
                                 APIResponseSuccess, TaggerInterrogateRequest)
from mikazuki.log import log
from mikazuki.tagger.interrogator import (available_interrogators,
                                          on_interrogate)
from mikazuki.tasks import tm
from mikazuki.utils import train_utils
from mikazuki.utils.devices import printable_devices
from mikazuki.utils.tk_window import (open_directory_selector,
                                      open_file_selector)

router = APIRouter()

# Whitelist of scripts (relative to ./sd-scripts) that /run_script may launch.
avaliable_scripts = [
    "networks/extract_lora_from_models.py",
    "networks/extract_lora_from_dylora.py",
    "networks/merge_lora.py",
    "tools/merge_models.py",
]

# In-memory schema cache; (re)populated by load_schemas().
avaliable_schemas = []

# Maps the "model_train_type" request field to the trainer script to run.
trainer_mapping = {
    "sd-lora": "./sd-scripts/train_network.py",
    "sdxl-lora": "./sd-scripts/sdxl_train_network.py",
    "sd3-lora": "./sd-scripts/sd3_train_network.py",
    "flux-lora": "./sd-scripts/flux_train_network.py",
    "sd-dreambooth": "./sd-scripts/train_db.py",
    "sdxl-finetune": "./sd-scripts/sdxl_train.py",
}


async def load_schemas():
    """Rebuild ``avaliable_schemas`` from the files in ``<cwd>/mikazuki/schema``.

    Each entry holds the schema name (file name without a trailing ``.ts``),
    the raw file content, and the MD5 hex digest of that content.
    """
    avaliable_schemas.clear()

    schema_dir = os.path.join(os.getcwd(), "mikazuki", "schema")

    for schema_name in os.listdir(schema_dir):
        with open(os.path.join(schema_dir, schema_name), encoding="utf-8") as f:
            content = f.read()
        avaliable_schemas.append({
            # removesuffix, not strip: str.strip(".ts") removes any of the
            # characters '.', 't', 's' from BOTH ends ("shared.ts" -> "hared").
            "name": schema_name.removesuffix(".ts"),
            "schema": content,
            # MD5 serves only as a content fingerprint here.
            "hash": hashlib.md5(content.encode()).hexdigest(),
        })


@router.post("/run")
async def create_toml_file(request: Request):
    """Save the posted training config as TOML and start the trainer.

    The request body is a JSON object. ``gpu_ids`` and ``model_train_type``
    are removed from it before it is written to
    ``<cwd>/config/autosave/<timestamp>.toml``. Returns an ``APIResponseFail``
    when the train type is unknown or validation fails, otherwise whatever
    ``process.run_train`` returns.
    """
    timestamp = datetime.now().strftime("%Y%m%d-%H%M%S")
    autosave_dir = os.path.join(os.getcwd(), "config", "autosave")
    toml_file = os.path.join(autosave_dir, f"{timestamp}.toml")

    json_data = await request.body()
    config: dict = json.loads(json_data.decode("utf-8"))
    train_utils.fix_config_types(config)

    gpu_ids = config.pop("gpu_ids", None)

    # NOTE: intentionally evaluated before validate_data_dir, as in the
    # original flow, so the image count is taken on the untouched directory.
    suggest_cpu_threads = 8 if len(train_utils.get_total_images(config["train_data_dir"])) > 200 else 2

    model_train_type = config.pop("model_train_type", "sd-lora")
    trainer_file = trainer_mapping.get(model_train_type)
    if trainer_file is None:
        # Previously an unknown type raised KeyError (HTTP 500).
        return APIResponseFail(message=f"Unknown model_train_type: {model_train_type}")

    if model_train_type != "sdxl-finetune":
        if not train_utils.validate_data_dir(config["train_data_dir"]):
            return APIResponseFail(message="训练数据集路径不存在或没有图片,请检查目录。")

    validated, message = train_utils.validate_model(config["pretrained_model_name_or_path"], model_train_type)
    if not validated:
        return APIResponseFail(message=message)

    # Make sure the autosave directory exists before writing into it.
    os.makedirs(autosave_dir, exist_ok=True)

    # When "sample_prompts" is inline prompt text rather than an existing file
    # path, persist it to a file and point the config at that file instead.
    sample_prompts = config.get("sample_prompts", None)
    if sample_prompts is not None and not os.path.exists(sample_prompts) and train_utils.is_promopt_like(sample_prompts):
        sample_prompts_file = os.path.join(autosave_dir, f"{timestamp}-promopt.txt")
        with open(sample_prompts_file, "w", encoding="utf-8") as f:
            f.write(sample_prompts)
        config["sample_prompts"] = sample_prompts_file
        log.info(f"Wrote promopts to file {sample_prompts_file}")

    with open(toml_file, "w", encoding="utf-8") as f:
        f.write(toml.dumps(config))

    return process.run_train(toml_file, trainer_file, gpu_ids, suggest_cpu_threads)


def _build_script_args(params: dict) -> str:
    """Render a mapping of option name -> value as a ``--key value`` string.

    ``True`` emits the bare flag, ``False`` emits nothing, and any other value
    is stringified and wrapped in double quotes when it contains a space.
    """
    parts = []
    for key, value in params.items():
        if isinstance(value, bool):
            # A bare flag switches the option on, so a False value must be
            # omitted entirely (it used to be emitted as if it were True).
            if value:
                parts.append(f"--{key}")
            continue
        text = str(value)
        if " " in text:
            text = f'"{text}"'
        parts.append(f"--{key}")
        parts.append(text)
    return " ".join(parts)


@router.post("/run_script")
async def run_script(request: Request, background_tasks: BackgroundTasks):
    """Run one of the whitelisted sd-scripts tools as a background task.

    The JSON body must contain ``script_name`` (one of ``avaliable_scripts``);
    every other key is passed to the script as a command-line option.
    """
    paras = await request.body()
    j = json.loads(paras.decode("utf-8"))

    # pop with a default: a missing key is reported as "Script not found"
    # instead of raising KeyError.
    script_name = j.pop("script_name", None)
    if script_name not in avaliable_scripts:
        return APIResponseFail(message="Script not found")

    script_args = _build_script_args(j)
    script_path = Path(os.getcwd()) / "sd-scripts" / script_name

    # NOTE(review): the command is assembled as a single string from request
    # data. If launch_utils.run executes it through a shell, option names and
    # values are an injection vector - confirm and prefer an argument list.
    cmd = f"{launch_utils.python_bin} {script_path} {script_args}"
    background_tasks.add_task(launch_utils.run, cmd)
    return APIResponseSuccess()


@router.post("/interrogate")
async def run_interrogate(req: TaggerInterrogateRequest, background_tasks: BackgroundTasks):
    """Schedule a batch tagging run with ``on_interrogate`` in the background.

    Uses the interrogator named by ``req.interrogator_model`` and falls back
    to ``wd14-convnextv2-v2`` when that name is not registered.
    """
    interrogator = available_interrogators.get(req.interrogator_model, available_interrogators["wd14-convnextv2-v2"])
    background_tasks.add_task(
        on_interrogate,
        image=None,
        batch_input_glob=req.path,
        batch_input_recursive=req.batch_input_recursive,
        batch_output_dir="",
        batch_output_filename_format="[name].[output_extension]",
        batch_output_action_on_conflict=req.batch_output_action_on_conflict,
        batch_remove_duplicated_tag=True,
        batch_output_save_json=False,
        interrogator=interrogator,
        threshold=req.threshold,
        additional_tags=req.additional_tags,
        exclude_tags=req.exclude_tags,
        sort_by_alphabetical_order=False,
        add_confident_as_weight=False,
        replace_underscore=req.replace_underscore,
        replace_underscore_excludes=req.replace_underscore_excludes,
        escape_tag=req.escape_tag,
        unload_model_after_running=True
    )
    return APIResponseSuccess()


@router.get("/pick_file")
async def pick_file(picker_type: str):
    """Open a native file/folder picker and return the selected path.

    ``picker_type`` is ``"folder"`` or ``"modelfile"``; any other value yields
    a failure response. An empty selection result is reported as a
    cancellation.
    """
    # The selector calls are run via asyncio.to_thread so the event loop is
    # not blocked while the dialog is open.
    if picker_type == "folder":
        coro = asyncio.to_thread(open_directory_selector, "")
    elif picker_type == "modelfile":
        file_types = [("checkpoints", "*.safetensors;*.ckpt;*.pt"), ("all files", "*.*")]
        coro = asyncio.to_thread(open_file_selector, "", "Select file", file_types)
    else:
        # Previously fell through to an unbound `coro` (UnboundLocalError).
        return APIResponseFail(message=f"Unsupported picker type: {picker_type}")

    result = await coro
    if result == "":
        return APIResponseFail(message="用户取消选择")

    return APIResponseSuccess(data={
        "path": result
    })


@router.get("/tasks", response_model_exclude_none=True)
async def get_tasks() -> APIResponse:
    """Return the task manager's current dump under the ``tasks`` key."""
    return APIResponseSuccess(data={
        "tasks": tm.dump()
    })


@router.get("/tasks/terminate/{task_id}", response_model_exclude_none=True)
async def terminate_task(task_id: str):
    """Ask the task manager to terminate the task identified by ``task_id``."""
    tm.terminate_task(task_id)
    return APIResponseSuccess()


@router.get("/graphic_cards")
async def list_avaliable_cards() -> APIResponse:
    """Return ``printable_devices``, or a ``pending`` status while it is empty."""
    if not printable_devices:
        return APIResponse(status="pending")

    return APIResponseSuccess(data={
        "cards": printable_devices
    })


@router.get("/schemas/hashes")
async def list_schema_hashes() -> APIResponse:
    """Return the name and content hash of every cached schema.

    When the ``MIKAZUKI_SCHEMA_HOT_RELOAD`` environment variable is ``"1"``,
    the schemas are reloaded from disk first.
    """
    if os.environ.get("MIKAZUKI_SCHEMA_HOT_RELOAD", "0") == "1":
        log.info("Hot reloading schemas")
        await load_schemas()

    return APIResponseSuccess(data={
        "schemas": [
            {
                "name": schema["name"],
                "hash": schema["hash"]
            }
            for schema in avaliable_schemas
        ]
    })


@router.get("/schemas/all")
async def get_all_schemas() -> APIResponse:
    """Return every cached schema, including its full content."""
    return APIResponseSuccess(data={
        "schemas": avaliable_schemas
    })


@router.get("/config/saved_params")
async def get_saved_params() -> APIResponse:
    """Return the ``saved_params`` entry of the application config."""
    saved_params = app_config["saved_params"]
    return APIResponseSuccess(data=saved_params)