import io
import os
import time
import json
import base64
import copy
import torch
import transformers
import transformers.dynamic_module_utils
from PIL import Image
from modules import shared, devices, errors, model_quant, sd_models, sd_models_compile
from modules.logger import log, console
from modules.caption import vqa_detection
from modules.caption.models_def import vlm_models, vlm_system, vlm_default, vlm_prefill, vlm_prompts, vlm_prompt_mapping, vlm_prompt_placeholders, vlm_prompts_common, vlm_prompts_florence, vlm_prompts_moondream, vlm_prompts_moondream2, vlm_prompts_promptgen
# Debug logging - function-based to avoid circular import
# Caption debug tracing is opt-in: it is active whenever SD_CAPTION_DEBUG exists in the environment (any value)
debug_enabled = 'SD_CAPTION_DEBUG' in os.environ


def debug(*args, **kwargs):
    """Forward the arguments to ``log.trace`` when caption debugging is enabled; otherwise do nothing."""
    if not debug_enabled:
        return
    log.trace(*args, **kwargs)
def get_prompts_for_model(model_name: str) -> list:
    """Return the prompt names offered for the selected model.

    Every model gets the common prompts. Names containing "promptgen",
    "florence" or "moondream" (case-insensitive) additionally get the prompt
    lists of that family; ``None`` yields the common prompts only.
    """
    if model_name is None:
        return vlm_prompts_common
    name = model_name.lower()
    # PromptGen is tested before Florence: it receives the Florence prompts plus its own extras
    if 'promptgen' in name:
        return vlm_prompts_common + vlm_prompts_florence + vlm_prompts_promptgen
    if 'florence' in name:
        return vlm_prompts_common + vlm_prompts_florence
    if 'moondream' in name:
        prompts = vlm_prompts_common + vlm_prompts_moondream
        is_v3 = any(tag in name for tag in ('moondream3', 'moondream 3'))
        # Anything that is not Moondream 3 also gets the Moondream 2 prompts (gaze detection)
        return prompts if is_v3 else prompts + vlm_prompts_moondream2
    return vlm_prompts_common
def get_internal_prompt(friendly_name: str, user_prompt: str = None) -> str:
    """Translate a friendly prompt name into its internal token/command.

    Names missing from ``vlm_prompt_mapping`` are passed through unchanged.
    When the mapped value is ``POINT_MODE`` or ``DETECT_MODE`` and a
    non-empty ``user_prompt`` is given, the trigger phrase is prepended to it.
    """
    token = vlm_prompt_mapping.get(friendly_name, friendly_name)
    if user_prompt:
        for mode, trigger in (("POINT_MODE", "Point at"), ("DETECT_MODE", "Detect")):
            if token == mode:
                return f"{trigger} {user_prompt}"
    return token
def get_prompt_placeholder(friendly_name: str) -> str:
    """Return the prompt-field placeholder for the selected question, or a generic hint when none is defined."""
    fallback = "Enter your question or instruction"
    return vlm_prompt_placeholders.get(friendly_name, fallback)
def is_florence_task(question: str) -> bool:
    """Check whether the question is a Florence-2 task (friendly name or internal task token).

    Both base Florence prompts and PromptGen-specific prompts count, since
    all of them are handled by the Florence handler.

    Args:
        question: Friendly prompt name or raw task token; empty/None is never a task.

    Returns:
        True when the question names a Florence-2 or PromptGen task.
    """
    if not question:
        return False
    # Friendly names (base Florence or PromptGen)
    if question in vlm_prompts_florence or question in vlm_prompts_promptgen:
        return True
    # Internal task tokens, accepted for backwards compatibility.
    # NOTE(review): spellings follow the Florence-2 and Florence-2 PromptGen model cards - confirm
    # they match the values used in vlm_prompt_mapping.
    florence_base_tokens = (
        '<CAPTION>', '<DETAILED_CAPTION>', '<MORE_DETAILED_CAPTION>', '<CAPTION_TO_PHRASE_GROUNDING>',
        '<OD>', '<DENSE_REGION_CAPTION>', '<REGION_PROPOSAL>', '<OCR>', '<OCR_WITH_REGION>',
    )
    promptgen_tokens = ('<GENERATE_TAGS>', '<ANALYZE>', '<MIXED_CAPTION>', '<MIXED_CAPTION_PLUS>')
    return question in florence_base_tokens or question in promptgen_tokens
def is_thinking_model(model_name: str) -> bool:
    """Tell from the model name alone whether the model supports a thinking/reasoning mode."""
    if not model_name:
        return False
    lowered = model_name.lower()
    # Substrings that mark a thinking-capable model: "thinking" (e.g. Qwen3-VL-*-Thinking),
    # Moondream 2 / Moondream 3 in both spellings, and MiMo
    markers = ('thinking', 'moondream3', 'moondream 3', 'moondream2', 'moondream 2', 'mimo')
    for marker in markers:
        if marker in lowered:
            return True
    return False
def truncate_b64_in_conversation(conversation, front_chars=50, tail_chars=50, threshold=200):
    """Return a deep copy of a conversation with long base64 image strings shortened for logging.

    Only string values stored under an ``"image"`` key are touched, at any
    nesting depth of dicts and lists. The input structure is never modified.

    Args:
        conversation: Nested dict/list structure (chat messages).
        front_chars: Number of leading characters to keep.
        tail_chars: Number of trailing characters to keep (0 keeps none).
        threshold: Only strings longer than this are shortened.

    Returns:
        The copied structure with each long image string replaced by
        ``<front>...[N chars truncated]...<tail>``.
    """
    conv_copy = copy.deepcopy(conversation)

    def shorten(value):
        removed = len(value) - front_chars - tail_chars
        if removed <= 0:
            # Nothing would actually be dropped; the marker would only make the string longer
            return value
        head = value[:front_chars]
        # Index from the end explicitly: value[-0:] would return the whole string
        tail = value[len(value) - tail_chars:]
        return f"{head}...[{removed} chars truncated]...{tail}"

    def truncate_recursive(obj):
        if isinstance(obj, dict):
            # Values of existing keys are replaced in place, so the dict size never changes while iterating
            for key, value in obj.items():
                if key == "image" and isinstance(value, str) and len(value) > threshold:
                    obj[key] = shorten(value)
                elif isinstance(value, (dict, list)):
                    truncate_recursive(value)
        elif isinstance(obj, list):
            for item in obj:
                truncate_recursive(item)

    truncate_recursive(conv_copy)
    return conv_copy
def keep_think_block_open(text_prompt: str) -> str:
    """Remove the closing </think> of the final assistant message so the model can continue reasoning.

    Only the first ``</think>`` that follows the last ``<think>`` is removed,
    together with the spaces/tabs and then the line breaks directly after it.

    Args:
        text_prompt: Rendered chat-template prompt.

    Returns:
        The prompt with that closing tag removed, or the prompt unchanged when
        there is no ``<think>`` or no ``</think>`` after the last one.
    """
    think_open = "<think>"
    think_close = "</think>"
    last_open = text_prompt.rfind(think_open)
    if last_open == -1:
        return text_prompt
    close_index = text_prompt.find(think_close, last_open)
    if close_index == -1:
        return text_prompt
    # Skip any whitespace immediately following the closing tag
    end_close = close_index + len(think_close)
    while end_close < len(text_prompt) and text_prompt[end_close] in (' ', '\t'):
        end_close += 1
    while end_close < len(text_prompt) and text_prompt[end_close] in ('\r', '\n'):
        end_close += 1
    trimmed_prompt = text_prompt[:close_index] + text_prompt[end_close:]
    debug('VQA caption: keep_think_block_open applied to prompt segment near assistant reply')
    return trimmed_prompt
def b64(image):
    """Encode an image as a base64 JPEG string.

    Args:
        image: PIL image, or None.

    Returns:
        Base64 text of the JPEG-encoded image; an empty string for None.
    """
    if image is None:
        return ''
    # JPEG has no alpha or palette support: Pillow refuses to save modes such as RGBA, LA or P.
    # Convert only those modes, so images that already saved fine are encoded exactly as before.
    if getattr(image, 'mode', 'RGB') not in ('1', 'L', 'RGB', 'RGBX', 'CMYK', 'YCbCr'):
        image = image.convert('RGB')
    with io.BytesIO() as stream:
        image.save(stream, 'JPEG')
        return base64.b64encode(stream.getvalue()).decode()
def clean(response, question, prefill=None):
    """Normalize a raw model response into plain caption text.

    Args:
        response: Model output as str, dict (keys ``reasoning``/``answer``/``caption``/``task``),
            list (first element is used) or any other object (stringified).
        question: Prompt that was sent; text up to and including its first echo is dropped.
        prefill: Prefill text used for generation; None falls back to ``vlm_prefill``.

    Returns:
        The cleaned response, with the prefill kept or removed according to ``get_keep_prefill()``.
    """
    # Every entry must be non-empty: an empty string is "in" every response and would keep the loop below spinning.
    # NOTE(review): curly quotes and '<pad>' are restored on the assumption that they were the original entries - confirm.
    strip = ['---', '\r', '\t', '**', '"', '\u201c', '\u201d', 'Assistant:', 'Caption:', '<|im_end|>', '<pad>']
    if isinstance(response, str):
        response = response.strip()
    elif isinstance(response, dict):
        text_response = ""
        if 'reasoning' in response and get_keep_thinking():
            r_text = response['reasoning']
            if isinstance(r_text, dict) and 'text' in r_text:
                r_text = r_text['text']
            text_response += f"Reasoning:\n{r_text}\n\nAnswer:\n"
        if 'answer' in response:
            text_response += response['answer']
        elif 'caption' in response:
            text_response += response['caption']
        elif 'task' in response:
            text_response += response['task']
        elif not text_response:
            # Unknown dict shape: fall back to its JSON representation
            text_response = json.dumps(response)
        response = text_response
    elif isinstance(response, list):
        # An empty batch yields an empty caption instead of an IndexError
        response = str(response[0]) if response else ''
    else:
        response = str(response)
    # Determine prefill text
    prefill_text = vlm_prefill if prefill is None else prefill
    if prefill_text is None:
        prefill_text = ""
    prefill_text = prefill_text.strip()
    question = (question or '').replace('<', '').replace('>', '').replace('_', ' ')
    # Guard on a non-empty question: str.split('') raises ValueError
    if question and question in response:
        response = response.split(question, 1)[1]
    # Removing one token can join its neighbours into another, so repeat until none is left
    while any(s in response for s in strip):
        for s in strip:
            response = response.replace(s, '')
    response = response.replace('  ', ' ').replace('* ', '- ').strip()
    # Handle prefill retention/removal
    if get_keep_prefill():
        # Add prefill if it's missing from the cleaned response
        if len(prefill_text) > 0 and not response.startswith(prefill_text):
            sep = " "
            if not response or response[0] in ".,!?;:":
                sep = ""
            response = f"{prefill_text}{sep}{response}"
    else:
        # Remove prefill if it's present in the cleaned response
        if len(prefill_text) > 0 and response.startswith(prefill_text):
            response = response[len(prefill_text):].strip()
    return response
def _get_overrides():
    """Return the per-request generation overrides of the VQA singleton, or an empty dict when there are none."""
    if _instance is None:
        return {}
    if _instance.generation_overrides is None:
        return {}
    return _instance.generation_overrides
def get_keep_thinking():
    """Return whether the thinking trace should be kept; a per-request override wins over the setting."""
    requested = _get_overrides().get('keep_thinking')
    if requested is None:
        return shared.opts.caption_vlm_keep_thinking
    return requested
def strip_think_xml_tags(text: str, keep: bool = False) -> str:
    """Strip or reformat XML-style <think>...</think> blocks from model output.

    Applies to models that use HuggingFace chat templates with <think>/</think>
    tokens (Qwen, Gemma, SmolVLM). Models with structured reasoning APIs
    (e.g. Moondream) handle their reasoning output separately.
    The opening tag is often in the prompt (not the response), so the
    response may only contain </think> without a matching <think>.

    Args:
        text: Model output text potentially containing <think>/</think> tags.
        keep: If True, reformat tags as human-readable Reasoning/Answer sections.
            If False, strip thinking blocks entirely.

    Returns:
        The reformatted or stripped text. With ``keep=False`` an opening tag
        that is never closed is left untouched.
    """
    think_open = '<think>'
    think_close = '</think>'
    if keep:
        if think_close in text and think_open not in text:
            # The opening tag lived in the prompt, so everything before the closer is reasoning
            return 'Reasoning:\n' + text.replace(think_close, '\n\nAnswer:')
        return text.replace(think_open, 'Reasoning:\n').replace(think_close, '\n\nAnswer:')
    while think_close in text:
        start = text.find(think_open)
        end = text.find(think_close)
        resume = end + len(think_close)
        if start != -1 and start < end:
            # Complete block: cut it out and keep the text on both sides
            text = text[:start] + text[resume:]
        else:
            # Closer without a preceding opener: everything up to it is reasoning
            text = text[resume:]
    return text
def get_keep_prefill():
    """Return whether the prefill should stay in the output; a per-request override wins over the setting."""
    requested = _get_overrides().get('keep_prefill')
    if requested is None:
        return shared.opts.caption_vlm_keep_prefill
    return requested
def get_kwargs():
    """Assemble the keyword arguments passed to ``generate``.

    Each value comes from the VQA singleton's per-request overrides when the
    override is present and not None, otherwise from the matching
    ``shared.opts.caption_vlm_*`` setting.
    Override keys: max_tokens, temperature, top_k, top_p, num_beams, do_sample.
    ``max_new_tokens`` and ``do_sample`` are always included; the remaining
    parameters are included only when greater than zero.
    """
    overrides = _get_overrides()

    def resolve(key, option):
        # An override counts only when it is not None; the setting is read only as a fallback
        requested = overrides.get(key)
        return getattr(shared.opts, option) if requested is None else requested

    kwargs = {
        'max_new_tokens': resolve('max_tokens', 'caption_vlm_max_length'),
        'do_sample': resolve('do_sample', 'caption_vlm_do_sample'),
    }
    optional = [
        (key, resolve(key, option))
        for key, option in (
            ('num_beams', 'caption_vlm_num_beams'),
            ('temperature', 'caption_vlm_temperature'),
            ('top_k', 'caption_vlm_top_k'),
            ('top_p', 'caption_vlm_top_p'),
        )
    ]
    for key, value in optional:
        if value > 0:
            kwargs[key] = value
    return kwargs
class VQA:
"""Vision-Language Model interrogation class with per-model self-contained loading."""
def __init__(self):
self.processor = None
self.model = None
self.loaded: str = None
self.last_annotated_image = None
self.last_detection_data = None
self._generation_overrides = None # Per-request generation parameter overrides
@property
def generation_overrides(self):
"""Get current per-request generation parameter overrides."""
return self._generation_overrides
def unload(self):
    """Drop the internally held model (if any) and ask every external handler to unload as well."""
    if self.model is None:
        log.debug('VQA unload: no internal model loaded')
    else:
        model_name = self.loaded
        log.debug(f'VQA unload: unloading model="{model_name}"')
        # Move the weights to CPU before dropping the references, then force a garbage collection
        sd_models.move_model(self.model, devices.cpu, force=True)
        self.model = self.processor = self.loaded = None
        devices.torch_gc(force=True, reason='vqa unload')
        log.debug(f'VQA unload: model="{model_name}" unloaded')
    # External handlers keep their models in their own module-level globals, which self.model does not cover
    from modules.caption import moondream3, joycaption, joytag, deepseek
    for handler in (moondream3, joycaption, joytag, deepseek):
        handler.unload()
def load(self, model_name: str = None):
    """Pre-load the VLM for ``model_name`` (defaults to the configured caption model).

    The repo registered for the name in ``vlm_models`` is matched,
    case-insensitively, against an ordered list of substrings; the first match
    selects the loader. Repos matching nothing are ignored silently.
    """
    model_name = model_name or shared.opts.caption_vlm_model
    if not model_name:
        log.warning('VQA load: no model specified')
        return
    repo = vlm_models.get(model_name)
    if repo is None:
        log.error(f'VQA load: unknown model="{model_name}"')
        return
    log.debug(f'VQA load: pre-loading model="{model_name}" repo="{repo}"')
    repo_lower = repo.lower()
    # Internal loaders as (any of these substrings, none of these substrings, method name); order matters.
    # Methods are looked up by name only when their rule matches.
    internal_rules = (
        (('qwen', 'torii', 'mimo'), (), '_load_qwen'),
        (('gemma',), ('pali',), '_load_gemma'),
        (('smol',), (), '_load_smol'),
        (('florence',), (), '_load_florence'),
        (('moondream2',), (), '_load_moondream'),
        (('git',), (), '_load_git'),
        (('blip',), (), '_load_blip'),
        (('vilt',), (), '_load_vilt'),
        (('pix',), (), '_load_pix'),
        (('paligemma',), (), '_load_paligemma'),
        (('ovis',), (), '_load_ovis'),
        (('sa2',), (), '_load_sa2'),
        (('fastvlm',), (), '_load_fastvlm'),
    )
    for wanted, excluded, method_name in internal_rules:
        if any(tag in repo_lower for tag in wanted) and not any(tag in repo_lower for tag in excluded):
            getattr(self, method_name)(repo)
            return
    # External handlers live in their own modules and are imported only when needed
    if 'moondream3' in repo_lower:
        from modules.caption import moondream3
        moondream3.load_model(repo)
    elif 'joytag' in repo_lower:
        from modules.caption import joytag
        joytag.load()
    elif 'joycaption' in repo_lower:
        from modules.caption import joycaption
        joycaption.load(repo)
    elif 'deepseek' in repo_lower:
        from modules.caption import deepseek
        deepseek.load(repo)
    else:
        # No pre-loader for this repo: nothing to do
        return
    log.info(f'VQA load: model="{model_name}" loaded (external handler)')
def _load_fastvlm(self, repo: str):
    """Load the FastVLM model and its tokenizer unless this repo is already loaded."""
    if self.model is not None and self.loaded == repo:
        return
    log.debug(f'Caption load: vlm="{repo}"')
    quant_args = model_quant.create_config(module='LLM')
    # Drop the previous model reference before loading the new one
    self.model = None
    # The tokenizer is stored in self.processor, the slot the handler reads it from
    self.processor = transformers.AutoTokenizer.from_pretrained(repo, trust_remote_code=True, cache_dir=shared.opts.hfcache_dir)
    load_args = {
        'torch_dtype': devices.dtype,
        'trust_remote_code': True,
        'use_safetensors': True,
        'cache_dir': shared.opts.hfcache_dir,
    }
    self.model = transformers.AutoModelForCausalLM.from_pretrained(repo, **load_args, **quant_args)
    self.loaded = repo
    devices.torch_gc()
def _fastvlm(self, question: str, image: Image.Image, repo: str, model_name: str = None):
    """Caption or answer a question about an image with FastVLM.

    Args:
        question: Prompt text; anything shorter than two characters becomes "Describe the image.".
        image: Input image.
        repo: Model repository to load.
        model_name: Friendly model name, used for debug logging only.

    Returns:
        The decoded model output.
    """
    debug(f'VQA caption: handler=fastvlm model_name="{model_name}" repo="{repo}" question="{question}" image_size={image.size if image else None}')
    self._load_fastvlm(repo)
    sd_models.move_model(self.model, devices.device)
    if len(question) < 2:
        question = "Describe the image."
    question = question.replace('<', '').replace('>', '')
    IMAGE_TOKEN_INDEX = -200 # what the model code looks for
    # The chat template is rendered around a textual placeholder, which is then swapped for the image token id.
    # The question has '<' and '>' removed above, so it cannot contain the placeholder itself.
    image_placeholder = "<image>"
    messages = [{"role": "user", "content": f"{image_placeholder}\n{question}"}]
    rendered = self.processor.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
    pre, post = rendered.split(image_placeholder, 1)
    pre_ids = self.processor(pre, return_tensors="pt", add_special_tokens=False).input_ids
    post_ids = self.processor(post, return_tensors="pt", add_special_tokens=False).input_ids
    img_tok = torch.tensor([[IMAGE_TOKEN_INDEX]], dtype=pre_ids.dtype)
    input_ids = torch.cat([pre_ids, img_tok, post_ids], dim=1)
    input_ids = input_ids.to(devices.device)
    attention_mask = torch.ones_like(input_ids, device=devices.device)
    px = self.model.get_vision_tower().image_processor(images=image, return_tensors="pt")
    px = px["pixel_values"].to(self.model.device, dtype=self.model.dtype)
    with devices.inference_context():
        outputs = self.model.generate(
            inputs=input_ids,
            attention_mask=attention_mask,
            images=px,
            max_new_tokens=128,
        )
    answer = self.processor.decode(outputs[0], skip_special_tokens=True)
    return answer
# Map Qwen VL config model_type strings to their model classes.
# Values are attribute names on the ``transformers`` module; _load_qwen consults this map
# only when the repo name itself does not identify the Qwen VL generation.
_QWEN_VL_MODEL_TYPE_MAP = {
    'qwen3_vl': 'Qwen3VLForConditionalGeneration',
    'qwen2_5_vl': 'Qwen2_5_VLForConditionalGeneration',
    'qwen2_vl': 'Qwen2VLForConditionalGeneration',
}
def _load_qwen(self, repo: str):
    """Load a Qwen VL family model and its processor unless this repo is already loaded."""
    if self.model is not None and self.loaded == repo:
        return
    log.debug(f'Caption load: vlm="{repo}"')
    self.model = None
    # Repo-name markers (case-sensitive) and the transformers class each one selects; first match wins.
    # Classes are looked up by name so that only the selected one has to exist in the installed transformers.
    name_rules = (
        (('Qwen3-VL', 'Qwen3VL'), 'Qwen3VLForConditionalGeneration'),
        (('Qwen2.5-VL', 'Qwen2_5_VL', 'MiMo-VL'), 'Qwen2_5_VLForConditionalGeneration'),
        (('Qwen2-VL', 'Qwen2VL'), 'Qwen2VLForConditionalGeneration'),
    )
    class_attr = next((attr for markers, attr in name_rules if any(marker in repo for marker in markers)), None)
    if class_attr is not None:
        model_cls = getattr(transformers, class_attr)
    else:
        # Fine-tunes (e.g. ToriiGate) may not have "Qwen" in the repo name:
        # fall back to the model_type declared in the repo config.
        config = transformers.AutoConfig.from_pretrained(repo, cache_dir=shared.opts.hfcache_dir)
        mapped_attr = self._QWEN_VL_MODEL_TYPE_MAP.get(getattr(config, 'model_type', ''))
        model_cls = getattr(transformers, mapped_attr) if mapped_attr else transformers.AutoModelForCausalLM
    quant_args = model_quant.create_config(module='LLM')
    load_args = {
        'torch_dtype': devices.dtype,
        'use_safetensors': True,
        'cache_dir': shared.opts.hfcache_dir,
    }
    self.model = model_cls.from_pretrained(repo, **load_args, **quant_args)
    self.processor = transformers.AutoProcessor.from_pretrained(repo, max_pixels=1024*1024, cache_dir=shared.opts.hfcache_dir)
    if 'LLM' in shared.opts.cuda_compile:
        self.model = sd_models_compile.compile_torch(self.model)
    self.loaded = repo
    devices.torch_gc()
def _qwen(self, question: str, image: Image.Image, repo: str, system_prompt: str = None, model_name: str = None, prefill: str = None, thinking_mode: bool = False):
    """Caption or answer a question about an image with a Qwen VL family model.

    Args:
        question: Prompt text; '<' and '>' are removed and '_' becomes a space.
        image: Input image.
        repo: Model repository to load.
        system_prompt: System message; falls back to the configured caption system prompt.
        model_name: Friendly model name, used to detect thinking-capable models.
        prefill: Text the answer should start with; None falls back to ``vlm_prefill``.
        thinking_mode: For thinking-capable models, whether the model should reason before answering.

    Returns:
        List of decoded responses (one per batch item); the first entry has its
        thinking tags stripped or reformatted according to ``get_keep_thinking()``.
    """
    self._load_qwen(repo)
    sd_models.move_model(self.model, devices.device)
    # Get model class name for logging
    cls_name = self.model.__class__.__name__
    debug(f'VQA caption: handler=qwen model_name="{model_name}" model_class="{cls_name}" repo="{repo}" question="{question}" system_prompt="{system_prompt}" image_size={image.size if image else None}')
    question = question.replace('<', '').replace('>', '').replace('_', ' ')
    system_prompt = system_prompt or shared.opts.caption_vlm_system
    conversation = [
        {
            "role": "system",
            "content": [{"type": "text", "text": system_prompt}],
        },
        {
            "role": "user",
            "content": [
                {"type": "image", "image": b64(image)},
                {"type": "text", "text": question},
            ],
        }
    ]
    # Standardize prefill once; a missing default is treated as "no prefill"
    prefill_value = vlm_prefill if prefill is None else prefill
    prefill_text = (prefill_value or '').strip()
    use_prefill = len(prefill_text) > 0
    # Only models with thinking capability can use thinking mode
    is_thinking = is_thinking_model(model_name)
    if debug_enabled:
        debug(f'VQA caption: handler=qwen conversation_roles={[msg["role"] for msg in conversation]}')
        debug(f'VQA caption: handler=qwen full_conversation={truncate_b64_in_conversation(conversation)}')
        debug(f'VQA caption: handler=qwen is_thinking={is_thinking} thinking_mode={thinking_mode} prefill="{prefill_text}"')
    # Generate base prompt using template.
    # Qwen-Thinking template automatically adds "<|im_start|>assistant\n<think>\n" when add_generation_prompt=True
    text_prompt = self.processor.apply_chat_template(conversation, add_generation_prompt=True)
    # Manually handle thinking tags and prefill
    if is_thinking and not thinking_mode:
        # User wants to SKIP thinking: the template opened the block with <think>, so close it immediately
        text_prompt += "</think>\n"
    if use_prefill:
        # With thinking enabled the prompt still ends inside <think>, so the prefill becomes part of the reasoning
        text_prompt += prefill_text
    if debug_enabled:
        debug(f'VQA caption: handler=qwen text_prompt="{text_prompt}"')
    inputs = self.processor(text=[text_prompt], images=[image], padding=True, return_tensors="pt")
    inputs = inputs.to(devices.device, devices.dtype)
    gen_kwargs = get_kwargs()
    debug(f'VQA caption: handler=qwen generation_kwargs={gen_kwargs} input_ids_shape={inputs.input_ids.shape}')
    with devices.inference_context():
        output_ids = self.model.generate(
            **inputs,
            **gen_kwargs,
        )
    debug(f'VQA caption: handler=qwen output_ids_shape={output_ids.shape}')
    # Drop the prompt tokens so that only newly generated tokens are decoded
    generated_ids = [
        full_ids[len(prompt_ids):]
        for prompt_ids, full_ids in zip(inputs.input_ids, output_ids, strict=False)
    ]
    response = self.processor.batch_decode(generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True)
    if debug_enabled:
        debug(f'VQA caption: handler=qwen response_before_clean="{response}"')
    if len(response) > 0:
        response[0] = strip_think_xml_tags(response[0], keep=get_keep_thinking())
    return response
def _load_gemma(self, repo: str):
    """Load a Gemma 3 / Gemma 3n model and its processor unless this repo is already loaded."""
    if self.model is not None and self.loaded == repo:
        return
    log.debug(f'Caption load: vlm="{repo}"')
    self.model = None
    # Repos with "3n" in their name use the Gemma3n class; everything else uses the Gemma3 class.
    # The class is looked up by name so that only the selected one has to exist.
    class_attr = 'Gemma3nForConditionalGeneration' if '3n' in repo else 'Gemma3ForConditionalGeneration'
    model_cls = getattr(transformers, class_attr)
    quant_args = model_quant.create_config(module='LLM')
    load_args = {
        'torch_dtype': devices.dtype,
        'use_safetensors': True,
        'cache_dir': shared.opts.hfcache_dir,
    }
    self.model = model_cls.from_pretrained(repo, **load_args, **quant_args)
    if 'LLM' in shared.opts.cuda_compile:
        self.model = sd_models_compile.compile_torch(self.model)
    self.processor = transformers.AutoProcessor.from_pretrained(repo, max_pixels=1024*1024, cache_dir=shared.opts.hfcache_dir)
    self.loaded = repo
    devices.torch_gc()
def _gemma(self, question: str, image: Image.Image, repo: str, system_prompt: str = None, model_name: str = None, prefill: str = None, thinking_mode: bool = False):
    """Caption or answer a question about an image with a Gemma 3 model.

    System and question texts of four characters or fewer are left out of the
    conversation. A non-empty prefill is appended as an assistant turn that the
    model continues. Returns the decoded text generated after the prompt, with
    thinking tags stripped or reformatted according to ``get_keep_thinking()``.
    """
    self._load_gemma(repo)
    sd_models.move_model(self.model, devices.device)
    # Model class name is used for logging only
    cls_name = self.model.__class__.__name__
    debug(f'VQA caption: handler=gemma model_name="{model_name}" model_class="{cls_name}" repo="{repo}" question="{question}" system_prompt="{system_prompt}" image_size={image.size if image else None}')
    question = question.replace('<', '').replace('>', '').replace('_', ' ')
    system_prompt = system_prompt or shared.opts.caption_vlm_system
    has_system = system_prompt is not None and len(system_prompt) > 4
    system_content = [{"type": "text", "text": system_prompt}] if has_system else []
    has_question = question is not None and len(question) > 4
    user_content = [{"type": "text", "text": question}] if has_question else []
    if image is not None:
        user_content.append({"type": "image", "image": b64(image)})
    conversation = [
        {"role": "system", "content": system_content},
        {"role": "user", "content": user_content},
    ]
    # Resolve the prefill: an explicit argument wins over the module default
    prefill_value = vlm_prefill if prefill is None else prefill
    prefill_text = prefill_value.strip()
    use_prefill = len(prefill_text) > 0
    # Thinking is enabled by the manual toggle or detected from the model name
    use_thinking = thinking_mode or is_thinking_model(model_name)
    if not use_prefill:
        debug('VQA caption: handler=gemma prefill disabled (empty), relying on add_generation_prompt')
    else:
        conversation.append({
            "role": "assistant",
            "content": [{"type": "text", "text": prefill_text}],
        })
        debug(f'VQA caption: handler=gemma prefill="{prefill_text}"')
    if debug_enabled:
        debug(f'VQA caption: handler=gemma conversation_roles={[msg["role"] for msg in conversation]}')
        debug(f'VQA caption: handler=gemma full_conversation={truncate_b64_in_conversation(conversation)}')
        debug_prefill_mode = 'add_generation_prompt=False continue_final_message=True' if use_prefill else 'add_generation_prompt=True'
        debug(f'VQA caption: handler=gemma template_mode={debug_prefill_mode}')
    # With a prefill the final assistant turn is continued; otherwise a fresh generation prompt is added
    if use_prefill:
        template_args = {"add_generation_prompt": False, "continue_final_message": True}
    else:
        template_args = {"add_generation_prompt": True}
    try:
        text_prompt = self.processor.apply_chat_template(conversation, tokenize=False, **template_args)
    except (TypeError, ValueError) as e:
        debug(f'VQA caption: handler=gemma chat_template fallback add_generation_prompt=True: {e}')
        text_prompt = self.processor.apply_chat_template(conversation, add_generation_prompt=True, tokenize=False)
    if use_prefill and use_thinking:
        text_prompt = keep_think_block_open(text_prompt)
    if debug_enabled:
        debug(f'VQA caption: handler=gemma text_prompt="{text_prompt}"')
    inputs = self.processor(text=[text_prompt], images=[image], padding=True, return_tensors="pt")
    inputs = inputs.to(device=devices.device, dtype=devices.dtype)
    input_len = inputs["input_ids"].shape[-1]
    gen_kwargs = get_kwargs()
    debug(f'VQA caption: handler=gemma generation_kwargs={gen_kwargs} input_len={input_len}')
    with devices.inference_context():
        generation = self.model.generate(**inputs, **gen_kwargs)
    debug(f'VQA caption: handler=gemma output_ids_shape={generation.shape}')
    # Decode only the tokens produced after the prompt
    new_tokens = generation[0][input_len:]
    response = self.processor.decode(new_tokens, skip_special_tokens=True)
    if debug_enabled:
        debug(f'VQA caption: handler=gemma response_before_clean="{response}"')
    return strip_think_xml_tags(response, keep=get_keep_thinking())
def _load_paligemma(self, repo: str):
    """Load the PaliGemma processor and model unless this repo is already loaded."""
    if self.model is not None and self.loaded == repo:
        return
    log.debug(f'Caption load: vlm="{repo}"')
    cache_dir = shared.opts.hfcache_dir
    self.processor = transformers.PaliGemmaProcessor.from_pretrained(repo, cache_dir=cache_dir)
    # Drop the previous model reference before loading the new one
    self.model = None
    self.model = transformers.PaliGemmaForConditionalGeneration.from_pretrained(
        repo,
        cache_dir=cache_dir,
        torch_dtype=devices.dtype,
        use_safetensors=True,
    )
    self.loaded = repo
    devices.torch_gc()
def _paligemma(self, question: str, image: Image.Image, repo: str, model_name: str = None): # pylint: disable=unused-argument
    """Run PaliGemma on the image and question and return the text generated after the prompt."""
    self._load_paligemma(repo)
    sd_models.move_model(self.model, devices.device)
    prompt = question.replace('<', '').replace('>', '').replace('_', ' ')
    model_inputs = self.processor(text=prompt, images=image, return_tensors="pt")
    model_inputs = model_inputs.to(devices.device, devices.dtype)
    prompt_len = model_inputs["input_ids"].shape[-1]
    with devices.inference_context():
        generated = self.model.generate(**model_inputs, **get_kwargs())
        # Keep only the tokens produced after the prompt
        new_tokens = generated[0][prompt_len:]
        response = self.processor.decode(new_tokens, skip_special_tokens=True)
    return response
def _load_ovis(self, repo: str):
    """Load the Ovis model (requires flash-attn) unless this repo is already loaded.

    While the model loads, ``AutoConfig.register`` is temporarily replaced by a
    wrapper that forces ``exist_ok=True``; the original attribute is put back
    afterwards, whether or not loading succeeds.
    """
    if self.model is not None and self.loaded == repo:
        return
    import inspect # local import: only needed for the temporary AutoConfig patch
    log.debug(f'Caption load: vlm="{repo}"')
    self.model = None
    # Ovis remote code calls AutoConfig.register("aimv2", ...) at module scope
    # without exist_ok=True, which fails on reload or when the type is already
    # registered by a newer transformers version.
    auto_config = transformers.AutoConfig
    # Keep the raw class attribute (descriptor) so that exactly the same object can be restored,
    # and the resolved callable for the wrapper to delegate to
    original_attr = inspect.getattr_static(auto_config, 'register')
    original_register = auto_config.register

    def register_exist_ok(model_type, config, exist_ok=False): # pylint: disable=unused-argument
        return original_register(model_type, config, exist_ok=True)

    auto_config.register = staticmethod(register_exist_ok)
    try:
        self.model = transformers.AutoModelForCausalLM.from_pretrained(
            repo,
            torch_dtype=devices.dtype,
            multimodal_max_length=32768,
            trust_remote_code=True,
            use_safetensors=True,
            cache_dir=shared.opts.hfcache_dir,
        )
    finally:
        auto_config.register = original_attr
    self.loaded = repo
    devices.torch_gc()
def _ovis(self, question: str, image: Image.Image, repo: str, model_name: str = None): # pylint: disable=unused-argument
    """Caption or answer a question about an image with Ovis.

    Args:
        question: Prompt text; '<' and '>' are removed and '_' becomes a space.
        image: Input image.
        repo: Model repository to load.
        model_name: Unused; accepted for a uniform handler signature.

    Returns:
        The decoded model output, or an empty string when flash-attn cannot be imported.
    """
    import importlib # local import: only used for the availability probe below
    try:
        # Probe for flash-attn up front so that a missing package gives a clear error
        importlib.import_module('flash_attn')
    except Exception:
        log.error(f'Caption: vlm="{repo}" flash-attn is not available')
        return ''
    self._load_ovis(repo)
    sd_models.move_model(self.model, devices.device)
    text_tokenizer = self.model.get_text_tokenizer()
    visual_tokenizer = self.model.get_visual_tokenizer()
    max_partition = 9
    question = question.replace('<', '').replace('>', '').replace('_', ' ')
    # The query starts with the image placeholder expected by the model's preprocess_inputs
    question = f'<image>\n{question}'
    _prompt, input_ids, pixel_values = self.model.preprocess_inputs(question, [image], max_partition=max_partition)
    attention_mask = torch.ne(input_ids, text_tokenizer.pad_token_id)
    input_ids = input_ids.unsqueeze(0).to(device=self.model.device)
    attention_mask = attention_mask.unsqueeze(0).to(device=self.model.device)
    if pixel_values is not None:
        pixel_values = pixel_values.to(dtype=visual_tokenizer.dtype, device=visual_tokenizer.device)
    pixel_values = [pixel_values]
    with devices.inference_context():
        output_ids = self.model.generate(
            input_ids,
            pixel_values=pixel_values,
            attention_mask=attention_mask,
            repetition_penalty=None,
            eos_token_id=self.model.generation_config.eos_token_id,
            pad_token_id=text_tokenizer.pad_token_id,
            use_cache=True,
            **get_kwargs())
        response = text_tokenizer.decode(output_ids[0], skip_special_tokens=True)
    return response
def _load_smol(self, repo: str):
    """Load the SmolVLM model and its processor unless this repo is already loaded."""
    if self.model is not None and self.loaded == repo:
        return
    log.debug(f'Caption load: vlm="{repo}"')
    self.model = None
    quant_args = model_quant.create_config(module='LLM')
    load_args = {
        'cache_dir': shared.opts.hfcache_dir,
        'torch_dtype': devices.dtype,
        'use_safetensors': True,
    }
    self.model = transformers.AutoModelForVision2Seq.from_pretrained(repo, **load_args, **quant_args)
    self.processor = transformers.AutoProcessor.from_pretrained(repo, max_pixels=1024*1024, cache_dir=shared.opts.hfcache_dir)
    if 'LLM' in shared.opts.cuda_compile:
        self.model = sd_models_compile.compile_torch(self.model)
    self.loaded = repo
    devices.torch_gc()
def _smol(self, question: str, image: Image.Image, repo: str, system_prompt: str = None, model_name: str = None, prefill: str = None, thinking_mode: bool = False):
    """Caption or answer a question about an image with SmolVLM.

    A non-empty prefill is appended as an assistant turn that the model
    continues. Returns the list of decoded output sequences; the first entry
    has its thinking tags stripped or reformatted according to ``get_keep_thinking()``.
    """
    self._load_smol(repo)
    sd_models.move_model(self.model, devices.device)
    # Model class name is used for logging only
    cls_name = self.model.__class__.__name__
    debug(f'VQA caption: handler=smol model_name="{model_name}" model_class="{cls_name}" repo="{repo}" question="{question}" system_prompt="{system_prompt}" image_size={image.size if image else None}')
    question = question.replace('<', '').replace('>', '').replace('_', ' ')
    system_prompt = system_prompt or shared.opts.caption_vlm_system
    system_message = {
        "role": "system",
        "content": [{"type": "text", "text": system_prompt}],
    }
    user_message = {
        "role": "user",
        "content": [
            {"type": "image", "image": b64(image)},
            {"type": "text", "text": question},
        ],
    }
    conversation = [system_message, user_message]
    # Resolve the prefill: an explicit argument wins over the module default
    prefill_value = vlm_prefill if prefill is None else prefill
    prefill_text = prefill_value.strip()
    use_prefill = len(prefill_text) > 0
    # Thinking is enabled by the manual toggle or detected from the model name
    use_thinking = thinking_mode or is_thinking_model(model_name)
    if not use_prefill:
        debug('VQA caption: handler=smol prefill disabled (empty), relying on add_generation_prompt')
    else:
        conversation.append({
            "role": "assistant",
            "content": [{"type": "text", "text": prefill_text}],
        })
        debug(f'VQA caption: handler=smol prefill="{prefill_text}"')
    if debug_enabled:
        debug(f'VQA caption: handler=smol conversation_roles={[msg["role"] for msg in conversation]}')
        debug(f'VQA caption: handler=smol full_conversation={truncate_b64_in_conversation(conversation)}')
        debug_prefill_mode = 'add_generation_prompt=False continue_final_message=True' if use_prefill else 'add_generation_prompt=True'
        debug(f'VQA caption: handler=smol template_mode={debug_prefill_mode}')
    # With a prefill the final assistant turn is continued; otherwise a fresh generation prompt is added
    if use_prefill:
        template_args = {"add_generation_prompt": False, "continue_final_message": True}
    else:
        template_args = {"add_generation_prompt": True}
    try:
        text_prompt = self.processor.apply_chat_template(conversation, **template_args)
    except (TypeError, ValueError) as e:
        debug(f'VQA caption: handler=smol chat_template fallback add_generation_prompt=True: {e}')
        text_prompt = self.processor.apply_chat_template(conversation, add_generation_prompt=True)
    if use_prefill and use_thinking:
        text_prompt = keep_think_block_open(text_prompt)
    if debug_enabled:
        debug(f'VQA caption: handler=smol text_prompt="{text_prompt}"')
    inputs = self.processor(text=text_prompt, images=[image], padding=True, return_tensors="pt")
    inputs = inputs.to(devices.device, devices.dtype)
    gen_kwargs = get_kwargs()
    debug(f'VQA caption: handler=smol generation_kwargs={gen_kwargs}')
    with devices.inference_context():
        output_ids = self.model.generate(**inputs, **gen_kwargs)
    debug(f'VQA caption: handler=smol output_ids_shape={output_ids.shape}')
    response = self.processor.batch_decode(output_ids, skip_special_tokens=True)
    if debug_enabled:
        debug(f'VQA caption: handler=smol response_before_clean="{response}"')
    if len(response) > 0:
        response[0] = strip_think_xml_tags(response[0], keep=get_keep_thinking())
    return response
def _load_git(self, repo: str):
    """Make sure the Microsoft GIT model and processor for `repo` are loaded.

    Nothing happens when a model is already present and `self.loaded` equals `repo`.
    Otherwise `self.model`, `self.processor` and `self.loaded` are replaced.
    """
    is_cached = self.model is not None and self.loaded == repo
    if is_cached:
        return
    log.debug(f'Caption load: vlm="{repo}"')
    self.model = None  # release the reference to any previous model before loading
    hf_cache = shared.opts.hfcache_dir
    self.model = transformers.GitForCausalLM.from_pretrained(
        repo,
        torch_dtype=devices.dtype,
        use_safetensors=True,
        cache_dir=hf_cache,
    )
    self.processor = transformers.GitProcessor.from_pretrained(repo, cache_dir=hf_cache)
    self.loaded = repo
    devices.torch_gc()
def _git(self, question: str, image: Image.Image, repo: str, model_name: str = None): # pylint: disable=unused-argument
    """Generate text for `image` with a GIT model, optionally primed with `question`.

    Args:
        question: Text prefix for generation; when empty only the image is passed.
        image: Image handed to the processor.
        repo: Repository id forwarded to `_load_git`.
        model_name: Unused; kept for a uniform handler signature.

    Returns:
        The first decoded sequence, with special tokens skipped.
    """
    self._load_git(repo)
    sd_models.move_model(self.model, devices.device)
    image_features = self.processor(images=image, return_tensors="pt")
    generate_args = {
        'pixel_values': image_features.pixel_values.to(devices.device, devices.dtype),
    }
    if len(question) > 0:
        # Tokenize without special tokens, then put the CLS token id in front manually
        token_ids = self.processor(text=question, add_special_tokens=False).input_ids
        token_ids = [self.processor.tokenizer.cls_token_id] + token_ids
        prompt_ids = torch.tensor(token_ids).unsqueeze(0)
        generate_args['input_ids'] = prompt_ids.to(devices.device)
    with devices.inference_context():
        sequences = self.model.generate(**generate_args)
    decoded = self.processor.batch_decode(sequences, skip_special_tokens=True)
    return decoded[0]
def _load_blip(self, repo: str):
    """Make sure the Salesforce BLIP question-answering model and processor for `repo` are loaded.

    Nothing happens when a model is already present and `self.loaded` equals `repo`.
    Otherwise `self.model`, `self.processor` and `self.loaded` are replaced.
    """
    is_cached = self.model is not None and self.loaded == repo
    if is_cached:
        return
    log.debug(f'Caption load: vlm="{repo}"')
    self.model = None  # release the reference to any previous model before loading
    hf_cache = shared.opts.hfcache_dir
    self.model = transformers.BlipForQuestionAnswering.from_pretrained(
        repo,
        torch_dtype=devices.dtype,
        use_safetensors=True,
        cache_dir=hf_cache,
    )
    self.processor = transformers.BlipProcessor.from_pretrained(repo, cache_dir=hf_cache)
    self.loaded = repo
    devices.torch_gc()
def _blip(self, question: str, image: Image.Image, repo: str, model_name: str = None): # pylint: disable=unused-argument
    """Answer `question` about `image` with a BLIP model.

    Args:
        question: Text passed to the processor together with the image.
        image: Image passed to the processor.
        repo: Repository id forwarded to `_load_blip`.
        model_name: Unused; kept for a uniform handler signature.

    Returns:
        The first generated sequence decoded with special tokens skipped.
    """
    self._load_blip(repo)
    sd_models.move_model(self.model, devices.device)
    model_inputs = self.processor(image, question, return_tensors="pt")
    model_inputs = model_inputs.to(devices.device, devices.dtype)
    with devices.inference_context():
        answer_ids = self.model.generate(**model_inputs)
    return self.processor.decode(answer_ids[0], skip_special_tokens=True)
def _load_vilt(self, repo: str):
    """Make sure the ViLT question-answering model and processor for `repo` are loaded.

    Nothing happens when a model is already present and `self.loaded` equals `repo`.
    Otherwise `self.model`, `self.processor` and `self.loaded` are replaced.
    """
    is_cached = self.model is not None and self.loaded == repo
    if is_cached:
        return
    log.debug(f'Caption load: vlm="{repo}"')
    self.model = None  # release the reference to any previous model before loading
    hf_cache = shared.opts.hfcache_dir
    self.model = transformers.ViltForQuestionAnswering.from_pretrained(
        repo,
        torch_dtype=devices.dtype,
        use_safetensors=True,
        cache_dir=hf_cache,
    )
    self.processor = transformers.ViltProcessor.from_pretrained(repo, cache_dir=hf_cache)
    self.loaded = repo
    devices.torch_gc()
def _vilt(self, question: str, image: Image.Image, repo: str, model_name: str = None): # pylint: disable=unused-argument
    """Answer `question` about `image` with a ViLT classifier-style VQA model.

    Args:
        question: Text passed to the processor together with the image.
        image: Image passed to the processor.
        repo: Repository id forwarded to `_load_vilt`.
        model_name: Unused; kept for a uniform handler signature.

    Returns:
        The label from `self.model.config.id2label` for the highest-scoring logit.
    """
    self._load_vilt(repo)
    sd_models.move_model(self.model, devices.device)
    inputs = self.processor(image, question, return_tensors="pt")
    # The model is loaded with torch_dtype=devices.dtype, so floating-point inputs must be cast to
    # the same dtype (as the other handlers do); integer tensors only change device.
    # Casting per tensor avoids depending on the `.to()` signature of the processor's return type.
    inputs = {k: v.to(devices.device, devices.dtype) if v.is_floating_point() else v.to(devices.device) for k, v in inputs.items()}
    with devices.inference_context():
        outputs = self.model(**inputs)
    idx = outputs.logits.argmax(-1).item()
    return self.model.config.id2label[idx]
def _load_pix(self, repo: str):
    """Make sure the Pix2Struct model and processor for `repo` are loaded.

    Nothing happens when a model is already present and `self.loaded` equals `repo`.
    Otherwise `self.model`, `self.processor` and `self.loaded` are replaced.
    """
    is_cached = self.model is not None and self.loaded == repo
    if is_cached:
        return
    log.debug(f'Caption load: vlm="{repo}"')
    self.model = None  # release the reference to any previous model before loading
    hf_cache = shared.opts.hfcache_dir
    self.model = transformers.Pix2StructForConditionalGeneration.from_pretrained(
        repo,
        torch_dtype=devices.dtype,
        use_safetensors=True,
        cache_dir=hf_cache,
    )
    self.processor = transformers.Pix2StructProcessor.from_pretrained(repo, cache_dir=hf_cache)
    self.loaded = repo
    devices.torch_gc()
def _pix(self, question: str, image: Image.Image, repo: str, model_name: str = None): # pylint: disable=unused-argument
    """Generate text for `image` with a Pix2Struct model, optionally conditioned on `question`.

    Args:
        question: Text passed to the processor; omitted from the call when empty.
        image: Image passed to the processor.
        repo: Repository id forwarded to `_load_pix`.
        model_name: Unused; kept for a uniform handler signature.

    Returns:
        The first generated sequence decoded with special tokens skipped.
    """
    self._load_pix(repo)
    sd_models.move_model(self.model, devices.device)
    has_question = len(question) > 0
    if has_question:
        encoded = self.processor(images=image, text=question, return_tensors="pt")
    else:
        encoded = self.processor(images=image, return_tensors="pt")
    # Floating-point tensors are cast to the working dtype; all others only change device
    model_inputs = {}
    for key, tensor in encoded.items():
        if tensor.is_floating_point():
            model_inputs[key] = tensor.to(devices.device, devices.dtype)
        else:
            model_inputs[key] = tensor.to(devices.device)
    with devices.inference_context():
        sequences = self.model.generate(**model_inputs)
    return self.processor.decode(sequences[0], skip_special_tokens=True)
def _load_moondream(self, repo: str):
    """Make sure the Moondream 2 model and tokenizer for `repo` are loaded.

    The model is fetched at the pinned revision "2025-06-21" with remote code enabled.
    Nothing happens when a model is already present and `self.loaded` equals `repo`.
    """
    is_cached = self.model is not None and self.loaded == repo
    if is_cached:
        return
    log.debug(f'Caption load: vlm="{repo}"')
    self.model = None  # release the reference to any previous model before loading
    hf_cache = shared.opts.hfcache_dir
    self.model = transformers.AutoModelForCausalLM.from_pretrained(
        repo,
        revision="2025-06-21",
        trust_remote_code=True,
        torch_dtype=devices.dtype,
        use_safetensors=True,
        cache_dir=hf_cache,
    )
    self.processor = transformers.AutoTokenizer.from_pretrained(repo, cache_dir=hf_cache)
    self.loaded = repo
    self.model.eval() # required: trust_remote_code model
    devices.torch_gc()
def _moondream(self, question: str, image: Image.Image, repo: str, model_name: str = None, thinking_mode: bool = False):
    """Run a Moondream 2 task chosen from the text of `question`.

    The question is normalised first ('<' and '>' removed, '_' replaced by a space) and then dispatched:
    caption tokens -> `model.caption`, 'point at <target>' -> `model.point`, 'detect gaze' -> face
    detection followed by `model.detect_gaze`, 'detect <target>' -> `model.detect`, anything else ->
    `model.query`. Point, gaze and detect results are stored in `self.last_detection_data`.

    Args:
        question: Task token, trigger phrase or free-form question.
        image: Image to analyse.
        repo: Repository id forwarded to `_load_moondream`.
        model_name: Only used for debug logging.
        thinking_mode: Passed to `model.query` as `reasoning`.

    Returns:
        Text answer, or a short message when nothing was found or no target was given.
    """
    debug(f'VQA caption: handler=moondream model_name="{model_name}" repo="{repo}" question="{question}" thinking_mode={thinking_mode}')
    self._load_moondream(repo)
    sd_models.move_model(self.model, devices.device)
    # After this normalisation no underscore can remain, so every mode token below must be
    # compared in its space-separated spelling ('POINT MODE', not 'POINT_MODE').
    question = question.replace('<', '').replace('>', '').replace('_', ' ')
    question_lower = question.lower()
    with devices.inference_context():
        if question == 'CAPTION':
            response = self.model.caption(image, length="short")['caption']
        elif question == 'DETAILED CAPTION':
            response = self.model.caption(image, length="normal")['caption']
        elif question == 'MORE DETAILED CAPTION':
            response = self.model.caption(image, length="long")['caption']
        elif question_lower.startswith('point at ') or question == 'POINT MODE':
            # A bare mode token carries no target and yields the "please specify" message
            target = question[9:].strip() if question_lower.startswith('point at ') else ''
            if not target:
                return "Please specify an object to locate"
            debug(f'VQA caption: handler=moondream method=point target="{target}"')
            result = self.model.point(image, target)
            debug(f'VQA caption: handler=moondream point_raw_result={result}')
            points = vqa_detection.parse_points(result)
            if points:
                self.last_detection_data = {'points': points}
                return vqa_detection.format_points_text(points)
            return "Object not found"
        elif question_lower == 'detect gaze':
            # Must be checked before generic 'detect ' prefix to avoid matching as detect target="Gaze"
            debug('VQA caption: handler=moondream method=detect_gaze')
            faces = self.model.detect(image, "face")
            debug(f'VQA caption: handler=moondream detect_gaze faces={faces}')
            if faces.get('objects'):
                # Only the first detected face is used as the gaze origin
                eye_x, eye_y = vqa_detection.calculate_eye_position(faces['objects'][0])
                result = self.model.detect_gaze(image, eye=(eye_x, eye_y))
                debug(f'VQA caption: handler=moondream detect_gaze result={result}')
                if result.get('gaze'):
                    gaze = result['gaze']
                    self.last_detection_data = {'points': [(gaze['x'], gaze['y'])]}
                    return f"Gaze direction: ({gaze['x']:.3f}, {gaze['y']:.3f})"
            return "No face/gaze detected"
        elif question_lower.startswith('detect ') or question == 'DETECT MODE':
            # A bare mode token carries no target and yields the "please specify" message
            target = question[7:].strip() if question_lower.startswith('detect ') else ''
            if not target:
                return "Please specify an object to detect"
            debug(f'VQA caption: handler=moondream method=detect target="{target}"')
            result = self.model.detect(image, target)
            debug(f'VQA caption: handler=moondream detect_raw_result={result}')
            detections = vqa_detection.parse_detections(result, target)
            if detections:
                self.last_detection_data = {'detections': detections}
                return vqa_detection.format_detections_text(detections, include_confidence=False)
            return "No objects detected"
        else:
            debug(f'VQA caption: handler=moondream method=query question="{question}" reasoning={thinking_mode}')
            result = self.model.query(image, question, reasoning=thinking_mode)
            response = result['answer']
            debug(f'VQA caption: handler=moondream query_result keys={list(result.keys()) if isinstance(result, dict) else "not dict"}')
            if thinking_mode and 'reasoning' in result:
                reasoning_text = result['reasoning'].get('text', '') if isinstance(result['reasoning'], dict) else str(result['reasoning'])
                debug(f'VQA caption: handler=moondream reasoning_text="{reasoning_text[:100]}..."')
                if get_keep_thinking():
                    response = f"Reasoning:\n{reasoning_text}\n\nAnswer:\n{response}"
                # When keep_thinking is False, just use the answer (reasoning is discarded)
    return response
def _load_florence(self, repo: str, revision: str = None):
    """Load Florence-2 model and processor.

    Args:
        repo: Repository id, optionally suffixed with '@<revision>'. The full string (including
            any suffix) is the cache key compared against `self.loaded`.
        revision: Revision used when `repo` carries no '@' suffix; a suffix takes precedence.

    While loading, `transformers.dynamic_module_utils.get_imports` is temporarily replaced by a
    wrapper that drops "flash_attn" from the reported imports; the original is always restored,
    even when loading raises.
    """
    _get_imports = transformers.dynamic_module_utils.get_imports

    def get_imports(f):
        R = _get_imports(f)
        if "flash_attn" in R:
            R.remove("flash_attn") # flash_attn is optional
        return R

    # Handle revision splitting and caching
    cache_key = repo
    effective_revision = revision
    repo_name = repo
    if repo and '@' in repo:
        # Split on the first '@' only so a stray second '@' cannot break the unpacking
        repo_name, effective_revision = repo.split('@', 1)
    if self.model is None or self.loaded != cache_key:
        log.debug(f'Caption load: vlm="{repo_name}" revision="{effective_revision}" path="{shared.opts.hfcache_dir}"')
        transformers.dynamic_module_utils.get_imports = get_imports
        try:
            self.model = None
            quant_args = model_quant.create_config(module='LLM')
            self.model = transformers.Florence2ForConditionalGeneration.from_pretrained(
                repo_name,
                revision=effective_revision,
                torch_dtype=devices.dtype,
                use_safetensors=True,
                cache_dir=shared.opts.hfcache_dir,
                **quant_args,
            )
            self.processor = transformers.AutoProcessor.from_pretrained(repo_name, max_pixels=1024*1024, trust_remote_code=True, revision=effective_revision, cache_dir=shared.opts.hfcache_dir)
        finally:
            # Never leave the process-wide monkey-patch installed, even if loading failed
            transformers.dynamic_module_utils.get_imports = _get_imports
        self.loaded = cache_key
        devices.torch_gc()
def _florence(self, question: str, image: Image.Image, repo: str, revision: str = None, model_name: str = None): # pylint: disable=unused-argument
    """Run a Florence-2 task on `image` and return the post-processed generation.

    Only the leading '<...>' task token of `question` is sent to the processor; a question that
    does not start with '<' results in an empty task string.

    Args:
        question: Task token, optionally followed by more text (which is not forwarded).
        image: Image to process.
        repo: Repository id forwarded to `_load_florence`.
        revision: Revision forwarded to `_load_florence`.
        model_name: Only used for debug logging.

    Returns:
        Whatever `processor.post_process_generation(..., task="task", ...)` returns.
    """
    self._load_florence(repo, revision)
    sd_models.move_model(self.model, devices.device)
    task = question.split('>', 1)[0] + '>' if question.startswith('<') else ''
    debug(f'VQA caption: handler=florence model_name="{model_name}" repo="{repo}" task="{task}" question="{question}" image_size={image.size}')
    encoded = self.processor(text=task, images=image, return_tensors="pt")
    token_ids = encoded['input_ids'].to(devices.device)
    pixels = encoded['pixel_values'].to(devices.device, devices.dtype)
    debug(f'VQA caption: handler=florence input_ids={token_ids.shape} pixel_values={pixels.shape} dtype={pixels.dtype}')
    # Florence-2 requires beam search, not sampling - sampling causes probability tensor errors
    max_tokens = _get_overrides().get('max_tokens')
    if max_tokens is None:
        max_tokens = shared.opts.caption_vlm_max_length
    generation_args = {'max_new_tokens': max_tokens, 'num_beams': 3, 'do_sample': False}
    # Some Florence fine-tunes (e.g., CogFlorence) don't have decoder_start_token_id set
    config_start_id = getattr(self.model.config, 'decoder_start_token_id', None)
    if config_start_id is None:
        fallback_start_id = getattr(self.processor.tokenizer, 'bos_token_id', None) or 0
        generation_args['decoder_start_token_id'] = fallback_start_id
        debug(f'VQA caption: handler=florence setting decoder_start_token_id={fallback_start_id}')
    debug(f'VQA caption: handler=florence generation_kwargs={generation_args}')
    with devices.inference_context(), devices.bypass_sdpa_hijacks():
        sequences = self.model.generate(
            input_ids=token_ids,
            pixel_values=pixels,
            **generation_args,
        )
    decoded_text = self.processor.batch_decode(sequences, skip_special_tokens=False)[0]
    debug(f'VQA caption: handler=florence generated_text="{decoded_text}"')
    # task="task" is intentional: produces {'task': text} which both parse_florence_detections and
    # format_florence_response handle via explicit 'task' key fallbacks, avoiding task-token-specific keys
    original_size = (image.width, image.height)
    parsed = self.processor.post_process_generation(decoded_text, task="task", image_size=original_size)
    debug(f'VQA caption: handler=florence raw_response={parsed}')
    return parsed
def _load_sa2(self, repo: str):
    """Make sure the SA2VA model and tokenizer for `repo` are loaded.

    Both are fetched with remote code enabled; the tokenizer is requested with `use_fast=False`.
    Nothing happens when a model is already present and `self.loaded` equals `repo`.
    """
    is_cached = self.model is not None and self.loaded == repo
    if is_cached:
        return
    self.model = None  # release the reference to any previous model before loading
    self.model = transformers.AutoModel.from_pretrained(
        repo,
        torch_dtype=devices.dtype,
        low_cpu_mem_usage=True,
        use_flash_attn=False,
        use_safetensors=True,
        trust_remote_code=True,
        cache_dir=shared.opts.hfcache_dir,
    )
    self.model = self.model.eval() # required: trust_remote_code model
    tokenizer_args = {
        'trust_remote_code': True,
        'use_fast': False,
        'cache_dir': shared.opts.hfcache_dir,
    }
    self.processor = transformers.AutoTokenizer.from_pretrained(repo, **tokenizer_args)
    self.loaded = repo
    devices.torch_gc()
def _sa2(self, question: str, image: Image.Image, repo: str, model_name: str = None): # pylint: disable=unused-argument
    """Run SA2VA `predict_forward` on `image` and return its text prediction.

    Only the leading '<...>' token of `question` is forwarded as text; a question that does not
    start with '<' results in an empty text prompt.

    Args:
        question: Task token, optionally followed by more text (which is not forwarded).
        image: Image to process.
        repo: Repository id forwarded to `_load_sa2`.
        model_name: Unused; kept for a uniform handler signature.

    Returns:
        The "prediction" entry of the dict returned by `predict_forward`.
    """
    self._load_sa2(repo)
    sd_models.move_model(self.model, devices.device)
    task = question.split('>', 1)[0] + '>' if question.startswith('<') else ''
    with devices.inference_context():
        prediction = self.model.predict_forward(
            image=image,
            text=task,
            past_text='',
            mask_prompts=None,
            tokenizer=self.processor,
        )
    return prediction["prediction"] # the text format answer
def caption(self, question: str = '', system_prompt: str = None, prompt: str = None, image: Image.Image = None, model_name: str = None, prefill: str = None, thinking_mode: bool = None, quiet: bool = False, generation_kwargs: dict = None) -> str:
    """
    Main entry point for VQA captioning. Returns string answer.
    Detection data stored in self.last_detection_data for annotated image creation.

    The handler is chosen by substring match on the lower-cased repo id looked up in `vlm_models`,
    testing the branches in the order written below.

    Args:
        question: Question/task to perform
        system_prompt: System prompt for the model
        prompt: Additional prompt text
        image: PIL Image to process (a list uses its first item; a dict with 'name' is opened from that path)
        model_name: Model to use (defaults to settings)
        prefill: Text to prefill the response with
        thinking_mode: Enable thinking/reasoning mode (None = use settings)
        quiet: Suppress logging
        generation_kwargs: Optional dict with generation parameter overrides:
            max_tokens, temperature, top_k, top_p, num_beams, do_sample, keep_thinking, keep_prefill

    Returns:
        The cleaned answer. Input validation failures return a message starting with 'Error:',
        a missing or unknown model returns '', and a handler exception sets the answer to 'error'
        before cleaning.
    """
    self.last_annotated_image = None
    self.last_detection_data = None
    self._generation_overrides = generation_kwargs # Set per-request overrides
    jobid = shared.state.begin('Caption LLM')
    t0 = time.time()
    model_name = model_name or shared.opts.caption_vlm_model
    prefill = vlm_prefill if prefill is None else prefill # Use provided prefill when specified
    thinking_mode = shared.opts.caption_vlm_thinking_mode if thinking_mode is None else thinking_mode # Resolve from settings if not specified
    if isinstance(image, list):
        image = image[0] if len(image) > 0 else None
    if isinstance(image, dict) and 'name' in image:
        image = Image.open(image['name'])
    if isinstance(image, Image.Image):
        # Cap the longest side at 768px (in place) and normalise to RGB
        if image.width > 768 or image.height > 768:
            image.thumbnail((768, 768), Image.Resampling.LANCZOS)
        if image.mode != 'RGB':
            image = image.convert('RGB')
    if image is None:
        log.error(f'VQA caption: model="{model_name}" error="No input image provided"')
        self._generation_overrides = None
        shared.state.end(jobid)
        return 'Error: No input image provided. Please upload or select an image.'
    # Convert friendly prompt names to internal tokens/commands
    if question == "Use Prompt":
        # Use content from Prompt field directly - requires user input
        if not prompt or len(prompt.strip()) < 2:
            log.error(f'VQA caption: model="{model_name}" error="Please enter a prompt"')
            self._generation_overrides = None
            shared.state.end(jobid)
            return 'Error: Please enter a question or instruction in the Prompt field.'
        question = prompt
    elif question in vlm_prompt_mapping:
        # Check if this is a mode that requires user input (Point/Detect)
        raw_mapping = vlm_prompt_mapping.get(question)
        if raw_mapping in ("POINT_MODE", "DETECT_MODE"):
            # These modes require user input in the prompt field
            if not prompt or len(prompt.strip()) < 2:
                log.error(f'VQA caption: model="{model_name}" error="Please specify what to find in the prompt field"')
                self._generation_overrides = None
                shared.state.end(jobid)
                return 'Error: Please specify what to find in the prompt field (e.g., "the red car" or "faces").'
        # Convert friendly name to internal token (handles Point/Detect prefix)
        question = get_internal_prompt(question, prompt)
    # else: question is already an internal token or custom text
    from modules import modelloader
    modelloader.hf_login()
    # Bound before the try block so the logging after it can never hit an unbound local
    handler = 'unknown'
    vqa_model = None
    try:
        if model_name is None:
            log.error(f'Caption: type=vlm model="{model_name}" no model selected')
            self._generation_overrides = None # clear overrides on every exit path
            shared.state.end(jobid)
            return ''
        vqa_model = vlm_models.get(model_name, None)
        if vqa_model is None:
            log.error(f'Caption: type=vlm model="{model_name}" unknown')
            self._generation_overrides = None # clear overrides on every exit path
            shared.state.end(jobid)
            return ''
        if 'git' in vqa_model.lower():
            handler = 'git'
            answer = self._git(question, image, vqa_model, model_name)
        elif 'vilt' in vqa_model.lower():
            handler = 'vilt'
            answer = self._vilt(question, image, vqa_model, model_name)
        elif 'blip' in vqa_model.lower():
            handler = 'blip'
            answer = self._blip(question, image, vqa_model, model_name)
        elif 'pix' in vqa_model.lower():
            handler = 'pix'
            answer = self._pix(question, image, vqa_model, model_name)
        elif 'moondream3' in vqa_model.lower():
            handler = 'moondream3'
            from modules.caption import moondream3
            answer = moondream3.predict(question, image, vqa_model, model_name, thinking_mode=thinking_mode)
        elif 'moondream2' in vqa_model.lower():
            handler = 'moondream'
            answer = self._moondream(question, image, vqa_model, model_name, thinking_mode)
        elif 'florence' in vqa_model.lower():
            handler = 'florence'
            answer = self._florence(question, image, vqa_model, None, model_name)
            # Parse Florence detection response for annotated image (handles both dict and string formats)
            florence_detections = vqa_detection.parse_florence_detections(answer, image.size if image else None)
            if florence_detections:
                self.last_detection_data = {'detections': florence_detections}
                debug(f'VQA caption: handler=florence parsed {len(florence_detections)} detections')
            # Format dict answer as readable string (string answers pass through unchanged)
            if isinstance(answer, dict):
                answer = vqa_detection.format_florence_response(answer)
        elif 'qwen' in vqa_model.lower() or 'torii' in vqa_model.lower() or 'mimo' in vqa_model.lower():
            handler = 'qwen'
            answer = self._qwen(question, image, vqa_model, system_prompt, model_name, prefill, thinking_mode)
        elif 'smol' in vqa_model.lower():
            handler = 'smol'
            answer = self._smol(question, image, vqa_model, system_prompt, model_name, prefill, thinking_mode)
        elif 'joytag' in vqa_model.lower():
            handler = 'joytag'
            from modules.caption import joytag
            answer = joytag.predict(image)
        elif 'joycaption' in vqa_model.lower():
            handler = 'joycaption'
            from modules.caption import joycaption
            answer = joycaption.predict(question, image, vqa_model)
        elif 'deepseek' in vqa_model.lower():
            handler = 'deepseek'
            from modules.caption import deepseek
            answer = deepseek.predict(question, image, vqa_model)
        elif 'paligemma' in vqa_model.lower():
            handler = 'paligemma'
            answer = self._paligemma(question, image, vqa_model, model_name)
        elif 'gemma' in vqa_model.lower():
            handler = 'gemma'
            answer = self._gemma(question, image, vqa_model, system_prompt, model_name, prefill, thinking_mode)
        elif 'ovis' in vqa_model.lower():
            handler = 'ovis'
            answer = self._ovis(question, image, vqa_model, model_name)
        elif 'sa2' in vqa_model.lower():
            handler = 'sa2'
            answer = self._sa2(question, image, vqa_model, model_name)
        elif 'fastvlm' in vqa_model.lower():
            handler = 'fastvlm'
            answer = self._fastvlm(question, image, vqa_model, model_name)
        elif 'gemini' in vqa_model.lower():
            handler = 'gemini'
            gen_kwargs = get_kwargs()
            from modules.caption import gemini
            answer = gemini.predict(question, image, vqa_model, system_prompt, model_name, prefill, thinking_mode, gen_kwargs)
        else:
            answer = 'unknown model'
    except Exception as e:
        errors.display(e, 'VQA')
        answer = 'error'
    if shared.opts.caption_offload and self.model is not None:
        sd_models.move_model(self.model, devices.cpu, force=True)
    devices.torch_gc(force=True, reason='vqa')
    # Clean the answer
    answer = clean(answer, question, prefill)
    # Create annotated image if detection data is available
    if self.last_detection_data and isinstance(self.last_detection_data, dict) and image:
        detections = self.last_detection_data.get('detections', None)
        points = self.last_detection_data.get('points', None)
        if detections or points:
            self.last_annotated_image = vqa_detection.draw_bounding_boxes(image, detections or [], points)
            debug(f'VQA caption: handler={handler} created annotated image detections={len(detections) if detections else 0} points={len(points) if points else 0}')
    debug(f'VQA caption: handler={handler} response="{answer}" annotation={self.last_annotated_image is not None}')
    t1 = time.time()
    if not quiet:
        model_name = model_name.split(' ')[0] if model_name else 'None'
        log.debug(f'Caption: type=vlm model="{model_name}" repo="{vqa_model}" args={get_kwargs()} time={t1-t0:.2f}')
    self._generation_overrides = None # Clear per-request overrides
    shared.state.end(jobid)
    return answer
def batch(self, model_name, system_prompt, batch_files, batch_folder, batch_str, question, prompt, write, append, recursive, prefill=None, thinking_mode=False):
    """Caption a set of image files and optionally write one .txt sidecar per image.

    Args:
        model_name: Model passed to `caption`.
        system_prompt: System prompt passed to `caption`.
        batch_files: Optional iterable of objects with a `.name` path attribute.
        batch_folder: Optional iterable of objects with a `.name` path attribute.
        batch_str: Optional directory path scanned for .png/.jpg/.jpeg/.webp/.jxl files.
        question: Question/task passed to `caption`.
        prompt: Prompt text passed to `caption`.
        write: When true, write captions (and annotated images, if any) next to each image.
        append: When true, append to existing .txt files instead of overwriting them.
        recursive: Passed to `list_files` when scanning `batch_str`.
        prefill: Prefill text passed to `caption`.
        thinking_mode: Thinking mode passed to `caption`.

    Returns:
        All captions joined by blank lines, or '' when no images were found.
    """
    class BatchWriter:
        """Writes each caption to a .txt file with the same path stem as its image."""

        def __init__(self, folder, mode='w'):
            self.folder = folder
            self.csv = None
            self.file = None
            self.mode = mode

        def add(self, file, prompt_text):
            # Derive the sidecar path from the image path itself, the same way the annotated image
            # path is built. Joining with `folder` again duplicated the directory for relative
            # paths; for absolute paths the result is identical to before.
            txt_file = os.path.splitext(file)[0] + ".txt"
            if self.mode == 'a':
                prompt_text = '\n' + prompt_text
            with open(txt_file, self.mode, encoding='utf-8') as f:
                f.write(prompt_text)

        def close(self):
            if self.file is not None:
                self.file.close()

    files = []
    if batch_files is not None:
        files += [f.name for f in batch_files]
    if batch_folder is not None:
        files += [f.name for f in batch_folder]
    if batch_str is not None and len(batch_str) > 0 and os.path.exists(batch_str) and os.path.isdir(batch_str):
        from modules.files_cache import list_files
        files += list(list_files(batch_str, ext_filter=['.png', '.jpg', '.jpeg', '.webp', '.jxl'], recursive=recursive))
    if len(files) == 0:
        log.warning('Caption batch: type=vlm no images')
        return ''
    jobid = shared.state.begin('Caption batch')
    prompts = []
    if write:
        mode = 'w' if not append else 'a'
        writer = BatchWriter(os.path.dirname(files[0]), mode=mode)
    # Offloading is disabled for the duration of the batch and restored in the finally block
    orig_offload = shared.opts.caption_offload
    shared.opts.caption_offload = False
    try:
        import rich.progress as rp
        pbar = rp.Progress(rp.TextColumn('[cyan]Caption:'), rp.BarColumn(), rp.MofNCompleteColumn(), rp.TaskProgressColumn(), rp.TimeRemainingColumn(), rp.TimeElapsedColumn(), rp.TextColumn('[cyan]{task.description}'), console=console)
        with pbar:
            task = pbar.add_task(total=len(files), description='starting...')
            for file in files:
                pbar.update(task, advance=1, description=file)
                try:
                    if shared.state.interrupted:
                        break
                    # Context manager releases the underlying file handle once the image is captioned
                    with Image.open(file) as img:
                        result = self.caption(question, system_prompt, prompt, img, model_name, prefill, thinking_mode, quiet=True)
                    # Save annotated image if available
                    if self.last_annotated_image and write:
                        annotated_path = os.path.splitext(file)[0] + "_annotated.png"
                        self.last_annotated_image.save(annotated_path)
                    prompts.append(result)
                    if write:
                        writer.add(file, result)
                except Exception as e:
                    # Best effort: a failure on one file is logged and the batch continues
                    log.error(f'Caption batch: {e}')
        if write:
            writer.close()
    finally:
        shared.opts.caption_offload = orig_offload
    shared.state.end(jobid)
    return '\n\n'.join(prompts)
# Module-level singleton instance; stays None until get_instance() creates the VQA object on first use
_instance = None
def get_instance() -> VQA:
    """Return the shared VQA object, constructing it on the first call."""
    global _instance # pylint: disable=global-statement
    if _instance is not None:
        return _instance
    _instance = VQA()
    return _instance
# Backwards-compatible module-level functions
def caption(*args, **kwargs):
    """Forward all arguments to `caption` on the shared VQA instance and return its result."""
    vqa = get_instance()
    return vqa.caption(*args, **kwargs)
def unload_model():
    """Call `unload` on the shared VQA instance and return its result."""
    vqa = get_instance()
    return vqa.unload()
def load_model(model_name: str = None):
    """Call `load(model_name)` on the shared VQA instance and return its result."""
    vqa = get_instance()
    return vqa.load(model_name)
def get_last_annotated_image():
    """Return the `last_annotated_image` attribute of the shared VQA instance."""
    vqa = get_instance()
    return vqa.last_annotated_image
def batch(*args, **kwargs):
    """Forward all arguments to `batch` on the shared VQA instance and return its result."""
    vqa = get_instance()
    return vqa.batch(*args, **kwargs)