commit
dfc14911a3
44
README.md
44
README.md
|
|
@ -75,6 +75,50 @@ Batch process image demo
|
|||
| --- | --- | --- | --- |
|
||||
|  |  |  | 
|
||||
|
||||
### API Usage
|
||||
|
||||
We have added an API endpoint to allow for automated workflows.
|
||||
|
||||
The API utilizes both Segment Anything and GroundingDINO to return masks of all instances of whatever object is specified in the text prompt.
|
||||
|
||||
This is an extension of the existing [Stable Diffusion Web UI API](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/API).
|
||||
|
||||
There are 2 endpoints exposed:
|
||||
- GET /sam-webui/heartbeat
|
||||
- POST /sam-webui/image-mask
|
||||
|
||||
The heartbeat endpoint can be used to ensure that the API is up.
|
||||
|
||||
The image-mask endpoint accepts a payload that includes your base64-encoded image.
|
||||
|
||||
Below is an example of how to interface with the API using requests.
|
||||
|
||||
#### API Example Usage
|
||||
|
||||
```
|
||||
import base64
|
||||
import requests
|
||||
from PIL import Image
|
||||
from io import BytesIO
|
||||
|
||||
def image_to_base64(img_path: str) -> str:
|
||||
with open(img_path, "rb") as img_file:
|
||||
img_base64 = base64.b64encode(img_file.read()).decode()
|
||||
return img_base64
|
||||
|
||||
payload = {
|
||||
"image": image_to_base64("IMAGE_FILE_PATH"),
|
||||
"prompt": "TEXT PROMPT",
|
||||
"box_threshold": 0.3
|
||||
}
|
||||
url = "http://127.0.0.1:7860/sam-webui/image-mask"
res = requests.post(url, json=payload)
|
||||
|
||||
for dct in res.json():
|
||||
image_data = base64.b64decode(dct['image'])
|
||||
image = Image.open(BytesIO(image_data))
|
||||
image.show()
|
||||
```
|
||||
|
||||
## Contribute
|
||||
|
||||
Disclaimer: I have not thoroughly tested this extension, so there might be bugs. Bear with me while I'm fixing them :)
|
||||
|
|
|
|||
|
|
@ -0,0 +1,61 @@
|
|||
from fastapi import FastAPI, Body
|
||||
from io import BytesIO
|
||||
import base64
|
||||
from pydantic import BaseModel
|
||||
from typing import Any
|
||||
import asyncio
|
||||
import gradio as gr
|
||||
import os
|
||||
from scripts.sam import init_sam_model, dilate_mask, sam_predict, sam_model_list
|
||||
from scripts.dino import dino_model_list
|
||||
from PIL import Image, ImageChops
|
||||
import base64
|
||||
|
||||
def sam_api(_: gr.Blocks, app: FastAPI):
    """Register the SAM web UI API routes on the given FastAPI app.

    Invoked via ``script_callbacks.on_app_started``; the ``gr.Blocks``
    argument is required by that callback signature but unused here.
    """
    @app.get("/sam-webui/heartbeat")
    async def heartbeat():
        # Liveness probe so clients can confirm the API is up.
        return {
            "msg": "Success!"
        }

    class MaskRequest(BaseModel):
        # Request payload for POST /sam-webui/image-mask.
        image: str  # base64 string containing the input image
        prompt: str  # text prompt describing the object(s) to mask
        box_threshold: float  # GroundingDINO box confidence threshold

    def pil_image_to_base64(img: Image.Image) -> str:
        """Serialize a PIL image to a base64-encoded PNG string.

        PNG is used (rather than JPEG) because sam_predict's outputs
        include binary mask images and RGBA matted images, neither of
        which JPEG can encode, and masks must not suffer lossy
        compression artifacts.
        """
        buffered = BytesIO()
        img.save(buffered, format="PNG")
        return base64.b64encode(buffered.getvalue()).decode()

    @app.post("/sam-webui/image-mask")
    async def process_image(payload: MaskRequest = Body(...)) -> Any:
        """Run GroundingDINO + SAM on the submitted image.

        Returns a list of ``{"image": <base64 PNG>}`` dicts, one entry
        per image produced by ``sam_predict``.
        """
        # Default to the first installed model; None when none is
        # installed — sam_predict is expected to handle that case.
        sam_model_name = sam_model_list[0] if sam_model_list else None
        dino_model_name = dino_model_list[0] if dino_model_list else None
        # Decode the base64 image string into a PIL image.
        input_img = Image.open(BytesIO(base64.b64decode(payload.image)))
        # Run DINO and SAM inference to get masks back. Empty point
        # lists + dino_checkbox=True mean detection is driven entirely
        # by the text prompt; gui=False skips preview/blended outputs.
        masks = sam_predict(sam_model_name,
                            input_img,
                            [],
                            [],
                            True,
                            dino_model_name,
                            payload.prompt,
                            payload.box_threshold,
                            None,
                            None,
                            gui=False)
        # Convert each resulting PIL image to a base64 string.
        response = [{"image": pil_image_to_base64(mask)} for mask in masks]
        return response
|
||||
|
||||
# Register the API routes once the webui app has started. Guarded so that
# running outside the webui (or against an incompatible version) degrades
# to a console message instead of crashing extension loading.
try:
    import modules.script_callbacks as script_callbacks

    script_callbacks.on_app_started(sam_api)
except Exception as e:
    # Narrowed from a bare `except:` so KeyboardInterrupt/SystemExit still
    # propagate, and the failure reason is no longer silently discarded.
    print(f"SAM Web UI API failed to initialize: {e}")
|
||||
|
|
@ -127,7 +127,7 @@ def init_sam_model(sam_model_name):
|
|||
|
||||
def sam_predict(sam_model_name, input_image, positive_points, negative_points,
|
||||
dino_checkbox, dino_model_name, text_prompt, box_threshold,
|
||||
dino_preview_checkbox, dino_preview_boxes_selection):
|
||||
dino_preview_checkbox, dino_preview_boxes_selection, gui=True):
|
||||
print("Start SAM Processing")
|
||||
image_np = np.array(input_image)
|
||||
image_np_rgb = image_np[..., :3]
|
||||
|
|
@ -182,12 +182,13 @@ def sam_predict(sam_model_name, input_image, positive_points, negative_points,
|
|||
|
||||
boxes_filt = boxes_filt.numpy().astype(int) if boxes_filt is not None else None
|
||||
for mask in masks:
|
||||
blended_image = show_masks(show_boxes(image_np, boxes_filt), mask)
|
||||
masks_gallery.append(Image.fromarray(np.any(mask, axis=0)))
|
||||
mask_images.append(Image.fromarray(blended_image))
|
||||
image_np_copy = copy.deepcopy(image_np)
|
||||
image_np_copy[~np.any(mask, axis=0)] = np.array([0, 0, 0, 0])
|
||||
matted_images.append(Image.fromarray(image_np_copy))
|
||||
if gui:
|
||||
blended_image = show_masks(show_boxes(image_np, boxes_filt), mask)
|
||||
mask_images.append(Image.fromarray(blended_image))
|
||||
image_np_copy = copy.deepcopy(image_np)
|
||||
image_np_copy[~np.any(mask, axis=0)] = np.array([0, 0, 0, 0])
|
||||
matted_images.append(Image.fromarray(image_np_copy))
|
||||
|
||||
return mask_images + masks_gallery + matted_images
|
||||
|
||||
|
|
@ -426,4 +427,4 @@ class Script(scripts.Script):
|
|||
if not enabled or input_image is None or mask is None or not isinstance(p, StableDiffusionProcessingImg2Img):
|
||||
return
|
||||
p.init_images = [input_image]
|
||||
p.image_mask = Image.open(expanded_mask[1]['name'] if dilation_enabled and expanded_mask is not None else mask[chosen_mask + 3]['name'])
|
||||
p.image_mask = Image.open(expanded_mask[1]['name'] if dilation_enabled and expanded_mask is not None else mask[chosen_mask + 3]['name'])
|
||||
Loading…
Reference in New Issue