223 lines
7.4 KiB
Python
223 lines
7.4 KiB
Python
import asyncio
|
|
import hashlib
|
|
import json
|
|
import os
|
|
from datetime import datetime
|
|
from pathlib import Path
|
|
|
|
import toml
|
|
from fastapi import APIRouter, BackgroundTasks, Request
|
|
from starlette.requests import Request
|
|
|
|
import mikazuki.process as process
|
|
from mikazuki import launch_utils
|
|
from mikazuki.app.config import app_config
|
|
from mikazuki.app.models import (APIResponse, APIResponseFail,
|
|
APIResponseSuccess, TaggerInterrogateRequest)
|
|
from mikazuki.log import log
|
|
from mikazuki.tagger.interrogator import (available_interrogators,
|
|
on_interrogate)
|
|
from mikazuki.tasks import tm
|
|
from mikazuki.utils import train_utils
|
|
from mikazuki.utils.devices import printable_devices
|
|
from mikazuki.utils.tk_window import (open_directory_selector,
|
|
open_file_selector)
|
|
|
|
# APIRouter that collects this module's endpoints (registered through the @router.* decorators below).
router = APIRouter()
|
|
|
|
# Helper scripts (paths relative to ./scripts) that /run_script is allowed to
# launch; any other script name is rejected. The "avaliable" spelling is kept
# because other modules may import these names.
avaliable_scripts = [
    "networks/extract_lora_from_models.py",
    "networks/extract_lora_from_dylora.py",
    "networks/merge_lora.py",
    "tools/merge_models.py",
]

# Schema records of the form {"name", "schema", "hash"}; filled by load_schemas().
avaliable_schemas = []

# Training script launched for each model_train_type accepted by /run.
trainer_mapping = {
    # LoRA variants
    "sd-lora": "./scripts/train_network.py",
    "sdxl-lora": "./scripts/sdxl_train_network.py",
    "sd3-lora": "./scripts/sd3_train_network.py",
    "flux-lora": "./scripts/flux_train_network.py",
    # DreamBooth / fine-tuning
    "sd-dreambooth": "./scripts/train_db.py",
    "sdxl-finetune": "./scripts/sdxl_train.py",
}
|
|
|
|
|
|
async def load_schemas():
    """(Re)load every schema file from ``<cwd>/mikazuki/schema`` into ``avaliable_schemas``.

    Each entry is ``{"name": <file name without a trailing ".ts">,
    "schema": <file text>, "hash": <md5 hex digest of the text>}``.
    Files are read in sorted name order so the result is deterministic, and
    sub-directories are skipped. The global list is replaced in one step
    after all files have been read, so a read error leaves the previously
    loaded schemas intact rather than a half-filled list.
    """
    schema_dir = os.path.join(os.getcwd(), "mikazuki", "schema")

    loaded = []
    for schema_name in sorted(os.listdir(schema_dir)):
        schema_path = os.path.join(schema_dir, schema_name)
        if not os.path.isfile(schema_path):
            # open() would fail on a directory; only regular files are schemas.
            continue
        with open(schema_path, encoding="utf-8") as f:
            content = f.read()
        loaded.append({
            # removesuffix, not strip(".ts"): strip removes any leading or
            # trailing '.', 't' or 's' characters (e.g. "tools.ts" -> "ool").
            "name": schema_name.removesuffix(".ts"),
            "schema": content,
            # md5 serves as a content fingerprint here, not as a security hash.
            "hash": hashlib.md5(content.encode()).hexdigest(),
        })

    # Slice assignment mutates the existing list, so references held elsewhere stay valid.
    avaliable_schemas[:] = loaded
|
|
|
|
|
|
@router.post("/run")
async def create_toml_file(request: Request):
    """Validate a training config sent by the UI, autosave it as TOML and start training.

    The request body must be a JSON object of trainer options. ``gpu_ids`` and
    ``model_train_type`` are removed from it before saving. ``model_train_type``
    (default ``"sd-lora"``) selects the training script via ``trainer_mapping``.
    The remaining options are written to ``config/autosave/<timestamp>.toml``.
    Before that, inline ``sample_prompts`` text is written to its own file in
    the same directory and replaced by that file's path.

    Returns:
        ``APIResponseFail`` when the body is not a JSON object, the training
        type is unknown, or dataset/model validation fails; otherwise the
        result of ``process.run_train``.
    """
    timestamp = datetime.now().strftime("%Y%m%d-%H%M%S")
    autosave_dir = os.path.join(os.getcwd(), "config", "autosave")
    toml_file = os.path.join(autosave_dir, f"{timestamp}.toml")

    json_data = await request.body()
    try:
        config = json.loads(json_data.decode("utf-8"))
    except ValueError:
        # UnicodeDecodeError and json.JSONDecodeError are both ValueError
        # subclasses; report bad input instead of failing with an HTTP 500.
        return APIResponseFail(message="Request body is not valid UTF-8 JSON")
    if not isinstance(config, dict):
        return APIResponseFail(message="Training config must be a JSON object")
    train_utils.fix_config_types(config)

    gpu_ids = config.pop("gpu_ids", None)

    model_train_type = config.pop("model_train_type", "sd-lora")
    trainer_file = trainer_mapping.get(model_train_type)
    if trainer_file is None:
        # Reject unknown types explicitly rather than raising KeyError (HTTP 500).
        return APIResponseFail(message=f"Unknown model_train_type: {model_train_type}")

    # The dataset-directory check is skipped for sdxl-finetune.
    if model_train_type != "sdxl-finetune":
        if not train_utils.validate_data_dir(config["train_data_dir"]):
            return APIResponseFail(message="训练数据集路径不存在或没有图片,请检查目录。")

    validated, message = train_utils.validate_model(config["pretrained_model_name_or_path"], model_train_type)
    if not validated:
        return APIResponseFail(message=message)

    # Count images only after validation has passed. A missing train_data_dir
    # (possible when the data-dir check is skipped) falls back to the smaller hint.
    train_data_dir = config.get("train_data_dir")
    image_count = len(train_utils.get_total_images(train_data_dir)) if train_data_dir else 0
    suggest_cpu_threads = 8 if image_count > 200 else 2

    # Both the prompt file and the TOML file are written here; the directory
    # may not exist on a fresh install.
    os.makedirs(autosave_dir, exist_ok=True)

    sample_prompts = config.get("sample_prompts", None)
    if sample_prompts is not None and not os.path.exists(sample_prompts) and train_utils.is_promopt_like(sample_prompts):
        # The value is not an existing path but looks like prompt text:
        # persist it and point the config at the new file instead.
        sample_prompts_file = os.path.join(autosave_dir, f"{timestamp}-promopt.txt")
        with open(sample_prompts_file, "w", encoding="utf-8") as f:
            f.write(sample_prompts)
        config["sample_prompts"] = sample_prompts_file
        log.info(f"Wrote promopts to file {sample_prompts_file}")

    with open(toml_file, "w", encoding="utf-8") as f:
        f.write(toml.dumps(config))

    return process.run_train(toml_file, trainer_file, gpu_ids, suggest_cpu_threads)
|
|
|
|
|
|
def _quote_cli_arg(arg: str) -> str:
    """Quote a single command-line argument for the host platform's argument parser."""
    if os.name == "nt":
        import subprocess

        # list2cmdline applies the MS C runtime rules (spaces, embedded quotes,
        # backslashes before a closing quote), which naive '"..."' wrapping gets wrong.
        return subprocess.list2cmdline([arg])
    import shlex

    return shlex.quote(arg)


def _format_script_args(params: dict) -> str:
    """Render request parameters as a ``--key value`` argument string.

    ``True`` emits the bare ``--key`` flag. ``False`` omits the key entirely,
    because emitting the bare flag would switch a store_true-style option on.
    Any other value is converted with ``str()`` and passed as the option's
    argument. Keys and values are quoted with :func:`_quote_cli_arg`, so spaces
    or shell metacharacters cannot split the argument or inject new ones.
    """
    parts = []
    for key, value in params.items():
        if value is False:
            continue
        parts.append(_quote_cli_arg(f"--{key}"))
        if value is not True:
            parts.append(_quote_cli_arg(str(value)))
    return " ".join(parts)


@router.post("/run_script")
async def run_script(request: Request, background_tasks: BackgroundTasks):
    """Launch one of the whitelisted helper scripts as a background task.

    The JSON body must contain ``script_name`` (one of ``avaliable_scripts``).
    All other keys are forwarded to the script as command-line options; see
    :func:`_format_script_args`. Returns ``APIResponseFail`` for a missing or
    unknown script, otherwise ``APIResponseSuccess`` once the run is queued.
    """
    paras = await request.body()
    j = json.loads(paras.decode("utf-8"))
    # pop() with a default: a missing script_name is "not found", not a KeyError/500.
    script_name = j.pop("script_name", None)
    if script_name not in avaliable_scripts:
        return APIResponseFail(message="Script not found")

    script_args = _format_script_args(j)
    script_path = Path(os.getcwd()) / "scripts" / script_name
    # NOTE(review): launch_utils.run receives one command string, so every
    # argument is quoted above. If it executes the string through cmd.exe,
    # list2cmdline does not neutralise metacharacters such as & or |; passing
    # an argv list to launch_utils.run would be the robust fix. python_bin is
    # left unquoted because its quoting convention is not visible here.
    cmd = f"{launch_utils.python_bin} {_quote_cli_arg(str(script_path))} {script_args}"
    background_tasks.add_task(launch_utils.run, cmd)
    return APIResponseSuccess()
|
|
|
|
|
|
@router.post("/interrogate")
async def run_interrogate(req: TaggerInterrogateRequest, background_tasks: BackgroundTasks):
    """Queue a batch tagging job for the images matched by ``req.path``.

    If ``req.interrogator_model`` is not a key of ``available_interrogators``,
    the ``wd14-convnextv2-v2`` interrogator is used instead. The job is added
    as a FastAPI background task, so success is returned as soon as it is
    queued, not when tagging finishes.
    """
    fallback_interrogator = available_interrogators["wd14-convnextv2-v2"]
    interrogator = available_interrogators.get(req.interrogator_model, fallback_interrogator)

    job_options = {
        "image": None,
        "batch_input_glob": req.path,
        "batch_input_recursive": req.batch_input_recursive,
        "batch_output_dir": "",
        "batch_output_filename_format": "[name].[output_extension]",
        "batch_output_action_on_conflict": req.batch_output_action_on_conflict,
        "batch_remove_duplicated_tag": True,
        "batch_output_save_json": False,
        "interrogator": interrogator,
        "threshold": req.threshold,
        "additional_tags": req.additional_tags,
        "exclude_tags": req.exclude_tags,
        "sort_by_alphabetical_order": False,
        "add_confident_as_weight": False,
        "replace_underscore": req.replace_underscore,
        "replace_underscore_excludes": req.replace_underscore_excludes,
        "escape_tag": req.escape_tag,
        "unload_model_after_running": True,
    }
    background_tasks.add_task(on_interrogate, **job_options)

    return APIResponseSuccess()
|
|
|
|
|
|
@router.get("/pick_file")
async def pick_file(picker_type: str):
    """Show a selector dialog (from ``mikazuki.utils.tk_window``) and return the chosen path.

    Args:
        picker_type: ``"folder"`` opens the directory selector; ``"modelfile"``
            opens the file selector filtered to checkpoints
            (``*.safetensors``, ``*.ckpt``, ``*.pt``).

    Returns:
        ``APIResponseSuccess`` with ``{"path": <selection>}``. Returns
        ``APIResponseFail`` for an unknown ``picker_type`` or when the
        selection is empty (the user cancelled).
    """
    if picker_type == "folder":
        coro = asyncio.to_thread(open_directory_selector, "")
    elif picker_type == "modelfile":
        file_types = [("checkpoints", "*.safetensors;*.ckpt;*.pt"), ("all files", "*.*")]
        coro = asyncio.to_thread(open_file_selector, "", "Select file", file_types)
    else:
        # Without this branch ``coro`` would be unbound below (UnboundLocalError -> HTTP 500).
        return APIResponseFail(message=f"Unknown picker_type: {picker_type}")

    # The selector runs in a worker thread so the event loop is not blocked
    # while the dialog is open.
    result = await coro
    # An empty string signals cancellation; any other empty result (e.g. None)
    # is treated the same way rather than returned as a path.
    if not result:
        return APIResponseFail(message="用户取消选择")

    return APIResponseSuccess(data={
        "path": result
    })
|
|
|
|
|
|
@router.get("/tasks", response_model_exclude_none=True)
async def get_tasks() -> APIResponse:
    """Return the task manager's dump of all tasks under the ``tasks`` key."""
    snapshot = tm.dump()
    payload = {"tasks": snapshot}
    return APIResponseSuccess(data=payload)
|
|
|
|
|
|
@router.get("/tasks/terminate/{task_id}", response_model_exclude_none=True)
async def terminate_task(task_id: str):
    """Ask the task manager to terminate the task identified by ``task_id``.

    Whatever ``tm.terminate_task`` returns is ignored; the endpoint always
    answers with a plain success response.
    """
    tm.terminate_task(task_id)

    return APIResponseSuccess()
|
|
|
|
|
|
@router.get("/graphic_cards")
async def list_avaliable_cards() -> APIResponse:
    """Return the printable device list, or a ``pending`` status while it is empty."""
    if printable_devices:
        return APIResponseSuccess(data={"cards": printable_devices})

    # Nothing is listed yet (presumably device detection has not finished), so
    # report "pending" instead of an empty success.
    return APIResponse(status="pending")
|
|
|
|
|
|
@router.get("/schemas/hashes")
async def list_schema_hashes() -> APIResponse:
    """Return the name and content hash of every loaded schema.

    When the environment variable ``MIKAZUKI_SCHEMA_HOT_RELOAD`` is ``"1"``,
    the schemas are re-read from disk before answering.
    """
    hot_reload_enabled = os.environ.get("MIKAZUKI_SCHEMA_HOT_RELOAD", "0") == "1"
    if hot_reload_enabled:
        log.info("Hot reloading schemas")
        await load_schemas()

    summaries = [
        {field: schema[field] for field in ("name", "hash")}
        for schema in avaliable_schemas
    ]
    return APIResponseSuccess(data={"schemas": summaries})
|
|
|
|
|
|
@router.get("/schemas/all")
async def get_all_schemas() -> APIResponse:
    """Return every loaded schema record, including its full source text."""
    payload = {"schemas": avaliable_schemas}
    return APIResponseSuccess(data=payload)
|
|
|
|
|
|
@router.get("/config/saved_params")
async def get_saved_params() -> APIResponse:
    """Return the value stored under ``saved_params`` in the application config."""
    return APIResponseSuccess(data=app_config["saved_params"])
|