# lora-scripts/mikazuki/app.py
import json
import os
import shlex
import subprocess
import sys
from datetime import datetime
from pathlib import Path
from threading import Lock
from typing import Optional

import starlette.responses as starlette_responses
import toml
from fastapi import BackgroundTasks, FastAPI, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse, Response
from fastapi.staticfiles import StaticFiles

import mikazuki.utils as utils
from mikazuki.log import log
from mikazuki.models import TaggerInterrogateRequest
from mikazuki.tagger.interrogator import (available_interrogators,
                                          on_interrogate)
app = FastAPI()
# Global training mutex: acquired (non-blocking) in /api/run, released by
# run_train when the training subprocess exits. Ensures one training at a time.
lock = Lock()
# Whitelist of sd-scripts utilities that /api/run_script may launch.
# NOTE(review): "avaliable" is a typo for "available", kept for compatibility.
avaliable_scripts = [
    "networks/extract_lora_from_models.py",
    "networks/extract_lora_from_dylora.py"
]
# Some systems have a broken mimetype registry, causing .js/.css files to be
# served with a wrong Content-Type; patch starlette's guess_type to force the
# correct types for those extensions.
_origin_guess_type = starlette_responses.guess_type


def _hooked_guess_type(*args, **kwargs):
    """Wrap starlette's guess_type, forcing correct JS/CSS mimetypes.

    Delegates to the original implementation and only overrides the result
    for ``.js`` and ``.css`` URLs.
    """
    # Be tolerant of keyword invocation: the original code assumed args[0]
    # and would raise IndexError if guess_type were called as guess_type(url=...).
    url = args[0] if args else kwargs.get("url", "")
    result = _origin_guess_type(*args, **kwargs)
    if url.endswith(".js"):
        result = ("application/javascript", None)
    elif url.endswith(".css"):
        result = ("text/css", None)
    return result


starlette_responses.guess_type = _hooked_guess_type
def run_train(toml_path: str,
              trainer_file: str = "./sd-scripts/train_network.py",
              multi_gpu: bool = False,
              cpu_threads: Optional[int] = 2):
    """Run a training subprocess via ``accelerate launch`` and wait for it.

    Expects the module-level training ``lock`` to already be held by the
    caller (the /api/run handler); the lock is always released here, even
    when launching the subprocess fails.

    Args:
        toml_path: Path of the TOML config file handed to the trainer.
        trainer_file: sd-scripts entry point to execute.
        multi_gpu: When True, pass accelerate's ``--multi_gpu`` flag.
        cpu_threads: Value for ``--num_cpu_threads_per_process``.
    """
    log.info(f"Training started with config file / 训练开始,使用配置文件: {toml_path}")
    args = [
        sys.executable, "-m", "accelerate.commands.launch",
        "--num_cpu_threads_per_process", str(cpu_threads),
        trainer_file,
        "--config_file", toml_path,
    ]
    if multi_gpu:
        # Must precede the positional trainer_file argument.
        args.insert(3, "--multi_gpu")
    try:
        result = subprocess.run(args, env=os.environ)
        if result.returncode != 0:
            log.error("Training failed / 训练失败")
        else:
            log.info("Training finished / 训练完成")
    except Exception as e:
        log.error(f"An error occurred when training / 创建训练进程时出现致命错误: {e}")
    finally:
        # Release the mutex acquired by the API handler so a new run can start.
        lock.release()
@app.middleware("http")
async def add_cache_control_header(request, call_next):
    """HTTP middleware: disable client-side caching on every response."""
    response = await call_next(request)
    # max-age=0 forces revalidation so the frontend always sees fresh state.
    response.headers["Cache-Control"] = "max-age=0"
    # response.headers["Access-Control-Allow-Origin"] = "*"
    return response
@app.post("/api/run")
async def create_toml_file(request: Request, background_tasks: BackgroundTasks):
    """Accept training params as JSON, save them as TOML and start training.

    Acquires the global training lock; on success the lock is released by
    run_train when the subprocess exits, on any preparation failure it is
    released here (the original code leaked the lock on exceptions, which
    permanently blocked further training runs).
    """
    acquired = lock.acquire(blocking=False)
    if not acquired:
        log.error("Training is already running / 已有正在进行的训练")
        return {"status": "fail", "detail": "已有正在进行的训练"}
    try:
        timestamp = datetime.now().strftime("%Y%m%d-%H%M%S")
        autosave_dir = os.path.join(os.getcwd(), "config", "autosave")
        # Make sure the autosave directory exists before writing into it.
        os.makedirs(autosave_dir, exist_ok=True)
        toml_file = os.path.join(autosave_dir, f"{timestamp}.toml")
        toml_data = await request.body()
        j = json.loads(toml_data.decode("utf-8"))
        # NOTE(review): parameter/directory validation
        # (utils.check_training_params) is currently disabled upstream.
        # Large datasets benefit from more dataloader CPU threads.
        suggest_cpu_threads = 8 if utils.get_total_images(j["train_data_dir"]) > 100 else 2
        trainer_file = "./sd-scripts/train_network.py"
        if j.pop("model_train_type", "sd-lora") == "sdxl-lora":
            trainer_file = "./sd-scripts/sdxl_train_network.py"
        multi_gpu = j.pop("multi_gpu", False)

        def is_promopt_like(s):
            # Heuristic: inline prompt text (as opposed to a file path)
            # contains sampler flags like "--n" / "--s" / "--l" / "--d".
            return any(p in s for p in ["--n", "--s", "--l", "--d"])

        sample_prompts = j.get("sample_prompts", None)
        if sample_prompts is not None and not os.path.exists(sample_prompts) and is_promopt_like(sample_prompts):
            # The trainer expects a file path; persist inline prompt text.
            sample_prompts_file = os.path.join(autosave_dir, f"{timestamp}-promopt.txt")
            with open(sample_prompts_file, "w", encoding="utf-8") as f:
                f.write(sample_prompts)
            j["sample_prompts"] = sample_prompts_file
            log.info(f"Wrote prompts to file {sample_prompts_file}")
        with open(toml_file, "w", encoding="utf-8") as f:
            f.write(toml.dumps(j))
    except Exception as e:
        # Any failure before training starts must release the lock, otherwise
        # no further training could ever be started.
        lock.release()
        log.error(f"An error occurred when preparing training / 准备训练时出错: {e}")
        return {"status": "fail", "detail": str(e)}
    background_tasks.add_task(run_train, toml_file, trainer_file, multi_gpu, suggest_cpu_threads)
    return {"status": "success"}
@app.post("/api/run_script")
async def run_script(request: Request, background_tasks: BackgroundTasks):
    """Run a whitelisted sd-scripts utility with JSON-supplied CLI options.

    The JSON body's "script_name" selects the script (must be listed in
    avaliable_scripts); every other key/value pair becomes a ``--key value``
    option. Boolean values are bare flags and are emitted only when True
    (the original code emitted the flag even for False, which a store_true
    flag cannot express).
    """
    paras = await request.body()
    j = json.loads(paras.decode("utf-8"))
    script_name = j["script_name"]
    if script_name not in avaliable_scripts:
        return {"status": "fail"}
    del j["script_name"]
    result = []
    for k, v in j.items():
        if isinstance(v, bool):
            if v:
                result.append(f"--{k}")
        else:
            result.append(f"--{k}")
            # shlex.quote guards against shell injection from user-supplied
            # values (the command string is run via the shell in utils.run).
            result.append(shlex.quote(str(v)))
    script_args = " ".join(result)
    script_path = Path(os.getcwd()) / "sd-scripts" / script_name
    cmd = f"{utils.python_bin} {script_path} {script_args}"
    background_tasks.add_task(utils.run, cmd)
    return {"status": "success"}
@app.post("/api/interrogate")
async def run_interrogate(req: TaggerInterrogateRequest, background_tasks: BackgroundTasks):
    """Schedule a background tagger interrogation over the images in req.path."""
    # Fall back to the wd14-convnextv2-v2 model when the requested one is unknown.
    interrogator = available_interrogators.get(
        req.interrogator_model, available_interrogators["wd14-convnextv2-v2"]
    )
    task_kwargs = dict(
        image=None,
        batch_input_glob=req.path,
        batch_input_recursive=False,
        batch_output_dir="",
        batch_output_filename_format="[name].[output_extension]",
        batch_output_action_on_conflict=req.batch_output_action_on_conflict,
        batch_remove_duplicated_tag=True,
        batch_output_save_json=False,
        interrogator=interrogator,
        threshold=req.threshold,
        additional_tags=req.additional_tags,
        exclude_tags=req.exclude_tags,
        sort_by_alphabetical_order=False,
        add_confident_as_weight=False,
        replace_underscore=req.replace_underscore,
        replace_underscore_excludes=req.replace_underscore_excludes,
        escape_tag=req.escape_tag,
        unload_model_after_running=True,
    )
    background_tasks.add_task(on_interrogate, **task_kwargs)
    return {"status": "success"}
# @app.get("/api/schema/{name}")
# async def get_schema(name: str):
# with open(os.path.join(os.getcwd(), "mikazuki", "schema", name), encoding="utf-8") as f:
# content = f.read()
# return Response(content=content, media_type="text/plain")
@app.get("/api/tasks")
async def get_tasks():
    """Return a dump of the task manager's known background tasks."""
    return tm.dump()
@app.get("/")
async def index():
    """Serve the frontend single-page app entry point."""
    return FileResponse("./frontend/dist/index.html")


# Serve the built frontend assets; mounted last so API routes take priority.
app.mount("/", StaticFiles(directory="frontend/dist"), name="static")