Merge pull request #30 from chace20/main

Add APIs for PBRemTools
pull/35/head
mattya_monaca 2023-05-13 08:11:11 +09:00 committed by GitHub
commit ac8da421e1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 78 additions and 0 deletions

View File

@ -61,3 +61,10 @@ Install from webui's Extensions tab.
## Precautions ## Precautions
This program uses code that contains the Apache License 2.0 This program uses code that contains the Apache License 2.0
# API
Currently, there are two APIs available:
- To get the currently available SAM model: `GET http://localhost:7861/pbrem/sam-model`
- To process an image: `POST http://localhost:7861/pbrem/predict`
After launching the web UI API, you can visit the detailed API documentation at `http://localhost:7861/redoc`

71
scripts/api.py Normal file
View File

@ -0,0 +1,71 @@
from typing import List
import gradio as gr
from PIL import Image
from fastapi import FastAPI
from pydantic import BaseModel, Field
from modules.api.api import encode_pil_to_base64, decode_base64_to_image
from scripts.main import process_image, model_list
class PBRRemRequest(BaseModel):
img: str = Field(..., description='Image to process, base64 encoded')
td_abg_enabled: bool = False
h_split: int = 256
v_split: int = 256
n_cluster: int = 500
alpha: int = 50
th_rate: float = 0.1
cascadePSP_enabled: bool = False
fast: bool = False
psp_L: int = 900
sa_enabled: bool = False
model_name: str = Field('sam_vit_h_4b8939.pth',
description='SAM model name')
query: str = Field('', description='segmentation prompt')
predicted_iou_threshold: float = 0.9
stability_score_threshold: float = 0.9
clip_threshold: float = 0.1
class PBRRemResponse(BaseModel):
img: str
mask: str
def pbrem_api(_: gr.Blocks, app: FastAPI):
@app.get("/pbrem/sam-model", description='query available SAM model', response_model=List[str])
async def get_pbrem_sam_model() -> List[str]:
return model_list
@app.post('/pbrem/predict', description='process image', response_model=PBRRemResponse)
async def post_pbrem_predict(payload: PBRRemRequest) -> PBRRemResponse:
print(f"PBRemTools API /pbrem/predict received request")
input_image = decode_base64_to_image(payload.img)
output_image, mask = process_image(
input_image,
payload.td_abg_enabled, payload.h_split, payload.v_split, payload.n_cluster, payload.alpha, payload.th_rate,
payload.cascadePSP_enabled, payload.fast, payload.psp_L,
payload.sa_enabled, payload.query, payload.model_name, payload.predicted_iou_threshold, payload.stability_score_threshold, payload.clip_threshold
)
base64_img = encode_pil_to_base64(Image.fromarray(output_image))
base64_mask = encode_pil_to_base64(Image.fromarray(mask))
# Compatible: low version webui will return base64 string directly
if not isinstance(base64_img, str):
base64_img = base64_img.decode()
if not isinstance(base64_mask, str):
base64_mask = base64_mask.decode()
return PBRRemResponse(img=base64_img, mask=base64_mask)
try:
import modules.script_callbacks as script_callbacks
script_callbacks.on_app_started(pbrem_api)
except Exception as e:
print(e)
print("PBRemTools API failed to initialize")