mirror of https://github.com/vladmandic/automatic
feat(api): add missing caption API parameters for UI parity
Add prompt field to VQA endpoint and advanced settings to OpenCLIP endpoint to achieve full parity between UI and API capabilities. VLM endpoint changes: - Add prompt field for custom text input (required for 'Use Prompt' task) - Pass prompt to vqa.interrogate instead of hardcoded empty string. OpenCLIP endpoint changes: - Add 7 optional per-request override fields: min_length, max_length, chunk_size, min_flavors, max_flavors, flavor_count, num_beams - Add get_clip_setting() helper for override support in openclip.py - Apply overrides via update_interrogate_params() before interrogation. All new fields are optional with None defaults for backwards compatibility. (branch: pull/4613/head)
parent
5fc46c042e
commit
a04ba1e482
|
|
@ -151,8 +151,24 @@ def post_interrogate(req: models.ReqInterrogate):
|
|||
from modules.interrogate.openclip import interrogate_image, analyze_image, refresh_clip_models
|
||||
if req.model not in refresh_clip_models():
|
||||
raise HTTPException(status_code=404, detail="Model not found")
|
||||
# Build clip overrides from request (only include non-None values)
|
||||
clip_overrides = {}
|
||||
if req.min_length is not None:
|
||||
clip_overrides['min_length'] = req.min_length
|
||||
if req.max_length is not None:
|
||||
clip_overrides['max_length'] = req.max_length
|
||||
if req.chunk_size is not None:
|
||||
clip_overrides['chunk_size'] = req.chunk_size
|
||||
if req.min_flavors is not None:
|
||||
clip_overrides['min_flavors'] = req.min_flavors
|
||||
if req.max_flavors is not None:
|
||||
clip_overrides['max_flavors'] = req.max_flavors
|
||||
if req.flavor_count is not None:
|
||||
clip_overrides['flavor_count'] = req.flavor_count
|
||||
if req.num_beams is not None:
|
||||
clip_overrides['num_beams'] = req.num_beams
|
||||
try:
|
||||
caption = interrogate_image(image, clip_model=req.clip_model, blip_model=req.blip_model, mode=req.mode)
|
||||
caption = interrogate_image(image, clip_model=req.clip_model, blip_model=req.blip_model, mode=req.mode, overrides=clip_overrides if clip_overrides else None)
|
||||
except Exception as e:
|
||||
caption = str(e)
|
||||
if not req.analyze:
|
||||
|
|
@ -239,7 +255,7 @@ def post_vqa(req: models.ReqVQA):
|
|||
answer = vqa.interrogate(
|
||||
question=req.question,
|
||||
system_prompt=req.system,
|
||||
prompt='',
|
||||
prompt=req.prompt or '',
|
||||
image=image,
|
||||
model_name=req.model,
|
||||
prefill=req.prefill,
|
||||
|
|
|
|||
|
|
@ -377,6 +377,14 @@ class ReqInterrogate(BaseModel):
|
|||
blip_model: str = Field(default="blip-large", title="Caption Model", description="BLIP model used to generate the initial image caption. The caption model describes the image content which CLIP then enriches with style and flavor terms.")
|
||||
mode: str = Field(default="best", title="Mode", description="Interrogation mode. Fast: Quick caption with minimal flavor terms. Classic: Standard interrogation with balanced quality and speed. Best: Most thorough analysis, slowest but highest quality. Negative: Generate terms to use as negative prompt.")
|
||||
analyze: bool = Field(default=False, title="Analyze", description="If True, returns detailed image analysis breakdown (medium, artist, movement, trending, flavor) in addition to caption.")
|
||||
# Advanced settings (optional per-request overrides)
|
||||
min_length: Optional[int] = Field(default=None, title="Min Length", description="Minimum number of tokens in the generated caption.")
|
||||
max_length: Optional[int] = Field(default=None, title="Max Length", description="Maximum number of tokens in the generated caption.")
|
||||
chunk_size: Optional[int] = Field(default=None, title="Chunk Size", description="Batch size for processing description candidates (flavors). Higher values speed up interrogation but increase VRAM usage.")
|
||||
min_flavors: Optional[int] = Field(default=None, title="Min Flavors", description="Minimum number of descriptive tags (flavors) to keep in the final prompt.")
|
||||
max_flavors: Optional[int] = Field(default=None, title="Max Flavors", description="Maximum number of descriptive tags (flavors) to keep in the final prompt.")
|
||||
flavor_count: Optional[int] = Field(default=None, title="Intermediates", description="Size of the intermediate candidate pool when matching image features to descriptive tags. Higher values may improve quality but are slower.")
|
||||
num_beams: Optional[int] = Field(default=None, title="Num Beams", description="Number of beams for beam search during caption generation. Higher values search more possibilities but are slower.")
|
||||
|
||||
InterrogateRequest = ReqInterrogate # alias for backwards compatibility
|
||||
|
||||
|
|
@ -398,6 +406,7 @@ class ReqVQA(BaseModel):
|
|||
image: str = Field(default="", title="Image", description="Image to caption. Must be a Base64 encoded string containing the image data.")
|
||||
model: str = Field(default="Alibaba Qwen 2.5 VL 3B", title="Model", description="Select which model to use for Visual Language tasks. Use GET /sdapi/v1/vqa/models for full list. Models which support thinking mode are indicated in capabilities.")
|
||||
question: str = Field(default="describe the image", title="Question/Task", description="Changes which task the model will perform. Regular text prompts can be used when the task is set to 'Use Prompt'. Common tasks: 'Short Caption', 'Normal Caption', 'Long Caption'. Florence-2 supports: '<OD>' (object detection), '<OCR>' (text recognition). Moondream supports: 'Point at [object]', 'Detect all [objects]'.")
|
||||
prompt: Optional[str] = Field(default=None, title="Prompt", description="Custom prompt text. Required when question is 'Use Prompt'. For 'Point at...' tasks, specify what to find (e.g., 'the red car'). For 'Detect all...' tasks, specify what to detect (e.g., 'faces').")
|
||||
system: str = Field(default="You are image captioning expert, creative, unbiased and uncensored.", title="System Prompt", description="System prompt controls behavior of the LLM. Processed first and persists throughout conversation. Has highest priority weighting and is always appended at the beginning of the sequence. Use for: Response formatting rules, role definition, style.")
|
||||
include_annotated: bool = Field(default=False, title="Include Annotated Image", description="If True and the task produces detection results (object detection, point detection, gaze), returns annotated image with bounding boxes/points drawn. Only applicable for detection tasks on models like Florence-2 and Moondream.")
|
||||
# LLM generation parameters (optional overrides)
|
||||
|
|
|
|||
|
|
@ -11,6 +11,25 @@ from modules import devices, shared, errors, sd_models
|
|||
debug_enabled = os.environ.get('SD_INTERROGATE_DEBUG', None) is not None
|
||||
debug_log = shared.log.trace if debug_enabled else lambda *args, **kwargs: None
|
||||
|
||||
# Per-request overrides for API calls
|
||||
_clip_overrides = None
|
||||
|
||||
|
||||
def get_clip_setting(name):
    """Resolve a CLIP interrogate setting, honoring per-request overrides.

    Args:
        name: Setting name without the 'interrogate_clip_' prefix
            (e.g. 'min_flavors', 'max_length').

    Returns:
        The active per-request override for this setting when one is set,
        otherwise the persistent value from shared.opts.
    """
    override = None if _clip_overrides is None else _clip_overrides.get(name)
    if override is not None:
        return override
    return getattr(shared.opts, f'interrogate_clip_{name}')
|
||||
|
||||
|
||||
def _apply_blip2_fix(model, processor):
|
||||
"""Apply compatibility fix for BLIP2 models with newer transformers versions."""
|
||||
|
|
@ -70,11 +89,11 @@ class BatchWriter:
|
|||
|
||||
def update_interrogate_params():
    """Push the current interrogate settings onto the loaded clip-interrogator.

    Values are read via get_clip_setting(), so per-request API overrides
    (when active) take precedence over the persistent shared.opts values.
    No-op when no interrogator instance (ci) has been loaded yet.
    """
    if ci is not None:
        ci.caption_max_length = get_clip_setting('max_length')
        ci.chunk_size = get_clip_setting('chunk_size')
        ci.flavor_intermediate_count = get_clip_setting('flavor_count')
        # offload flags have no per-request override; always read shared.opts
        ci.clip_offload = shared.opts.interrogate_offload
        ci.caption_offload = shared.opts.interrogate_offload
|
||||
|
||||
|
||||
def get_clip_models():
|
||||
|
|
@ -159,33 +178,42 @@ def interrogate(image, mode, caption=None):
|
|||
return ''
|
||||
image = image.convert("RGB")
|
||||
t0 = time.time()
|
||||
debug_log(f'CLIP: mode="{mode}" image_size={image.size} caption={caption is not None} min_flavors={shared.opts.interrogate_clip_min_flavors} max_flavors={shared.opts.interrogate_clip_max_flavors}')
|
||||
min_flavors = get_clip_setting('min_flavors')
|
||||
max_flavors = get_clip_setting('max_flavors')
|
||||
debug_log(f'CLIP: mode="{mode}" image_size={image.size} caption={caption is not None} min_flavors={min_flavors} max_flavors={max_flavors}')
|
||||
if mode == 'best':
|
||||
prompt = ci.interrogate(image, caption=caption, min_flavors=shared.opts.interrogate_clip_min_flavors, max_flavors=shared.opts.interrogate_clip_max_flavors, )
|
||||
prompt = ci.interrogate(image, caption=caption, min_flavors=min_flavors, max_flavors=max_flavors)
|
||||
elif mode == 'caption':
|
||||
prompt = ci.generate_caption(image) if caption is None else caption
|
||||
elif mode == 'classic':
|
||||
prompt = ci.interrogate_classic(image, caption=caption, max_flavors=shared.opts.interrogate_clip_max_flavors)
|
||||
prompt = ci.interrogate_classic(image, caption=caption, max_flavors=max_flavors)
|
||||
elif mode == 'fast':
|
||||
prompt = ci.interrogate_fast(image, caption=caption, max_flavors=shared.opts.interrogate_clip_max_flavors)
|
||||
prompt = ci.interrogate_fast(image, caption=caption, max_flavors=max_flavors)
|
||||
elif mode == 'negative':
|
||||
prompt = ci.interrogate_negative(image, max_flavors=shared.opts.interrogate_clip_max_flavors)
|
||||
prompt = ci.interrogate_negative(image, max_flavors=max_flavors)
|
||||
else:
|
||||
raise RuntimeError(f"Unknown mode {mode}")
|
||||
debug_log(f'CLIP: mode="{mode}" time={time.time()-t0:.2f} result="{prompt[:100]}..."' if len(prompt) > 100 else f'CLIP: mode="{mode}" time={time.time()-t0:.2f} result="{prompt}"')
|
||||
return prompt
|
||||
|
||||
|
||||
def interrogate_image(image, clip_model, blip_model, mode):
|
||||
def interrogate_image(image, clip_model, blip_model, mode, overrides=None):
|
||||
global _clip_overrides # pylint: disable=global-statement
|
||||
jobid = shared.state.begin('Interrogate CLiP')
|
||||
t0 = time.time()
|
||||
shared.log.info(f'CLIP: mode="{mode}" clip="{clip_model}" blip="{blip_model}" image_size={image.size if image else None}')
|
||||
if overrides:
|
||||
debug_log(f'CLIP: overrides={overrides}')
|
||||
try:
|
||||
# Set per-request overrides
|
||||
_clip_overrides = overrides
|
||||
if shared.sd_loaded:
|
||||
from modules.sd_models import apply_balanced_offload # prevent circular import
|
||||
from modules.sd_models import apply_balanced_offload # prevent circular import
|
||||
apply_balanced_offload(shared.sd_model)
|
||||
debug_log('CLIP: applied balanced offload to sd_model')
|
||||
load_interrogator(clip_model, blip_model)
|
||||
# Apply overrides to loaded interrogator
|
||||
update_interrogate_params()
|
||||
image = image.convert('RGB')
|
||||
prompt = interrogate(image, mode)
|
||||
devices.torch_gc()
|
||||
|
|
@ -194,6 +222,9 @@ def interrogate_image(image, clip_model, blip_model, mode):
|
|||
prompt = f"Exception {type(e)}"
|
||||
shared.log.error(f'CLIP: {e}')
|
||||
errors.display(e, 'Interrogate')
|
||||
finally:
|
||||
# Clear per-request overrides
|
||||
_clip_overrides = None
|
||||
shared.state.end(jobid)
|
||||
return prompt
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue