import io
import os
import re
import time
import json
import base64
import copy
import torch
import transformers
import transformers.dynamic_module_utils
from PIL import Image
from modules import shared, devices, errors, model_quant, sd_models, sd_models_compile
from modules.sd_offload_aux import register_aux, deregister_aux, move_aux_to_gpu, offload_aux
from modules.logger import log, console
from modules.caption import vqa_detection
from modules.caption.models_def import vlm_models, vlm_system, vlm_default, vlm_prefill, vlm_prompts, vlm_prompt_mapping, vlm_prompt_placeholders, vlm_prompts_common, vlm_prompts_florence, vlm_prompts_moondream, vlm_prompts_moondream2, vlm_prompts_promptgen, get_vlm_repo


# Debug logging - function-based to avoid circular import
debug_enabled = os.environ.get('SD_CAPTION_DEBUG', None) is not None


def debug(*args, **kwargs):
    if debug_enabled:
        log.trace(*args, **kwargs)


def get_prompts_for_model(model_name: str) -> list:
    """Get available prompts based on selected model."""
    if model_name is None:
        return vlm_prompts_common
    model_lower = model_name.lower()
    # Check for PromptGen models (MiaoshouAI fine-tunes with extra prompts)
    if 'promptgen' in model_lower:
        return vlm_prompts_common + vlm_prompts_florence + vlm_prompts_promptgen
    # Check for Florence-2 base / CogFlorence models (no PromptGen-specific prompts)
    if 'florence' in model_lower:
        return vlm_prompts_common + vlm_prompts_florence
    # Check for Moondream models (Moondream 2 has gaze detection, Moondream 3 does not)
    if 'moondream' in model_lower:
        if 'moondream3' in model_lower or 'moondream 3' in model_lower:
            return vlm_prompts_common + vlm_prompts_moondream
        else:  # Moondream 2 includes gaze detection
            return vlm_prompts_common + vlm_prompts_moondream + vlm_prompts_moondream2
    # Default: common prompts only
    return vlm_prompts_common


def get_internal_prompt(friendly_name: str, user_prompt: str = None) -> str:
    """Convert friendly prompt name to internal token/command."""
    internal = vlm_prompt_mapping.get(friendly_name, friendly_name)
    # Handle Moondream point/detect modes - prepend trigger phrase
    if internal == "POINT_MODE" and user_prompt:
        return f"Point at {user_prompt}"
    elif internal == "DETECT_MODE" and user_prompt:
        return f"Detect {user_prompt}"
    return internal


def get_prompt_placeholder(friendly_name: str) -> str:
    """Get placeholder text for the prompt field based on selected question."""
    return vlm_prompt_placeholders.get(friendly_name, "Enter your question or instruction")
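# Illustrative only: how friendly prompt names resolve to internal commands. The concrete
# keys of vlm_prompt_mapping live in modules/caption/models_def.py; the names used below
# are assumptions for the example, not guaranteed entries.
#   get_internal_prompt('Point at Object', 'the red car')  -> 'Point at the red car'   (if mapped to POINT_MODE)
#   get_internal_prompt('Detect Object', 'faces')          -> 'Detect faces'           (if mapped to DETECT_MODE)
#   get_internal_prompt('Describe the image')              -> the mapped value, or the name itself if unmapped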
""" if not question: return False # Check if it's a Florence-specific friendly name (base or PromptGen) if question in vlm_prompts_florence or question in vlm_prompts_promptgen: return True # Check if it's an internal Florence-2 task token (for backwards compatibility) florence_base_tokens = ['', '', '', '', '', '', '', '', ''] promptgen_tokens = ['', '', '', ''] return question in florence_base_tokens or question in promptgen_tokens def is_thinking_model(model_name: str) -> bool: """Check if the model supports thinking mode based on its name.""" if not model_name: return False model_lower = model_name.lower() # Check for known thinking models thinking_indicators = [ 'thinking', # Qwen3-VL-*-Thinking models 'reasoning', # Mistral-3-*-Reasoning models 'moondream3', # Moondream 3 supports thinking 'moondream 3', 'moondream2', # Moondream 2 supports reasoning mode 'moondream 2', 'mimo', 'qwen3.5', # Qwen3.5 native thinking (repo names) 'qwen 3.5', # Qwen3.5 native thinking (display names) ] return any(indicator in model_lower for indicator in thinking_indicators) def truncate_b64_in_conversation(conversation, front_chars=50, tail_chars=50, threshold=200): """ Deep copy a conversation structure and truncate long base64 image strings for logging. Preserves front and tail of base64 strings with truncation indicator. """ conv_copy = copy.deepcopy(conversation) def truncate_recursive(obj): if isinstance(obj, dict): for key, value in obj.items(): if key == "image" and isinstance(value, str) and len(value) > threshold: # Truncate the base64 image string truncated_count = len(value) - front_chars - tail_chars obj[key] = f"{value[:front_chars]}...[{truncated_count} chars truncated]...{value[-tail_chars:]}" elif isinstance(value, (dict, list)): truncate_recursive(value) elif isinstance(obj, list): for item in obj: truncate_recursive(item) truncate_recursive(conv_copy) return conv_copy def keep_think_block_open(text_prompt: str) -> str: """Remove the closing of the final assistant message so the model can continue reasoning.""" think_open = "" think_close = "" last_open = text_prompt.rfind(think_open) if last_open == -1: return text_prompt close_index = text_prompt.find(think_close, last_open) if close_index == -1: return text_prompt # Skip any whitespace immediately following the closing tag end_close = close_index + len(think_close) while end_close < len(text_prompt) and text_prompt[end_close] in (' ', '\t'): end_close += 1 while end_close < len(text_prompt) and text_prompt[end_close] in ('\r', '\n'): end_close += 1 trimmed_prompt = text_prompt[:close_index] + text_prompt[end_close:] debug('VQA caption: keep_think_block_open applied to prompt segment near assistant reply') return trimmed_prompt def b64(image): if image is None: return '' with io.BytesIO() as stream: image.save(stream, 'JPEG') values = stream.getvalue() encoded = base64.b64encode(values).decode() return encoded def clean(response, question, prefill=None): strip = ['---', '\r', '\t', '**', '"', '"', '"', 'Assistant:', 'Caption:', '<|im_end|>', ''] if isinstance(response, str): response = response.strip() elif isinstance(response, dict): text_response = "" if 'reasoning' in response and get_keep_thinking(): r_text = response['reasoning'] if isinstance(r_text, dict) and 'text' in r_text: r_text = r_text['text'] text_response += f"Reasoning:\n{r_text}\n\nAnswer:\n" if 'answer' in response: text_response += response['answer'] elif 'caption' in response: text_response += response['caption'] elif 'task' in response: text_response += 
response['task'] else: if not text_response: text_response = json.dumps(response) response = text_response elif isinstance(response, list): response = response[0] else: response = str(response) # Determine prefill text prefill_text = vlm_prefill if prefill is None else prefill if prefill_text is None: prefill_text = "" prefill_text = prefill_text.strip() question = question.replace('<', '').replace('>', '').replace('_', ' ') if question in response: response = response.split(question, 1)[1] while any(s in response for s in strip): for s in strip: response = response.replace(s, '') response = response.replace(' ', ' ').replace('* ', '- ').strip() # Handle prefill retention/removal if get_keep_prefill(): # Add prefill if it's missing from the cleaned response if len(prefill_text) > 0 and not response.startswith(prefill_text): sep = " " if not response or response[0] in ".,!?;:": sep = "" response = f"{prefill_text}{sep}{response}" else: # Remove prefill if it's present in the cleaned response if len(prefill_text) > 0 and response.startswith(prefill_text): response = response[len(prefill_text):].strip() return response def _get_overrides(): """Get generation overrides from VQA singleton if available.""" if _instance is not None and _instance.generation_overrides is not None: return _instance.generation_overrides return {} def get_keep_thinking(): """Check if thinking trace should be kept, with per-request override support.""" overrides = _get_overrides() if overrides.get('keep_thinking') is not None: return overrides['keep_thinking'] return shared.opts.caption_vlm_keep_thinking def strip_think_xml_tags(text: str, keep: bool = False) -> str: """Strip or reformat XML-style ... blocks from model output. Applies to models that use HuggingFace chat templates with / tokens (Qwen, Gemma, SmolVLM). Models with structured reasoning APIs (e.g. Moondream) handle their reasoning output separately. The opening tag is often in the prompt (not the response), so the response may only contain without a matching . Args: text: Model output text potentially containing / tags. keep: If True, reformat tags as human-readable Reasoning/Answer sections. If False, strip thinking blocks entirely. """ if keep: if '' in text and '' not in text: text = 'Reasoning:\n' + text.replace('', '\n\nAnswer:') else: text = text.replace('', 'Reasoning:\n').replace('', '\n\nAnswer:') else: while '' in text: start = text.find('') end = text.find('') if start != -1 and start < end: text = text[:start] + text[end + 8:] else: text = text[end + 8:] return text def get_keep_prefill(): """Check if prefill should be kept in output, with per-request override support.""" overrides = _get_overrides() if overrides.get('keep_prefill') is not None: return overrides['keep_prefill'] return shared.opts.caption_vlm_keep_prefill def get_kwargs(): """Build generation kwargs from settings with per-request overrides from VQA instance. Checks the singleton VQA instance's generation_overrides for per-request overrides. Override keys: max_tokens, temperature, top_k, top_p, num_beams, do_sample None values are ignored, allowing selective override. 
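# Illustrative behaviour of strip_think_xml_tags (comments only, not executed):
#   strip_think_xml_tags('<think>count the cats</think>Two cats.', keep=False)
#       -> 'Two cats.'
#   strip_think_xml_tags('count the cats</think>Two cats.', keep=True)
#       -> 'Reasoning:\ncount the cats\n\nAnswer:Two cats.'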
""" # Get overrides from VQA singleton if available overrides = _get_overrides() # Get base values from settings, apply overrides if provided max_tokens = overrides.get('max_tokens') if overrides.get('max_tokens') is not None else shared.opts.caption_vlm_max_length do_sample = overrides.get('do_sample') if overrides.get('do_sample') is not None else shared.opts.caption_vlm_do_sample num_beams = overrides.get('num_beams') if overrides.get('num_beams') is not None else shared.opts.caption_vlm_num_beams temperature = overrides.get('temperature') if overrides.get('temperature') is not None else shared.opts.caption_vlm_temperature top_k = overrides.get('top_k') if overrides.get('top_k') is not None else shared.opts.caption_vlm_top_k top_p = overrides.get('top_p') if overrides.get('top_p') is not None else shared.opts.caption_vlm_top_p kwargs = { 'max_new_tokens': max_tokens, 'do_sample': do_sample, } if num_beams > 0: kwargs['num_beams'] = num_beams if temperature > 0: kwargs['temperature'] = temperature if top_k > 0: kwargs['top_k'] = top_k if top_p > 0: kwargs['top_p'] = top_p return kwargs class VQA: """Vision-Language Model interrogation class with per-model self-contained loading.""" def __init__(self): self.processor = None self.model = None self.loaded: str = None self.last_annotated_image = None self.last_detection_data = None self._generation_overrides = None # Per-request generation parameter overrides @property def generation_overrides(self): """Get current per-request generation parameter overrides.""" return self._generation_overrides def unload(self): """Release VLM model from GPU/memory, including external handlers.""" if self.model is not None: model_name = self.loaded log.debug(f'VQA unload: unloading model="{model_name}"') sd_models.move_model(self.model, devices.cpu, force=True) self.model = None self.processor = None self.loaded = None devices.torch_gc(force=True, reason='vqa unload') log.debug(f'VQA unload: model="{model_name}" unloaded') else: log.debug('VQA unload: no internal model loaded') # external handlers manage their own module-level globals and are not covered by self.model from modules.caption import moondream3, joycaption, joytag, deepseek moondream3.unload() joycaption.unload() joytag.unload() deepseek.unload() def _unload_current(self): """Free current model memory before loading a new one.""" if self.model is not None: deregister_aux('vqa') sd_models.move_model(self.model, devices.cpu, force=True) self.model = None self.processor = None devices.torch_gc(force=True, reason='vqa model switch') def load(self, model_name: str = None): """Load VLM model into memory for the specified model name.""" model_name = model_name or shared.opts.caption_vlm_model if not model_name: log.warning('VQA load: no model specified') return repo = get_vlm_repo(model_name) if repo == model_name and model_name not in vlm_models.values(): log.error(f'VQA load: unknown model="{model_name}"') return log.debug(f'VQA load: pre-loading model="{model_name}" repo="{repo}"') sd_models.set_caption_load_options() try: # dispatch to appropriate loader (same logic as caption) repo_lower = repo.lower() if 'mistral' in repo_lower: self._load_mistral(repo) elif 'qwen' in repo_lower or 'torii' in repo_lower or 'mimo' in repo_lower: self._load_qwen(repo) elif 'gemma' in repo_lower and 'pali' not in repo_lower: self._load_gemma(repo) elif 'smol' in repo_lower: self._load_smol(repo) elif 'florence' in repo_lower: self._load_florence(repo) elif 'moondream2' in repo_lower: self._load_moondream(repo) elif 
class VQA:
    """Vision-Language Model interrogation class with per-model self-contained loading."""

    def __init__(self):
        self.processor = None
        self.model = None
        self.loaded: str = None
        self.last_annotated_image = None
        self.last_detection_data = None
        self._generation_overrides = None  # Per-request generation parameter overrides

    @property
    def generation_overrides(self):
        """Get current per-request generation parameter overrides."""
        return self._generation_overrides

    def unload(self):
        """Release VLM model from GPU/memory, including external handlers."""
        if self.model is not None:
            model_name = self.loaded
            log.debug(f'VQA unload: unloading model="{model_name}"')
            sd_models.move_model(self.model, devices.cpu, force=True)
            self.model = None
            self.processor = None
            self.loaded = None
            devices.torch_gc(force=True, reason='vqa unload')
            log.debug(f'VQA unload: model="{model_name}" unloaded')
        else:
            log.debug('VQA unload: no internal model loaded')
        # external handlers manage their own module-level globals and are not covered by self.model
        from modules.caption import moondream3, joycaption, joytag, deepseek
        moondream3.unload()
        joycaption.unload()
        joytag.unload()
        deepseek.unload()

    def _unload_current(self):
        """Free current model memory before loading a new one."""
        if self.model is not None:
            deregister_aux('vqa')
            sd_models.move_model(self.model, devices.cpu, force=True)
            self.model = None
            self.processor = None
            devices.torch_gc(force=True, reason='vqa model switch')

    def load(self, model_name: str = None):
        """Load VLM model into memory for the specified model name."""
        model_name = model_name or shared.opts.caption_vlm_model
        if not model_name:
            log.warning('VQA load: no model specified')
            return
        repo = get_vlm_repo(model_name)
        if repo == model_name and model_name not in vlm_models.values():
            log.error(f'VQA load: unknown model="{model_name}"')
            return
        log.debug(f'VQA load: pre-loading model="{model_name}" repo="{repo}"')
        sd_models.set_caption_load_options()
        try:
            # dispatch to appropriate loader (same logic as caption)
            repo_lower = repo.lower()
            if 'mistral' in repo_lower:
                self._load_mistral(repo)
            elif 'qwen' in repo_lower or 'torii' in repo_lower or 'mimo' in repo_lower:
                self._load_qwen(repo)
            elif 'gemma' in repo_lower and 'pali' not in repo_lower:
                self._load_gemma(repo)
            elif 'smol' in repo_lower:
                self._load_smol(repo)
            elif 'florence' in repo_lower:
                self._load_florence(repo)
            elif 'moondream2' in repo_lower:
                self._load_moondream(repo)
            elif 'git' in repo_lower:
                self._load_git(repo)
            elif 'blip' in repo_lower:
                self._load_blip(repo)
            elif 'vilt' in repo_lower:
                self._load_vilt(repo)
            elif 'pix' in repo_lower:
                self._load_pix(repo)
            elif 'paligemma' in repo_lower:
                self._load_paligemma(repo)
            elif 'ovis' in repo_lower:
                self._load_ovis(repo)
            elif 'sa2' in repo_lower:
                self._load_sa2(repo)
            elif 'fastvlm' in repo_lower:
                self._load_fastvlm(repo)
            elif 'moondream3' in repo_lower:
                from modules.caption import moondream3
                moondream3.load_model(repo)
                log.info(f'VQA load: model="{model_name}" loaded (external handler)')
                return
            elif 'joytag' in repo_lower:
                from modules.caption import joytag
                joytag.load()
                log.info(f'VQA load: model="{model_name}" loaded (external handler)')
                return
            elif 'joycaption' in repo_lower:
                from modules.caption import joycaption
                joycaption.load(repo)
                log.info(f'VQA load: model="{model_name}" loaded (external handler)')
                return
            elif 'deepseek' in repo_lower:
                from modules.caption import deepseek
                deepseek.load(repo)
                log.info(f'VQA load: model="{model_name}" loaded (external handler)')
                return
            else:
                log.warning(f'VQA load: no pre-loader for model="{model_name}"')
                return
            move_aux_to_gpu('vqa')
            log.info(f'VQA load: model="{model_name}" loaded')
        finally:
            sd_models.set_huggingface_options(quiet=True)

    def _load_fastvlm(self, repo: str):
        """Load FastVLM model and tokenizer."""
        if self.model is None or self.loaded != repo:
            log.debug(f'Caption load: vlm="{repo}"')
            self._unload_current()
            quant_args = model_quant.create_config(module='LLM')
            self.processor = transformers.AutoTokenizer.from_pretrained(repo, trust_remote_code=True, cache_dir=shared.opts.hfcache_dir)
            self.model = transformers.AutoModelForCausalLM.from_pretrained(
                repo,
                torch_dtype=devices.dtype,
                trust_remote_code=True,
                use_safetensors=True,
                low_cpu_mem_usage=True,
                cache_dir=shared.opts.hfcache_dir,
                **quant_args,
            )
            self.model.eval()
            register_aux('vqa', self.model)
            self.loaded = repo
            devices.torch_gc()

    def _fastvlm(self, question: str, image: Image.Image, repo: str, model_name: str = None):
        debug(f'VQA caption: handler=fastvlm model_name="{model_name}" repo="{repo}" question="{question}" image_size={image.size if image else None}')
        self._load_fastvlm(repo)
        move_aux_to_gpu('vqa')
        if len(question) < 2:
            question = "Describe the image."
        question = question.replace('<', '').replace('>', '')
        IMAGE_TOKEN_INDEX = -200  # what the model code looks for
        messages = [{"role": "user", "content": f"<image>\n{question}"}]
        rendered = self.processor.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
        pre, post = rendered.split("<image>", 1)
        pre_ids = self.processor(pre, return_tensors="pt", add_special_tokens=False).input_ids
        post_ids = self.processor(post, return_tensors="pt", add_special_tokens=False).input_ids
        img_tok = torch.tensor([[IMAGE_TOKEN_INDEX]], dtype=pre_ids.dtype)
        input_ids = torch.cat([pre_ids, img_tok, post_ids], dim=1)
        input_ids = input_ids.to(devices.device)
        attention_mask = torch.ones_like(input_ids, device=devices.device)
        px = self.model.get_vision_tower().image_processor(images=image, return_tensors="pt")
        px = px["pixel_values"].to(self.model.device, dtype=self.model.dtype)
        with devices.inference_context():
            outputs = self.model.generate(
                inputs=input_ids,
                attention_mask=attention_mask,
                images=px,
                max_new_tokens=128,
            )
        answer = self.processor.decode(outputs[0], skip_special_tokens=True)
        return answer
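    # Illustrative only: the shape of the spliced FastVLM input built above. The exact chat
    # template text is model-specific; splitting on "<image>" and the -200 placeholder id
    # follow the LLaVA-style convention used by the FastVLM remote code (assumption).
    #   rendered  = "...user turn...<image>\nDescribe the image....assistant turn..."
    #   pre_ids   = tokenize(text before "<image>")
    #   post_ids  = tokenize(text after "<image>")
    #   input_ids = [pre_ids, [-200], post_ids]   # vision tower output replaces the -200 slot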
    # Map Qwen VL config model_type strings to their model classes.
    _QWEN_VL_MODEL_TYPE_MAP = {
        'qwen3_5': 'Qwen3_5ForConditionalGeneration',
        'qwen3_5_moe': 'Qwen3_5MoeForConditionalGeneration',
        'qwen3_vl': 'Qwen3VLForConditionalGeneration',
        'qwen2_5_vl': 'Qwen2_5_VLForConditionalGeneration',
        'qwen2_vl': 'Qwen2VLForConditionalGeneration',
    }

    def _load_qwen(self, repo: str):
        """Load Qwen VL model and processor."""
        if self.model is None or self.loaded != repo:
            log.debug(f'Caption load: vlm="{repo}"')
            self._unload_current()
            if 'Qwen3.5' in repo and re.search(r'-A\d+B', repo):
                cls_name = transformers.Qwen3_5MoeForConditionalGeneration
            elif 'Qwen3.5' in repo:
                cls_name = transformers.Qwen3_5ForConditionalGeneration
            elif 'Qwen3-VL' in repo or 'Qwen3VL' in repo:
                cls_name = transformers.Qwen3VLForConditionalGeneration
            elif 'Qwen2.5-VL' in repo or 'Qwen2_5_VL' in repo or 'MiMo-VL' in repo:
                cls_name = transformers.Qwen2_5_VLForConditionalGeneration
            elif 'Qwen2-VL' in repo or 'Qwen2VL' in repo:
                cls_name = transformers.Qwen2VLForConditionalGeneration
            else:
                # Fine-tunes (e.g. ToriiGate) may not have "Qwen" in the repo name.
                # Detect the correct class from the config's model_type.
                config = transformers.AutoConfig.from_pretrained(repo, cache_dir=shared.opts.hfcache_dir)
                model_type = getattr(config, 'model_type', '')
                cls_attr = self._QWEN_VL_MODEL_TYPE_MAP.get(model_type)
                cls_name = getattr(transformers, cls_attr) if cls_attr else transformers.AutoModelForCausalLM
            quant_args = model_quant.create_config(module='LLM')
            self.model = cls_name.from_pretrained(
                repo,
                torch_dtype=devices.dtype,
                use_safetensors=True,
                low_cpu_mem_usage=True,
                cache_dir=shared.opts.hfcache_dir,
                **quant_args,
            )
            self.model.eval()
            self.processor = transformers.AutoProcessor.from_pretrained(repo, max_pixels=1024*1024, cache_dir=shared.opts.hfcache_dir)
            if 'LLM' in shared.opts.cuda_compile:
                self.model = sd_models_compile.compile_torch(self.model, apply_to_components=False, op="VQA")
            register_aux('vqa', self.model)
            self.loaded = repo
            devices.torch_gc()

    def _qwen(self, question: str, image: Image.Image, repo: str, system_prompt: str = None, model_name: str = None, prefill: str = None, thinking_mode: bool = False):
        self._load_qwen(repo)
        move_aux_to_gpu('vqa')
        # Get model class name for logging
        cls_name = self.model.__class__.__name__
        debug(f'VQA caption: handler=qwen model_name="{model_name}" model_class="{cls_name}" repo="{repo}" question="{question}" system_prompt="{system_prompt}" image_size={image.size if image else None}')
        question = question.replace('<', '').replace('>', '').replace('_', ' ')
        system_prompt = system_prompt or shared.opts.caption_vlm_system
        conversation = [
            {
                "role": "system",
                "content": [{"type": "text", "text": system_prompt}],
            },
            {
                "role": "user",
                "content": [
                    {"type": "image", "image": b64(image)},
                    {"type": "text", "text": question},
                ],
            }
        ]
        # Thinking models emit their own <think> tags via the chat template
        # Only models with thinking capability can use thinking mode
        is_thinking = is_thinking_model(model_name)
        # Standardize prefill
        prefill_value = vlm_prefill if prefill is None else prefill
        prefill_text = prefill_value.strip()
        use_prefill = len(prefill_text) > 0
        if debug_enabled:
            debug(f'VQA caption: handler=qwen conversation_roles={[msg["role"] for msg in conversation]}')
            debug(f'VQA caption: handler=qwen full_conversation={truncate_b64_in_conversation(conversation)}')
            debug(f'VQA caption: handler=qwen is_thinking={is_thinking} thinking_mode={thinking_mode} prefill="{prefill_text}"')
        # Qwen3.5 uses native enable_thinking parameter in the chat template
        is_qwen35 = 'qwen3.5' in (model_name or '').lower() or 'qwen3.5' in repo.lower()
        template_kwargs = {'enable_thinking': thinking_mode} if is_qwen35 else {}
        # Generate base prompt using template
        # Qwen-Thinking template automatically adds "<|im_start|>assistant\n<think>\n" when add_generation_prompt=True
        try:
            text_prompt = self.processor.apply_chat_template(
                conversation,
                add_generation_prompt=True,
                **template_kwargs,
            )
        except (TypeError, ValueError) as e:
            debug(f'VQA caption: handler=qwen chat_template fallback add_generation_prompt=True: {e}')
            text_prompt = self.processor.apply_chat_template(conversation, add_generation_prompt=True)
        # Manual think handling - skip for Qwen3.5 (template handles it natively)
        if is_thinking and not is_qwen35:
            if not thinking_mode:
                # User wants to SKIP thinking.
                # Since template opened the block with <think>, we close it immediately.
                text_prompt += "</think>\n"
                if use_prefill:
                    text_prompt += prefill_text
            else:
                # User wants thinking. Prompt already ends in <think>.
                # If prefill is provided, it becomes part of the thought process.
                if use_prefill:
                    text_prompt += prefill_text
        else:
            # Standard model or Qwen3.5 (no manual manipulation needed)
            if use_prefill:
                text_prompt += prefill_text
        if debug_enabled:
            debug(f'VQA caption: handler=qwen text_prompt="{text_prompt}"')
        inputs = self.processor(text=[text_prompt], images=[image], padding=True, return_tensors="pt")
        inputs = inputs.to(devices.device, devices.dtype)
        gen_kwargs = get_kwargs()
        debug(f'VQA caption: handler=qwen generation_kwargs={gen_kwargs} input_ids_shape={inputs.input_ids.shape}')
        with devices.inference_context():
            output_ids = self.model.generate(
                **inputs,
                **gen_kwargs,
            )
        debug(f'VQA caption: handler=qwen output_ids_shape={output_ids.shape}')
        generated_ids = [
            output_ids[len(input_ids):]
            for input_ids, output_ids in zip(inputs.input_ids, output_ids, strict=False)
        ]
        response = self.processor.batch_decode(generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True)
        if debug_enabled:
            debug(f'VQA caption: handler=qwen response_before_clean="{response}"')
        if len(response) > 0:
            response[0] = strip_think_xml_tags(response[0], keep=get_keep_thinking())
        return response

    def _load_gemma(self, repo: str):
        """Load Gemma 3 model and processor."""
        if self.model is None or self.loaded != repo:
            log.debug(f'Caption load: vlm="{repo}"')
            self._unload_current()
            if '3n' in repo:
                cls = transformers.Gemma3nForConditionalGeneration  # pylint: disable=no-member
            else:
                cls = transformers.Gemma3ForConditionalGeneration
            quant_args = model_quant.create_config(module='LLM')
            self.model = cls.from_pretrained(
                repo,
                torch_dtype=devices.dtype,
                use_safetensors=True,
                low_cpu_mem_usage=True,
                cache_dir=shared.opts.hfcache_dir,
                **quant_args,
            )
            self.model.eval()
            if 'LLM' in shared.opts.cuda_compile:
                self.model = sd_models_compile.compile_torch(self.model, apply_to_components=False, op="VQA")
            self.processor = transformers.AutoProcessor.from_pretrained(repo, max_pixels=1024*1024, cache_dir=shared.opts.hfcache_dir)
            register_aux('vqa', self.model)
            self.loaded = repo
            devices.torch_gc()

    def _gemma(self, question: str, image: Image.Image, repo: str, system_prompt: str = None, model_name: str = None, prefill: str = None, thinking_mode: bool = False):
        self._load_gemma(repo)
        move_aux_to_gpu('vqa')
        # Get model class name for logging
        cls_name = self.model.__class__.__name__
        debug(f'VQA caption: handler=gemma model_name="{model_name}" model_class="{cls_name}" repo="{repo}"
question="{question}" system_prompt="{system_prompt}" image_size={image.size if image else None}') question = question.replace('<', '').replace('>', '').replace('_', ' ') system_prompt = system_prompt or shared.opts.caption_vlm_system system_content = [] if system_prompt is not None and len(system_prompt) > 4: system_content.append({"type": "text", "text": system_prompt}) user_content = [] if question is not None and len(question) > 4: user_content.append({"type": "text", "text": question}) if image is not None: user_content.append({"type": "image", "image": b64(image)}) conversation = [ {"role": "system", "content": system_content}, {"role": "user", "content": user_content}, ] # Add prefill if provided prefill_value = vlm_prefill if prefill is None else prefill prefill_text = prefill_value.strip() use_prefill = len(prefill_text) > 0 # Thinking models emit their own tags via the chat template # Use manual toggle OR auto-detection based on model name use_thinking = thinking_mode or is_thinking_model(model_name) if use_prefill: conversation.append({ "role": "assistant", "content": [{"type": "text", "text": prefill_text}], }) debug(f'VQA caption: handler=gemma prefill="{prefill_text}"') else: debug('VQA caption: handler=gemma prefill disabled (empty), relying on add_generation_prompt') if debug_enabled: debug(f'VQA caption: handler=gemma conversation_roles={[msg["role"] for msg in conversation]}') debug(f'VQA caption: handler=gemma full_conversation={truncate_b64_in_conversation(conversation)}') debug_prefill_mode = 'add_generation_prompt=False continue_final_message=True' if use_prefill else 'add_generation_prompt=True' debug(f'VQA caption: handler=gemma template_mode={debug_prefill_mode}') try: if use_prefill: text_prompt = self.processor.apply_chat_template( conversation, add_generation_prompt=False, continue_final_message=True, tokenize=False, ) else: text_prompt = self.processor.apply_chat_template( conversation, add_generation_prompt=True, tokenize=False, ) except (TypeError, ValueError) as e: debug(f'VQA caption: handler=gemma chat_template fallback add_generation_prompt=True: {e}') text_prompt = self.processor.apply_chat_template( conversation, add_generation_prompt=True, tokenize=False, ) if use_prefill and use_thinking: text_prompt = keep_think_block_open(text_prompt) if debug_enabled: debug(f'VQA caption: handler=gemma text_prompt="{text_prompt}"') inputs = self.processor( text=[text_prompt], images=[image], padding=True, return_tensors="pt", ).to(device=devices.device, dtype=devices.dtype) input_len = inputs["input_ids"].shape[-1] gen_kwargs = get_kwargs() debug(f'VQA caption: handler=gemma generation_kwargs={gen_kwargs} input_len={input_len}') with devices.inference_context(): generation = self.model.generate( **inputs, **gen_kwargs, ) debug(f'VQA caption: handler=gemma output_ids_shape={generation.shape}') generation = generation[0][input_len:] response = self.processor.decode(generation, skip_special_tokens=True) if debug_enabled: debug(f'VQA caption: handler=gemma response_before_clean="{response}"') response = strip_think_xml_tags(response, keep=get_keep_thinking()) return response def _load_mistral(self, repo: str): """Load Mistral3 vision model and processor.""" if self.model is None or self.loaded != repo: log.debug(f'Caption load: vlm="{repo}"') self._unload_current() quant_args = model_quant.create_config(module='LLM') self.model = transformers.Mistral3ForConditionalGeneration.from_pretrained( repo, torch_dtype=devices.dtype, use_safetensors=True, low_cpu_mem_usage=True, 
cache_dir=shared.opts.hfcache_dir, **quant_args, ) self.model.eval() if 'LLM' in shared.opts.cuda_compile: self.model = sd_models_compile.compile_torch(self.model, apply_to_components=False, op="VQA") self.processor = transformers.AutoProcessor.from_pretrained(repo, cache_dir=shared.opts.hfcache_dir) register_aux('vqa', self.model) self.loaded = repo devices.torch_gc() def _mistral(self, question: str, image: Image.Image, repo: str, system_prompt: str = None, model_name: str = None, prefill: str = None, thinking_mode: bool = False): self._load_mistral(repo) move_aux_to_gpu('vqa') cls_name = self.model.__class__.__name__ debug(f'VQA caption: handler=mistral model_name="{model_name}" model_class="{cls_name}" repo="{repo}" question="{question}" system_prompt="{system_prompt}" image_size={image.size if image else None}') question = question.replace('<', '').replace('>', '').replace('_', ' ') system_prompt = system_prompt or shared.opts.caption_vlm_system conversation = [] if system_prompt and len(system_prompt) > 4: conversation.append({"role": "system", "content": [{"type": "text", "text": system_prompt}]}) user_content = [] if image is not None: user_content.append({"type": "image", "image": b64(image)}) if question and len(question) > 1: user_content.append({"type": "text", "text": question}) conversation.append({"role": "user", "content": user_content}) prefill_value = vlm_prefill if prefill is None else prefill prefill_text = prefill_value.strip() use_prefill = len(prefill_text) > 0 if use_prefill: conversation.append({"role": "assistant", "content": [{"type": "text", "text": prefill_text}]}) if debug_enabled: debug(f'VQA caption: handler=mistral conversation_roles={[msg["role"] for msg in conversation]}') debug(f'VQA caption: handler=mistral full_conversation={truncate_b64_in_conversation(conversation)}') try: if use_prefill: text_prompt = self.processor.apply_chat_template(conversation, add_generation_prompt=False, continue_final_message=True, tokenize=False) else: text_prompt = self.processor.apply_chat_template(conversation, add_generation_prompt=True, tokenize=False) except (TypeError, ValueError) as e: debug(f'VQA caption: handler=mistral chat_template fallback: {e}') text_prompt = self.processor.apply_chat_template(conversation, add_generation_prompt=True, tokenize=False) if debug_enabled: debug(f'VQA caption: handler=mistral text_prompt="{text_prompt}"') inputs = self.processor(text=[text_prompt], images=[image], padding=True, return_tensors="pt").to(device=devices.device, dtype=devices.dtype) input_len = inputs["input_ids"].shape[-1] gen_kwargs = get_kwargs() debug(f'VQA caption: handler=mistral generation_kwargs={gen_kwargs} input_len={input_len}') with devices.inference_context(): generation = self.model.generate(**inputs, **gen_kwargs) generation = generation[0][input_len:] response = self.processor.decode(generation, skip_special_tokens=True) if debug_enabled: debug(f'VQA caption: handler=mistral response_before_clean="{response}"') return response def _load_paligemma(self, repo: str): """Load PaliGemma model and processor.""" if self.model is None or self.loaded != repo: log.debug(f'Caption load: vlm="{repo}"') self._unload_current() self.processor = transformers.PaliGemmaProcessor.from_pretrained(repo, cache_dir=shared.opts.hfcache_dir) self.model = transformers.PaliGemmaForConditionalGeneration.from_pretrained( repo, cache_dir=shared.opts.hfcache_dir, torch_dtype=devices.dtype, use_safetensors=True, low_cpu_mem_usage=True, ) self.model.eval() register_aux('vqa', 
            self.model)
            self.loaded = repo
            devices.torch_gc()

    def _paligemma(self, question: str, image: Image.Image, repo: str, model_name: str = None):  # pylint: disable=unused-argument
        self._load_paligemma(repo)
        move_aux_to_gpu('vqa')
        question = question.replace('<', '').replace('>', '').replace('_', ' ')
        model_inputs = self.processor(text=question, images=image, return_tensors="pt").to(devices.device, devices.dtype)
        input_len = model_inputs["input_ids"].shape[-1]
        with devices.inference_context():
            generation = self.model.generate(
                **model_inputs,
                **get_kwargs(),
            )
        generation = generation[0][input_len:]
        response = self.processor.decode(generation, skip_special_tokens=True)
        return response

    def _load_ovis(self, repo: str):
        """Load Ovis model (requires flash-attn)."""
        if self.model is None or self.loaded != repo:
            log.debug(f'Caption load: vlm="{repo}"')
            self._unload_current()
            # Ovis remote code calls AutoConfig.register("aimv2", ...) at module scope
            # without exist_ok=True, which fails on reload or when the type is already
            # registered by a newer transformers version.
            _orig = transformers.AutoConfig.register.__func__ if hasattr(transformers.AutoConfig.register, '__func__') else transformers.AutoConfig.register
            transformers.AutoConfig.register = staticmethod(lambda model_type, config, exist_ok=False: _orig(model_type, config, exist_ok=True))
            try:
                self.model = transformers.AutoModelForCausalLM.from_pretrained(
                    repo,
                    torch_dtype=devices.dtype,
                    multimodal_max_length=32768,
                    trust_remote_code=True,
                    use_safetensors=True,
                    low_cpu_mem_usage=True,
                    cache_dir=shared.opts.hfcache_dir,
                )
            finally:
                transformers.AutoConfig.register = _orig
            self.model.eval()
            register_aux('vqa', self.model)
            self.loaded = repo
            devices.torch_gc()

    def _ovis(self, question: str, image: Image.Image, repo: str, model_name: str = None):  # pylint: disable=unused-argument
        try:
            pass  # pylint: disable=unused-import
        except Exception:
            log.error(f'Caption: vlm="{repo}" flash-attn is not available')
            return ''
        self._load_ovis(repo)
        move_aux_to_gpu('vqa')
        text_tokenizer = self.model.get_text_tokenizer()
        visual_tokenizer = self.model.get_visual_tokenizer()
        max_partition = 9
        question = question.replace('<', '').replace('>', '').replace('_', ' ')
        question = f'<image>\n{question}'
        _prompt, input_ids, pixel_values = self.model.preprocess_inputs(question, [image], max_partition=max_partition)
        attention_mask = torch.ne(input_ids, text_tokenizer.pad_token_id)
        input_ids = input_ids.unsqueeze(0).to(device=self.model.device)
        attention_mask = attention_mask.unsqueeze(0).to(device=self.model.device)
        if pixel_values is not None:
            pixel_values = pixel_values.to(dtype=visual_tokenizer.dtype, device=visual_tokenizer.device)
        pixel_values = [pixel_values]
        with devices.inference_context():
            output_ids = self.model.generate(
                input_ids,
                pixel_values=pixel_values,
                attention_mask=attention_mask,
                repetition_penalty=None,
                eos_token_id=self.model.generation_config.eos_token_id,
                pad_token_id=text_tokenizer.pad_token_id,
                use_cache=True,
                **get_kwargs())
        response = text_tokenizer.decode(output_ids[0], skip_special_tokens=True)
        return response

    def _load_smol(self, repo: str):
        """Load SmolVLM model and processor."""
        if self.model is None or self.loaded != repo:
            log.debug(f'Caption load: vlm="{repo}"')
            self._unload_current()
            quant_args = model_quant.create_config(module='LLM')
            self.model = transformers.AutoModelForVision2Seq.from_pretrained(
                repo,
                cache_dir=shared.opts.hfcache_dir,
                torch_dtype=devices.dtype,
                use_safetensors=True,
                low_cpu_mem_usage=True,
                **quant_args,
            )
            self.model.eval()
self.processor = transformers.AutoProcessor.from_pretrained(repo, max_pixels=1024*1024, cache_dir=shared.opts.hfcache_dir) if 'LLM' in shared.opts.cuda_compile: self.model = sd_models_compile.compile_torch(self.model, apply_to_components=False, op="VQA") register_aux('vqa', self.model) self.loaded = repo devices.torch_gc() def _smol(self, question: str, image: Image.Image, repo: str, system_prompt: str = None, model_name: str = None, prefill: str = None, thinking_mode: bool = False): self._load_smol(repo) move_aux_to_gpu('vqa') # Get model class name for logging cls_name = self.model.__class__.__name__ debug(f'VQA caption: handler=smol model_name="{model_name}" model_class="{cls_name}" repo="{repo}" question="{question}" system_prompt="{system_prompt}" image_size={image.size if image else None}') question = question.replace('<', '').replace('>', '').replace('_', ' ') system_prompt = system_prompt or shared.opts.caption_vlm_system conversation = [ { "role": "system", "content": [{"type": "text", "text": system_prompt}], }, { "role": "user", "content": [ {"type": "image", "image": b64(image)}, {"type": "text", "text": question}, ], } ] # Add prefill if provided prefill_value = vlm_prefill if prefill is None else prefill prefill_text = prefill_value.strip() use_prefill = len(prefill_text) > 0 # Thinking models emit their own tags via the chat template # Use manual toggle OR auto-detection based on model name use_thinking = thinking_mode or is_thinking_model(model_name) if use_prefill: conversation.append({ "role": "assistant", "content": [{"type": "text", "text": prefill_text}], }) debug(f'VQA caption: handler=smol prefill="{prefill_text}"') else: debug('VQA caption: handler=smol prefill disabled (empty), relying on add_generation_prompt') if debug_enabled: debug(f'VQA caption: handler=smol conversation_roles={[msg["role"] for msg in conversation]}') debug(f'VQA caption: handler=smol full_conversation={truncate_b64_in_conversation(conversation)}') debug_prefill_mode = 'add_generation_prompt=False continue_final_message=True' if use_prefill else 'add_generation_prompt=True' debug(f'VQA caption: handler=smol template_mode={debug_prefill_mode}') try: if use_prefill: text_prompt = self.processor.apply_chat_template( conversation, add_generation_prompt=False, continue_final_message=True, ) else: text_prompt = self.processor.apply_chat_template(conversation, add_generation_prompt=True) except (TypeError, ValueError) as e: debug(f'VQA caption: handler=smol chat_template fallback add_generation_prompt=True: {e}') text_prompt = self.processor.apply_chat_template(conversation, add_generation_prompt=True) if use_prefill and use_thinking: text_prompt = keep_think_block_open(text_prompt) if debug_enabled: debug(f'VQA caption: handler=smol text_prompt="{text_prompt}"') inputs = self.processor(text=text_prompt, images=[image], padding=True, return_tensors="pt") inputs = inputs.to(devices.device, devices.dtype) gen_kwargs = get_kwargs() debug(f'VQA caption: handler=smol generation_kwargs={gen_kwargs}') with devices.inference_context(): output_ids = self.model.generate( **inputs, **gen_kwargs, ) debug(f'VQA caption: handler=smol output_ids_shape={output_ids.shape}') response = self.processor.batch_decode(output_ids, skip_special_tokens=True) if debug_enabled: debug(f'VQA caption: handler=smol response_before_clean="{response}"') if len(response) > 0: response[0] = strip_think_xml_tags(response[0], keep=get_keep_thinking()) return response def _load_git(self, repo: str): """Load Microsoft GIT model and 
processor.""" if self.model is None or self.loaded != repo: log.debug(f'Caption load: vlm="{repo}"') self._unload_current() self.model = transformers.GitForCausalLM.from_pretrained( repo, torch_dtype=devices.dtype, use_safetensors=True, low_cpu_mem_usage=True, cache_dir=shared.opts.hfcache_dir, ) self.model.eval() self.processor = transformers.GitProcessor.from_pretrained(repo, cache_dir=shared.opts.hfcache_dir) register_aux('vqa', self.model) self.loaded = repo devices.torch_gc() def _git(self, question: str, image: Image.Image, repo: str, model_name: str = None): # pylint: disable=unused-argument self._load_git(repo) move_aux_to_gpu('vqa') pixel_values = self.processor(images=image, return_tensors="pt").pixel_values git_dict = {} git_dict['pixel_values'] = pixel_values.to(devices.device, devices.dtype) if len(question) > 0: input_ids = self.processor(text=question, add_special_tokens=False).input_ids input_ids = [self.processor.tokenizer.cls_token_id] + input_ids input_ids = torch.tensor(input_ids).unsqueeze(0) git_dict['input_ids'] = input_ids.to(devices.device) with devices.inference_context(): generated_ids = self.model.generate(**git_dict) response = self.processor.batch_decode(generated_ids, skip_special_tokens=True)[0] return response def _load_blip(self, repo: str): """Load Salesforce BLIP model and processor.""" if self.model is None or self.loaded != repo: log.debug(f'Caption load: vlm="{repo}"') self._unload_current() self.model = transformers.BlipForQuestionAnswering.from_pretrained( repo, torch_dtype=devices.dtype, use_safetensors=True, low_cpu_mem_usage=True, cache_dir=shared.opts.hfcache_dir, ) self.model.eval() self.processor = transformers.BlipProcessor.from_pretrained(repo, cache_dir=shared.opts.hfcache_dir) register_aux('vqa', self.model) self.loaded = repo devices.torch_gc() def _blip(self, question: str, image: Image.Image, repo: str, model_name: str = None): # pylint: disable=unused-argument self._load_blip(repo) move_aux_to_gpu('vqa') inputs = self.processor(image, question, return_tensors="pt") inputs = inputs.to(devices.device, devices.dtype) with devices.inference_context(): outputs = self.model.generate(**inputs) response = self.processor.decode(outputs[0], skip_special_tokens=True) return response def _load_vilt(self, repo: str): """Load ViLT model and processor.""" if self.model is None or self.loaded != repo: log.debug(f'Caption load: vlm="{repo}"') self._unload_current() self.model = transformers.ViltForQuestionAnswering.from_pretrained( repo, torch_dtype=devices.dtype, use_safetensors=True, low_cpu_mem_usage=True, cache_dir=shared.opts.hfcache_dir, ) self.model.eval() self.processor = transformers.ViltProcessor.from_pretrained(repo, cache_dir=shared.opts.hfcache_dir) register_aux('vqa', self.model) self.loaded = repo devices.torch_gc() def _vilt(self, question: str, image: Image.Image, repo: str, model_name: str = None): # pylint: disable=unused-argument self._load_vilt(repo) move_aux_to_gpu('vqa') inputs = self.processor(image, question, return_tensors="pt") inputs = inputs.to(devices.device) with devices.inference_context(): outputs = self.model(**inputs) logits = outputs.logits idx = logits.argmax(-1).item() response = self.model.config.id2label[idx] return response def _load_pix(self, repo: str): """Load Pix2Struct model and processor.""" if self.model is None or self.loaded != repo: log.debug(f'Caption load: vlm="{repo}"') self._unload_current() self.model = transformers.Pix2StructForConditionalGeneration.from_pretrained( repo, 
torch_dtype=devices.dtype, use_safetensors=True, low_cpu_mem_usage=True, cache_dir=shared.opts.hfcache_dir, ) self.model.eval() self.processor = transformers.Pix2StructProcessor.from_pretrained(repo, cache_dir=shared.opts.hfcache_dir) register_aux('vqa', self.model) self.loaded = repo devices.torch_gc() def _pix(self, question: str, image: Image.Image, repo: str, model_name: str = None): # pylint: disable=unused-argument self._load_pix(repo) move_aux_to_gpu('vqa') if len(question) > 0: inputs = self.processor(images=image, text=question, return_tensors="pt") else: inputs = self.processor(images=image, return_tensors="pt") inputs = {k: v.to(devices.device, devices.dtype) if v.is_floating_point() else v.to(devices.device) for k, v in inputs.items()} with devices.inference_context(): outputs = self.model.generate(**inputs) response = self.processor.decode(outputs[0], skip_special_tokens=True) return response def _load_moondream(self, repo: str): """Load Moondream 2 model and tokenizer.""" if self.model is None or self.loaded != repo: log.debug(f'Caption load: vlm="{repo}"') self._unload_current() self.model = transformers.AutoModelForCausalLM.from_pretrained( repo, revision="2025-06-21", trust_remote_code=True, torch_dtype=devices.dtype, use_safetensors=True, low_cpu_mem_usage=True, cache_dir=shared.opts.hfcache_dir, ) self.processor = transformers.AutoTokenizer.from_pretrained(repo, cache_dir=shared.opts.hfcache_dir) self.loaded = repo self.model.eval() # required: trust_remote_code model register_aux('vqa', self.model) devices.torch_gc() def _moondream(self, question: str, image: Image.Image, repo: str, model_name: str = None, thinking_mode: bool = False): debug(f'VQA caption: handler=moondream model_name="{model_name}" repo="{repo}" question="{question}" thinking_mode={thinking_mode}') self._load_moondream(repo) move_aux_to_gpu('vqa') question = question.replace('<', '').replace('>', '').replace('_', ' ') with devices.inference_context(): if question == 'CAPTION': response = self.model.caption(image, length="short")['caption'] elif question == 'DETAILED CAPTION': response = self.model.caption(image, length="normal")['caption'] elif question == 'MORE DETAILED CAPTION': response = self.model.caption(image, length="long")['caption'] elif question.lower().startswith('point at ') or question == 'POINT_MODE': target = question[9:].strip() if question.lower().startswith('point at ') else '' if not target: return "Please specify an object to locate" debug(f'VQA caption: handler=moondream method=point target="{target}"') result = self.model.point(image, target) debug(f'VQA caption: handler=moondream point_raw_result={result}') points = vqa_detection.parse_points(result) if points: self.last_detection_data = {'points': points} return vqa_detection.format_points_text(points) return "Object not found" elif question == 'DETECT_GAZE' or question.lower() == 'detect gaze': # Must be checked before generic 'detect ' prefix to avoid matching as detect target="Gaze" debug('VQA caption: handler=moondream method=detect_gaze') faces = self.model.detect(image, "face") debug(f'VQA caption: handler=moondream detect_gaze faces={faces}') if faces.get('objects'): eye_x, eye_y = vqa_detection.calculate_eye_position(faces['objects'][0]) result = self.model.detect_gaze(image, eye=(eye_x, eye_y)) debug(f'VQA caption: handler=moondream detect_gaze result={result}') if result.get('gaze'): gaze = result['gaze'] self.last_detection_data = {'points': [(gaze['x'], gaze['y'])]} return f"Gaze direction: ({gaze['x']:.3f}, 
{gaze['y']:.3f})" return "No face/gaze detected" elif question.lower().startswith('detect ') or question == 'DETECT_MODE': target = question[7:].strip() if question.lower().startswith('detect ') else '' if not target: return "Please specify an object to detect" debug(f'VQA caption: handler=moondream method=detect target="{target}"') result = self.model.detect(image, target) debug(f'VQA caption: handler=moondream detect_raw_result={result}') detections = vqa_detection.parse_detections(result, target) if detections: self.last_detection_data = {'detections': detections} return vqa_detection.format_detections_text(detections, include_confidence=False) return "No objects detected" else: debug(f'VQA caption: handler=moondream method=query question="{question}" reasoning={thinking_mode}') result = self.model.query(image, question, reasoning=thinking_mode) response = result['answer'] debug(f'VQA caption: handler=moondream query_result keys={list(result.keys()) if isinstance(result, dict) else "not dict"}') if thinking_mode and 'reasoning' in result: reasoning_text = result['reasoning'].get('text', '') if isinstance(result['reasoning'], dict) else str(result['reasoning']) debug(f'VQA caption: handler=moondream reasoning_text="{reasoning_text[:100]}..."') if get_keep_thinking(): response = f"Reasoning:\n{reasoning_text}\n\nAnswer:\n{response}" # When keep_thinking is False, just use the answer (reasoning is discarded) return response def _load_florence(self, repo: str, revision: str = None): """Load Florence-2 model and processor.""" _get_imports = transformers.dynamic_module_utils.get_imports def get_imports(f): R = _get_imports(f) if "flash_attn" in R: R.remove("flash_attn") # flash_attn is optional return R # Handle revision splitting and caching cache_key = repo effective_revision = revision repo_name = repo if repo and '@' in repo: repo_name, revision_from_repo = repo.split('@') effective_revision = revision_from_repo if self.model is None or self.loaded != cache_key: log.debug(f'Caption load: vlm="{repo_name}" revision="{effective_revision}" path="{shared.opts.hfcache_dir}"') self._unload_current() transformers.dynamic_module_utils.get_imports = get_imports quant_args = model_quant.create_config(module='LLM') self.model = transformers.Florence2ForConditionalGeneration.from_pretrained( repo_name, revision=effective_revision, torch_dtype=devices.dtype, use_safetensors=True, low_cpu_mem_usage=True, cache_dir=shared.opts.hfcache_dir, **quant_args, ) self.model.eval() self.processor = transformers.AutoProcessor.from_pretrained(repo_name, max_pixels=1024*1024, trust_remote_code=True, revision=effective_revision, cache_dir=shared.opts.hfcache_dir) transformers.dynamic_module_utils.get_imports = _get_imports register_aux('vqa', self.model) self.loaded = cache_key devices.torch_gc() def _florence(self, question: str, image: Image.Image, repo: str, revision: str = None, model_name: str = None): # pylint: disable=unused-argument self._load_florence(repo, revision) move_aux_to_gpu('vqa') if question.startswith('<'): task = question.split('>', 1)[0] + '>' else: task = '' debug(f'VQA caption: handler=florence model_name="{model_name}" repo="{repo}" task="{task}" question="{question}" image_size={image.size}') inputs = self.processor(text=task, images=image, return_tensors="pt") input_ids = inputs['input_ids'].to(devices.device) pixel_values = inputs['pixel_values'].to(devices.device, devices.dtype) debug(f'VQA caption: handler=florence input_ids={input_ids.shape} pixel_values={pixel_values.shape} 
dtype={pixel_values.dtype}')
        # Florence-2 requires beam search, not sampling - sampling causes probability tensor errors
        overrides = _get_overrides()
        max_tokens = overrides.get('max_tokens') if overrides.get('max_tokens') is not None else shared.opts.caption_vlm_max_length
        gen_kwargs = {'max_new_tokens': max_tokens, 'num_beams': 3, 'do_sample': False}
        # Some Florence fine-tunes (e.g., CogFlorence) don't have decoder_start_token_id set
        if getattr(self.model.config, 'decoder_start_token_id', None) is None:
            bos_token_id = getattr(self.processor.tokenizer, 'bos_token_id', None) or 0
            gen_kwargs['decoder_start_token_id'] = bos_token_id
            debug(f'VQA caption: handler=florence setting decoder_start_token_id={bos_token_id}')
        debug(f'VQA caption: handler=florence generation_kwargs={gen_kwargs}')
        with devices.inference_context(), devices.bypass_sdpa_hijacks():
            generated_ids = self.model.generate(
                input_ids=input_ids,
                pixel_values=pixel_values,
                **gen_kwargs,
            )
        generated_text = self.processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
        debug(f'VQA caption: handler=florence generated_text="{generated_text}"')
        # task="task" is intentional: produces {'task': text} which both parse_florence_detections and
        # format_florence_response handle via explicit 'task' key fallbacks, avoiding task-token-specific keys
        response = self.processor.post_process_generation(generated_text, task="task", image_size=(image.width, image.height))
        debug(f'VQA caption: handler=florence raw_response={response}')
        return response

    def _load_sa2(self, repo: str):
        """Load SA2VA model and tokenizer."""
        if self.model is None or self.loaded != repo:
            self._unload_current()
            self.model = transformers.AutoModel.from_pretrained(
                repo,
                torch_dtype=devices.dtype,
                low_cpu_mem_usage=True,
                use_flash_attn=False,
                use_safetensors=True,
                trust_remote_code=True,
                cache_dir=shared.opts.hfcache_dir)
            self.model = self.model.eval()  # required: trust_remote_code model
            self.processor = transformers.AutoTokenizer.from_pretrained(
                repo,
                trust_remote_code=True,
                use_fast=False,
                cache_dir=shared.opts.hfcache_dir,
            )
            register_aux('vqa', self.model)
            self.loaded = repo
            devices.torch_gc()

    def _sa2(self, question: str, image: Image.Image, repo: str, model_name: str = None):  # pylint: disable=unused-argument
        self._load_sa2(repo)
        move_aux_to_gpu('vqa')
        if question.startswith('<'):
            task = question.split('>', 1)[0] + '>'
        else:
            task = ''
        input_dict = {
            'image': image,
            'text': f'<image>{task}',
            'past_text': '',
            'mask_prompts': None,
            'tokenizer': self.processor,
        }
        with devices.inference_context():
            return_dict = self.model.predict_forward(**input_dict)
        response = return_dict["prediction"]  # the text format answer
        return response

    def caption(self, question: str = '', system_prompt: str = None, prompt: str = None, image: Image.Image = None, model_name: str = None, prefill: str = None, thinking_mode: bool = None, quiet: bool = False, generation_kwargs: dict = None) -> str:
        """
        Main entry point for VQA captioning.

        Returns string answer. Detection data stored in self.last_detection_data for annotated image creation.
Args: question: Question/task to perform system_prompt: System prompt for the model prompt: Additional prompt text image: PIL Image to process model_name: Model to use (defaults to settings) prefill: Text to prefill the response with thinking_mode: Enable thinking/reasoning mode (None = use settings) quiet: Suppress logging generation_kwargs: Optional dict with generation parameter overrides: max_tokens, temperature, top_k, top_p, num_beams, do_sample, keep_thinking, keep_prefill """ self.last_annotated_image = None self.last_detection_data = None self._generation_overrides = generation_kwargs # Set per-request overrides jobid = shared.state.begin('Caption LLM') t0 = time.time() model_name = model_name or shared.opts.caption_vlm_model prefill = vlm_prefill if prefill is None else prefill # Use provided prefill when specified thinking_mode = shared.opts.caption_vlm_thinking_mode if thinking_mode is None else thinking_mode # Resolve from settings if not specified if isinstance(image, list): image = image[0] if len(image) > 0 else None if isinstance(image, dict) and 'name' in image: image = Image.open(image['name']) if isinstance(image, Image.Image): if image.width > 768 or image.height > 768: image.thumbnail((768, 768), Image.Resampling.LANCZOS) if image.mode != 'RGB': image = image.convert('RGB') if image is None: log.error(f'VQA caption: model="{model_name}" error="No input image provided"') self._generation_overrides = None shared.state.end(jobid) return 'Error: No input image provided. Please upload or select an image.' # Convert friendly prompt names to internal tokens/commands if question == "Use Prompt": # Use content from Prompt field directly - requires user input if not prompt or len(prompt.strip()) < 2: log.error(f'VQA caption: model="{model_name}" error="Please enter a prompt"') self._generation_overrides = None shared.state.end(jobid) return 'Error: Please enter a question or instruction in the Prompt field.' question = prompt elif question in vlm_prompt_mapping: # Check if this is a mode that requires user input (Point/Detect) raw_mapping = vlm_prompt_mapping.get(question) if raw_mapping in ("POINT_MODE", "DETECT_MODE"): # These modes require user input in the prompt field if not prompt or len(prompt.strip()) < 2: log.error(f'VQA caption: model="{model_name}" error="Please specify what to find in the prompt field"') self._generation_overrides = None shared.state.end(jobid) return 'Error: Please specify what to find in the prompt field (e.g., "the red car" or "faces").' 
# Convert friendly name to internal token (handles Point/Detect prefix) question = get_internal_prompt(question, prompt) # else: question is already an internal token or custom text from modules import modelloader modelloader.hf_login() sd_models.set_caption_load_options() try: if model_name is None: log.error(f'Caption: type=vlm model="{model_name}" no model selected') shared.state.end(jobid) return '' vqa_model = get_vlm_repo(model_name) if vqa_model == model_name and model_name not in vlm_models.values(): log.error(f'Caption: type=vlm model="{model_name}" unknown') shared.state.end(jobid) return '' handler = 'unknown' if 'git' in vqa_model.lower(): handler = 'git' answer = self._git(question, image, vqa_model, model_name) elif 'vilt' in vqa_model.lower(): handler = 'vilt' answer = self._vilt(question, image, vqa_model, model_name) elif 'blip' in vqa_model.lower(): handler = 'blip' answer = self._blip(question, image, vqa_model, model_name) elif 'pix' in vqa_model.lower(): handler = 'pix' answer = self._pix(question, image, vqa_model, model_name) elif 'moondream3' in vqa_model.lower(): handler = 'moondream3' from modules.caption import moondream3 answer = moondream3.predict(question, image, vqa_model, model_name, thinking_mode=thinking_mode) elif 'moondream2' in vqa_model.lower(): handler = 'moondream' answer = self._moondream(question, image, vqa_model, model_name, thinking_mode) elif 'florence' in vqa_model.lower(): handler = 'florence' answer = self._florence(question, image, vqa_model, None, model_name) # Parse Florence detection response for annotated image (handles both dict and string formats) florence_detections = vqa_detection.parse_florence_detections(answer, image.size if image else None) if florence_detections: self.last_detection_data = {'detections': florence_detections} debug(f'VQA caption: handler=florence parsed {len(florence_detections)} detections') # Format dict answer as readable string (string answers pass through unchanged) if isinstance(answer, dict): answer = vqa_detection.format_florence_response(answer) elif 'mistral' in vqa_model.lower(): handler = 'mistral' answer = self._mistral(question, image, vqa_model, system_prompt, model_name, prefill, thinking_mode) elif 'qwen' in vqa_model.lower() or 'torii' in vqa_model.lower() or 'mimo' in vqa_model.lower(): handler = 'qwen' answer = self._qwen(question, image, vqa_model, system_prompt, model_name, prefill, thinking_mode) elif 'smol' in vqa_model.lower(): handler = 'smol' answer = self._smol(question, image, vqa_model, system_prompt, model_name, prefill, thinking_mode) elif 'joytag' in vqa_model.lower(): handler = 'joytag' from modules.caption import joytag answer = joytag.predict(image) elif 'joycaption' in vqa_model.lower(): handler = 'joycaption' from modules.caption import joycaption answer = joycaption.predict(question, image, vqa_model) elif 'deepseek' in vqa_model.lower(): handler = 'deepseek' from modules.caption import deepseek answer = deepseek.predict(question, image, vqa_model) elif 'paligemma' in vqa_model.lower(): handler = 'paligemma' answer = self._paligemma(question, image, vqa_model, model_name) elif 'gemma' in vqa_model.lower(): handler = 'gemma' answer = self._gemma(question, image, vqa_model, system_prompt, model_name, prefill, thinking_mode) elif 'ovis' in vqa_model.lower(): handler = 'ovis' answer = self._ovis(question, image, vqa_model, model_name) elif 'sa2' in vqa_model.lower(): handler = 'sa2' answer = self._sa2(question, image, vqa_model, model_name) elif 'fastvlm' in vqa_model.lower(): 
handler = 'fastvlm' answer = self._fastvlm(question, image, vqa_model, model_name) elif 'gemini' in vqa_model.lower(): handler = 'gemini' gen_kwargs = get_kwargs() from modules.caption import gemini answer = gemini.predict(question, image, vqa_model, system_prompt, model_name, prefill, thinking_mode, gen_kwargs) else: answer = 'unknown model' except Exception as e: errors.display(e, 'VQA') answer = 'error' finally: sd_models.set_huggingface_options(quiet=True) if self.model is not None: offload_aux('vqa') devices.torch_gc(force=True, reason='vqa') # Clean the answer answer = clean(answer, question, prefill) # Create annotated image if detection data is available if self.last_detection_data and isinstance(self.last_detection_data, dict) and image: detections = self.last_detection_data.get('detections', None) points = self.last_detection_data.get('points', None) if detections or points: self.last_annotated_image = vqa_detection.draw_bounding_boxes(image, detections or [], points) debug(f'VQA caption: handler={handler} created annotated image detections={len(detections) if detections else 0} points={len(points) if points else 0}') debug(f'VQA caption: handler={handler} response="{answer}" annotation={self.last_annotated_image is not None}') t1 = time.time() if not quiet: model_name = model_name.split(' ')[0] if model_name else 'None' log.debug(f'Caption: type=vlm model="{model_name}" repo="{vqa_model}" args={get_kwargs()} time={t1-t0:.2f}') self._generation_overrides = None # Clear per-request overrides shared.state.end(jobid) return answer def batch(self, model_name, system_prompt, batch_files, batch_folder, batch_str, question, prompt, write, append, recursive, prefill=None, thinking_mode=False): class BatchWriter: def __init__(self, folder, mode='w'): self.folder = folder self.csv = None self.file = None self.mode = mode def add(self, file, prompt_text): txt_file = os.path.splitext(file)[0] + ".txt" if self.mode == 'a': prompt_text = '\n' + prompt_text with open(os.path.join(self.folder, txt_file), self.mode, encoding='utf-8') as f: f.write(prompt_text) def close(self): if self.file is not None: self.file.close() files = [] if batch_files is not None: files += [f.name for f in batch_files] if batch_folder is not None: files += [f.name for f in batch_folder] if batch_str is not None and len(batch_str) > 0 and os.path.exists(batch_str) and os.path.isdir(batch_str): from modules.files_cache import list_files files += list(list_files(batch_str, ext_filter=['.png', '.jpg', '.jpeg', '.webp', '.jxl'], recursive=recursive)) if len(files) == 0: log.warning('Caption batch: type=vlm no images') return '' jobid = shared.state.begin('Caption batch') prompts = [] if write: mode = 'w' if not append else 'a' writer = BatchWriter(os.path.dirname(files[0]), mode=mode) orig_offload = shared.opts.caption_offload shared.opts.caption_offload = False try: import rich.progress as rp pbar = rp.Progress(rp.TextColumn('[cyan]Caption:'), rp.BarColumn(), rp.MofNCompleteColumn(), rp.TaskProgressColumn(), rp.TimeRemainingColumn(), rp.TimeElapsedColumn(), rp.TextColumn('[cyan]{task.description}'), console=console) with pbar: task = pbar.add_task(total=len(files), description='starting...') for file in files: pbar.update(task, advance=1, description=file) try: if shared.state.interrupted: break img = Image.open(file) result = self.caption(question, system_prompt, prompt, img, model_name, prefill, thinking_mode, quiet=True) # Save annotated image if available if self.last_annotated_image and write: annotated_path = 
os.path.splitext(file)[0] + "_annotated.png"
                            self.last_annotated_image.save(annotated_path)
                        prompts.append(result)
                        if write:
                            writer.add(file, result)
                    except Exception as e:
                        log.error(f'Caption batch: {e}')
            if write:
                writer.close()
        finally:
            shared.opts.caption_offload = orig_offload
            offload_aux('vqa')
        shared.state.end(jobid)
        return '\n\n'.join(prompts)


# Module-level singleton instance
_instance = None


def get_instance() -> VQA:
    """Get or create the singleton VQA instance."""
    global _instance  # pylint: disable=global-statement
    if _instance is None:
        _instance = VQA()
    return _instance


# Backwards-compatible module-level functions
def caption(*args, **kwargs):
    return get_instance().caption(*args, **kwargs)


def unload_model():
    return get_instance().unload()


def load_model(model_name: str = None):
    return get_instance().load(model_name)


def get_last_annotated_image():
    return get_instance().last_annotated_image


def batch(*args, **kwargs):
    return get_instance().batch(*args, **kwargs)
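# Minimal usage sketch (illustrative, not part of the module API): assumes the host app has
# already initialized shared/devices and loaded settings, and that 'example.jpg' is a
# hypothetical image path. Guarded so it never runs on import.
if __name__ == '__main__':
    sample = Image.open('example.jpg')
    print(caption(question='CAPTION', image=sample, quiet=True))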