61 lines
2.0 KiB
Python
61 lines
2.0 KiB
Python
from fastapi import FastAPI, Body
|
|
from io import BytesIO
|
|
import base64
|
|
from pydantic import BaseModel
|
|
from typing import Any
|
|
import asyncio
|
|
import gradio as gr
|
|
import os
|
|
from scripts.sam import init_sam_model, dilate_mask, sam_predict, sam_model_list
|
|
from scripts.dino import dino_model_list
|
|
from PIL import Image, ImageChops
|
|
import base64
|
|
|
|
def sam_api(_: gr.Blocks, app: FastAPI):
    """Register the segment-anything REST endpoints on the webui's FastAPI app.

    The first argument (the Gradio ``Blocks`` instance) is unused; only
    ``app`` is needed to attach routes.

    Endpoints:
        GET  /sam-webui/heartbeat   -- liveness probe, returns {"msg": "Success!"}.
        POST /sam-webui/image-mask  -- runs DINO + SAM on a base64 image and
                                       returns the results as base64 JPEGs.
    """
    # Local import so the file's top-level import block stays untouched.
    from fastapi import HTTPException

    # Image modes Pillow's JPEG encoder can write directly; anything else
    # (RGBA, LA, P, ...) must be converted before saving.
    jpeg_modes = ("1", "L", "RGB", "RGBX", "CMYK", "YCbCr")

    @app.get("/sam-webui/heartbeat")
    async def heartbeat():
        """Liveness probe for the SAM API."""
        return {
            "msg": "Success!"
        }

    class MaskRequest(BaseModel):
        """Request body for ``/sam-webui/image-mask``."""

        # Base64-encoded image bytes. An optional "data:<mime>;base64,"
        # prefix (as produced by browsers' toDataURL) is also accepted.
        image: str
        # Text prompt forwarded to sam_predict for the DINO detector.
        prompt: str
        # Box threshold forwarded to sam_predict alongside the prompt.
        box_threshold: float

    def pil_image_to_base64(img: Image.Image) -> str:
        """Encode *img* as JPEG and return the bytes as a base64 ASCII string.

        Modes JPEG cannot store (e.g. RGBA, LA, P) are converted to RGB first
        instead of making ``img.save`` raise ``OSError``; modes JPEG already
        supports are encoded unchanged.
        """
        if img.mode not in jpeg_modes:
            img = img.convert("RGB")
        buffered = BytesIO()
        # NOTE(review): JPEG is lossy, so binary masks pick up compression
        # artifacts at edges. PNG would be exact but changes the wire format
        # clients receive -- confirm before switching.
        img.save(buffered, format="JPEG")
        return base64.b64encode(buffered.getvalue()).decode()

    def decode_base64_image(data: str) -> Image.Image:
        """Decode a (possibly data-URL prefixed) base64 string into a PIL image.

        Raises:
            HTTPException: 400 if *data* is not valid base64 or is not an
                image Pillow can read.
        """
        if data.startswith("data:"):
            # Non-validating b64decode would keep the prefix's letters and
            # corrupt the decoded bytes, so drop everything up to the comma.
            data = data.split(",", 1)[-1]
        try:
            img = Image.open(BytesIO(base64.b64decode(data)))
            # Image.open is lazy; force decoding now so truncated/corrupt
            # images are reported as a client error, not deep inside SAM.
            img.load()
        except (ValueError, OSError) as err:
            # binascii.Error subclasses ValueError; PIL's
            # UnidentifiedImageError and truncated-file errors are OSErrors.
            raise HTTPException(
                status_code=400, detail=f"Invalid image payload: {err}"
            ) from err
        return img

    @app.post("/sam-webui/image-mask")
    async def process_image(payload: MaskRequest = Body(...)) -> Any:
        """Run DINO + SAM on ``payload.image`` and return the results.

        Returns:
            A list of ``{"image": <base64 JPEG>}`` dicts, one per image
            returned by ``sam_predict``.

        Raises:
            HTTPException: 400 if ``payload.image`` cannot be decoded.
        """
        # Use the first available model of each kind, or None if none exist.
        # NOTE(review): how sam_predict handles a None model name is not
        # visible here -- confirm it fails cleanly.
        sam_model_name = sam_model_list[0] if sam_model_list else None
        dino_model_name = dino_model_list[0] if dino_model_list else None

        input_img = decode_base64_image(payload.image)

        # Run DINO and SAM inference to get masks back. Arguments are
        # positional: empty positive/negative point lists, True (presumably
        # "use DINO"), then the DINO model, prompt and threshold; the two
        # trailing Nones fill remaining positional slots -- confirm against
        # sam_predict's signature.
        # NOTE(review): sam_predict is synchronous inside an async handler, so
        # it blocks the event loop (including /heartbeat) while it runs.
        # Offloading to a thread would also allow concurrent inference, which
        # the model code may not be safe for -- verify before changing.
        masks = sam_predict(
            sam_model_name,
            input_img,
            [],
            [],
            True,
            dino_model_name,
            payload.prompt,
            payload.box_threshold,
            None,
            None,
            gui=False,
        )

        # Encode every returned PIL image as base64 for the JSON response.
        return [{"image": pil_image_to_base64(mask)} for mask in masks]
|
|
|
|
# Hook the API into the webui's app-startup callback. `modules.script_callbacks`
# is presumably only importable when running inside the webui, so failure here
# is reported rather than treated as fatal.
try:
    import modules.script_callbacks as script_callbacks

    script_callbacks.on_app_started(sam_api)
except Exception as err:
    # Best-effort registration: keep the original message, but also report why
    # it failed. Catching Exception (not a bare `except:`) lets
    # KeyboardInterrupt / SystemExit propagate.
    print("SAM Web UI API failed to initialize")
    print(f"  {type(err).__name__}: {err}")