PBRemTools/scripts/api.py

72 lines
2.5 KiB
Python

from typing import List
import gradio as gr
from PIL import Image
from fastapi import FastAPI
from pydantic import BaseModel, Field
from modules.api.api import encode_pil_to_base64, decode_base64_to_image
from scripts.main import process_image, model_list
class PBRRemRequest(BaseModel):
img: str = Field(..., description='Image to process, base64 encoded')
td_abg_enabled: bool = False
h_split: int = 256
v_split: int = 256
n_cluster: int = 500
alpha: int = 50
th_rate: float = 0.1
cascadePSP_enabled: bool = False
fast: bool = False
psp_L: int = 900
sa_enabled: bool = False
model_name: str = Field('sam_vit_h_4b8939.pth',
description='SAM model name')
query: str = Field('', description='segmentation prompt')
predicted_iou_threshold: float = 0.9
stability_score_threshold: float = 0.9
clip_threshold: float = 0.1
class PBRRemResponse(BaseModel):
img: str
mask: str
def pbrem_api(_: gr.Blocks, app: FastAPI):
@app.get("/pbrem/sam-model", description='query available SAM model', response_model=List[str])
async def get_pbrem_sam_model() -> List[str]:
return model_list
@app.post('/pbrem/predict', description='process image', response_model=PBRRemResponse)
async def post_pbrem_predict(payload: PBRRemRequest) -> PBRRemResponse:
print(f"PBRemTools API /pbrem/predict received request")
input_image = decode_base64_to_image(payload.img)
output_image, mask = process_image(
input_image,
payload.td_abg_enabled, payload.h_split, payload.v_split, payload.n_cluster, payload.alpha, payload.th_rate,
payload.cascadePSP_enabled, payload.fast, payload.psp_L,
payload.sa_enabled, payload.query, payload.model_name, payload.predicted_iou_threshold, payload.stability_score_threshold, payload.clip_threshold
)
base64_img = encode_pil_to_base64(Image.fromarray(output_image))
base64_mask = encode_pil_to_base64(Image.fromarray(mask))
# Compatible: low version webui will return base64 string directly
if not isinstance(base64_img, str):
base64_img = base64_img.decode()
if not isinstance(base64_mask, str):
base64_mask = base64_mask.decode()
return PBRRemResponse(img=base64_img, mask=base64_mask)
try:
import modules.script_callbacks as script_callbacks
script_callbacks.on_app_started(pbrem_api)
except Exception as e:
print(e)
print("PBRemTools API failed to initialize")