# lora-scripts/mikazuki/app/proxy.py
# 86 lines, 3.1 KiB, Python

import asyncio
import os
import httpx
import starlette
import websockets
from fastapi import APIRouter, Request, WebSocket
from httpx import ConnectError
from starlette.background import BackgroundTask
from starlette.requests import Request
from starlette.responses import PlainTextResponse, StreamingResponse
# Router on which this module registers its HTTP and websocket proxy endpoints.
router = APIRouter()
def reverse_proxy_maker(url_type: str, full_path: bool = False):
    """Build an async request handler that reverse-proxies to a local service.

    Args:
        url_type: Which upstream to target. ``"tensorboard"`` uses the host and
            port from the ``MIKAZUKI_TENSORBOARD_HOST`` / ``MIKAZUKI_TENSORBOARD_PORT``
            environment variables (defaults ``127.0.0.1`` and ``6006``);
            ``"tageditor"`` uses ``http://127.0.0.1:28001/``.
        full_path: When true, the incoming request's full URL path is forwarded
            upstream; otherwise only the ``path`` route parameter is used.

    Returns:
        An ``async`` callable taking a Starlette ``Request`` and returning the
        proxied response.

    Raises:
        ValueError: If ``url_type`` is not one of the supported upstreams.
    """
    if url_type == "tensorboard":
        host = os.environ.get("MIKAZUKI_TENSORBOARD_HOST", "127.0.0.1")
        port = os.environ.get("MIKAZUKI_TENSORBOARD_PORT", "6006")
        base_url = f"http://{host}:{port}/"
    elif url_type == "tageditor":
        base_url = "http://127.0.0.1:28001/"
    else:
        # Fail at construction time instead of returning a handler that would
        # hit an unbound ``client`` on every request.
        raise ValueError(f"unsupported reverse proxy target: {url_type!r}")

    client = httpx.AsyncClient(base_url=base_url)

    async def _reverse_proxy(request: Request):
        """Forward ``request`` upstream and stream the upstream response back."""
        query = request.url.query.encode("utf-8")
        if full_path:
            url = httpx.URL(path=request.url.path, query=query)
        else:
            url = httpx.URL(path=request.path_params.get("path", ""), query=query)

        rp_req = client.build_request(
            request.method,
            url,
            headers=request.headers.raw,
            # GET requests are forwarded without a body; other methods stream it.
            content=request.stream() if request.method != "GET" else None,
        )
        try:
            rp_resp = await client.send(rp_req, stream=True)
        except ConnectError:
            return PlainTextResponse(
                content="The requested service not started yet or tensorboard started fail.\n请求的服务尚未启动或启动失败。",
                status_code=502,
            )

        return StreamingResponse(
            rp_resp.aiter_raw(),
            status_code=rp_resp.status_code,
            headers=rp_resp.headers,
            # Close the upstream response once the body has been streamed out.
            background=BackgroundTask(rp_resp.aclose),
        )

    return _reverse_proxy
async def proxy_ws_forward(ws_a: WebSocket, ws_b: websockets.WebSocketClientProtocol):
    """Relay text messages received on ``ws_a`` to ``ws_b``.

    The loop ends when ``ws_a`` disconnects, or when ``ws_b`` turns out to be
    closed while a message is being sent to it.

    Args:
        ws_a: Server-side websocket accepted from the client.
        ws_b: Client connection to the upstream websocket.
    """
    while True:
        try:
            data = await ws_a.receive_text()
        except starlette.websockets.WebSocketDisconnect:
            # The client went away; nothing more to forward.
            break
        try:
            await ws_b.send(data)
        except websockets.exceptions.ConnectionClosed:
            # The upstream connection is gone; stop instead of raising.
            break
async def proxy_ws_reverse(ws_a: WebSocket, ws_b: websockets.WebSocketClientProtocol):
    """Relay messages received on ``ws_b`` back to ``ws_a``.

    The loop ends when ``ws_b`` is closed, whether the closure was clean or not.

    Args:
        ws_a: Server-side websocket accepted from the client.
        ws_b: Client connection to the upstream websocket.
    """
    while True:
        try:
            data = await ws_b.recv()
        except websockets.exceptions.ConnectionClosed:
            # Base class of ConnectionClosedOK and ConnectionClosedError, so an
            # abnormal upstream close also ends the relay instead of raising.
            break
        if isinstance(data, bytes):
            # recv() yields bytes for binary frames; send them as binary.
            await ws_a.send_bytes(data)
        else:
            await ws_a.send_text(data)
@router.websocket("/proxy/tageditor/queue/join")
async def websocket_a(ws_a: WebSocket):
    """Proxy the ``/queue/join`` websocket to the service on 127.0.0.1:28001.

    Accepts the incoming websocket, opens a client connection upstream and
    relays messages in both directions. As soon as either direction finishes,
    the other one is cancelled so that the handler returns and the upstream
    connection is closed by the ``async with`` block.

    Args:
        ws_a: Incoming websocket connection.
    """
    # for temp use
    ws_b_uri = "ws://127.0.0.1:28001/queue/join"
    await ws_a.accept()
    async with websockets.connect(ws_b_uri, timeout=360) as ws_b_client:
        fwd_task = asyncio.create_task(proxy_ws_forward(ws_a, ws_b_client))
        rev_task = asyncio.create_task(proxy_ws_reverse(ws_a, ws_b_client))
        tasks = {fwd_task, rev_task}
        done = set()
        try:
            # Waiting for both would keep the handler (and the upstream
            # connection) alive until the *second* side also disconnects.
            done, _ = await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
        finally:
            for task in tasks:
                if not task.done():
                    task.cancel()
            # Let the cancelled relay unwind before leaving the context manager.
            await asyncio.gather(*tasks, return_exceptions=True)
        for task in done:
            # Re-raise an error from the relay that finished first, if any.
            task.result()
# HTTP reverse-proxy routes; each accepts GET and POST.

# TensorBoard, addressed by the part of the path after the prefix.
router.add_route(
    path="/proxy/tensorboard/{path:path}",
    endpoint=reverse_proxy_maker("tensorboard"),
    methods=["GET", "POST"],
)

# Requests under /font-roboto/ go to the TensorBoard upstream with the full path.
router.add_route(
    path="/font-roboto/{path:path}",
    endpoint=reverse_proxy_maker("tensorboard", full_path=True),
    methods=["GET", "POST"],
)

# Tag editor, addressed by the part of the path after the prefix.
router.add_route(
    path="/proxy/tageditor/{path:path}",
    endpoint=reverse_proxy_maker("tageditor"),
    methods=["GET", "POST"],
)