diff --git a/modules/api/api.py b/modules/api/api.py index 80d3c6d34..78a7a04f9 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -76,7 +76,7 @@ class Api: # enumerator api self.add_api_route("/sdapi/v1/preprocessors", self.process.get_preprocess, methods=["GET"], response_model=List[process.ItemPreprocess]) self.add_api_route("/sdapi/v1/masking", self.process.get_mask, methods=["GET"], response_model=process.ItemMask) - self.add_api_route("/sdapi/v1/interrogate", endpoints.get_interrogate, methods=["GET"], response_model=List[str], tags=["Caption"]) + self.add_api_route("/sdapi/v1/openclip", endpoints.get_caption, methods=["GET"], response_model=List[str], tags=["Caption"]) self.add_api_route("/sdapi/v1/samplers", endpoints.get_samplers, methods=["GET"], response_model=List[models.ItemSampler]) self.add_api_route("/sdapi/v1/upscalers", endpoints.get_upscalers, methods=["GET"], response_model=List[models.ItemUpscaler]) self.add_api_route("/sdapi/v1/sd-models", endpoints.get_sd_models, methods=["GET"], response_model=List[models.ItemModel]) @@ -91,7 +91,10 @@ class Api: # functional api self.add_api_route("/sdapi/v1/png-info", endpoints.post_pnginfo, methods=["POST"], response_model=models.ResImageInfo) - self.add_api_route("/sdapi/v1/interrogate", endpoints.post_interrogate, methods=["POST"], response_model=models.ResInterrogate, tags=["Caption"]) + # Caption dispatch endpoint (routes to openclip, tagger, or vlm based on 'backend' field) + self.add_api_route("/sdapi/v1/caption", endpoints.post_caption_dispatch, methods=["POST"], response_model=models.ResCaptionDispatch, tags=["Caption"]) + # Direct caption endpoints (bypass dispatch, use specific backend) + self.add_api_route("/sdapi/v1/openclip", endpoints.post_caption, methods=["POST"], response_model=models.ResCaption, tags=["Caption"]) self.add_api_route("/sdapi/v1/vqa", endpoints.post_vqa, methods=["POST"], response_model=models.ResVQA, tags=["Caption"]) self.add_api_route("/sdapi/v1/vqa/models", endpoints.get_vqa_models, methods=["GET"], response_model=List[models.ItemVLMModel], tags=["Caption"]) self.add_api_route("/sdapi/v1/vqa/prompts", endpoints.get_vqa_prompts, methods=["GET"], response_model=models.ResVLMPrompts, tags=["Caption"]) diff --git a/modules/api/endpoints.py b/modules/api/endpoints.py index 64b8b738d..6ad917da6 100644 --- a/modules/api/endpoints.py +++ b/modules/api/endpoints.py @@ -91,23 +91,21 @@ def get_extra_networks(page: Optional[str] = None, name: Optional[str] = None, f }) return res -def get_interrogate(): +def get_caption(): """ - List available interrogation models. + List available OpenCLIP caption models. - Returns model identifiers for use with POST /sdapi/v1/interrogate. + Returns model identifiers for use with POST /sdapi/v1/openclip or POST /sdapi/v1/caption (with backend="openclip"). **Model Types:** - OpenCLIP models: Format `architecture/pretrained_dataset` (e.g., `ViT-L-14/openai`) - For anime-style tagging (WaifuDiffusion, DeepBooru), use `/sdapi/v1/tagger` instead. - **Example Response:** ```json ["ViT-L-14/openai", "ViT-H-14/laion2b_s32b_b79k"] ``` """ - from modules.interrogate.openclip import refresh_clip_models + from modules.caption.openclip import refresh_clip_models return refresh_clip_models() def get_schedulers(): @@ -117,9 +115,12 @@ def get_schedulers(): shared.log.critical(s) return all_schedulers -def post_interrogate(req: models.ReqInterrogate): +def post_caption(req: models.ReqCaption): """ - Interrogate an image using OpenCLIP/BLIP. 
+ Caption an image using OpenCLIP/BLIP (direct endpoint). + + This is the direct endpoint for OpenCLIP captioning. For a unified interface + that can dispatch to OpenCLIP, Tagger, or VLM, use POST /sdapi/v1/caption instead. Analyze image using CLIP model via OpenCLIP to generate Stable Diffusion prompts. @@ -127,13 +128,11 @@ def post_interrogate(req: models.ReqInterrogate): - **Modes:** - `best`: Highest quality, combines multiple techniques - `fast`: Quick results with fewer iterations - - `classic`: Traditional CLIP interrogator style + - `classic`: Traditional CLIP captioner style - `caption`: BLIP caption only - `negative`: Generate negative prompt suggestions - Set `analyze=True` for detailed breakdown (medium, artist, movement, trending, flavor) - For anime/illustration tagging, use `/sdapi/v1/tagger` with WaifuDiffusion or DeepBooru models. - **Error Codes:** - 404: Image not provided or model not found """ @@ -141,7 +140,7 @@ def post_interrogate(req: models.ReqInterrogate): raise HTTPException(status_code=404, detail="Image not found") image = helpers.decode_base64_to_image(req.image) image = image.convert('RGB') - from modules.interrogate.openclip import interrogate_image, analyze_image, refresh_clip_models + from modules.caption.openclip import caption_image, analyze_image, refresh_clip_models if req.model not in refresh_clip_models(): raise HTTPException(status_code=404, detail="Model not found") # Build clip overrides from request (only include non-None values) @@ -161,11 +160,11 @@ def post_interrogate(req: models.ReqInterrogate): if req.num_beams is not None: clip_overrides['num_beams'] = req.num_beams try: - caption = interrogate_image(image, clip_model=req.clip_model, blip_model=req.blip_model, mode=req.mode, overrides=clip_overrides if clip_overrides else None) + caption = caption_image(image, clip_model=req.clip_model, blip_model=req.blip_model, mode=req.mode, overrides=clip_overrides if clip_overrides else None) except Exception as e: caption = str(e) if not req.analyze: - return models.ResInterrogate(caption=caption) + return models.ResCaption(caption=caption) analyze_results = analyze_image(image, clip_model=req.clip_model, blip_model=req.blip_model) # Extract top-ranked item from each Gradio update dict def get_top_item(result): @@ -181,7 +180,8 @@ def post_interrogate(req: models.ReqInterrogate): movement = get_top_item(analyze_results[2]) trending = get_top_item(analyze_results[3]) flavor = get_top_item(analyze_results[4]) - return models.ResInterrogate(caption=caption, medium=medium, artist=artist, movement=movement, trending=trending, flavor=flavor) + return models.ResCaption(caption=caption, medium=medium, artist=artist, movement=movement, trending=trending, flavor=flavor) + def post_vqa(req: models.ReqVQA): """ @@ -243,8 +243,8 @@ def post_vqa(req: models.ReqVQA): if req.keep_prefill is not None: generation_kwargs['keep_prefill'] = req.keep_prefill - from modules.interrogate import vqa - answer = vqa.interrogate( + from modules.caption import vqa + answer = vqa.caption( question=req.question, system_prompt=req.system, prompt=req.prompt or '', @@ -262,6 +262,201 @@ def post_vqa(req: models.ReqVQA): annotated_b64 = helpers.encode_pil_to_base64(annotated_img) return models.ResVQA(answer=answer, annotated_image=annotated_b64) + +def post_caption_dispatch(req: models.ReqCaptionDispatch): + """ + Unified caption endpoint - dispatches to OpenCLIP, Tagger, or VLM backends. + + This endpoint provides a single entry point for all captioning needs. 
Select the backend + using the `backend` field, then provide backend-specific parameters. + + **Backends:** + + 1. **OpenCLIP** (`backend: "openclip"`): + - CLIP/BLIP-based captioning for Stable Diffusion prompts + - Modes: best, fast, classic, caption, negative + - Set `analyze=True` for style breakdown + + 2. **Tagger** (`backend: "tagger"`): + - WaifuDiffusion or DeepBooru anime/illustration tagging + - Returns comma-separated booru-style tags + - Configurable thresholds for general and character tags + + 3. **VLM** (`backend: "vlm"`): + - Vision-Language Models (Qwen, Gemma, Florence, Moondream, etc.) + - Flexible tasks: captioning, Q&A, object detection, OCR + - Supports thinking mode for reasoning models + + **Direct Endpoints:** + For simpler requests, you can also use the direct endpoints: + - POST /sdapi/v1/openclip - OpenCLIP only + - POST /sdapi/v1/tagger - Tagger only + - POST /sdapi/v1/vqa - VLM only + """ + if req.backend == "openclip": + return _dispatch_openclip(req) + elif req.backend == "tagger": + return _dispatch_tagger(req) + elif req.backend == "vlm": + return _dispatch_vlm(req) + else: + raise HTTPException(status_code=400, detail=f"Unknown backend: {req.backend}") + + +def _dispatch_openclip(req: models.ReqCaptionOpenCLIP) -> models.ResCaptionDispatch: + """Handle OpenCLIP dispatch.""" + if req.image is None or len(req.image) < 64: + raise HTTPException(status_code=404, detail="Image not found") + image = helpers.decode_base64_to_image(req.image) + image = image.convert('RGB') + from modules.caption.openclip import caption_image, analyze_image, refresh_clip_models + if req.model not in refresh_clip_models(): + raise HTTPException(status_code=404, detail="Model not found") + # Build clip overrides from request + clip_overrides = {} + if req.min_length is not None: + clip_overrides['min_length'] = req.min_length + if req.max_length is not None: + clip_overrides['max_length'] = req.max_length + if req.chunk_size is not None: + clip_overrides['chunk_size'] = req.chunk_size + if req.min_flavors is not None: + clip_overrides['min_flavors'] = req.min_flavors + if req.max_flavors is not None: + clip_overrides['max_flavors'] = req.max_flavors + if req.flavor_count is not None: + clip_overrides['flavor_count'] = req.flavor_count + if req.num_beams is not None: + clip_overrides['num_beams'] = req.num_beams + try: + caption = caption_image(image, clip_model=req.clip_model, blip_model=req.blip_model, mode=req.mode, overrides=clip_overrides if clip_overrides else None) + except Exception as e: + caption = str(e) + if not req.analyze: + return models.ResCaptionDispatch(backend="openclip", caption=caption) + analyze_results = analyze_image(image, clip_model=req.clip_model, blip_model=req.blip_model) + # Extract top-ranked item from each Gradio update dict + def get_top_item(result): + if isinstance(result, dict) and 'value' in result: + value = result['value'] + if isinstance(value, dict) and value: + return next(iter(value.keys())) + if isinstance(value, str): + return value + return None + medium = get_top_item(analyze_results[0]) + artist = get_top_item(analyze_results[1]) + movement = get_top_item(analyze_results[2]) + trending = get_top_item(analyze_results[3]) + flavor = get_top_item(analyze_results[4]) + return models.ResCaptionDispatch(backend="openclip", caption=caption, medium=medium, artist=artist, movement=movement, trending=trending, flavor=flavor) + + +def _dispatch_tagger(req: models.ReqCaptionTagger) -> models.ResCaptionDispatch: + """Handle Tagger dispatch.""" + if 
req.image is None or len(req.image) < 64: + raise HTTPException(status_code=404, detail="Image not found") + image = helpers.decode_base64_to_image(req.image) + image = image.convert('RGB') + from modules.caption import tagger + is_deepbooru = req.model.lower() in ('deepbooru', 'deepdanbooru') + # Store original settings + original_opts = { + 'tagger_threshold': shared.opts.tagger_threshold, + 'tagger_max_tags': shared.opts.tagger_max_tags, + 'tagger_include_rating': shared.opts.tagger_include_rating, + 'tagger_sort_alpha': shared.opts.tagger_sort_alpha, + 'tagger_use_spaces': shared.opts.tagger_use_spaces, + 'tagger_escape_brackets': shared.opts.tagger_escape_brackets, + 'tagger_exclude_tags': shared.opts.tagger_exclude_tags, + 'tagger_show_scores': shared.opts.tagger_show_scores, + } + if not is_deepbooru: + original_opts['waifudiffusion_character_threshold'] = shared.opts.waifudiffusion_character_threshold + original_opts['waifudiffusion_model'] = shared.opts.waifudiffusion_model + try: + shared.opts.tagger_threshold = req.threshold + shared.opts.tagger_max_tags = req.max_tags + shared.opts.tagger_include_rating = req.include_rating + shared.opts.tagger_sort_alpha = req.sort_alpha + shared.opts.tagger_use_spaces = req.use_spaces + shared.opts.tagger_escape_brackets = req.escape_brackets + shared.opts.tagger_exclude_tags = req.exclude_tags + shared.opts.tagger_show_scores = req.show_scores + if not is_deepbooru: + shared.opts.waifudiffusion_character_threshold = req.character_threshold + shared.opts.waifudiffusion_model = req.model + tags = tagger.tag(image, model_name='DeepBooru' if is_deepbooru else None) + # Parse scores if requested + scores = None + if req.show_scores: + scores = {} + for item in tags.split(', '): + item = item.strip() + if item.startswith('(') and item.endswith(')') and ':' in item: + inner = item[1:-1] + tag, score_str = inner.rsplit(':', 1) + try: + scores[tag.strip()] = float(score_str.strip()) + except ValueError: + pass + elif ':' in item: + tag, score_str = item.rsplit(':', 1) + try: + scores[tag.strip()] = float(score_str.strip()) + except ValueError: + pass + if not scores: + scores = None + return models.ResCaptionDispatch(backend="tagger", tags=tags, scores=scores) + finally: + for key, value in original_opts.items(): + setattr(shared.opts, key, value) + + +def _dispatch_vlm(req: models.ReqCaptionVLM) -> models.ResCaptionDispatch: + """Handle VLM dispatch.""" + if req.image is None or len(req.image) < 64: + raise HTTPException(status_code=404, detail="Image not found") + image = helpers.decode_base64_to_image(req.image) + image = image.convert('RGB') + # Build generation kwargs + generation_kwargs = {} + if req.max_tokens is not None: + generation_kwargs['max_tokens'] = req.max_tokens + if req.temperature is not None: + generation_kwargs['temperature'] = req.temperature + if req.top_k is not None: + generation_kwargs['top_k'] = req.top_k + if req.top_p is not None: + generation_kwargs['top_p'] = req.top_p + if req.num_beams is not None: + generation_kwargs['num_beams'] = req.num_beams + if req.do_sample is not None: + generation_kwargs['do_sample'] = req.do_sample + if req.keep_thinking is not None: + generation_kwargs['keep_thinking'] = req.keep_thinking + if req.keep_prefill is not None: + generation_kwargs['keep_prefill'] = req.keep_prefill + from modules.caption import vqa + answer = vqa.caption( + question=req.question, + system_prompt=req.system, + prompt=req.prompt or '', + image=image, + model_name=req.model, + prefill=req.prefill, + 
thinking_mode=req.thinking_mode, + generation_kwargs=generation_kwargs if generation_kwargs else None + ) + annotated_b64 = None + if req.include_annotated: + annotated_img = vqa.get_last_annotated_image() + if annotated_img is not None: + annotated_b64 = helpers.encode_pil_to_base64(annotated_img) + return models.ResCaptionDispatch(backend="vlm", answer=answer, annotated_image=annotated_b64) + + def get_vqa_models(): """ List available VLM models for captioning. @@ -274,7 +469,7 @@ def get_vqa_models(): - `prompts`: Available prompts/tasks - `capabilities`: Model features (caption, vqa, detection, ocr, thinking) """ - from modules.interrogate import vqa + from modules.caption import vqa models_list = [] for name, repo in vqa.vlm_models.items(): prompts = vqa.get_prompts_for_model(name) @@ -307,7 +502,7 @@ def get_vqa_prompts(model: Optional[str] = None): - Florence: Phrase Grounding, Object Detection, OCR, Dense Region Caption - Moondream: Point at..., Detect all..., Detect Gaze """ - from modules.interrogate import vqa + from modules.caption import vqa if model: prompts = vqa.get_prompts_for_model(model) return {"available": prompts} @@ -331,7 +526,7 @@ def get_tagger_models(): **DeepBooru:** - Legacy tagger for anime images """ - from modules.interrogate import waifudiffusion + from modules.caption import waifudiffusion models_list = [] # Add WaifuDiffusion models for name in waifudiffusion.get_models(): @@ -361,7 +556,7 @@ def post_tagger(req: models.ReqTagger): raise HTTPException(status_code=404, detail="Image not found") image = helpers.decode_base64_to_image(req.image) image = image.convert('RGB') - from modules.interrogate import tagger + from modules.caption import tagger # Determine if using DeepBooru is_deepbooru = req.model.lower() in ('deepbooru', 'deepdanbooru') # Store original settings and apply request settings diff --git a/modules/api/models.py b/modules/api/models.py index 86eb14886..329627bbc 100644 --- a/modules/api/models.py +++ b/modules/api/models.py @@ -1,5 +1,5 @@ import inspect -from typing import Any, Optional, Dict, List, Type, Callable, Union +from typing import Any, Optional, Dict, List, Type, Callable, Union, Literal, Annotated from pydantic import BaseModel, Field, create_model # pylint: disable=no-name-in-module from inflection import underscore from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img @@ -365,31 +365,29 @@ class ResStatus(BaseModel): eta: Optional[float] = Field(default=None, title="ETA in secs") progress: Optional[float] = Field(default=None, title="Progress", description="The progress with a range of 0 to 1") -class ReqInterrogate(BaseModel): - """Request model for OpenCLIP/BLIP image interrogation. +class ReqCaption(BaseModel): + """Request model for OpenCLIP/BLIP image captioning. Analyze image using CLIP model via OpenCLIP to generate prompts. For anime-style tagging, use /sdapi/v1/tagger with WaifuDiffusion or DeepBooru. """ - image: str = Field(default="", title="Image", description="Image to interrogate. Must be a Base64 encoded string containing the image data (PNG/JPEG).") - model: str = Field(default="ViT-L-14/openai", title="Model", description="OpenCLIP model to use. Get available models from GET /sdapi/v1/interrogate.") + image: str = Field(default="", title="Image", description="Image to caption. Must be a Base64 encoded string containing the image data (PNG/JPEG).") + model: str = Field(default="ViT-L-14/openai", title="Model", description="OpenCLIP model to use. 
Get available models from GET /sdapi/v1/openclip.")
     clip_model: str = Field(default="ViT-L-14/openai", title="CLIP Model", description="CLIP model used for image-text similarity matching. Larger models (ViT-L, ViT-H) are more accurate but slower and use more VRAM.")
     blip_model: str = Field(default="blip-large", title="Caption Model", description="BLIP model used to generate the initial image caption. The caption model describes the image content which CLIP then enriches with style and flavor terms.")
-    mode: str = Field(default="best", title="Mode", description="Interrogation mode. Fast: Quick caption with minimal flavor terms. Classic: Standard interrogation with balanced quality and speed. Best: Most thorough analysis, slowest but highest quality. Negative: Generate terms to use as negative prompt.")
+    mode: str = Field(default="best", title="Mode", description="Caption mode. Fast: Quick caption with minimal flavor terms. Classic: Standard captioning with balanced quality and speed. Best: Most thorough analysis, slowest but highest quality. Negative: Generate terms to use as negative prompt.")
     analyze: bool = Field(default=False, title="Analyze", description="If True, returns detailed image analysis breakdown (medium, artist, movement, trending, flavor) in addition to caption.")
     # Advanced settings (optional per-request overrides)
     min_length: Optional[int] = Field(default=None, title="Min Length", description="Minimum number of tokens in the generated caption.")
     max_length: Optional[int] = Field(default=None, title="Max Length", description="Maximum number of tokens in the generated caption.")
-    chunk_size: Optional[int] = Field(default=None, title="Chunk Size", description="Batch size for processing description candidates (flavors). Higher values speed up interrogation but increase VRAM usage.")
+    chunk_size: Optional[int] = Field(default=None, title="Chunk Size", description="Batch size for processing description candidates (flavors). Higher values speed up captioning but increase VRAM usage.")
     min_flavors: Optional[int] = Field(default=None, title="Min Flavors", description="Minimum number of descriptive tags (flavors) to keep in the final prompt.")
     max_flavors: Optional[int] = Field(default=None, title="Max Flavors", description="Maximum number of descriptive tags (flavors) to keep in the final prompt.")
     flavor_count: Optional[int] = Field(default=None, title="Intermediates", description="Size of the intermediate candidate pool when matching image features to descriptive tags. Higher values may improve quality but are slower.")
     num_beams: Optional[int] = Field(default=None, title="Num Beams", description="Number of beams for beam search during caption generation. Higher values search more possibilities but are slower.")
 
-InterrogateRequest = ReqInterrogate # alias for backwards compatibility
-
-class ResInterrogate(BaseModel):
-    """Response model for image interrogation results."""
+class ResCaption(BaseModel):
+    """Response model for image captioning results."""
     caption: Optional[str] = Field(default=None, title="Caption", description="Generated caption/prompt describing the image content and style.")
     medium: Optional[str] = Field(default=None, title="Medium", description="Detected artistic medium (e.g., 'oil painting', 'digital art', 'photograph'). Only returned when analyze=True.")
     artist: Optional[str] = Field(default=None, title="Artist", description="Detected similar artist style (e.g., 'by greg rutkowski'). 
Only returned when analyze=True.") @@ -468,6 +466,109 @@ class ResTagger(BaseModel): tags: str = Field(title="Tags", description="Comma-separated list of detected tags") scores: Optional[dict] = Field(default=None, title="Scores", description="Tag confidence scores (when show_scores=True)") + +# ============================================================================= +# Caption Dispatch Models (Discriminated Union) +# ============================================================================= +# These models support the unified /sdapi/v1/caption dispatch endpoint. +# Users can also access backends directly via /openclip, /tagger, /vqa. + +class ReqCaptionOpenCLIP(BaseModel): + """OpenCLIP/BLIP caption request for the dispatch endpoint. + + Generate Stable Diffusion prompts using CLIP for image-text matching and BLIP for captioning. + Best for: General image captioning, prompt generation, style analysis. + """ + backend: Literal["openclip"] = Field(default="openclip", description="Backend selector. Use 'openclip' for CLIP/BLIP captioning.") + image: str = Field(default="", title="Image", description="Image to caption. Must be a Base64 encoded string containing the image data (PNG/JPEG).") + model: str = Field(default="ViT-L-14/openai", title="Model", description="OpenCLIP model to use. Get available models from GET /sdapi/v1/openclip.") + clip_model: str = Field(default="ViT-L-14/openai", title="CLIP Model", description="CLIP model used for image-text similarity matching. Larger models (ViT-L, ViT-H) are more accurate but slower.") + blip_model: str = Field(default="blip-large", title="Caption Model", description="BLIP model used to generate the initial image caption.") + mode: str = Field(default="best", title="Mode", description="Caption mode: 'best' (highest quality), 'fast' (quick), 'classic' (traditional), 'caption' (BLIP only), 'negative' (for negative prompts).") + analyze: bool = Field(default=False, title="Analyze", description="If True, returns detailed breakdown (medium, artist, movement, trending, flavor).") + min_length: Optional[int] = Field(default=None, title="Min Length", description="Minimum tokens in generated caption.") + max_length: Optional[int] = Field(default=None, title="Max Length", description="Maximum tokens in generated caption.") + chunk_size: Optional[int] = Field(default=None, title="Chunk Size", description="Batch size for processing flavors.") + min_flavors: Optional[int] = Field(default=None, title="Min Flavors", description="Minimum descriptive tags to keep.") + max_flavors: Optional[int] = Field(default=None, title="Max Flavors", description="Maximum descriptive tags to keep.") + flavor_count: Optional[int] = Field(default=None, title="Intermediates", description="Size of intermediate candidate pool.") + num_beams: Optional[int] = Field(default=None, title="Num Beams", description="Beams for beam search during caption generation.") + + +class ReqCaptionTagger(BaseModel): + """Tagger request for the dispatch endpoint. + + Generate anime/illustration tags using WaifuDiffusion or DeepBooru models. + Best for: Anime images, booru-style tagging, character identification. + """ + backend: Literal["tagger"] = Field(..., description="Backend selector. Use 'tagger' for WaifuDiffusion/DeepBooru tagging.") + image: str = Field(default="", title="Image", description="Image to tag. Must be a Base64 encoded string.") + model: str = Field(default="wd-eva02-large-tagger-v3", title="Model", description="Tagger model. 
WaifuDiffusion (wd-*) or DeepBooru (deepbooru/deepdanbooru).") + threshold: float = Field(default=0.5, title="Threshold", description="Confidence threshold for general tags.", ge=0.0, le=1.0) + character_threshold: float = Field(default=0.85, title="Character Threshold", description="Confidence threshold for character tags (WaifuDiffusion only).", ge=0.0, le=1.0) + max_tags: int = Field(default=74, title="Max Tags", description="Maximum number of tags to return.", ge=1, le=512) + include_rating: bool = Field(default=False, title="Include Rating", description="Include content rating tags (safe/questionable/explicit).") + sort_alpha: bool = Field(default=False, title="Sort Alphabetically", description="Sort tags alphabetically instead of by confidence.") + use_spaces: bool = Field(default=False, title="Use Spaces", description="Replace underscores with spaces in tags.") + escape_brackets: bool = Field(default=True, title="Escape Brackets", description="Escape parentheses/brackets for prompt syntax.") + exclude_tags: str = Field(default="", title="Exclude Tags", description="Comma-separated tags to exclude from output.") + show_scores: bool = Field(default=False, title="Show Scores", description="Include confidence scores with tags.") + + +class ReqCaptionVLM(BaseModel): + """Vision-Language Model request for the dispatch endpoint. + + Flexible image understanding using modern VLMs (Qwen, Gemma, Florence, Moondream, etc.). + Best for: Detailed descriptions, question answering, object detection, OCR. + """ + backend: Literal["vlm"] = Field(..., description="Backend selector. Use 'vlm' for Vision-Language Model captioning.") + image: str = Field(default="", title="Image", description="Image to caption. Must be a Base64 encoded string.") + model: str = Field(default="Alibaba Qwen 2.5 VL 3B", title="Model", description="VLM model to use. 
See GET /sdapi/v1/vqa/models for full list.")
+    question: str = Field(default="describe the image", title="Question/Task", description="Task to perform: 'Short Caption', 'Normal Caption', 'Long Caption', or a model-specific task prompt for detection or text/OCR (see GET /sdapi/v1/vqa/prompts).")
+    prompt: Optional[str] = Field(default=None, title="Prompt", description="Custom prompt text when question is 'Use Prompt'.")
+    system: str = Field(default="You are image captioning expert, creative, unbiased and uncensored.", title="System Prompt", description="System prompt for LLM behavior.")
+    include_annotated: bool = Field(default=False, title="Include Annotated Image", description="Return annotated image for detection tasks.")
+    max_tokens: Optional[int] = Field(default=None, title="Max Tokens", description="Maximum tokens in response.")
+    temperature: Optional[float] = Field(default=None, title="Temperature", description="Randomness in token selection (0=deterministic, 0.9=creative).")
+    top_k: Optional[int] = Field(default=None, title="Top-K", description="Limit to K most likely tokens per step.")
+    top_p: Optional[float] = Field(default=None, title="Top-P", description="Nucleus sampling threshold.")
+    num_beams: Optional[int] = Field(default=None, title="Num Beams", description="Beam search width (1=disabled).")
+    do_sample: Optional[bool] = Field(default=None, title="Use Samplers", description="Enable sampling vs greedy decoding.")
+    thinking_mode: Optional[bool] = Field(default=None, title="Thinking Mode", description="Enable reasoning mode (supported models only).")
+    prefill: Optional[str] = Field(default=None, title="Prefill Text", description="Pre-fill response start to guide output.")
+    keep_thinking: Optional[bool] = Field(default=None, title="Keep Thinking Trace", description="Include reasoning in output.")
+    keep_prefill: Optional[bool] = Field(default=None, title="Keep Prefill", description="Keep prefill text in final output.")
+
+
+# Discriminated union for the dispatch endpoint
+ReqCaptionDispatch = Annotated[
+    Union[ReqCaptionOpenCLIP, ReqCaptionTagger, ReqCaptionVLM],
+    Field(discriminator="backend")
+]
+
+
+class ResCaptionDispatch(BaseModel):
+    """Unified response for the caption dispatch endpoint.
+
+    Contains fields from all backends - only relevant fields are populated based on the backend used.
+ """ + # Common + backend: str = Field(title="Backend", description="The backend that processed the request: 'openclip', 'tagger', or 'vlm'.") + # OpenCLIP fields + caption: Optional[str] = Field(default=None, title="Caption", description="Generated caption (OpenCLIP backend).") + medium: Optional[str] = Field(default=None, title="Medium", description="Detected artistic medium (OpenCLIP with analyze=True).") + artist: Optional[str] = Field(default=None, title="Artist", description="Detected artist style (OpenCLIP with analyze=True).") + movement: Optional[str] = Field(default=None, title="Movement", description="Detected art movement (OpenCLIP with analyze=True).") + trending: Optional[str] = Field(default=None, title="Trending", description="Trending tags (OpenCLIP with analyze=True).") + flavor: Optional[str] = Field(default=None, title="Flavor", description="Flavor descriptors (OpenCLIP with analyze=True).") + # Tagger fields + tags: Optional[str] = Field(default=None, title="Tags", description="Comma-separated tags (Tagger backend).") + scores: Optional[dict] = Field(default=None, title="Scores", description="Tag confidence scores (Tagger with show_scores=True).") + # VLM fields + answer: Optional[str] = Field(default=None, title="Answer", description="VLM response (VLM backend).") + annotated_image: Optional[str] = Field(default=None, title="Annotated Image", description="Base64 annotated image (VLM with include_annotated=True).") + + class ResTrain(BaseModel): info: str = Field(title="Train info", description="Response string from train embedding task.")
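Since this patch changes the public routes (GET/POST /sdapi/v1/interrogate is replaced by /sdapi/v1/openclip, with the new unified POST /sdapi/v1/caption dispatcher on top), a minimal client-side sketch of the new surface follows. It is illustrative only: it assumes a local server at http://127.0.0.1:7860, the `requests` library, and a hypothetical helper and test image; the field names come from the ReqCaption, ReqCaptionTagger, ReqCaptionVLM, and ResCaptionDispatch models above.

```python
import base64
import requests

BASE = "http://127.0.0.1:7860"  # assumed local SD.Next API address

def encode_image(path: str) -> str:
    # hypothetical helper: read a local file and return base64 for the 'image' field
    with open(path, "rb") as f:
        return base64.b64encode(f.read()).decode()

image = encode_image("test.png")  # hypothetical test image

# GET /sdapi/v1/openclip replaces GET /sdapi/v1/interrogate: list OpenCLIP caption models
print(requests.get(f"{BASE}/sdapi/v1/openclip", timeout=60).json())

# POST /sdapi/v1/openclip replaces POST /sdapi/v1/interrogate: direct OpenCLIP captioning
res = requests.post(f"{BASE}/sdapi/v1/openclip", json={"image": image, "mode": "fast"}, timeout=300)
print(res.json()["caption"])

# POST /sdapi/v1/caption: unified dispatcher, the 'backend' field selects the handler
res = requests.post(f"{BASE}/sdapi/v1/caption", json={"backend": "tagger", "image": image, "threshold": 0.35}, timeout=300)
print(res.json()["tags"])

res = requests.post(f"{BASE}/sdapi/v1/caption", json={"backend": "vlm", "image": image, "question": "describe the image"}, timeout=300)
print(res.json()["answer"])
```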
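For reviewers of the models.py change, here is a sketch of how the ReqCaptionDispatch discriminated union resolves the `backend` tag to a concrete request model before post_caption_dispatch runs. It assumes Pydantic v2 (TypeAdapter); under Pydantic v1 parse_obj_as would play the same role. Payload values are illustrative only.

```python
from pydantic import TypeAdapter  # assumes Pydantic v2
from modules.api import models

adapter = TypeAdapter(models.ReqCaptionDispatch)

# 'backend' is the discriminator: validation picks the matching request model
req = adapter.validate_python({"backend": "tagger", "image": "<base64 data>", "threshold": 0.35})
assert isinstance(req, models.ReqCaptionTagger)

# an unknown backend is rejected at validation time, before the dispatcher's 400 branch is reached
try:
    adapter.validate_python({"backend": "clip", "image": "<base64 data>"})
except Exception as err:
    print(err)
```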