sd-webui-segment-anything/scripts/api.py

61 lines
2.0 KiB
Python

from fastapi import FastAPI, Body
from io import BytesIO
import base64
from pydantic import BaseModel
from typing import Any
import asyncio
import gradio as gr
import os
from scripts.sam import init_sam_model, dilate_mask, sam_predict, sam_model_list
from scripts.dino import dino_model_list
from PIL import Image, ImageChops
import base64
def sam_api(_: gr.Blocks, app: FastAPI):
    """Register the SAM web UI HTTP endpoints on ``app``.

    Adds two routes:

    * ``GET /sam-webui/heartbeat`` returns a fixed success message.
    * ``POST /sam-webui/image-mask`` runs ``sam_predict`` on a base64-encoded
      image and returns the results as base64-encoded JPEGs.

    Args:
        _: Gradio Blocks instance supplied by the caller; unused here.
        app: FastAPI application the routes are attached to.
    """
    # Function-scope import: fastapi is already a dependency of this file.
    from fastapi import HTTPException

    @app.get("/sam-webui/heartbeat")
    async def heartbeat():
        """Return a constant payload so clients can check the API is mounted."""
        return {
            "msg": "Success!"
        }

    class MaskRequest(BaseModel):
        """Request body for ``POST /sam-webui/image-mask``."""

        image: str  # base64 string containing image
        prompt: str  # text prompt forwarded to sam_predict
        box_threshold: float  # threshold forwarded to sam_predict

    def pil_image_to_base64(img: Image.Image) -> str:
        """Encode ``img`` as JPEG and return the bytes as a base64 string.

        JPEG cannot store every PIL mode (for example RGBA). When the direct
        save is rejected, the image is converted to RGB and saved again so a
        single such image does not fail the whole request.
        """
        buffered = BytesIO()
        try:
            img.save(buffered, format="JPEG")
        except OSError:
            # Start from a fresh buffer in case the failed save wrote anything.
            # NOTE(review): convert("RGB") drops any alpha channel - confirm
            # that is acceptable for the images sam_predict returns.
            buffered = BytesIO()
            img.convert("RGB").save(buffered, format="JPEG")
        return base64.b64encode(buffered.getvalue()).decode()

    @app.post("/sam-webui/image-mask")
    async def process_image(payload: MaskRequest = Body(...)) -> Any:
        """Run ``sam_predict`` on the posted image and return the results.

        Returns:
            A list of ``{"image": <base64 JPEG>}`` dicts, one per item
            returned by ``sam_predict``.

        Raises:
            HTTPException: 400 when ``payload.image`` is not valid base64 or
                does not decode to an image PIL can open.
        """
        # Use the first listed model of each kind; None is passed through
        # when the corresponding list is empty.
        sam_model_name = sam_model_list[0] if len(sam_model_list) > 0 else None
        dino_model_name = dino_model_list[0] if len(dino_model_list) > 0 else None

        # Decode the base64 image string. binascii.Error is a ValueError and
        # PIL's UnidentifiedImageError is an OSError, so both bad base64 and
        # non-image bytes are reported as a client error instead of a 500.
        try:
            img_bytes = base64.b64decode(payload.image)
            input_img = Image.open(BytesIO(img_bytes))
        except (ValueError, OSError) as err:
            raise HTTPException(
                status_code=400,
                detail="'image' must be a base64-encoded image",
            ) from err

        # Run DINO and SAM inference to get masks back.
        # NOTE(review): this call is synchronous inside an async handler, so
        # the event loop is blocked while it runs - confirm this is intended.
        # NOTE(review): the two empty lists are presumably point prompts -
        # verify against sam_predict's signature.
        masks = sam_predict(
            sam_model_name,
            input_img,
            [],
            [],
            True,
            dino_model_name,
            payload.prompt,
            payload.box_threshold,
            None,
            None,
            gui=False,
        )

        # Convert each returned PIL image to a base64 string
        return [{"image": pil_image_to_base64(mask)} for mask in masks]
# Hook the API into the host application once it has started. Registration is
# best-effort: a failure is reported on stdout and does not stop this script
# from loading.
try:
    import modules.script_callbacks as script_callbacks
    script_callbacks.on_app_started(sam_api)
except Exception as err:
    # Catch Exception rather than using a bare except so KeyboardInterrupt and
    # SystemExit still propagate, and include the cause in the message.
    print(f"SAM Web UI API failed to initialize: {err!r}")