mirror of https://github.com/vladmandic/automatic
refactor: update API for caption module
Update API endpoints and models for caption module rename:
- modules/api/api.py - update imports and endpoint handlers
- modules/api/endpoints.py - update endpoint definitions
- modules/api/models.py - update request/response models
pull/4613/head
parent
61b031ada5
commit
f4b5abde68
|
|
@ -76,7 +76,7 @@ class Api:
|
|||
# enumerator api
|
||||
self.add_api_route("/sdapi/v1/preprocessors", self.process.get_preprocess, methods=["GET"], response_model=List[process.ItemPreprocess])
|
||||
self.add_api_route("/sdapi/v1/masking", self.process.get_mask, methods=["GET"], response_model=process.ItemMask)
|
||||
self.add_api_route("/sdapi/v1/interrogate", endpoints.get_interrogate, methods=["GET"], response_model=List[str], tags=["Caption"])
|
||||
self.add_api_route("/sdapi/v1/openclip", endpoints.get_caption, methods=["GET"], response_model=List[str], tags=["Caption"])
|
||||
self.add_api_route("/sdapi/v1/samplers", endpoints.get_samplers, methods=["GET"], response_model=List[models.ItemSampler])
|
||||
self.add_api_route("/sdapi/v1/upscalers", endpoints.get_upscalers, methods=["GET"], response_model=List[models.ItemUpscaler])
|
||||
self.add_api_route("/sdapi/v1/sd-models", endpoints.get_sd_models, methods=["GET"], response_model=List[models.ItemModel])
|
||||
|
|
@ -91,7 +91,10 @@ class Api:
|
|||
|
||||
# functional api
|
||||
self.add_api_route("/sdapi/v1/png-info", endpoints.post_pnginfo, methods=["POST"], response_model=models.ResImageInfo)
|
||||
self.add_api_route("/sdapi/v1/interrogate", endpoints.post_interrogate, methods=["POST"], response_model=models.ResInterrogate, tags=["Caption"])
|
||||
# Caption dispatch endpoint (routes to openclip, tagger, or vlm based on 'backend' field)
|
||||
self.add_api_route("/sdapi/v1/caption", endpoints.post_caption_dispatch, methods=["POST"], response_model=models.ResCaptionDispatch, tags=["Caption"])
|
||||
# Direct caption endpoints (bypass dispatch, use specific backend)
|
||||
self.add_api_route("/sdapi/v1/openclip", endpoints.post_caption, methods=["POST"], response_model=models.ResCaption, tags=["Caption"])
|
||||
self.add_api_route("/sdapi/v1/vqa", endpoints.post_vqa, methods=["POST"], response_model=models.ResVQA, tags=["Caption"])
|
||||
self.add_api_route("/sdapi/v1/vqa/models", endpoints.get_vqa_models, methods=["GET"], response_model=List[models.ItemVLMModel], tags=["Caption"])
|
||||
self.add_api_route("/sdapi/v1/vqa/prompts", endpoints.get_vqa_prompts, methods=["GET"], response_model=models.ResVLMPrompts, tags=["Caption"])
|
||||
|
|
|
|||
|
|
@ -91,23 +91,21 @@ def get_extra_networks(page: Optional[str] = None, name: Optional[str] = None, f
|
|||
})
|
||||
return res
|
||||
|
||||
def get_interrogate():
|
||||
def get_caption():
|
||||
"""
|
||||
List available interrogation models.
|
||||
List available OpenCLIP caption models.
|
||||
|
||||
Returns model identifiers for use with POST /sdapi/v1/interrogate.
|
||||
Returns model identifiers for use with POST /sdapi/v1/openclip or POST /sdapi/v1/caption (with backend="openclip").
|
||||
|
||||
**Model Types:**
|
||||
- OpenCLIP models: Format `architecture/pretrained_dataset` (e.g., `ViT-L-14/openai`)
|
||||
|
||||
For anime-style tagging (WaifuDiffusion, DeepBooru), use `/sdapi/v1/tagger` instead.
|
||||
|
||||
**Example Response:**
|
||||
```json
|
||||
["ViT-L-14/openai", "ViT-H-14/laion2b_s32b_b79k"]
|
||||
```
|
||||
"""
|
||||
from modules.interrogate.openclip import refresh_clip_models
|
||||
from modules.caption.openclip import refresh_clip_models
|
||||
return refresh_clip_models()
|
||||
|
||||
def get_schedulers():
|
||||
|
|
@ -117,9 +115,12 @@ def get_schedulers():
|
|||
shared.log.critical(s)
|
||||
return all_schedulers
|
||||
|
||||
def post_interrogate(req: models.ReqInterrogate):
|
||||
def post_caption(req: models.ReqCaption):
|
||||
"""
|
||||
Interrogate an image using OpenCLIP/BLIP.
|
||||
Caption an image using OpenCLIP/BLIP (direct endpoint).
|
||||
|
||||
This is the direct endpoint for OpenCLIP captioning. For a unified interface
|
||||
that can dispatch to OpenCLIP, Tagger, or VLM, use POST /sdapi/v1/caption instead.
|
||||
|
||||
Analyze image using CLIP model via OpenCLIP to generate Stable Diffusion prompts.
|
||||
|
||||
|
|
@ -127,13 +128,11 @@ def post_interrogate(req: models.ReqInterrogate):
|
|||
- **Modes:**
|
||||
- `best`: Highest quality, combines multiple techniques
|
||||
- `fast`: Quick results with fewer iterations
|
||||
- `classic`: Traditional CLIP interrogator style
|
||||
- `classic`: Traditional CLIP captioner style
|
||||
- `caption`: BLIP caption only
|
||||
- `negative`: Generate negative prompt suggestions
|
||||
- Set `analyze=True` for detailed breakdown (medium, artist, movement, trending, flavor)
|
||||
|
||||
For anime/illustration tagging, use `/sdapi/v1/tagger` with WaifuDiffusion or DeepBooru models.
|
||||
|
||||
**Error Codes:**
|
||||
- 404: Image not provided or model not found
|
||||
"""
|
||||
|
|
@ -141,7 +140,7 @@ def post_interrogate(req: models.ReqInterrogate):
|
|||
raise HTTPException(status_code=404, detail="Image not found")
|
||||
image = helpers.decode_base64_to_image(req.image)
|
||||
image = image.convert('RGB')
|
||||
from modules.interrogate.openclip import interrogate_image, analyze_image, refresh_clip_models
|
||||
from modules.caption.openclip import caption_image, analyze_image, refresh_clip_models
|
||||
if req.model not in refresh_clip_models():
|
||||
raise HTTPException(status_code=404, detail="Model not found")
|
||||
# Build clip overrides from request (only include non-None values)
|
||||
|
|
@ -161,11 +160,11 @@ def post_interrogate(req: models.ReqInterrogate):
|
|||
if req.num_beams is not None:
|
||||
clip_overrides['num_beams'] = req.num_beams
|
||||
try:
|
||||
caption = interrogate_image(image, clip_model=req.clip_model, blip_model=req.blip_model, mode=req.mode, overrides=clip_overrides if clip_overrides else None)
|
||||
caption = caption_image(image, clip_model=req.clip_model, blip_model=req.blip_model, mode=req.mode, overrides=clip_overrides if clip_overrides else None)
|
||||
except Exception as e:
|
||||
caption = str(e)
|
||||
if not req.analyze:
|
||||
return models.ResInterrogate(caption=caption)
|
||||
return models.ResCaption(caption=caption)
|
||||
analyze_results = analyze_image(image, clip_model=req.clip_model, blip_model=req.blip_model)
|
||||
# Extract top-ranked item from each Gradio update dict
|
||||
def get_top_item(result):
|
||||
|
|
@ -181,7 +180,8 @@ def post_interrogate(req: models.ReqInterrogate):
|
|||
movement = get_top_item(analyze_results[2])
|
||||
trending = get_top_item(analyze_results[3])
|
||||
flavor = get_top_item(analyze_results[4])
|
||||
return models.ResInterrogate(caption=caption, medium=medium, artist=artist, movement=movement, trending=trending, flavor=flavor)
|
||||
return models.ResCaption(caption=caption, medium=medium, artist=artist, movement=movement, trending=trending, flavor=flavor)
|
||||
|
||||
|
||||
def post_vqa(req: models.ReqVQA):
|
||||
"""
|
||||
|
|
@ -243,8 +243,8 @@ def post_vqa(req: models.ReqVQA):
|
|||
if req.keep_prefill is not None:
|
||||
generation_kwargs['keep_prefill'] = req.keep_prefill
|
||||
|
||||
from modules.interrogate import vqa
|
||||
answer = vqa.interrogate(
|
||||
from modules.caption import vqa
|
||||
answer = vqa.caption(
|
||||
question=req.question,
|
||||
system_prompt=req.system,
|
||||
prompt=req.prompt or '',
|
||||
|
|
@ -262,6 +262,201 @@ def post_vqa(req: models.ReqVQA):
|
|||
annotated_b64 = helpers.encode_pil_to_base64(annotated_img)
|
||||
return models.ResVQA(answer=answer, annotated_image=annotated_b64)
|
||||
|
||||
|
||||
def post_caption_dispatch(req: models.ReqCaptionDispatch):
    """
    Unified caption endpoint - dispatches to OpenCLIP, Tagger, or VLM backends.

    This endpoint provides a single entry point for all captioning needs. Select the backend
    using the `backend` field, then provide backend-specific parameters.

    **Backends:**

    1. **OpenCLIP** (`backend: "openclip"`):
       - CLIP/BLIP-based captioning for Stable Diffusion prompts
       - Modes: best, fast, classic, caption, negative
       - Set `analyze=True` for style breakdown

    2. **Tagger** (`backend: "tagger"`):
       - WaifuDiffusion or DeepBooru anime/illustration tagging
       - Returns comma-separated booru-style tags
       - Configurable thresholds for general and character tags

    3. **VLM** (`backend: "vlm"`):
       - Vision-Language Models (Qwen, Gemma, Florence, Moondream, etc.)
       - Flexible tasks: captioning, Q&A, object detection, OCR
       - Supports thinking mode for reasoning models

    **Direct Endpoints:**
    For simpler requests, you can also use the direct endpoints:
    - POST /sdapi/v1/openclip - OpenCLIP only
    - POST /sdapi/v1/tagger - Tagger only
    - POST /sdapi/v1/vqa - VLM only
    """
    # One handler per value of the `backend` selector; each returns a ResCaptionDispatch.
    handlers = {
        "openclip": _dispatch_openclip,
        "tagger": _dispatch_tagger,
        "vlm": _dispatch_vlm,
    }
    handler = handlers.get(req.backend)
    if handler is None:
        # Any other selector value is rejected as a client error.
        raise HTTPException(status_code=400, detail=f"Unknown backend: {req.backend}")
    return handler(req)
|
||||
|
||||
|
||||
def _dispatch_openclip(req: models.ReqCaptionOpenCLIP) -> models.ResCaptionDispatch:
    """
    Serve an OpenCLIP request coming through the caption dispatch endpoint.

    Decodes the base64 image, checks the requested model against the list returned by
    `refresh_clip_models()`, generates a caption and, when `req.analyze` is set, adds the
    first item of each of the five analysis results (medium, artist, movement, trending, flavor).

    Raises:
        HTTPException: 404 when the image is missing or shorter than 64 characters,
            or when `req.model` is not in the available model list.
    """
    encoded = req.image
    if encoded is None or len(encoded) < 64:
        raise HTTPException(status_code=404, detail="Image not found")
    image = helpers.decode_base64_to_image(encoded).convert('RGB')
    from modules.caption.openclip import caption_image, analyze_image, refresh_clip_models
    # NOTE(review): `req.model` is validated here but `req.clip_model` is what gets passed on below - confirm this is intended
    if req.model not in refresh_clip_models():
        raise HTTPException(status_code=404, detail="Model not found")
    # Forward only the overrides the caller actually supplied (non-None values)
    override_fields = ('min_length', 'max_length', 'chunk_size', 'min_flavors', 'max_flavors', 'flavor_count', 'num_beams')
    clip_overrides = {name: getattr(req, name) for name in override_fields if getattr(req, name) is not None}
    try:
        caption = caption_image(image, clip_model=req.clip_model, blip_model=req.blip_model, mode=req.mode, overrides=clip_overrides or None)
    except Exception as e:
        # A failed caption is reported through the caption field instead of an HTTP error
        caption = str(e)
    if not req.analyze:
        return models.ResCaptionDispatch(backend="openclip", caption=caption)
    analyze_results = analyze_image(image, clip_model=req.clip_model, blip_model=req.blip_model)

    def top_item(result):
        # Each analysis result is expected to be a Gradio update dict; use the first key (or the plain string) under 'value'
        if not isinstance(result, dict) or 'value' not in result:
            return None
        value = result['value']
        if isinstance(value, dict) and value:
            return next(iter(value))
        return value if isinstance(value, str) else None

    medium, artist, movement, trending, flavor = (top_item(analyze_results[idx]) for idx in range(5))
    return models.ResCaptionDispatch(backend="openclip", caption=caption, medium=medium, artist=artist, movement=movement, trending=trending, flavor=flavor)
|
||||
|
||||
|
||||
def _dispatch_tagger(req: models.ReqCaptionTagger) -> models.ResCaptionDispatch:
    """
    Serve a Tagger request coming through the caption dispatch endpoint.

    The tagger is configured through `shared.opts`, so the request values are written into
    those options for the duration of the call and the previous values are restored
    afterwards, whether or not tagging succeeds.

    Raises:
        HTTPException: 404 when the image is missing or shorter than 64 characters.
    """
    encoded = req.image
    if encoded is None or len(encoded) < 64:
        raise HTTPException(status_code=404, detail="Image not found")
    image = helpers.decode_base64_to_image(encoded).convert('RGB')
    from modules.caption import tagger
    is_deepbooru = req.model.lower() in ('deepbooru', 'deepdanbooru')
    # (option name in shared.opts, request field) pairs, listed in the order they are applied
    option_fields = [
        ('tagger_threshold', 'threshold'),
        ('tagger_max_tags', 'max_tags'),
        ('tagger_include_rating', 'include_rating'),
        ('tagger_sort_alpha', 'sort_alpha'),
        ('tagger_use_spaces', 'use_spaces'),
        ('tagger_escape_brackets', 'escape_brackets'),
        ('tagger_exclude_tags', 'exclude_tags'),
        ('tagger_show_scores', 'show_scores'),
    ]
    if not is_deepbooru:
        # WaifuDiffusion-only options are left untouched for DeepBooru requests
        option_fields += [
            ('waifudiffusion_character_threshold', 'character_threshold'),
            ('waifudiffusion_model', 'model'),
        ]
    # Snapshot the current values so they can be put back in the finally clause
    original_opts = {opt_name: getattr(shared.opts, opt_name) for opt_name, _ in option_fields}
    try:
        for opt_name, field_name in option_fields:
            setattr(shared.opts, opt_name, getattr(req, field_name))
        tags = tagger.tag(image, model_name='DeepBooru' if is_deepbooru else None)
        scores = None
        if req.show_scores:
            # Entries look like "(tag:0.95)" or "tag:0.95"; anything else is ignored
            scores = {}
            for entry in tags.split(', '):
                entry = entry.strip()
                if ':' not in entry:
                    continue
                if entry.startswith('(') and entry.endswith(')'):
                    entry = entry[1:-1]
                label, _, raw_score = entry.rpartition(':')
                try:
                    scores[label.strip()] = float(raw_score.strip())
                except ValueError:
                    pass  # text after the last colon is not a number, so this is not a scored tag
            if not scores:
                scores = None
        return models.ResCaptionDispatch(backend="tagger", tags=tags, scores=scores)
    finally:
        for opt_name, saved_value in original_opts.items():
            setattr(shared.opts, opt_name, saved_value)
|
||||
|
||||
|
||||
def _dispatch_vlm(req: models.ReqCaptionVLM) -> models.ResCaptionDispatch:
    """
    Serve a VLM request coming through the caption dispatch endpoint.

    Decodes the base64 image, passes the question/prompt and any supplied generation
    settings to `vqa.caption`, and optionally attaches the last annotated image as base64.

    Raises:
        HTTPException: 404 when the image is missing or shorter than 64 characters.
    """
    encoded = req.image
    if encoded is None or len(encoded) < 64:
        raise HTTPException(status_code=404, detail="Image not found")
    image = helpers.decode_base64_to_image(encoded).convert('RGB')
    # Forward only the generation settings the caller actually supplied (non-None values)
    option_names = ('max_tokens', 'temperature', 'top_k', 'top_p', 'num_beams', 'do_sample', 'keep_thinking', 'keep_prefill')
    generation_kwargs = {name: getattr(req, name) for name in option_names if getattr(req, name) is not None}
    from modules.caption import vqa
    answer = vqa.caption(
        question=req.question,
        system_prompt=req.system,
        prompt=req.prompt or '',
        image=image,
        model_name=req.model,
        prefill=req.prefill,
        thinking_mode=req.thinking_mode,
        generation_kwargs=generation_kwargs or None,
    )
    # The annotated image is only fetched when requested and may still be absent
    annotated_img = vqa.get_last_annotated_image() if req.include_annotated else None
    annotated_b64 = helpers.encode_pil_to_base64(annotated_img) if annotated_img is not None else None
    return models.ResCaptionDispatch(backend="vlm", answer=answer, annotated_image=annotated_b64)
|
||||
|
||||
|
||||
def get_vqa_models():
|
||||
"""
|
||||
List available VLM models for captioning.
|
||||
|
|
@ -274,7 +469,7 @@ def get_vqa_models():
|
|||
- `prompts`: Available prompts/tasks
|
||||
- `capabilities`: Model features (caption, vqa, detection, ocr, thinking)
|
||||
"""
|
||||
from modules.interrogate import vqa
|
||||
from modules.caption import vqa
|
||||
models_list = []
|
||||
for name, repo in vqa.vlm_models.items():
|
||||
prompts = vqa.get_prompts_for_model(name)
|
||||
|
|
@ -307,7 +502,7 @@ def get_vqa_prompts(model: Optional[str] = None):
|
|||
- Florence: Phrase Grounding, Object Detection, OCR, Dense Region Caption
|
||||
- Moondream: Point at..., Detect all..., Detect Gaze
|
||||
"""
|
||||
from modules.interrogate import vqa
|
||||
from modules.caption import vqa
|
||||
if model:
|
||||
prompts = vqa.get_prompts_for_model(model)
|
||||
return {"available": prompts}
|
||||
|
|
@ -331,7 +526,7 @@ def get_tagger_models():
|
|||
**DeepBooru:**
|
||||
- Legacy tagger for anime images
|
||||
"""
|
||||
from modules.interrogate import waifudiffusion
|
||||
from modules.caption import waifudiffusion
|
||||
models_list = []
|
||||
# Add WaifuDiffusion models
|
||||
for name in waifudiffusion.get_models():
|
||||
|
|
@ -361,7 +556,7 @@ def post_tagger(req: models.ReqTagger):
|
|||
raise HTTPException(status_code=404, detail="Image not found")
|
||||
image = helpers.decode_base64_to_image(req.image)
|
||||
image = image.convert('RGB')
|
||||
from modules.interrogate import tagger
|
||||
from modules.caption import tagger
|
||||
# Determine if using DeepBooru
|
||||
is_deepbooru = req.model.lower() in ('deepbooru', 'deepdanbooru')
|
||||
# Store original settings and apply request settings
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
import inspect
|
||||
from typing import Any, Optional, Dict, List, Type, Callable, Union
|
||||
from typing import Any, Optional, Dict, List, Type, Callable, Union, Literal, Annotated
|
||||
from pydantic import BaseModel, Field, create_model # pylint: disable=no-name-in-module
|
||||
from inflection import underscore
|
||||
from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img
|
||||
|
|
@ -365,31 +365,29 @@ class ResStatus(BaseModel):
|
|||
eta: Optional[float] = Field(default=None, title="ETA in secs")
|
||||
progress: Optional[float] = Field(default=None, title="Progress", description="The progress with a range of 0 to 1")
|
||||
|
||||
class ReqInterrogate(BaseModel):
|
||||
"""Request model for OpenCLIP/BLIP image interrogation.
|
||||
class ReqCaption(BaseModel):
|
||||
"""Request model for OpenCLIP/BLIP image captioning.
|
||||
|
||||
Analyze image using CLIP model via OpenCLIP to generate prompts.
|
||||
For anime-style tagging, use /sdapi/v1/tagger with WaifuDiffusion or DeepBooru.
|
||||
"""
|
||||
image: str = Field(default="", title="Image", description="Image to interrogate. Must be a Base64 encoded string containing the image data (PNG/JPEG).")
|
||||
model: str = Field(default="ViT-L-14/openai", title="Model", description="OpenCLIP model to use. Get available models from GET /sdapi/v1/interrogate.")
|
||||
image: str = Field(default="", title="Image", description="Image to caption. Must be a Base64 encoded string containing the image data (PNG/JPEG).")
|
||||
model: str = Field(default="ViT-L-14/openai", title="Model", description="OpenCLIP model to use. Get available models from GET /sdapi/v1/openclip.")
|
||||
clip_model: str = Field(default="ViT-L-14/openai", title="CLIP Model", description="CLIP model used for image-text similarity matching. Larger models (ViT-L, ViT-H) are more accurate but slower and use more VRAM.")
|
||||
blip_model: str = Field(default="blip-large", title="Caption Model", description="BLIP model used to generate the initial image caption. The caption model describes the image content which CLIP then enriches with style and flavor terms.")
|
||||
mode: str = Field(default="best", title="Mode", description="Interrogation mode. Fast: Quick caption with minimal flavor terms. Classic: Standard interrogation with balanced quality and speed. Best: Most thorough analysis, slowest but highest quality. Negative: Generate terms to use as negative prompt.")
|
||||
mode: str = Field(default="best", title="Mode", description="Caption mode. Fast: Quick caption with minimal flavor terms. Classic: Standard captioning with balanced quality and speed. Best: Most thorough analysis, slowest but highest quality. Negative: Generate terms to use as negative prompt.")
|
||||
analyze: bool = Field(default=False, title="Analyze", description="If True, returns detailed image analysis breakdown (medium, artist, movement, trending, flavor) in addition to caption.")
|
||||
# Advanced settings (optional per-request overrides)
|
||||
min_length: Optional[int] = Field(default=None, title="Min Length", description="Minimum number of tokens in the generated caption.")
|
||||
max_length: Optional[int] = Field(default=None, title="Max Length", description="Maximum number of tokens in the generated caption.")
|
||||
chunk_size: Optional[int] = Field(default=None, title="Chunk Size", description="Batch size for processing description candidates (flavors). Higher values speed up interrogation but increase VRAM usage.")
|
||||
chunk_size: Optional[int] = Field(default=None, title="Chunk Size", description="Batch size for processing description candidates (flavors). Higher values speed up captioning but increase VRAM usage.")
|
||||
min_flavors: Optional[int] = Field(default=None, title="Min Flavors", description="Minimum number of descriptive tags (flavors) to keep in the final prompt.")
|
||||
max_flavors: Optional[int] = Field(default=None, title="Max Flavors", description="Maximum number of descriptive tags (flavors) to keep in the final prompt.")
|
||||
flavor_count: Optional[int] = Field(default=None, title="Intermediates", description="Size of the intermediate candidate pool when matching image features to descriptive tags. Higher values may improve quality but are slower.")
|
||||
num_beams: Optional[int] = Field(default=None, title="Num Beams", description="Number of beams for beam search during caption generation. Higher values search more possibilities but are slower.")
|
||||
|
||||
InterrogateRequest = ReqInterrogate # alias for backwards compatibility
|
||||
|
||||
class ResInterrogate(BaseModel):
|
||||
"""Response model for image interrogation results."""
|
||||
class ResCaption(BaseModel):
|
||||
"""Response model for image captioning results."""
|
||||
caption: Optional[str] = Field(default=None, title="Caption", description="Generated caption/prompt describing the image content and style.")
|
||||
medium: Optional[str] = Field(default=None, title="Medium", description="Detected artistic medium (e.g., 'oil painting', 'digital art', 'photograph'). Only returned when analyze=True.")
|
||||
artist: Optional[str] = Field(default=None, title="Artist", description="Detected similar artist style (e.g., 'by greg rutkowski'). Only returned when analyze=True.")
|
||||
|
|
@ -468,6 +466,109 @@ class ResTagger(BaseModel):
|
|||
tags: str = Field(title="Tags", description="Comma-separated list of detected tags")
|
||||
scores: Optional[dict] = Field(default=None, title="Scores", description="Tag confidence scores (when show_scores=True)")
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Caption Dispatch Models (Discriminated Union)
|
||||
# =============================================================================
|
||||
# These models support the unified /sdapi/v1/caption dispatch endpoint.
|
||||
# Users can also access backends directly via /openclip, /tagger, /vqa.
|
||||
|
||||
class ReqCaptionOpenCLIP(BaseModel):
    """OpenCLIP/BLIP caption request for the dispatch endpoint.

    Generate Stable Diffusion prompts using CLIP for image-text matching and BLIP for captioning.
    Best for: General image captioning, prompt generation, style analysis.
    """
    # Discriminator for the ReqCaptionDispatch union; unlike the tagger/vlm request models it declares a default value
    backend: Literal["openclip"] = Field(default="openclip", description="Backend selector. Use 'openclip' for CLIP/BLIP captioning.")
    # Input image and model selection
    image: str = Field(default="", title="Image", description="Image to caption. Must be a Base64 encoded string containing the image data (PNG/JPEG).")
    model: str = Field(default="ViT-L-14/openai", title="Model", description="OpenCLIP model to use. Get available models from GET /sdapi/v1/openclip.")
    clip_model: str = Field(default="ViT-L-14/openai", title="CLIP Model", description="CLIP model used for image-text similarity matching. Larger models (ViT-L, ViT-H) are more accurate but slower.")
    blip_model: str = Field(default="blip-large", title="Caption Model", description="BLIP model used to generate the initial image caption.")
    mode: str = Field(default="best", title="Mode", description="Caption mode: 'best' (highest quality), 'fast' (quick), 'classic' (traditional), 'caption' (BLIP only), 'negative' (for negative prompts).")
    analyze: bool = Field(default=False, title="Analyze", description="If True, returns detailed breakdown (medium, artist, movement, trending, flavor).")
    # Optional per-request overrides; all default to None (not supplied)
    min_length: Optional[int] = Field(default=None, title="Min Length", description="Minimum tokens in generated caption.")
    max_length: Optional[int] = Field(default=None, title="Max Length", description="Maximum tokens in generated caption.")
    chunk_size: Optional[int] = Field(default=None, title="Chunk Size", description="Batch size for processing flavors.")
    min_flavors: Optional[int] = Field(default=None, title="Min Flavors", description="Minimum descriptive tags to keep.")
    max_flavors: Optional[int] = Field(default=None, title="Max Flavors", description="Maximum descriptive tags to keep.")
    flavor_count: Optional[int] = Field(default=None, title="Intermediates", description="Size of intermediate candidate pool.")
    num_beams: Optional[int] = Field(default=None, title="Num Beams", description="Beams for beam search during caption generation.")
|
||||
|
||||
|
||||
class ReqCaptionTagger(BaseModel):
    """Tagger request for the dispatch endpoint.

    Generate anime/illustration tags using WaifuDiffusion or DeepBooru models.
    Best for: Anime images, booru-style tagging, character identification.
    """
    # Discriminator for the ReqCaptionDispatch union; required (no default)
    backend: Literal["tagger"] = Field(..., description="Backend selector. Use 'tagger' for WaifuDiffusion/DeepBooru tagging.")
    # Input image and model selection
    image: str = Field(default="", title="Image", description="Image to tag. Must be a Base64 encoded string.")
    model: str = Field(default="wd-eva02-large-tagger-v3", title="Model", description="Tagger model. WaifuDiffusion (wd-*) or DeepBooru (deepbooru/deepdanbooru).")
    # Tag selection: thresholds are constrained to [0.0, 1.0], max_tags to [1, 512]
    threshold: float = Field(default=0.5, title="Threshold", description="Confidence threshold for general tags.", ge=0.0, le=1.0)
    character_threshold: float = Field(default=0.85, title="Character Threshold", description="Confidence threshold for character tags (WaifuDiffusion only).", ge=0.0, le=1.0)
    max_tags: int = Field(default=74, title="Max Tags", description="Maximum number of tags to return.", ge=1, le=512)
    # Output formatting
    include_rating: bool = Field(default=False, title="Include Rating", description="Include content rating tags (safe/questionable/explicit).")
    sort_alpha: bool = Field(default=False, title="Sort Alphabetically", description="Sort tags alphabetically instead of by confidence.")
    use_spaces: bool = Field(default=False, title="Use Spaces", description="Replace underscores with spaces in tags.")
    escape_brackets: bool = Field(default=True, title="Escape Brackets", description="Escape parentheses/brackets for prompt syntax.")
    exclude_tags: str = Field(default="", title="Exclude Tags", description="Comma-separated tags to exclude from output.")
    show_scores: bool = Field(default=False, title="Show Scores", description="Include confidence scores with tags.")
|
||||
|
||||
|
||||
class ReqCaptionVLM(BaseModel):
    """Vision-Language Model request for the dispatch endpoint.

    Flexible image understanding using modern VLMs (Qwen, Gemma, Florence, Moondream, etc.).
    Best for: Detailed descriptions, question answering, object detection, OCR.
    """
    # Discriminator for the ReqCaptionDispatch union; required (no default)
    backend: Literal["vlm"] = Field(..., description="Backend selector. Use 'vlm' for Vision-Language Model captioning.")
    # Input image, model selection and task/prompt text
    image: str = Field(default="", title="Image", description="Image to caption. Must be a Base64 encoded string.")
    model: str = Field(default="Alibaba Qwen 2.5 VL 3B", title="Model", description="VLM model to use. See GET /sdapi/v1/vqa/models for full list.")
    question: str = Field(default="describe the image", title="Question/Task", description="Task to perform: 'Short Caption', 'Normal Caption', 'Long Caption', '<OD>' (detection), '<OCR>' (text), etc.")
    prompt: Optional[str] = Field(default=None, title="Prompt", description="Custom prompt text when question is 'Use Prompt'.")
    system: str = Field(default="You are image captioning expert, creative, unbiased and uncensored.", title="System Prompt", description="System prompt for LLM behavior.")
    include_annotated: bool = Field(default=False, title="Include Annotated Image", description="Return annotated image for detection tasks.")
    # Optional generation settings; all default to None (not supplied)
    max_tokens: Optional[int] = Field(default=None, title="Max Tokens", description="Maximum tokens in response.")
    temperature: Optional[float] = Field(default=None, title="Temperature", description="Randomness in token selection (0=deterministic, 0.9=creative).")
    top_k: Optional[int] = Field(default=None, title="Top-K", description="Limit to K most likely tokens per step.")
    top_p: Optional[float] = Field(default=None, title="Top-P", description="Nucleus sampling threshold.")
    num_beams: Optional[int] = Field(default=None, title="Num Beams", description="Beam search width (1=disabled).")
    do_sample: Optional[bool] = Field(default=None, title="Use Samplers", description="Enable sampling vs greedy decoding.")
    # Optional reasoning / prefill controls
    thinking_mode: Optional[bool] = Field(default=None, title="Thinking Mode", description="Enable reasoning mode (supported models only).")
    prefill: Optional[str] = Field(default=None, title="Prefill Text", description="Pre-fill response start to guide output.")
    keep_thinking: Optional[bool] = Field(default=None, title="Keep Thinking Trace", description="Include reasoning in output.")
    keep_prefill: Optional[bool] = Field(default=None, title="Keep Prefill", description="Keep prefill text in final output.")
|
||||
|
||||
|
||||
# Discriminated union for the dispatch endpoint.
# The value of the `backend` field ('openclip', 'tagger' or 'vlm') selects which of the
# three request models below the payload is validated against.
ReqCaptionDispatch = Annotated[
    Union[ReqCaptionOpenCLIP, ReqCaptionTagger, ReqCaptionVLM],
    Field(discriminator="backend")
]
|
||||
|
||||
|
||||
class ResCaptionDispatch(BaseModel):
    """Unified response for the caption dispatch endpoint.

    Contains fields from all backends - only relevant fields are populated based on the backend used.
    """
    # Common: the only required field; every backend-specific field below defaults to None
    backend: str = Field(title="Backend", description="The backend that processed the request: 'openclip', 'tagger', or 'vlm'.")
    # OpenCLIP fields
    caption: Optional[str] = Field(default=None, title="Caption", description="Generated caption (OpenCLIP backend).")
    medium: Optional[str] = Field(default=None, title="Medium", description="Detected artistic medium (OpenCLIP with analyze=True).")
    artist: Optional[str] = Field(default=None, title="Artist", description="Detected artist style (OpenCLIP with analyze=True).")
    movement: Optional[str] = Field(default=None, title="Movement", description="Detected art movement (OpenCLIP with analyze=True).")
    trending: Optional[str] = Field(default=None, title="Trending", description="Trending tags (OpenCLIP with analyze=True).")
    flavor: Optional[str] = Field(default=None, title="Flavor", description="Flavor descriptors (OpenCLIP with analyze=True).")
    # Tagger fields
    tags: Optional[str] = Field(default=None, title="Tags", description="Comma-separated tags (Tagger backend).")
    scores: Optional[dict] = Field(default=None, title="Scores", description="Tag confidence scores (Tagger with show_scores=True).")
    # VLM fields
    answer: Optional[str] = Field(default=None, title="Answer", description="VLM response (VLM backend).")
    annotated_image: Optional[str] = Field(default=None, title="Annotated Image", description="Base64 annotated image (VLM with include_annotated=True).")
|
||||
|
||||
|
||||
class ResTrain(BaseModel):
|
||||
info: str = Field(title="Train info", description="Response string from train embedding task.")
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue