feat(api): add missing caption API parameters for UI parity

Add prompt field to VQA endpoint and advanced settings to OpenCLIP endpoint
to achieve full parity between UI and API capabilities.

VLM endpoint changes:
- Add prompt field for custom text input (required for 'Use Prompt' task)
- Pass prompt to vqa.interrogate instead of hardcoded empty string

OpenCLIP endpoint changes:
- Add 7 optional per-request override fields: min_length, max_length,
  chunk_size, min_flavors, max_flavors, flavor_count, num_beams
- Add get_clip_setting() helper for override support in openclip.py
- Apply overrides via update_interrogate_params() before interrogation

All new fields are optional and default to None, preserving backward compatibility.
pull/4613/head
CalamitousFelicitousness 2026-01-25 01:31:18 +00:00
parent 5fc46c042e
commit a04ba1e482
3 changed files with 70 additions and 14 deletions

View File

@ -151,8 +151,24 @@ def post_interrogate(req: models.ReqInterrogate):
from modules.interrogate.openclip import interrogate_image, analyze_image, refresh_clip_models
if req.model not in refresh_clip_models():
raise HTTPException(status_code=404, detail="Model not found")
# Build clip overrides from request (only include non-None values)
clip_overrides = {}
if req.min_length is not None:
clip_overrides['min_length'] = req.min_length
if req.max_length is not None:
clip_overrides['max_length'] = req.max_length
if req.chunk_size is not None:
clip_overrides['chunk_size'] = req.chunk_size
if req.min_flavors is not None:
clip_overrides['min_flavors'] = req.min_flavors
if req.max_flavors is not None:
clip_overrides['max_flavors'] = req.max_flavors
if req.flavor_count is not None:
clip_overrides['flavor_count'] = req.flavor_count
if req.num_beams is not None:
clip_overrides['num_beams'] = req.num_beams
try:
caption = interrogate_image(image, clip_model=req.clip_model, blip_model=req.blip_model, mode=req.mode)
caption = interrogate_image(image, clip_model=req.clip_model, blip_model=req.blip_model, mode=req.mode, overrides=clip_overrides if clip_overrides else None)
except Exception as e:
caption = str(e)
if not req.analyze:
@ -239,7 +255,7 @@ def post_vqa(req: models.ReqVQA):
answer = vqa.interrogate(
question=req.question,
system_prompt=req.system,
prompt='',
prompt=req.prompt or '',
image=image,
model_name=req.model,
prefill=req.prefill,

View File

@ -377,6 +377,14 @@ class ReqInterrogate(BaseModel):
blip_model: str = Field(default="blip-large", title="Caption Model", description="BLIP model used to generate the initial image caption. The caption model describes the image content which CLIP then enriches with style and flavor terms.")
mode: str = Field(default="best", title="Mode", description="Interrogation mode. Fast: Quick caption with minimal flavor terms. Classic: Standard interrogation with balanced quality and speed. Best: Most thorough analysis, slowest but highest quality. Negative: Generate terms to use as negative prompt.")
analyze: bool = Field(default=False, title="Analyze", description="If True, returns detailed image analysis breakdown (medium, artist, movement, trending, flavor) in addition to caption.")
# Advanced settings (optional per-request overrides)
min_length: Optional[int] = Field(default=None, title="Min Length", description="Minimum number of tokens in the generated caption.")
max_length: Optional[int] = Field(default=None, title="Max Length", description="Maximum number of tokens in the generated caption.")
chunk_size: Optional[int] = Field(default=None, title="Chunk Size", description="Batch size for processing description candidates (flavors). Higher values speed up interrogation but increase VRAM usage.")
min_flavors: Optional[int] = Field(default=None, title="Min Flavors", description="Minimum number of descriptive tags (flavors) to keep in the final prompt.")
max_flavors: Optional[int] = Field(default=None, title="Max Flavors", description="Maximum number of descriptive tags (flavors) to keep in the final prompt.")
flavor_count: Optional[int] = Field(default=None, title="Intermediates", description="Size of the intermediate candidate pool when matching image features to descriptive tags. Higher values may improve quality but are slower.")
num_beams: Optional[int] = Field(default=None, title="Num Beams", description="Number of beams for beam search during caption generation. Higher values search more possibilities but are slower.")
InterrogateRequest = ReqInterrogate # alias for backwards compatibility
@ -398,6 +406,7 @@ class ReqVQA(BaseModel):
image: str = Field(default="", title="Image", description="Image to caption. Must be a Base64 encoded string containing the image data.")
model: str = Field(default="Alibaba Qwen 2.5 VL 3B", title="Model", description="Select which model to use for Visual Language tasks. Use GET /sdapi/v1/vqa/models for full list. Models which support thinking mode are indicated in capabilities.")
question: str = Field(default="describe the image", title="Question/Task", description="Changes which task the model will perform. Regular text prompts can be used when the task is set to 'Use Prompt'. Common tasks: 'Short Caption', 'Normal Caption', 'Long Caption'. Florence-2 supports: '<OD>' (object detection), '<OCR>' (text recognition). Moondream supports: 'Point at [object]', 'Detect all [objects]'.")
prompt: Optional[str] = Field(default=None, title="Prompt", description="Custom prompt text. Required when question is 'Use Prompt'. For 'Point at...' tasks, specify what to find (e.g., 'the red car'). For 'Detect all...' tasks, specify what to detect (e.g., 'faces').")
system: str = Field(default="You are image captioning expert, creative, unbiased and uncensored.", title="System Prompt", description="System prompt controls behavior of the LLM. Processed first and persists throughout conversation. Has highest priority weighting and is always appended at the beginning of the sequence. Use for: Response formatting rules, role definition, style.")
include_annotated: bool = Field(default=False, title="Include Annotated Image", description="If True and the task produces detection results (object detection, point detection, gaze), returns annotated image with bounding boxes/points drawn. Only applicable for detection tasks on models like Florence-2 and Moondream.")
# LLM generation parameters (optional overrides)

View File

@ -11,6 +11,25 @@ from modules import devices, shared, errors, sd_models
debug_enabled = os.environ.get('SD_INTERROGATE_DEBUG', None) is not None
debug_log = shared.log.trace if debug_enabled else lambda *args, **kwargs: None
# Per-request overrides for API calls
_clip_overrides = None


def get_clip_setting(name):
    """Resolve a CLIP interrogation setting, honoring per-request overrides.

    Args:
        name: Setting name without the 'interrogate_clip_' prefix
            (e.g. 'min_flavors', 'max_length').

    Returns:
        The active override value when one is set for *name*, otherwise the
        corresponding value from shared.opts (attribute
        'interrogate_clip_<name>').
    """
    overrides = _clip_overrides or {}
    override = overrides.get(name)
    if override is not None:
        return override
    # No per-request override: fall back to the global option value
    return getattr(shared.opts, f'interrogate_clip_{name}')
def _apply_blip2_fix(model, processor):
"""Apply compatibility fix for BLIP2 models with newer transformers versions."""
@ -70,11 +89,11 @@ class BatchWriter:
def update_interrogate_params():
if ci is not None:
ci.caption_max_length=shared.opts.interrogate_clip_max_length
ci.chunk_size=shared.opts.interrogate_clip_chunk_size
ci.flavor_intermediate_count=shared.opts.interrogate_clip_flavor_count
ci.clip_offload=shared.opts.interrogate_offload
ci.caption_offload=shared.opts.interrogate_offload
ci.caption_max_length = get_clip_setting('max_length')
ci.chunk_size = get_clip_setting('chunk_size')
ci.flavor_intermediate_count = get_clip_setting('flavor_count')
ci.clip_offload = shared.opts.interrogate_offload
ci.caption_offload = shared.opts.interrogate_offload
def get_clip_models():
@ -159,33 +178,42 @@ def interrogate(image, mode, caption=None):
return ''
image = image.convert("RGB")
t0 = time.time()
debug_log(f'CLIP: mode="{mode}" image_size={image.size} caption={caption is not None} min_flavors={shared.opts.interrogate_clip_min_flavors} max_flavors={shared.opts.interrogate_clip_max_flavors}')
min_flavors = get_clip_setting('min_flavors')
max_flavors = get_clip_setting('max_flavors')
debug_log(f'CLIP: mode="{mode}" image_size={image.size} caption={caption is not None} min_flavors={min_flavors} max_flavors={max_flavors}')
if mode == 'best':
prompt = ci.interrogate(image, caption=caption, min_flavors=shared.opts.interrogate_clip_min_flavors, max_flavors=shared.opts.interrogate_clip_max_flavors, )
prompt = ci.interrogate(image, caption=caption, min_flavors=min_flavors, max_flavors=max_flavors)
elif mode == 'caption':
prompt = ci.generate_caption(image) if caption is None else caption
elif mode == 'classic':
prompt = ci.interrogate_classic(image, caption=caption, max_flavors=shared.opts.interrogate_clip_max_flavors)
prompt = ci.interrogate_classic(image, caption=caption, max_flavors=max_flavors)
elif mode == 'fast':
prompt = ci.interrogate_fast(image, caption=caption, max_flavors=shared.opts.interrogate_clip_max_flavors)
prompt = ci.interrogate_fast(image, caption=caption, max_flavors=max_flavors)
elif mode == 'negative':
prompt = ci.interrogate_negative(image, max_flavors=shared.opts.interrogate_clip_max_flavors)
prompt = ci.interrogate_negative(image, max_flavors=max_flavors)
else:
raise RuntimeError(f"Unknown mode {mode}")
debug_log(f'CLIP: mode="{mode}" time={time.time()-t0:.2f} result="{prompt[:100]}..."' if len(prompt) > 100 else f'CLIP: mode="{mode}" time={time.time()-t0:.2f} result="{prompt}"')
return prompt
def interrogate_image(image, clip_model, blip_model, mode):
def interrogate_image(image, clip_model, blip_model, mode, overrides=None):
global _clip_overrides # pylint: disable=global-statement
jobid = shared.state.begin('Interrogate CLiP')
t0 = time.time()
shared.log.info(f'CLIP: mode="{mode}" clip="{clip_model}" blip="{blip_model}" image_size={image.size if image else None}')
if overrides:
debug_log(f'CLIP: overrides={overrides}')
try:
# Set per-request overrides
_clip_overrides = overrides
if shared.sd_loaded:
from modules.sd_models import apply_balanced_offload # prevent circular import
from modules.sd_models import apply_balanced_offload # prevent circular import
apply_balanced_offload(shared.sd_model)
debug_log('CLIP: applied balanced offload to sd_model')
load_interrogator(clip_model, blip_model)
# Apply overrides to loaded interrogator
update_interrogate_params()
image = image.convert('RGB')
prompt = interrogate(image, mode)
devices.torch_gc()
@ -194,6 +222,9 @@ def interrogate_image(image, clip_model, blip_model, mode):
prompt = f"Exception {type(e)}"
shared.log.error(f'CLIP: {e}')
errors.display(e, 'Interrogate')
finally:
# Clear per-request overrides
_clip_overrides = None
shared.state.end(jobid)
return prompt