lora-scripts/mikazuki/app/api.py

289 lines
9.5 KiB
Python

import asyncio
import hashlib
import json
import os
import re
from datetime import datetime
from pathlib import Path
import toml
from fastapi import APIRouter, BackgroundTasks, Request
from starlette.requests import Request
import mikazuki.process as process
from mikazuki import launch_utils
from mikazuki.app.config import app_config
from mikazuki.app.models import (APIResponse, APIResponseFail,
APIResponseSuccess, TaggerInterrogateRequest)
from mikazuki.log import log
from mikazuki.tagger.interrogator import (available_interrogators,
on_interrogate)
from mikazuki.tasks import tm
from mikazuki.utils import train_utils
from mikazuki.utils.devices import printable_devices
from mikazuki.utils.tk_window import (open_directory_selector,
open_file_selector)
# Router on which every endpoint in this module is registered.
router = APIRouter()
# Allow-list for the /run_script endpoint: a requested script_name must be one
# of these entries, which are resolved relative to <cwd>/scripts.
avaliable_scripts = [
"networks/extract_lora_from_models.py",
"networks/extract_lora_from_dylora.py",
"networks/merge_lora.py",
"tools/merge_models.py",
]
# Schema cache; cleared and refilled in place by load_schemas().
avaliable_schemas = []
# Maps the "model_train_type" field of a /run request to the trainer script
# that is passed on to process.run_train.
trainer_mapping = {
"sd-lora": "./scripts/stable/train_network.py",
"sdxl-lora": "./scripts/stable/sdxl_train_network.py",
"sd-dreambooth": "./scripts/stable/train_db.py",
"sdxl-finetune": "./scripts/stable/sdxl_train.py",
"sd3-lora": "./scripts/dev/sd3_train_network.py",
"flux-lora": "./scripts/dev/flux_train_network.py",
"flux-finetune": "./scripts/dev/flux_train.py",
}
async def load_schemas():
    """Reload every file under ``<cwd>/mikazuki/schema`` into ``avaliable_schemas``.

    The module-level list is cleared and repopulated in place. Each entry is a
    dict with three keys:

    * ``name``   - the file name with a trailing ``.ts`` extension removed,
    * ``schema`` - the file content, read as UTF-8 text,
    * ``hash``   - the hex MD5 digest of that content.

    Raises:
        OSError: if the schema directory or one of its entries cannot be read.
    """
    avaliable_schemas.clear()

    schema_dir = os.path.join(os.getcwd(), "mikazuki", "schema")

    for schema_name in os.listdir(schema_dir):
        with open(os.path.join(schema_dir, schema_name), encoding="utf-8") as f:
            content = f.read()

        avaliable_schemas.append({
            # removesuffix() drops exactly the ".ts" extension. The previous
            # rstrip(".ts") stripped any trailing run of the characters
            # '.', 't' and 's', turning e.g. "presets.ts" into "prese".
            "name": schema_name.removesuffix(".ts"),
            "schema": content,
            "hash": hashlib.md5(content.encode()).hexdigest(),
        })
@router.post("/run")
async def create_toml_file(request: Request):
    """Persist the posted training config as TOML and start a training run.

    The request body is a JSON object of training parameters. ``gpu_ids`` and
    ``model_train_type`` (default ``"sd-lora"``) are removed from the config
    before it is written; the latter selects the trainer script through
    ``trainer_mapping``. The config is saved to
    ``<cwd>/config/autosave/<timestamp>.toml``.

    Returns:
        ``APIResponseFail`` when the trainer type is unknown or when dataset /
        model validation fails; otherwise the value returned by
        ``process.run_train``.
    """
    timestamp = datetime.now().strftime("%Y%m%d-%H%M%S")
    autosave_dir = os.path.join(os.getcwd(), "config", "autosave")
    toml_file = os.path.join(autosave_dir, f"{timestamp}.toml")

    json_data = await request.body()
    config: dict = json.loads(json_data.decode("utf-8"))
    train_utils.fix_config_types(config)

    gpu_ids = config.pop("gpu_ids", None)

    # More than 200 training images -> suggest 8 CPU threads, otherwise 2.
    suggest_cpu_threads = 8 if len(train_utils.get_total_images(config["train_data_dir"])) > 200 else 2

    model_train_type = config.pop("model_train_type", "sd-lora")
    # Report an unknown trainer type as an API failure instead of letting a
    # KeyError surface as an HTTP 500.
    trainer_file = trainer_mapping.get(model_train_type)
    if trainer_file is None:
        return APIResponseFail(message=f"Unsupported model_train_type: {model_train_type}")

    # The dataset directory check is skipped for "sdxl-finetune" only.
    if model_train_type != "sdxl-finetune":
        if not train_utils.validate_data_dir(config["train_data_dir"]):
            return APIResponseFail(message="训练数据集路径不存在或没有图片,请检查目录。")

    # NOTE(review): the original indentation was lost; model validation is
    # assumed to apply to every trainer type - confirm against upstream.
    validated, message = train_utils.validate_model(config["pretrained_model_name_or_path"], model_train_type)
    if not validated:
        return APIResponseFail(message=message)

    # Both files written below live in the autosave directory; make sure it
    # exists so that open(..., "w") cannot fail on a fresh checkout.
    os.makedirs(autosave_dir, exist_ok=True)

    # When "sample_prompts" is inline prompt text rather than an existing file
    # path, write it to a file and point the config at that file instead.
    sample_prompts = config.get("sample_prompts", None)
    if sample_prompts is not None and not os.path.exists(sample_prompts) and train_utils.is_promopt_like(sample_prompts):
        sample_prompts_file = os.path.join(autosave_dir, f"{timestamp}-promopt.txt")
        with open(sample_prompts_file, "w", encoding="utf-8") as f:
            f.write(sample_prompts)
        config["sample_prompts"] = sample_prompts_file
        log.info(f"Wrote promopts to file {sample_prompts_file}")

    with open(toml_file, "w", encoding="utf-8") as f:
        f.write(toml.dumps(config))

    return process.run_train(toml_file, trainer_file, gpu_ids, suggest_cpu_threads)
def _build_script_args(params: dict) -> str:
    """Render an ``{option: value}`` mapping as one command-line argument string."""
    parts = []
    for key, value in params.items():
        if isinstance(value, bool):
            # Booleans are switches: emit the bare flag only when it is
            # enabled. Emitting "--flag" for False would turn the switch on.
            if value:
                parts.append(f"--{key}")
            continue

        text = str(value)
        if " " in text:
            text = f'"{text}"'
        parts.append(f"--{key}")
        parts.append(text)
    return " ".join(parts)


@router.post("/run_script")
async def run_script(request: Request, background_tasks: BackgroundTasks):
    """Run one of the allow-listed helper scripts as a background task.

    The JSON body must contain ``script_name`` (one of ``avaliable_scripts``);
    every other key/value pair is forwarded to the script as ``--key value``,
    or as a bare ``--key`` switch for a boolean ``True``.

    Returns:
        ``APIResponseFail`` when ``script_name`` is missing or not allow-listed,
        otherwise ``APIResponseSuccess`` once the command has been scheduled.
    """
    paras = await request.body()
    j = json.loads(paras.decode("utf-8"))

    # pop() with a default so a missing "script_name" yields the regular
    # "Script not found" failure instead of a KeyError.
    script_name = j.pop("script_name", None)
    if script_name not in avaliable_scripts:
        return APIResponseFail(message="Script not found")

    script_args = _build_script_args(j)
    script_path = Path(os.getcwd()) / "scripts" / script_name

    # NOTE(review): the command is assembled as a single string and option
    # values are only wrapped in double quotes when they contain a space.
    # If launch_utils.run executes it through a shell, request values can
    # inject shell syntax - consider passing an argument list instead.
    cmd = f"{launch_utils.python_bin} {script_path} {script_args}"
    background_tasks.add_task(launch_utils.run, cmd)
    return APIResponseSuccess()
@router.post("/interrogate")
async def run_interrogate(req: TaggerInterrogateRequest, background_tasks: BackgroundTasks):
    """Schedule a batch tagging run for ``req.path`` as a background task.

    The interrogator named by ``req.interrogator_model`` is used when it is
    registered in ``available_interrogators``; otherwise the
    ``"wd14-convnextv2-v2"`` entry is used. The response is returned
    immediately, before the task has run.
    """
    requested_model = req.interrogator_model
    fallback = available_interrogators["wd14-convnextv2-v2"]
    interrogator = available_interrogators.get(requested_model, fallback)

    # Keyword arguments forwarded verbatim to on_interrogate.
    interrogate_options = dict(
        image=None,
        batch_input_glob=req.path,
        batch_input_recursive=req.batch_input_recursive,
        batch_output_dir="",
        batch_output_filename_format="[name].[output_extension]",
        batch_output_action_on_conflict=req.batch_output_action_on_conflict,
        batch_remove_duplicated_tag=True,
        batch_output_save_json=False,
        interrogator=interrogator,
        threshold=req.threshold,
        additional_tags=req.additional_tags,
        exclude_tags=req.exclude_tags,
        sort_by_alphabetical_order=False,
        add_confident_as_weight=False,
        replace_underscore=req.replace_underscore,
        replace_underscore_excludes=req.replace_underscore_excludes,
        escape_tag=req.escape_tag,
        unload_model_after_running=True,
    )
    background_tasks.add_task(on_interrogate, **interrogate_options)

    return APIResponseSuccess()
@router.get("/pick_file")
async def pick_file(picker_type: str):
    """Let the user pick a folder or a model file and return the chosen path.

    Args:
        picker_type: ``"folder"`` for a directory selector or ``"model-file"``
            for a file selector filtered to checkpoint extensions.

    Returns:
        ``APIResponseSuccess`` with ``{"path": <selection>}``, or
        ``APIResponseFail`` when ``picker_type`` is not supported or the
        selector returned nothing (the user cancelled).
    """
    # The selector call blocks, so it runs in a worker thread to keep the
    # event loop free.
    if picker_type == "folder":
        coro = asyncio.to_thread(open_directory_selector, "")
    elif picker_type == "model-file":
        file_types = [("checkpoints", "*.safetensors;*.ckpt;*.pt"), ("all files", "*.*")]
        coro = asyncio.to_thread(open_file_selector, "", "Select file", file_types)
    else:
        # Previously an unknown type left `coro` unbound and the request died
        # with an UnboundLocalError.
        return APIResponseFail(message="Invalid request")

    result = await coro

    # Treat any empty result (e.g. "", None or an empty tuple) as a cancel.
    if not result:
        return APIResponseFail(message="用户取消选择")

    return APIResponseSuccess(data={
        "path": result
    })
@router.get("/get_files")
async def get_files(pick_type) -> APIResponse:
    """List the files or folders available for a given picker preset.

    Args:
        pick_type: one of ``"model-file"`` (model files under ``./sd-models``),
            ``"model-saved-file"`` (model files under ``./output``) or
            ``"train-dir"`` (sub-folders of ``./train``).

    Returns:
        ``APIResponseSuccess`` with ``{"files": [...]}`` where every entry has
        ``path`` (absolute, forward slashes), ``name`` and ``size`` (a
        ``"<n> GB"`` string for files, ``0`` for folders). The list is empty
        when the preset directory does not exist. ``APIResponseFail`` is
        returned for an unknown ``pick_type``.
    """
    # The extension filter is anchored and its dot escaped. The previous
    # pattern "(.safetensors|.ckpt|.pt)" let "." match any character anywhere
    # in the name, so e.g. "excerpt.txt" was listed as a model file.
    model_file_filter = r"\.(safetensors|ckpt|pt)$"

    pick_preset = {
        "model-file": {
            "type": "file",
            "path": "./sd-models",
            "filter": model_file_filter
        },
        "model-saved-file": {
            "type": "file",
            "path": "./output",
            "filter": model_file_filter
        },
        "train-dir": {
            "type": "folder",
            "path": "./train",
            "filter": None
        },
    }

    folder_blacklist = [".ipynb_checkpoints", ".DS_Store"]

    def list_path_or_files(preset_info):
        """Collect the result entries for one preset."""
        path = Path(preset_info["path"])
        file_type = preset_info["type"]
        regex_filter = preset_info["filter"]
        result_list = []

        # A preset directory that has not been created yet simply has nothing
        # to offer; iterdir() would otherwise raise FileNotFoundError.
        if not path.is_dir():
            return result_list

        if file_type == "file":
            pattern = re.compile(regex_filter) if regex_filter else None
            for file in path.glob("**/*"):
                if not file.is_file():
                    continue
                if pattern is not None and not pattern.search(file.name):
                    continue
                result_list.append({
                    "path": str(file.resolve()).replace("\\", "/"),
                    "name": file.name,
                    "size": f"{round(file.stat().st_size / (1024**3),2)} GB"
                })
        elif file_type == "folder":
            for folder in path.iterdir():
                if not folder.is_dir() or folder.name in folder_blacklist:
                    continue
                result_list.append({
                    "path": str(folder.resolve()).replace("\\", "/"),
                    "name": folder.name,
                    "size": 0
                })

        return result_list

    if pick_type not in pick_preset:
        return APIResponseFail(message="Invalid request")

    dirs = list_path_or_files(pick_preset[pick_type])
    return APIResponseSuccess(data={
        "files": dirs
    })
@router.get("/tasks", response_model_exclude_none=True)
async def get_tasks() -> APIResponse:
    """Return the task manager's current dump under the ``tasks`` key."""
    payload = {"tasks": tm.dump()}
    return APIResponseSuccess(data=payload)
@router.get("/tasks/terminate/{task_id}", response_model_exclude_none=True)
async def terminate_task(task_id: str):
    """Ask the task manager to terminate ``task_id``.

    The result of ``tm.terminate_task`` is not inspected; a success response
    is returned whenever that call does not raise.
    """
    tm.terminate_task(task_id)

    return APIResponseSuccess()
@router.get("/graphic_cards")
async def list_avaliable_cards() -> APIResponse:
    """Return the known graphics devices, or a ``pending`` status if there are none yet."""
    if printable_devices:
        return APIResponseSuccess(data={"cards": printable_devices})

    # Empty / unset device list: answer with a "pending" status instead.
    return APIResponse(status="pending")
@router.get("/schemas/hashes")
async def list_schema_hashes() -> APIResponse:
    """Return the name and content hash of every cached schema.

    When the environment variable ``MIKAZUKI_SCHEMA_HOT_RELOAD`` is ``"1"``,
    the schemas are reloaded from disk before the response is built.
    """
    hot_reload = os.environ.get("MIKAZUKI_SCHEMA_HOT_RELOAD", "0")
    if hot_reload == "1":
        log.info("Hot reloading schemas")
        await load_schemas()

    # Only name and hash are exposed here; the schema text is left out.
    summaries = []
    for schema in avaliable_schemas:
        summaries.append({"name": schema["name"], "hash": schema["hash"]})

    return APIResponseSuccess(data={"schemas": summaries})
@router.get("/schemas/all")
async def get_all_schemas() -> APIResponse:
    """Return every cached schema entry under the ``schemas`` key."""
    payload = {"schemas": avaliable_schemas}
    return APIResponseSuccess(data=payload)
@router.get("/config/saved_params")
async def get_saved_params() -> APIResponse:
    """Return the ``saved_params`` entry of the application config."""
    return APIResponseSuccess(data=app_config["saved_params"])