# File-viewer header captured with the paste (not part of the program): 87 lines, 2.3 KiB, Python.
import asyncio
|
|
import mimetypes
|
|
import os
|
|
import sys
|
|
import webbrowser
|
|
from contextlib import asynccontextmanager
|
|
|
|
from fastapi import FastAPI
|
|
from fastapi.middleware.cors import CORSMiddleware
|
|
from fastapi.responses import FileResponse
|
|
from fastapi.staticfiles import StaticFiles
|
|
from starlette.exceptions import HTTPException
|
|
|
|
from mikazuki.app.config import app_config
|
|
from mikazuki.app.api import load_schemas
|
|
from mikazuki.app.api import router as api_router
|
|
# from mikazuki.app.ipc import router as ipc_router
|
|
from mikazuki.app.proxy import router as proxy_router
|
|
from mikazuki.utils.devices import check_torch_gpu
|
|
|
|
# Pin the MIME types used for the frontend's script and stylesheet files.
# NOTE(review): presumably this overrides platform-specific mappings for
# these extensions — confirm the original motivation.
mimetypes.add_type(type="application/javascript", ext=".js")
mimetypes.add_type(type="text/css", ext=".css")
|
|
|
|
|
|
class SPAStaticFiles(StaticFiles):
    """Static-file app that answers unknown paths with ``index.html``.

    Any path the parent class reports as 404 is served ``index.html``
    instead, so the response for an unresolved path is the frontend's
    entry page rather than an error.
    """

    async def get_response(self, path: str, scope):
        """Return the response for *path*, falling back to ``index.html``.

        The lookup is delegated to the parent class. If it raises an
        ``HTTPException`` with status 404, ``index.html`` is looked up
        instead; any other ``HTTPException`` is re-raised unchanged.
        """
        try:
            return await super().get_response(path, scope)
        except HTTPException as exc:
            # Only a "not found" triggers the fallback; everything else
            # propagates to the caller.
            if exc.status_code != 404:
                raise
            return await super().get_response("index.html", scope)
|
|
|
|
|
|
async def app_startup():
    """Run the one-time initialisation steps before the app serves requests.

    In order: calls ``app_config.load_config()``, awaits ``load_schemas()``,
    runs ``check_torch_gpu`` in a worker thread, and finally, on Windows
    when ``MIKAZUKI_DEV`` is not ``"1"``, opens the UI in the default web
    browser using ``MIKAZUKI_HOST`` and ``MIKAZUKI_PORT``.

    Opening the browser is best-effort: if either variable is unset or
    empty, the step is skipped instead of aborting startup.
    """
    app_config.load_config()

    await load_schemas()
    # Run the GPU check via to_thread so it does not block the event loop.
    await asyncio.to_thread(check_torch_gpu)

    if sys.platform != "win32" or os.environ.get("MIKAZUKI_DEV", "0") == "1":
        return

    host = os.environ.get("MIKAZUKI_HOST")
    port = os.environ.get("MIKAZUKI_PORT")
    if not host or not port:
        # A missing variable used to raise KeyError here and abort startup;
        # the browser launch is only a convenience, so skip it instead.
        return

    # The wildcard bind address is not a connectable destination for a
    # browser on Windows; point it at loopback instead.
    if host == "0.0.0.0":
        host = "127.0.0.1"

    webbrowser.open(f"http://{host}:{port}")
|
|
|
|
|
|
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Lifespan context manager handed to the ``FastAPI`` constructor.

    Awaits ``app_startup()`` before yielding control. Nothing follows the
    ``yield``, so no work is done when the context exits.
    """
    # Startup phase: must complete before control is handed over.
    await app_startup()
    yield
|
|
|
|
|
|
app = FastAPI(lifespan=lifespan)
app.include_router(proxy_router)

# CORS is configured through the MIKAZUKI_APP_CORS environment variable:
#   unset / empty -> no CORS middleware is installed
#   "1"           -> "http://localhost:8004" plus the "*" wildcard
#   anything else -> a ";"-separated list of allowed origins
cors_config = os.environ.get("MIKAZUKI_APP_CORS", "")
if cors_config != "":
    if cors_config == "1":
        # NOTE(review): "*" together with allow_credentials=True is advised
        # against by the FastAPI CORS documentation — confirm this is intended.
        cors_config = ["http://localhost:8004", "*"]
    else:
        # Trim whitespace and drop empty entries so that values such as
        # "http://a; http://b" or a trailing ";" do not produce origins
        # that can never match a request's Origin header.
        cors_config = [
            origin
            for origin in map(str.strip, cors_config.split(";"))
            if origin
        ]
    app.add_middleware(
        CORSMiddleware,
        allow_origins=cors_config,
        allow_credentials=True,
        allow_methods=["*"],
        allow_headers=["*"],
    )
|
|
|
|
|
|
@app.middleware("http")
async def add_cache_control_header(request, call_next):
    """Set ``Cache-Control: max-age=0`` on every HTTP response.

    The request is passed to ``call_next`` untouched; the header is then
    assigned on the resulting response, replacing any existing
    ``Cache-Control`` value, and that response is returned.
    """
    downstream_response = await call_next(request)
    downstream_response.headers["Cache-Control"] = "max-age=0"
    return downstream_response
|
|
|
|
# Expose the API router under the "/api" path prefix.
app.include_router(router=api_router, prefix="/api")
# app.include_router(ipc_router, prefix="/ipc")
|
|
|
|
|
|
@app.get("/")
async def index():
    """Serve the built frontend's ``index.html`` for the site root.

    The path is relative, so it is resolved against the process's current
    working directory.
    """
    return FileResponse(path="./frontend/dist/index.html")
|
|
|
|
|
|
# Serve the built frontend from "frontend/dist" at the root path, using the
# index.html-fallback static file class.
app.mount(
    path="/",
    app=SPAStaticFiles(directory="frontend/dist", html=True),
    name="static",
)
|