# 246 lines, 9.6 KiB, Python
import os
|
|
from fastapi import FastAPI, Body
|
|
from pydantic import BaseModel
|
|
from typing import Any, Optional, List
|
|
import gradio as gr
|
|
from PIL import Image
|
|
import numpy as np
|
|
|
|
from modules.api.api import encode_pil_to_base64, decode_base64_to_image
|
|
from scripts.sam import sam_predict, dino_predict, update_mask, cnet_seg, categorical_mask
|
|
from scripts.sam import sam_model_list
|
|
|
|
|
|
def decode_to_pil(image):
    """Decode an image in any supported representation to a PIL image.

    Args:
        image: a filesystem path or base64-encoded string, a PIL image,
            or a numpy array.

    Returns:
        PIL.Image.Image: the decoded image.

    Raises:
        TypeError: if ``image`` is none of the supported representations.
    """
    if isinstance(image, str):
        # A string is either a path on disk or base64-encoded image data.
        # Checking the type first avoids the original bug where
        # os.path.exists() was called on PIL/ndarray inputs and raised
        # TypeError before the type checks could run.
        if os.path.exists(image):
            return Image.open(image)
        return decode_base64_to_image(image)
    if isinstance(image, Image.Image):
        return image
    if isinstance(image, np.ndarray):
        return Image.fromarray(image)
    # The original code constructed this exception without raising it,
    # silently returning None for unsupported inputs; fail loudly instead.
    raise TypeError("Not an image")
|
|
|
|
|
|
def encode_to_base64(image):
    """Encode an image to a base64 string for inclusion in an API response.

    Args:
        image: an already-encoded base64 string (returned unchanged),
            a PIL image, or a numpy array.

    Returns:
        str: the base64-encoded image.

    Raises:
        TypeError: if ``image`` is none of the supported representations.
    """
    if isinstance(image, str):
        return image
    if isinstance(image, Image.Image):
        return encode_pil_to_base64(image).decode()
    if isinstance(image, np.ndarray):
        return encode_pil_to_base64(Image.fromarray(image)).decode()
    # The original code constructed this exception without raising it,
    # silently returning None for unsupported inputs; fail loudly instead.
    raise TypeError("Invalid type")
|
|
|
|
|
|
def sam_api(_: gr.Blocks, app: FastAPI):
    """Register the segment-anything HTTP endpoints on the WebUI FastAPI app.

    Intended to be passed to the WebUI ``on_app_started`` callback; ``_`` is
    the Gradio Blocks instance (unused) and ``app`` is the FastAPI
    application to extend with the ``/sam/*`` routes.
    """

    @app.get("/sam/heartbeat")
    async def heartbeat():
        """Liveness probe: always answers with a success message."""
        return {
            "msg": "Success!"
        }

    @app.get("/sam/sam-model", description='Query available SAM model')
    async def api_sam_model() -> List[str]:
        """Return the list of available SAM model names."""
        return sam_model_list

    class SamPredictRequest(BaseModel):
        """Request body for POST /sam/sam-predict."""
        sam_model_name: str = "sam_vit_h_4b8939.pth"  # SAM checkpoint to use
        input_image: str  # filesystem path or base64-encoded image
        sam_positive_points: List[List[float]] = []  # [x, y] prompt points to include
        sam_negative_points: List[List[float]] = []  # [x, y] prompt points to exclude
        dino_enabled: bool = False  # also prompt SAM with GroundingDINO boxes
        dino_model_name: Optional[str] = "GroundingDINO_SwinT_OGC (694MB)"
        dino_text_prompt: Optional[str] = None  # text prompt for GroundingDINO
        dino_box_threshold: Optional[float] = 0.3  # box confidence cutoff
        dino_preview_checkbox: bool = False
        dino_preview_boxes_selection: Optional[List[int]] = None

    @app.post("/sam/sam-predict")
    async def api_sam_predict(payload: SamPredictRequest = Body(...)) -> Any:
        """Run SAM point/box-prompted segmentation; return masks as base64."""
        print(f"SAM API /sam/sam-predict received request")
        # Replace the path/base64 string with the decoded RGBA PIL image in place.
        payload.input_image = decode_to_pil(payload.input_image).convert('RGBA')
        sam_output_mask_gallery, sam_message = sam_predict(
            payload.sam_model_name,
            payload.input_image,
            payload.sam_positive_points,
            payload.sam_negative_points,
            payload.dino_enabled,
            payload.dino_model_name,
            payload.dino_text_prompt,
            payload.dino_box_threshold,
            payload.dino_preview_checkbox,
            payload.dino_preview_boxes_selection)
        print(f"SAM API /sam/sam-predict finished with message: {sam_message}")
        result = {
            "msg": sam_message,
        }
        # A successful run yields 9 images: 3 blends, 3 raw masks, 3 masked images
        # (anything else — e.g. a preview-only run — returns just the message).
        if len(sam_output_mask_gallery) == 9:
            result["blended_images"] = list(map(encode_to_base64, sam_output_mask_gallery[:3]))
            result["masks"] = list(map(encode_to_base64, sam_output_mask_gallery[3:6]))
            result["masked_images"] = list(map(encode_to_base64, sam_output_mask_gallery[6:]))
        return result

    class DINOPredictRequest(BaseModel):
        """Request body for POST /sam/dino-predict."""
        input_image: str  # filesystem path or base64-encoded image
        dino_model_name: str = "GroundingDINO_SwinT_OGC (694MB)"
        text_prompt: str  # text description of the objects to detect
        box_threshold: float = 0.3  # box confidence cutoff

    @app.post("/sam/dino-predict")
    async def api_dino_predict(payload: DINOPredictRequest = Body(...)) -> Any:
        """Run GroundingDINO text-prompted detection; return the boxed preview."""
        print(f"SAM API /sam/dino-predict received request")
        payload.input_image = decode_to_pil(payload.input_image)
        dino_output_img, _, dino_msg = dino_predict(
            payload.input_image,
            payload.dino_model_name,
            payload.text_prompt,
            payload.box_threshold)
        # dino_predict returns a dict-like status (presumably a gradio update
        # dict — TODO confirm); unwrap its "value" when present.
        if "value" in dino_msg:
            dino_msg = dino_msg["value"]
        else:
            dino_msg = "Done"
        print(f"SAM API /sam/dino-predict finished with message: {dino_msg}")
        return {
            "msg": dino_msg,
            "image_with_box": encode_to_base64(dino_output_img) if dino_output_img is not None else None,
        }

    class DilateMaskRequest(BaseModel):
        """Request body for POST /sam/dilate-mask."""
        input_image: str  # filesystem path or base64-encoded image
        mask: str  # mask to dilate, path or base64
        dilate_amount: int = 10  # dilation strength forwarded to update_mask

    @app.post("/sam/dilate-mask")
    async def api_dilate_mask(payload: DilateMaskRequest = Body(...)) -> Any:
        """Dilate a mask; return blended/mask/masked-image previews as base64."""
        print(f"SAM API /sam/dilate-mask received request")
        payload.input_image = decode_to_pil(payload.input_image).convert("RGBA")
        payload.mask = decode_to_pil(payload.mask)
        dilate_result = list(map(encode_to_base64, update_mask(payload.mask, 0, payload.dilate_amount, payload.input_image)))
        print(f"SAM API /sam/dilate-mask finished")
        return {"blended_image": dilate_result[0], "mask": dilate_result[1], "masked_image": dilate_result[2]}

    class AutoSAMConfig(BaseModel):
        """Tuning knobs forwarded verbatim to SAM automatic mask generation."""
        points_per_side: Optional[int] = 32
        points_per_batch: int = 64
        pred_iou_thresh: float = 0.88
        stability_score_thresh: float = 0.95
        stability_score_offset: float = 1.0
        box_nms_thresh: float = 0.7
        crop_n_layers: int = 0
        crop_nms_thresh: float = 0.7
        crop_overlap_ratio: float = 512 / 1500
        crop_n_points_downscale_factor: int = 1
        min_mask_region_area: int = 0

    class ControlNetSegRequest(BaseModel):
        """Request body for POST /sam/controlnet-seg."""
        sam_model_name: str = "sam_vit_h_4b8939.pth"  # SAM checkpoint to use
        input_image: str  # filesystem path or base64-encoded image
        processor: str = "seg_ofade20k"  # segmentation preprocessor identifier
        processor_res: int = 512  # preprocessor resolution
        pixel_perfect: bool = False
        resize_mode: Optional[int] = 1 # 0: just resize, 1: crop and resize, 2: resize and fill
        target_W: Optional[int] = None  # presumably the generation target size; verify against cnet_seg
        target_H: Optional[int] = None

    @app.post("/sam/controlnet-seg")
    async def api_controlnet_seg(payload: ControlNetSegRequest = Body(...),
                                 autosam_conf: AutoSAMConfig = Body(...)) -> Any:
        """Run automatic segmentation for ControlNet; return result images."""
        print(f"SAM API /sam/controlnet-seg received request")
        payload.input_image = decode_to_pil(payload.input_image)
        cnet_seg_img, cnet_seg_msg = cnet_seg(
            payload.sam_model_name,
            payload.input_image,
            payload.processor,
            payload.processor_res,
            payload.pixel_perfect,
            payload.resize_mode,
            payload.target_W,
            payload.target_H,
            autosam_conf.points_per_side,
            autosam_conf.points_per_batch,
            autosam_conf.pred_iou_thresh,
            autosam_conf.stability_score_thresh,
            autosam_conf.stability_score_offset,
            autosam_conf.box_nms_thresh,
            autosam_conf.crop_n_layers,
            autosam_conf.crop_nms_thresh,
            autosam_conf.crop_overlap_ratio,
            autosam_conf.crop_n_points_downscale_factor,
            autosam_conf.min_mask_region_area)
        cnet_seg_img = list(map(encode_to_base64, cnet_seg_img))
        print(f"SAM API /sam/controlnet-seg finished with message {cnet_seg_msg}")
        result = {
            "msg": cnet_seg_msg,
        }
        # The number of returned images distinguishes the two result shapes:
        # 3 → random-seg/edit-anything output, 4 → semantic pre/post-SAM output.
        if len(cnet_seg_img) == 3:
            result["blended_images"] = cnet_seg_img[0]
            result["random_seg"] = cnet_seg_img[1]
            result["edit_anything_control"] = cnet_seg_img[2]
        elif len(cnet_seg_img) == 4:
            result["sem_presam"] = cnet_seg_img[0]
            result["sem_postsam"] = cnet_seg_img[1]
            result["blended_presam"] = cnet_seg_img[2]
            result["blended_postsam"] = cnet_seg_img[3]
        return result

    class CategoryMaskRequest(BaseModel):
        """Request body for POST /sam/category-mask."""
        sam_model_name: str = "sam_vit_h_4b8939.pth"  # SAM checkpoint to use
        processor: str = "seg_ofade20k"  # segmentation preprocessor identifier
        processor_res: int = 512  # preprocessor resolution
        pixel_perfect: bool = False
        resize_mode: Optional[int] = 1  # 0: just resize, 1: crop and resize, 2: resize and fill
        target_W: Optional[int] = None  # presumably the generation target size; verify against categorical_mask
        target_H: Optional[int] = None
        category: str  # semantic category selector forwarded to categorical_mask
        input_image: str  # filesystem path or base64-encoded image

    @app.post("/sam/category-mask")
    async def api_category_mask(payload: CategoryMaskRequest = Body(...),
                                autosam_conf: AutoSAMConfig = Body(...)) -> Any:
        """Generate a mask for the requested semantic category; return previews."""
        print(f"SAM API /sam/category-mask received request")
        payload.input_image = decode_to_pil(payload.input_image)
        category_mask_img, category_mask_msg, resized_input_img = categorical_mask(
            payload.sam_model_name,
            payload.processor,
            payload.processor_res,
            payload.pixel_perfect,
            payload.resize_mode,
            payload.target_W,
            payload.target_H,
            payload.category,
            payload.input_image,
            autosam_conf.points_per_side,
            autosam_conf.points_per_batch,
            autosam_conf.pred_iou_thresh,
            autosam_conf.stability_score_thresh,
            autosam_conf.stability_score_offset,
            autosam_conf.box_nms_thresh,
            autosam_conf.crop_n_layers,
            autosam_conf.crop_nms_thresh,
            autosam_conf.crop_overlap_ratio,
            autosam_conf.crop_n_points_downscale_factor,
            autosam_conf.min_mask_region_area)
        category_mask_img = list(map(encode_to_base64, category_mask_img))
        print(f"SAM API /sam/category-mask finished with message {category_mask_msg}")
        result = {
            "msg": category_mask_msg,
        }
        # A successful run yields exactly blended image, mask and masked image.
        if len(category_mask_img) == 3:
            result["blended_image"] = category_mask_img[0]
            result["mask"] = category_mask_img[1]
            result["masked_image"] = category_mask_img[2]
        if resized_input_img is not None:
            result["resized_input"] = encode_to_base64(resized_input_img)
        return result
|
|
|
|
|
|
# Register the API routes once the WebUI FastAPI app starts. Import inside
# the try so running this file outside the WebUI degrades gracefully.
try:
    import modules.script_callbacks as script_callbacks
    script_callbacks.on_app_started(sam_api)
except Exception as e:
    # Narrowed from a bare `except:` (which also swallowed SystemExit and
    # KeyboardInterrupt) and surface the reason instead of hiding it.
    print(f"SAM Web UI API failed to initialize: {e}")