Merge pull request #28 from jordan-barrett-jm/feature/add-api

Feature/add api
pull/30/head v1.0.1
Chengsong Zhang 2023-04-15 14:13:01 +08:00 committed by GitHub
commit dfc14911a3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 113 additions and 7 deletions

View File

@ -75,6 +75,50 @@ Batch process image demo
| --- | --- | --- | --- |
| ![Input Image](https://user-images.githubusercontent.com/63914308/232157678-fcaaf6b6-1805-49fd-91fa-8a722cc01c8a.png) | ![Output Image](https://user-images.githubusercontent.com/63914308/232157721-2754ccf2-b341-4b24-95f2-b75ac5b4fcd2.png) | ![Output Mask](https://user-images.githubusercontent.com/63914308/232157975-05de0b23-1225-4187-89b1-032c731b46eb.png) | ![Output Blend](https://user-images.githubusercontent.com/63914308/232158575-228f687c-8045-4079-bcf5-5a4dd0c8d7bd.png)
### API Usage
We have added an API endpoint to allow for automated workflows.
The API utilizes both Segment Anything and GroundingDINO to return masks of all instances of whatever object is specified in the text prompt.
This is an extension of the existing [Stable Diffusion Web UI API](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/API).
There are 2 endpoints exposed:
- GET /sam-webui/heartbeat
- POST /sam-webui/image-mask
The heartbeat endpoint can be used to ensure that the API is up.
The image-mask endpoint accepts a payload that includes your base64-encoded image.
Below is an example of how to interface with the API using requests.
#### API Example Usage
```python
import base64
import requests
from PIL import Image
from io import BytesIO

def image_to_base64(img_path: str) -> str:
    """Read an image file from disk and return its contents as a base64 string."""
    with open(img_path, "rb") as img_file:
        img_base64 = base64.b64encode(img_file.read()).decode()
    return img_base64

# Address of the image-mask endpoint on a locally running web UI instance.
url = "http://127.0.0.1:7860/sam-webui/image-mask"

payload = {
    "image": image_to_base64("IMAGE_FILE_PATH"),
    "prompt": "TEXT PROMPT",
    "box_threshold": 0.3,
}
res = requests.post(url, json=payload)

# The endpoint returns a JSON list of {"image": <base64 mask>} entries,
# one per detected instance.
for dct in res.json():
    image_data = base64.b64decode(dct['image'])
    image = Image.open(BytesIO(image_data))
    image.show()
```
## Contribute
Disclaimer: I have not thoroughly tested this extension, so there might be bugs. Bear with me while I'm fixing them :)

61
scripts/api.py Normal file
View File

@ -0,0 +1,61 @@
from fastapi import FastAPI, Body
from io import BytesIO
import base64
from pydantic import BaseModel
from typing import Any
import asyncio
import gradio as gr
import os
from scripts.sam import init_sam_model, dilate_mask, sam_predict, sam_model_list
from scripts.dino import dino_model_list
from PIL import Image, ImageChops
import base64
def sam_api(_: gr.Blocks, app: FastAPI):
    """Register the SAM web-UI API routes on the given FastAPI app.

    Exposes:
      - GET  /sam-webui/heartbeat  — liveness probe.
      - POST /sam-webui/image-mask — run GroundingDINO + SAM on a
        base64-encoded image and return one base64 mask per instance.
    """
    @app.get("/sam-webui/heartbeat")
    async def heartbeat():
        return {
            "msg": "Success!"
        }

    class MaskRequest(BaseModel):
        image: str  # base64 string containing image
        prompt: str  # text prompt describing the object(s) to mask
        box_threshold: float  # GroundingDINO box confidence threshold

    def pil_image_to_base64(img: Image.Image) -> str:
        """Encode a PIL image as a base64 string.

        PNG is used instead of JPEG: the masks produced by sam_predict are
        1-bit images (Image.fromarray of a boolean array), which Pillow
        cannot write as JPEG ("cannot write mode 1 as JPEG"). PNG is also
        lossless, so mask edges stay exact.
        """
        buffered = BytesIO()
        img.save(buffered, format="PNG")
        img_base64 = base64.b64encode(buffered.getvalue()).decode()
        return img_base64

    @app.post("/sam-webui/image-mask")
    async def process_image(payload: MaskRequest = Body(...)) -> Any:
        # Default to the first available checkpoint of each model family,
        # or None when none is installed.
        sam_model_name = sam_model_list[0] if len(sam_model_list) > 0 else None
        dino_model_name = dino_model_list[0] if len(dino_model_list) > 0 else None
        # Decode the base64 image string
        img_b64 = base64.b64decode(payload.image)
        input_img = Image.open(BytesIO(img_b64))
        # Run DINO and SAM inference to get masks back; gui=False skips the
        # blended/matted preview images and returns only the raw masks.
        masks = sam_predict(sam_model_name,
                            input_img,
                            [],
                            [],
                            True,
                            dino_model_name,
                            payload.prompt,
                            payload.box_threshold,
                            None,
                            None,
                            gui=False)
        # Convert each final PIL mask to a base64 string
        response = [{"image": pil_image_to_base64(mask)} for mask in masks]
        return response
# Register the API routes once the web UI app has started. Catch only
# Exception (a bare `except:` would also swallow KeyboardInterrupt /
# SystemExit) and report the actual failure instead of hiding it.
try:
    import modules.script_callbacks as script_callbacks
    script_callbacks.on_app_started(sam_api)
except Exception as e:
    print(f"SAM Web UI API failed to initialize: {e}")

View File

@ -127,7 +127,7 @@ def init_sam_model(sam_model_name):
def sam_predict(sam_model_name, input_image, positive_points, negative_points,
dino_checkbox, dino_model_name, text_prompt, box_threshold,
dino_preview_checkbox, dino_preview_boxes_selection):
dino_preview_checkbox, dino_preview_boxes_selection, gui=True):
print("Start SAM Processing")
image_np = np.array(input_image)
image_np_rgb = image_np[..., :3]
@ -182,12 +182,13 @@ def sam_predict(sam_model_name, input_image, positive_points, negative_points,
boxes_filt = boxes_filt.numpy().astype(int) if boxes_filt is not None else None
for mask in masks:
blended_image = show_masks(show_boxes(image_np, boxes_filt), mask)
masks_gallery.append(Image.fromarray(np.any(mask, axis=0)))
mask_images.append(Image.fromarray(blended_image))
image_np_copy = copy.deepcopy(image_np)
image_np_copy[~np.any(mask, axis=0)] = np.array([0, 0, 0, 0])
matted_images.append(Image.fromarray(image_np_copy))
if gui:
blended_image = show_masks(show_boxes(image_np, boxes_filt), mask)
mask_images.append(Image.fromarray(blended_image))
image_np_copy = copy.deepcopy(image_np)
image_np_copy[~np.any(mask, axis=0)] = np.array([0, 0, 0, 0])
matted_images.append(Image.fromarray(image_np_copy))
return mask_images + masks_gallery + matted_images
@ -426,4 +427,4 @@ class Script(scripts.Script):
if not enabled or input_image is None or mask is None or not isinstance(p, StableDiffusionProcessingImg2Img):
return
p.init_images = [input_image]
p.image_mask = Image.open(expanded_mask[1]['name'] if dilation_enabled and expanded_mask is not None else mask[chosen_mask + 3]['name'])
p.image_mask = Image.open(expanded_mask[1]['name'] if dilation_enabled and expanded_mask is not None else mask[chosen_mask + 3]['name'])