import asyncio import hashlib import json import os import re import random from glob import glob from datetime import datetime from pathlib import Path from typing import Tuple, Optional import toml from fastapi import APIRouter, BackgroundTasks, Request from starlette.requests import Request import mikazuki.process as process from mikazuki import launch_utils from mikazuki.app.config import app_config from mikazuki.app.models import (APIResponse, APIResponseFail, APIResponseSuccess, TaggerInterrogateRequest) from mikazuki.log import log from mikazuki.tagger.interrogator import (available_interrogators, on_interrogate) from mikazuki.tasks import tm from mikazuki.utils import train_utils from mikazuki.utils.devices import printable_devices from mikazuki.utils.tk_window import (open_directory_selector, open_file_selector) router = APIRouter() avaliable_scripts = [ "networks/extract_lora_from_models.py", "networks/extract_lora_from_dylora.py", "networks/merge_lora.py", "tools/merge_models.py", ] avaliable_schemas = [] avaliable_presets = [] trainer_mapping = { "sd-lora": "./scripts/stable/train_network.py", "sdxl-lora": "./scripts/stable/sdxl_train_network.py", "sd-dreambooth": "./scripts/stable/train_db.py", "sdxl-finetune": "./scripts/stable/sdxl_train.py", "sd3-lora": "./scripts/dev/sd3_train_network.py", "flux-lora": "./scripts/dev/flux_train_network.py", "flux-finetune": "./scripts/dev/flux_train.py", } async def load_schemas(): avaliable_schemas.clear() schema_dir = os.path.join(os.getcwd(), "mikazuki", "schema") schemas = os.listdir(schema_dir) def lambda_hash(x): return hashlib.md5(x.encode()).hexdigest() for schema_name in schemas: with open(os.path.join(schema_dir, schema_name), encoding="utf-8") as f: content = f.read() avaliable_schemas.append({ "name": schema_name.rstrip(".ts"), "schema": content, "hash": lambda_hash(content) }) async def load_presets(): avaliable_presets.clear() preset_dir = os.path.join(os.getcwd(), "config", "presets") presets = os.listdir(preset_dir) for preset_name in presets: with open(os.path.join(preset_dir, preset_name), encoding="utf-8") as f: content = f.read() avaliable_presets.append(toml.loads(content)) def get_sample_prompts(config: dict) -> Tuple[Optional[str], str]: # backward compatibility if "sample_prompts" in config and "positive_prompts" not in config: return None, config["sample_prompts"] train_data_dir = config["train_data_dir"] sub_dir = [dir for dir in glob(os.path.join(train_data_dir, '*')) if os.path.isdir(dir)] positive_prompts = config.pop('positive_prompts', None) negative_prompts = config.pop('negative_prompts', '') sample_width = config.pop('sample_width', 512) sample_height = config.pop('sample_height', 512) sample_cfg = config.pop('sample_cfg', 7) sample_seed = config.pop('sample_seed', 2333) sample_steps = config.pop('sample_steps', 24) randomly_choice_prompt = config.pop('randomly_choice_prompt', False) if randomly_choice_prompt: if len(sub_dir) != 1: raise ValueError('训练数据集下有多个子文件夹,无法启用随机选取 Prompt 功能') txt_files = glob(os.path.join(sub_dir[0], '*.txt')) if not txt_files: raise ValueError('训练数据集路径没有 txt 文件') try: sample_prompt_file = random.choice(txt_files) with open(sample_prompt_file, 'r', encoding='utf-8') as f: positive_prompts = f.read() except IOError: log.error(f"读取 {sample_prompt_file} 文件失败") return positive_prompts, f'{positive_prompts} --n {negative_prompts} --w {sample_width} --h {sample_height} --l {sample_cfg} --s {sample_steps} --d {sample_seed}' @router.post("/run") async def create_toml_file(request: Request): timestamp = datetime.now().strftime("%Y%m%d-%H%M%S") toml_file = os.path.join(os.getcwd(), f"config", "autosave", f"{timestamp}.toml") json_data = await request.body() config: dict = json.loads(json_data.decode("utf-8")) train_utils.fix_config_types(config) gpu_ids = config.pop("gpu_ids", None) suggest_cpu_threads = 8 if len(train_utils.get_total_images(config["train_data_dir"])) > 200 else 2 model_train_type = config.pop("model_train_type", "sd-lora") trainer_file = trainer_mapping[model_train_type] if model_train_type != "sdxl-finetune": if not train_utils.validate_data_dir(config["train_data_dir"]): return APIResponseFail(message="训练数据集路径不存在或没有图片,请检查目录。") validated, message = train_utils.validate_model(config["pretrained_model_name_or_path"], model_train_type) if not validated: return APIResponseFail(message=message) try: positive_prompt, sample_prompts_arg = get_sample_prompts(config=config) if positive_prompt is not None and train_utils.is_promopt_like(sample_prompts_arg): sample_prompts_file = os.path.join(os.getcwd(), f"config", "autosave", f"{timestamp}-promopt.txt") with open(sample_prompts_file, "w", encoding="utf-8") as f: f.write(sample_prompts_arg) config["sample_prompts"] = sample_prompts_file log.info(f"Wrote prompts to file {sample_prompts_file}") except ValueError as e: log.error(f"Error while processing prompts: {e}") return APIResponseFail(message=str(e)) with open(toml_file, "w", encoding="utf-8") as f: f.write(toml.dumps(config)) result = process.run_train(toml_file, trainer_file, gpu_ids, suggest_cpu_threads) return result @router.post("/run_script") async def run_script(request: Request, background_tasks: BackgroundTasks): paras = await request.body() j = json.loads(paras.decode("utf-8")) script_name = j["script_name"] if script_name not in avaliable_scripts: return APIResponseFail(message="Script not found") del j["script_name"] result = [] for k, v in j.items(): result.append(f"--{k}") if not isinstance(v, bool): value = str(v) if " " in value: value = f'"{v}"' result.append(value) script_args = " ".join(result) script_path = Path(os.getcwd()) / "scripts" / script_name cmd = f"{launch_utils.python_bin} {script_path} {script_args}" background_tasks.add_task(launch_utils.run, cmd) return APIResponseSuccess() @router.post("/interrogate") async def run_interrogate(req: TaggerInterrogateRequest, background_tasks: BackgroundTasks): interrogator = available_interrogators.get(req.interrogator_model, available_interrogators["wd14-convnextv2-v2"]) background_tasks.add_task( on_interrogate, image=None, batch_input_glob=req.path, batch_input_recursive=req.batch_input_recursive, batch_output_dir="", batch_output_filename_format="[name].[output_extension]", batch_output_action_on_conflict=req.batch_output_action_on_conflict, batch_remove_duplicated_tag=True, batch_output_save_json=False, interrogator=interrogator, threshold=req.threshold, additional_tags=req.additional_tags, exclude_tags=req.exclude_tags, sort_by_alphabetical_order=False, add_confident_as_weight=False, replace_underscore=req.replace_underscore, replace_underscore_excludes=req.replace_underscore_excludes, escape_tag=req.escape_tag, unload_model_after_running=True ) return APIResponseSuccess() @router.get("/pick_file") async def pick_file(picker_type: str): if picker_type == "folder": coro = asyncio.to_thread(open_directory_selector, "") elif picker_type == "model-file": file_types = [("checkpoints", "*.safetensors;*.ckpt;*.pt"), ("all files", "*.*")] coro = asyncio.to_thread(open_file_selector, "", "Select file", file_types) result = await coro if result == "": return APIResponseFail(message="用户取消选择") return APIResponseSuccess(data={ "path": result }) @router.get("/get_files") async def get_files(pick_type) -> APIResponse: pick_preset = { "model-file": { "type": "file", "path": "./sd-models", "filter": "(.safetensors|.ckpt|.pt)" }, "model-saved-file": { "type": "file", "path": "./output", "filter": "(.safetensors|.ckpt|.pt)" }, "train-dir": { "type": "folder", "path": "./train", "filter": None }, } folder_blacklist = [".ipynb_checkpoints", ".DS_Store"] def list_path_or_files(preset_info): path = Path(preset_info["path"]) file_type = preset_info["type"] regex_filter = preset_info["filter"] result_list = [] if file_type == "file": if regex_filter: pattern = re.compile(regex_filter) files = [f for f in path.glob("**/*") if f.is_file() and pattern.search(f.name)] else: files = [f for f in path.glob("**/*") if f.is_file()] for file in files: result_list.append({ "path": str(file.resolve().absolute()).replace("\\", "/"), "name": file.name, "size": f"{round(file.stat().st_size / (1024**3),2)} GB" }) elif file_type == "folder": folders = [f for f in path.iterdir() if f.is_dir()] for folder in folders: if folder.name in folder_blacklist: continue result_list.append({ "path": str(folder.resolve().absolute()).replace("\\", "/"), "name": folder.name, "size": 0 }) return result_list if pick_type not in pick_preset: return APIResponseFail(message="Invalid request") dirs = list_path_or_files(pick_preset[pick_type]) return APIResponseSuccess(data={ "files": dirs }) @router.get("/tasks", response_model_exclude_none=True) async def get_tasks() -> APIResponse: return APIResponseSuccess(data={ "tasks": tm.dump() }) @router.get("/tasks/terminate/{task_id}", response_model_exclude_none=True) async def terminate_task(task_id: str): tm.terminate_task(task_id) return APIResponseSuccess() @router.get("/graphic_cards") async def list_avaliable_cards() -> APIResponse: if not printable_devices: return APIResponse(status="pending") return APIResponseSuccess(data={ "cards": printable_devices }) @router.get("/schemas/hashes") async def list_schema_hashes() -> APIResponse: if os.environ.get("MIKAZUKI_SCHEMA_HOT_RELOAD", "0") == "1": log.info("Hot reloading schemas") await load_schemas() return APIResponseSuccess(data={ "schemas": [ { "name": schema["name"], "hash": schema["hash"] } for schema in avaliable_schemas ] }) @router.get("/schemas/all") async def get_all_schemas() -> APIResponse: return APIResponseSuccess(data={ "schemas": avaliable_schemas }) @router.get("/presets") async def get_presets() -> APIResponse: if os.environ.get("MIKAZUKI_SCHEMA_HOT_RELOAD", "0") == "1": log.info("Hot reloading presets") await load_presets() return APIResponseSuccess(data={ "presets": avaliable_presets }) @router.get("/config/saved_params") async def get_saved_params() -> APIResponse: saved_params = app_config["saved_params"] return APIResponseSuccess(data=saved_params)