diff --git a/scripts/api.py b/scripts/api.py index fd52ff9..0d602db 100644 --- a/scripts/api.py +++ b/scripts/api.py @@ -8,6 +8,7 @@ import numpy as np from modules.api.api import encode_pil_to_base64, decode_base64_to_image from scripts.sam import sam_predict, dino_predict, update_mask, cnet_seg, categorical_mask +from scripts.sam import sam_model_list def decode_to_pil(image): @@ -42,6 +43,10 @@ def sam_api(_: gr.Blocks, app: FastAPI): "msg": "Success!" } + @app.get("/sam/sam-model", description='Query available SAM model') + async def api_sam_model() -> List[str]: + return sam_model_list + class SamPredictRequest(BaseModel): sam_model_name: str = "sam_vit_h_4b8939.pth" input_image: str