95 lines
2.9 KiB
Python
95 lines
2.9 KiB
Python
import argparse
|
|
import json
|
|
import os
|
|
import subprocess
|
|
import sys
|
|
import webbrowser
|
|
from datetime import datetime
|
|
from threading import Lock
|
|
|
|
import uvicorn
|
|
from fastapi import BackgroundTasks, FastAPI, Request
|
|
from fastapi.responses import FileResponse
|
|
from fastapi.staticfiles import StaticFiles
|
|
|
|
import toml
|
|
|
|
# ASGI application object; the API route, the middleware and the static
# frontend mount defined in this module are all attached to it.
app = FastAPI()

# Held for the duration of a training run: the /api/run handler acquires it
# without blocking and run_train releases it when the subprocess returns.
lock = Lock()
|
|
|
|
# On some systems the MIME type guessed for .js / .css files is wrong, which
# makes browsers refuse to load the frontend bundle.  Wrap the static file
# handler so that those two extensions always get the correct Content-Type.
sf = StaticFiles(directory="frontend/dist")
_o_fr = sf.file_response

# Content-Type forced for each file extension handled by the hook.
_FORCED_MEDIA_TYPES = {
    ".js": "application/javascript",
    ".css": "text/css",
}


def _hooked_file_response(*args, **kwargs):
    """Build the static file response, then force the type of .js/.css files.

    Accepts the same arguments as ``StaticFiles.file_response`` and forwards
    them unchanged to the original method.  The first argument (``full_path``)
    is inspected to decide whether the Content-Type has to be overridden.

    Returns:
        The response produced by the original ``file_response``, with its
        ``media_type`` and ``content-type`` header corrected when the served
        file ends in ``.js`` or ``.css``.
    """
    # full_path is normally passed positionally, but tolerate a keyword call
    # and path-like objects as well.
    full_path = str(args[0] if args else kwargs["full_path"])
    r = _o_fr(*args, **kwargs)

    for suffix, media_type in _FORCED_MEDIA_TYPES.items():
        if not full_path.endswith(suffix):
            continue
        r.media_type = media_type
        # The response headers were already built when the response object
        # was constructed, so changing ``media_type`` alone does not alter
        # what is sent to the client; the header must be rewritten too.
        content_type = media_type
        if media_type.startswith("text/"):
            content_type += "; charset=utf-8"
        r.headers["content-type"] = content_type
        break
    return r


sf.file_response = _hooked_file_response
|
|
|
|
# Command-line interface of the GUI server; only the listening port is
# configurable.
parser = argparse.ArgumentParser(description="GUI for training network")
parser.add_argument(
    "--port",
    type=int,
    default=28000,
    help="Port to run the server on",
)
|
|
|
|
def run_train(toml_path: str):
    """Run one training job through ``accelerate launch`` and wait for it.

    The outcome is reported on stdout.  The module-level ``lock`` is always
    released when the subprocess returns or fails to start, so the caller
    must hold the lock before scheduling this function.

    Args:
        toml_path: Path of the TOML config file handed to
            ``train_network.py`` via ``--config_file``.
    """
    print(f"Training started with config file / 训练开始,使用配置文件: {toml_path}")
    args = [
        "accelerate", "launch", "--num_cpu_threads_per_process", "8",
        "./sd-scripts/train_network.py",
        "--config_file", toml_path,
    ]
    # With shell=True and an argument *list*, POSIX runs only the first item
    # as the command and hands the remaining items to the shell itself, so
    # the training script would never be launched.  Only Windows keeps the
    # shell (there the list is joined into a single command line).
    use_shell = sys.platform == "win32"
    try:
        result = subprocess.run(args, shell=use_shell, env=os.environ)
        if result.returncode != 0:
            print("Training failed / 训练失败")
        else:
            print("Training finished / 训练完成")
    except Exception as e:
        print(f"An error occurred when training / 创建训练进程时出现致命错误: {e}")
    finally:
        # Allow the next /api/run request to start a new run.
        lock.release()
|
|
|
|
|
|
@app.post("/api/run")
async def create_toml_file(request: Request, background_tasks: BackgroundTasks):
    """Store the posted JSON config as a TOML file and schedule a training run.

    Only one run may be active at a time: the module-level ``lock`` is taken
    here without blocking and released by ``run_train`` once training ends.

    Args:
        request: Incoming request whose body is a UTF-8 JSON document.
        background_tasks: Task queue used to start ``run_train`` after the
            response has been sent.

    Returns:
        ``{"status": "success"}`` when the run was scheduled, or a
        ``{"status": "fail", ...}`` dict when a run is already in progress.

    Raises:
        Any exception raised while reading, parsing or writing the config is
        propagated after the lock has been released.
    """
    if not lock.acquire(blocking=False):
        print("Training is already running / 已有正在进行的训练")
        return {"status": "fail", "detail": "Training is already running"}

    try:
        timestamp = datetime.now().strftime("%Y%m%d-%H%M%S")
        toml_file = f"toml/{timestamp}.toml"
        toml_data = await request.body()
        j = json.loads(toml_data.decode("utf-8"))

        # The output directory may not exist on a fresh checkout.
        os.makedirs(os.path.dirname(toml_file), exist_ok=True)
        # TOML documents must be UTF-8; do not depend on the locale encoding.
        with open(toml_file, "w", encoding="utf-8") as f:
            f.write(toml.dumps(j))

        background_tasks.add_task(run_train, toml_file)
    except BaseException:
        # run_train will never run to release the lock, so release it here;
        # otherwise every later request would be told a run is in progress.
        lock.release()
        raise

    return {"status": "success"}
|
|
|
|
@app.middleware("http")
async def add_cache_control_header(request, call_next):
    """Stamp every HTTP response with ``Cache-Control: max-age=0``.

    Args:
        request: The incoming request, passed through untouched.
        call_next: Callable that runs the rest of the request pipeline and
            yields the response.

    Returns:
        The downstream response with the ``Cache-Control`` header set.
    """
    downstream_response = await call_next(request)
    downstream_response.headers["Cache-Control"] = "max-age=0"
    return downstream_response
|
|
|
|
@app.get("/")
async def index():
    """Serve the built frontend's entry page for the site root."""
    index_page = "./frontend/dist/index.html"
    return FileResponse(index_page)
|
|
|
|
|
|
# Serve the built frontend assets (through the patched StaticFiles instance)
# under the site root.
app.mount(path="/", app=sf, name="static")
|
|
|
|
if __name__ == "__main__":
    # Unknown command-line arguments are ignored rather than rejected.
    args, _ = parser.parse_known_args()
    url = f"http://127.0.0.1:{args.port}"
    print(f"Server started at {url}")

    if sys.platform == "win32":
        # Disable Triton on Windows; the variable is inherited by the
        # training subprocess through os.environ.
        os.environ["XFORMERS_FORCE_DISABLE_TRITON"] = "1"

    # uvicorn.run blocks, so the browser has to be opened beforehand.
    webbrowser.open(url)
    # Listen on the port requested with --port (previously hard-coded to
    # 28000, so the opened URL pointed at a port nobody was serving).
    uvicorn.run(app, host="127.0.0.1", port=args.port, log_level="error")
|