lora-scripts/mikazuki/app/api.py

355 lines
12 KiB
Python

import asyncio
import hashlib
import json
import os
import re
import random
from glob import glob
from datetime import datetime
from pathlib import Path
from typing import Tuple, Optional
import toml
from fastapi import APIRouter, BackgroundTasks, Request
from starlette.requests import Request
import mikazuki.process as process
from mikazuki import launch_utils
from mikazuki.app.config import app_config
from mikazuki.app.models import (APIResponse, APIResponseFail,
APIResponseSuccess, TaggerInterrogateRequest)
from mikazuki.log import log
from mikazuki.tagger.interrogator import (available_interrogators,
on_interrogate)
from mikazuki.tasks import tm
from mikazuki.utils import train_utils
from mikazuki.utils.devices import printable_devices
from mikazuki.utils.tk_window import (open_directory_selector,
open_file_selector)
# Router on which every endpoint in this module is registered.
router = APIRouter()

# Whitelist of helper scripts the /run_script endpoint may launch. Each entry
# is a path relative to the "scripts" directory under the working directory.
avaliable_scripts = [
    "networks/extract_lora_from_models.py",
    "networks/extract_lora_from_dylora.py",
    "networks/merge_lora.py",
    "tools/merge_models.py",
]

# In-memory cache of schema entries ("name", "schema", "hash"); emptied and
# refilled by load_schemas().
avaliable_schemas = []
# In-memory cache of parsed preset files; emptied and refilled by load_presets().
avaliable_presets = []

# Maps the "model_train_type" value of a /run request to the trainer script
# that is started for it.
trainer_mapping = {
    "sd-lora": "./scripts/stable/train_network.py",
    "sdxl-lora": "./scripts/stable/sdxl_train_network.py",
    "sd-dreambooth": "./scripts/stable/train_db.py",
    "sdxl-finetune": "./scripts/stable/sdxl_train.py",
    "sd3-lora": "./scripts/dev/sd3_train_network.py",
    "flux-lora": "./scripts/dev/flux_train_network.py",
    "flux-finetune": "./scripts/dev/flux_train.py",
}
async def load_schemas():
    """Reload the schema cache from ``<cwd>/mikazuki/schema``.

    ``avaliable_schemas`` is emptied and then receives one entry per item in
    the schema directory. Each entry holds the file name without its ``.ts``
    extension (``name``), the raw file text (``schema``) and the MD5 hex
    digest of that text (``hash``).

    Raises:
        OSError: if the schema directory or one of its entries cannot be read.
    """
    avaliable_schemas.clear()

    schema_dir = os.path.join(os.getcwd(), "mikazuki", "schema")
    for schema_name in os.listdir(schema_dir):
        with open(os.path.join(schema_dir, schema_name), encoding="utf-8") as f:
            content = f.read()

        avaliable_schemas.append({
            # removesuffix() drops exactly one trailing ".ts". rstrip(".ts")
            # would strip any trailing run of the characters '.', 't' and 's',
            # mangling names such as "presets.ts" into "prese".
            "name": schema_name.removesuffix(".ts"),
            "schema": content,
            "hash": hashlib.md5(content.encode()).hexdigest(),
        })
async def load_presets():
    """Rebuild ``avaliable_presets`` from the files in ``<cwd>/config/presets``.

    The cache is emptied first; every directory entry is then read as UTF-8
    text, parsed as TOML and appended in directory-listing order.
    """
    avaliable_presets.clear()

    preset_root = Path(os.getcwd(), "config", "presets")
    for entry_name in os.listdir(preset_root):
        raw_text = (preset_root / entry_name).read_text(encoding="utf-8")
        avaliable_presets.append(toml.loads(raw_text))
def get_sample_prompts(config: dict) -> Tuple[Optional[str], str]:
    """Assemble the sample-prompt line from the sampling options in ``config``.

    The sampling keys (``positive_prompts``, ``negative_prompts``,
    ``sample_width``, ``sample_height``, ``sample_cfg``, ``sample_seed``,
    ``sample_steps`` and ``randomly_choice_prompt``) are popped, so ``config``
    is mutated in place. ``train_data_dir`` is read but left untouched.

    When ``randomly_choice_prompt`` is truthy, the positive prompt is replaced
    by the content of a randomly chosen ``*.txt`` file from the single
    subfolder of ``train_data_dir``.

    Args:
        config: training configuration; must contain ``train_data_dir``.

    Returns:
        ``(positive_prompt, prompt_line)``. ``positive_prompt`` is ``None``
        when no prompt was configured or picked. ``prompt_line`` is the
        prompt followed by the ``--n/--w/--h/--l/--s/--d`` options.

    Raises:
        KeyError: if ``train_data_dir`` is missing from ``config``.
        ValueError: if random selection is requested but the dataset does not
            have exactly one subfolder, or that subfolder has no txt file.
    """
    train_data_dir = config["train_data_dir"]

    positive_prompts = config.pop('positive_prompts', None)
    negative_prompts = config.pop('negative_prompts', '')
    sample_width = config.pop('sample_width', 512)    # default width: 512
    sample_height = config.pop('sample_height', 512)  # default height: 512
    sample_cfg = config.pop('sample_cfg', 7)          # default CFG scale: 7
    sample_seed = config.pop('sample_seed', 2333)     # default seed: 2333
    sample_steps = config.pop('sample_steps', 24)     # default step count: 24
    randomly_choice_prompt = config.pop('randomly_choice_prompt', False)

    if randomly_choice_prompt:
        # The dataset is only scanned when random selection is requested.
        sub_dirs = [p for p in glob(os.path.join(train_data_dir, '*')) if os.path.isdir(p)]
        if not sub_dirs:
            # Previously reported as "multiple subfolders", which was misleading.
            raise ValueError('训练数据集下没有子文件夹,无法启用自动选取 Prompt 功能')
        if len(sub_dirs) > 1:
            raise ValueError('训练数据集下有多个子文件夹,无法启用自动选取 Prompt 功能')

        txt_files = glob(os.path.join(sub_dirs[0], '*.txt'))
        if not txt_files:
            raise ValueError('训练数据集路径没有 txt 文件')

        sample_prompt_file = random.choice(txt_files)
        try:
            with open(sample_prompt_file, 'r', encoding='utf-8') as f:
                caption = f.read()
        except OSError as e:
            # Best effort: keep the configured prompt, but leave a trace
            # instead of failing silently.
            log.info(f"Failed to read {sample_prompt_file}, keeping configured prompt: {e}")
        else:
            # Collapse the caption onto one line. The prompt and its options
            # must share a line, otherwise a trailing newline in the caption
            # detaches the "--n ... --d" options from the prompt.
            positive_prompts = " ".join(
                part.strip() for part in caption.splitlines() if part.strip()
            )

    prompt_line = (
        f'{positive_prompts} --n {negative_prompts} --w {sample_width} --h {sample_height} '
        f'--l {sample_cfg} --s {sample_steps} --d {sample_seed}'
    )
    return positive_prompts, prompt_line
@router.post("/run")
async def create_toml_file(request: Request):
    """Validate a training request, save it as TOML and start the trainer.

    The request body is a JSON object holding the training configuration.
    ``gpu_ids`` and ``model_train_type`` are removed from it before it is
    written to ``<cwd>/config/autosave/<timestamp>.toml``. When a sample
    prompt is available, it is written next to the TOML file and referenced
    through ``config["sample_prompts"]``.

    Returns:
        ``APIResponseFail`` when the train type, dataset directory, base
        model or sample-prompt options are rejected; otherwise whatever
        ``process.run_train`` returns.
    """
    timestamp = datetime.now().strftime("%Y%m%d-%H%M%S")
    autosave_dir = os.path.join(os.getcwd(), "config", "autosave")
    # open(..., "w") does not create missing parent directories.
    os.makedirs(autosave_dir, exist_ok=True)
    toml_file = os.path.join(autosave_dir, f"{timestamp}.toml")

    json_data = await request.body()
    config: dict = json.loads(json_data.decode("utf-8"))
    train_utils.fix_config_types(config)

    gpu_ids = config.pop("gpu_ids", None)

    model_train_type = config.pop("model_train_type", "sd-lora")
    trainer_file = trainer_mapping.get(model_train_type)
    if trainer_file is None:
        # Report an unknown train type to the client rather than failing with
        # an unhandled KeyError.
        return APIResponseFail(message=f"不支持的训练类型: {model_train_type}")

    if model_train_type != "sdxl-finetune":
        if not train_utils.validate_data_dir(config["train_data_dir"]):
            return APIResponseFail(message="训练数据集路径不存在或没有图片,请检查目录。")

    validated, message = train_utils.validate_model(config["pretrained_model_name_or_path"], model_train_type)
    if not validated:
        return APIResponseFail(message=message)

    # Count images only after the dataset directory has been validated.
    suggest_cpu_threads = 8 if len(train_utils.get_total_images(config["train_data_dir"])) > 200 else 2

    try:
        positive_prompt, sample_prompts_arg = get_sample_prompts(config=config)

        if positive_prompt is not None and train_utils.is_promopt_like(sample_prompts_arg):
            sample_prompts_file = os.path.join(autosave_dir, f"{timestamp}-promopt.txt")
            with open(sample_prompts_file, "w", encoding="utf-8") as f:
                f.write(sample_prompts_arg)
            config["sample_prompts"] = sample_prompts_file
            log.info(f"Wrote promopts to file {sample_prompts_file}")
    except ValueError as e:
        # get_sample_prompts signals unusable prompt options with ValueError.
        return APIResponseFail(message=str(e))

    with open(toml_file, "w", encoding="utf-8") as f:
        f.write(toml.dumps(config))

    return process.run_train(toml_file, trainer_file, gpu_ids, suggest_cpu_threads)
@router.post("/run_script")
async def run_script(request: Request, background_tasks: BackgroundTasks):
    """Launch one of the whitelisted helper scripts as a background task.

    The JSON body must contain ``script_name`` (one of ``avaliable_scripts``).
    Every other key/value pair becomes a ``--key value`` command-line option.
    Boolean values are treated as switches: ``true`` emits a bare ``--key``
    and ``false`` emits nothing.

    Returns:
        ``APIResponseFail`` if the script is missing or not whitelisted,
        otherwise ``APIResponseSuccess`` once the task has been queued.
    """
    paras = await request.body()
    j = json.loads(paras.decode("utf-8"))

    # pop() with a default turns a missing "script_name" into the regular
    # "Script not found" response instead of a KeyError.
    script_name = j.pop("script_name", None)
    if script_name not in avaliable_scripts:
        return APIResponseFail(message="Script not found")

    result = []
    for k, v in j.items():
        if isinstance(v, bool):
            # A switch is passed only when enabled. Emitting "--key" for a
            # False value would turn the option on.
            if v:
                result.append(f"--{k}")
            continue

        value = str(v)
        if " " in value:
            value = f'"{v}"'
        result.append(f"--{k}")
        result.append(value)

    script_args = " ".join(result)
    script_path = Path(os.getcwd()) / "scripts" / script_name
    # NOTE(review): the command is one string built from request data, and only
    # values containing spaces are quoted. If launch_utils.run hands it to a
    # shell, embedded quotes or shell metacharacters are not neutralised.
    # Consider passing an argument list instead - confirm against launch_utils.
    cmd = f"{launch_utils.python_bin} {script_path} {script_args}"
    background_tasks.add_task(launch_utils.run, cmd)
    return APIResponseSuccess()
@router.post("/interrogate")
async def run_interrogate(req: TaggerInterrogateRequest, background_tasks: BackgroundTasks):
    """Queue a batch tagging run for the path given in the request.

    The interrogator named by ``req.interrogator_model`` is used, falling back
    to ``wd14-convnextv2-v2`` when that name is unknown. The tagging itself
    runs as a background task; this handler returns immediately.
    """
    requested_model = req.interrogator_model
    fallback_model = available_interrogators["wd14-convnextv2-v2"]
    chosen_model = available_interrogators.get(requested_model, fallback_model)

    # Options taken from the request are mixed with the fixed batch settings.
    tagger_options = dict(
        image=None,
        batch_input_glob=req.path,
        batch_input_recursive=req.batch_input_recursive,
        batch_output_dir="",
        batch_output_filename_format="[name].[output_extension]",
        batch_output_action_on_conflict=req.batch_output_action_on_conflict,
        batch_remove_duplicated_tag=True,
        batch_output_save_json=False,
        interrogator=chosen_model,
        threshold=req.threshold,
        additional_tags=req.additional_tags,
        exclude_tags=req.exclude_tags,
        sort_by_alphabetical_order=False,
        add_confident_as_weight=False,
        replace_underscore=req.replace_underscore,
        replace_underscore_excludes=req.replace_underscore_excludes,
        escape_tag=req.escape_tag,
        unload_model_after_running=True,
    )
    background_tasks.add_task(on_interrogate, **tagger_options)
    return APIResponseSuccess()
@router.get("/pick_file")
async def pick_file(picker_type: str):
    """Open a selection dialog and return the chosen path.

    Args:
        picker_type: ``"folder"`` for a directory dialog or ``"model-file"``
            for a checkpoint file dialog.

    Returns:
        ``APIResponseSuccess`` with ``{"path": ...}``, or ``APIResponseFail``
        when ``picker_type`` is unknown or nothing was selected.
    """
    # The selector calls are run in a worker thread via asyncio.to_thread.
    if picker_type == "folder":
        coro = asyncio.to_thread(open_directory_selector, "")
    elif picker_type == "model-file":
        file_types = [("checkpoints", "*.safetensors;*.ckpt;*.pt"), ("all files", "*.*")]
        coro = asyncio.to_thread(open_file_selector, "", "Select file", file_types)
    else:
        # Without this branch `coro` stays unbound and the await below raises
        # UnboundLocalError.
        return APIResponseFail(message="Invalid request")

    result = await coro
    # Any empty result counts as a cancelled dialog. Presumably the selectors
    # return "" on cancel; a falsy check also covers None or an empty tuple.
    if not result:
        return APIResponseFail(message="用户取消选择")

    return APIResponseSuccess(data={
        "path": result
    })
@router.get("/get_files")
async def get_files(pick_type) -> APIResponse:
    """List selectable files or folders for one of the built-in locations.

    Args:
        pick_type: ``"model-file"`` (checkpoints under ``./sd-models``),
            ``"model-saved-file"`` (checkpoints under ``./output``) or
            ``"train-dir"`` (subfolders of ``./train``).

    Returns:
        ``APIResponseSuccess`` with ``{"files": [...]}``, each item holding
        ``path``, ``name`` and ``size``. ``APIResponseFail`` is returned for
        an unknown ``pick_type``.
    """
    # The extension must be a literal suffix of the file name. The previous
    # pattern "(.safetensors|.ckpt|.pt)" had unescaped dots and no anchor, so
    # names like "prompt.txt" matched through ".pt".
    checkpoint_filter = r"\.(safetensors|ckpt|pt)$"

    pick_preset = {
        "model-file": {
            "type": "file",
            "path": "./sd-models",
            "filter": checkpoint_filter
        },
        "model-saved-file": {
            "type": "file",
            "path": "./output",
            "filter": checkpoint_filter
        },
        "train-dir": {
            "type": "folder",
            "path": "./train",
            "filter": None
        },
    }

    folder_blacklist = [".ipynb_checkpoints", ".DS_Store"]

    def list_path_or_files(preset_info):
        """Return the entries described by one ``pick_preset`` item."""
        path = Path(preset_info["path"])
        file_type = preset_info["type"]
        regex_filter = preset_info["filter"]
        result_list = []

        # A location that does not exist yet has nothing to offer. Without
        # this guard iterdir() raises FileNotFoundError.
        if not path.is_dir():
            return result_list

        if file_type == "file":
            pattern = re.compile(regex_filter) if regex_filter else None
            for file in path.glob("**/*"):
                if not file.is_file():
                    continue
                if pattern is not None and not pattern.search(file.name):
                    continue
                result_list.append({
                    "path": str(file.resolve()).replace("\\", "/"),
                    "name": file.name,
                    "size": f"{round(file.stat().st_size / (1024**3),2)} GB"
                })
        elif file_type == "folder":
            for folder in path.iterdir():
                if not folder.is_dir() or folder.name in folder_blacklist:
                    continue
                result_list.append({
                    "path": str(folder.resolve()).replace("\\", "/"),
                    "name": folder.name,
                    "size": 0
                })

        return result_list

    if pick_type not in pick_preset:
        return APIResponseFail(message="Invalid request")

    dirs = list_path_or_files(pick_preset[pick_type])
    return APIResponseSuccess(data={
        "files": dirs
    })
@router.get("/tasks", response_model_exclude_none=True)
async def get_tasks() -> APIResponse:
    """Return the task manager's current dump under the ``tasks`` key."""
    payload = {"tasks": tm.dump()}
    return APIResponseSuccess(data=payload)
@router.get("/tasks/terminate/{task_id}", response_model_exclude_none=True)
async def terminate_task(task_id: str):
    """Ask the task manager to terminate ``task_id`` and report success."""
    tm.terminate_task(task_id)
    response = APIResponseSuccess()
    return response
@router.get("/graphic_cards")
async def list_avaliable_cards() -> APIResponse:
    """Return the known graphics devices, or ``pending`` while none are listed."""
    if printable_devices:
        return APIResponseSuccess(data={"cards": printable_devices})

    # Nothing to report yet: answer with a "pending" status.
    return APIResponse(status="pending")
@router.get("/schemas/hashes")
async def list_schema_hashes() -> APIResponse:
    """Return the name and hash of every cached schema.

    When the ``MIKAZUKI_SCHEMA_HOT_RELOAD`` environment variable equals
    ``"1"``, the schema cache is reloaded from disk first.
    """
    hot_reload_flag = os.environ.get("MIKAZUKI_SCHEMA_HOT_RELOAD", "0")
    if hot_reload_flag == "1":
        log.info("Hot reloading schemas")
        await load_schemas()

    digests = []
    for entry in avaliable_schemas:
        digests.append({"name": entry["name"], "hash": entry["hash"]})

    return APIResponseSuccess(data={"schemas": digests})
@router.get("/schemas/all")
async def get_all_schemas() -> APIResponse:
    """Return every cached schema entry (name, schema text and hash)."""
    payload = {"schemas": avaliable_schemas}
    return APIResponseSuccess(data=payload)
@router.get("/presets")
async def get_presets() -> APIResponse:
    """Return the cached presets.

    When the ``MIKAZUKI_SCHEMA_HOT_RELOAD`` environment variable equals
    ``"1"``, the preset cache is reloaded from disk first.
    """
    hot_reload_flag = os.environ.get("MIKAZUKI_SCHEMA_HOT_RELOAD", "0")
    if hot_reload_flag == "1":
        log.info("Hot reloading presets")
        await load_presets()

    payload = {"presets": avaliable_presets}
    return APIResponseSuccess(data=payload)
@router.get("/config/saved_params")
async def get_saved_params() -> APIResponse:
    """Return the ``saved_params`` section of the application config."""
    return APIResponseSuccess(data=app_config["saved_params"])