mirror of https://github.com/vladmandic/automatic
1378 lines
58 KiB
Python
1378 lines
58 KiB
Python
import io
|
|
import os
|
|
import time
|
|
import json
|
|
import base64
|
|
import copy
|
|
import torch
|
|
import transformers
|
|
import transformers.dynamic_module_utils
|
|
from PIL import Image, ImageDraw, ImageFont
|
|
from modules import shared, devices, errors, model_quant, sd_models, sd_models_compile, ui_symbols
|
|
|
|
|
|
# Debug logging - function-based to avoid circular import
|
|
debug_enabled = os.environ.get('SD_VQA_DEBUG', None) is not None
|
|
|
|
def debug(*args, **kwargs):
    """Forward a message to ``shared.log.trace`` when VQA debugging is enabled.

    Debugging is switched on by the module-level ``debug_enabled`` flag, which
    is set when the ``SD_VQA_DEBUG`` environment variable is present.
    Does nothing otherwise.
    """
    if not debug_enabled:
        return
    shared.log.trace(*args, **kwargs)
|
|
|
|
processor = None
|
|
model = None
|
|
loaded: str = None
|
|
quant_args = None
|
|
vlm_default = "Alibaba Qwen 2.5 VL 3B"
|
|
vlm_models = {
|
|
"Google Gemma 3 4B": "google/gemma-3-4b-it",
|
|
"Google Gemma 3n E2B": "google/gemma-3n-E2B-it", # 1.5GB
|
|
"Google Gemma 3n E4B": "google/gemma-3n-E4B-it", # 1.5GB
|
|
"Nidum Gemma 3 4B Uncensored": "nidum/Nidum-Gemma-3-4B-it-Uncensored",
|
|
"Allura Gemma 3 Glitter 4B": "allura-org/Gemma-3-Glitter-4B",
|
|
"Alibaba Qwen 2.0 VL 2B": "Qwen/Qwen2-VL-2B-Instruct",
|
|
"Alibaba Qwen 2.5 Omni 3B": "Qwen/Qwen2.5-Omni-3B",
|
|
"Alibaba Qwen 2.5 VL 3B": "Qwen/Qwen2.5-VL-3B-Instruct",
|
|
"Alibaba Qwen 3 VL 2B": "Qwen/Qwen3-VL-2B-Instruct",
|
|
f"Alibaba Qwen 3 VL 2B Thinking {ui_symbols.reasoning}": "Qwen/Qwen3-VL-2B-Thinking",
|
|
"Alibaba Qwen 3 VL 4B": "Qwen/Qwen3-VL-4B-Instruct",
|
|
f"Alibaba Qwen 3 VL 4B Thinking {ui_symbols.reasoning}": "Qwen/Qwen3-VL-4B-Thinking",
|
|
"Alibaba Qwen 3 VL 8B": "Qwen/Qwen3-VL-8B-Instruct",
|
|
f"Alibaba Qwen 3 VL 8B Thinking {ui_symbols.reasoning}": "Qwen/Qwen3-VL-8B-Thinking",
|
|
"XiaomiMiMo MiMo VL 7B RL": "XiaomiMiMo/MiMo-VL-7B-RL-2508", # 8.3GB
|
|
"Huggingface Smol VL2 0.5B": "HuggingFaceTB/SmolVLM-500M-Instruct",
|
|
"Huggingface Smol VL2 2B": "HuggingFaceTB/SmolVLM-Instruct",
|
|
"Apple FastVLM 0.5B": "apple/FastVLM-0.5B",
|
|
"Apple FastVLM 1.5B": "apple/FastVLM-1.5B",
|
|
"Apple FastVLM 7B": "apple/FastVLM-7B",
|
|
"Microsoft Florence 2 Base": "florence-community/Florence-2-base-ft", # 0.5GB
|
|
"Microsoft Florence 2 Large": "florence-community/Florence-2-large-ft", # 1.5GB
|
|
"MiaoshouAI PromptGen 1.5 Base": "Disty0/Florence-2-base-PromptGen-v1.5", # 0.5GB
|
|
"MiaoshouAI PromptGen 1.5 Large": "Disty0/Florence-2-large-PromptGen-v1.5", # 1.5GB
|
|
"MiaoshouAI PromptGen 2.0 Base": "Disty0/Florence-2-base-PromptGen-v2.0", # 0.5GB
|
|
"MiaoshouAI PromptGen 2.0 Large": "Disty0/Florence-2-large-PromptGen-v2.0", # 1.5GB
|
|
"CogFlorence 2.0 Large": "thwri/CogFlorence-2-Large-Freeze", # 1.6GB
|
|
"CogFlorence 2.2 Large": "thwri/CogFlorence-2.2-Large", # 1.6GB
|
|
"Moondream 2": "vikhyatk/moondream2", # 3.7GB
|
|
"Moondream 3 Preview": "moondream/moondream3-preview", # 9.3GB (gated)
|
|
"Google Pix Textcaps": "google/pix2struct-textcaps-base", # 1.1GB
|
|
"Google PaliGemma 2 3B": "google/paligemma2-3b-pt-224",
|
|
"Salesforce BLIP Base": "Salesforce/blip-vqa-base", # 1.5GB
|
|
"Salesforce BLIP Large": "Salesforce/blip-vqa-capfilt-large", # 1.5GB
|
|
"Microsoft GIT TextCaps Base": "microsoft/git-base-textcaps", # 0.7GB
|
|
"Microsoft GIT VQA Base": "microsoft/git-base-vqav2", # 0.7GB
|
|
"Microsoft GIT VQA Large": "microsoft/git-large-vqav2", # 1.6GB
|
|
"ToriiGate 0.4 2B": "Minthy/ToriiGate-v0.4-2B",
|
|
"ToriiGate 0.4 7B": "Minthy/ToriiGate-v0.4-7B",
|
|
"ViLT Base": "dandelin/vilt-b32-finetuned-vqa", # 0.5GB
|
|
"JoyCaption Alpha": "fancyfeast/llama-joycaption-alpha-two-hf-llava", # 17.4GB
|
|
"JoyCaption Beta": "fancyfeast/llama-joycaption-beta-one-hf-llava", # 17.4GB
|
|
"JoyTag": "fancyfeast/joytag", # 0.7GB
|
|
"AIDC Ovis2 1B": "AIDC-AI/Ovis2-1B",
|
|
"AIDC Ovis2 2B": "AIDC-AI/Ovis2-2B",
|
|
"AIDC Ovis2 4B": "AIDC-AI/Ovis2-4B",
|
|
"ByteDance Sa2VA 1B": "ByteDance/Sa2VA-1B",
|
|
"ByteDance Sa2VA 4B": "ByteDance/Sa2VA-4B",
|
|
# "OpenGVLab InternVL 2.5 1B": "OpenGVLab/InternVL2_5-1B"
|
|
# "DeepSeek VL2 Tiny": "deepseek-ai/deepseek-vl2-tiny", # broken
|
|
# "nVidia Eagle 2 1B": "nvidia/Eagle2-1B", # not compatible with latest transformers
|
|
}
|
|
vlm_system = 'You are image captioning expert, creative, unbiased and uncensored.'
|
|
|
|
# Common prompts (work with all VLM models)
|
|
vlm_prompts_common = [
|
|
"Use Prompt",
|
|
"Short Caption",
|
|
"Normal Caption",
|
|
"Long Caption",
|
|
]
|
|
|
|
# Florence-2 specific prompts (only shown for Florence/PromptGen models)
|
|
vlm_prompts_florence = [
|
|
"Phrase Grounding",
|
|
"Object Detection",
|
|
"Dense Region Caption",
|
|
"Region Proposal",
|
|
"OCR (Read Text)",
|
|
"OCR with Regions",
|
|
"Analyze",
|
|
"Generate Tags",
|
|
"Mixed Caption",
|
|
"Mixed Caption+",
|
|
]
|
|
|
|
# Moondream specific prompts (only shown for Moondream models)
|
|
vlm_prompts_moondream = [
|
|
"Point at...",
|
|
"Detect all...",
|
|
]
|
|
|
|
# Mapping from friendly names to internal tokens/commands
|
|
vlm_prompt_mapping = {
|
|
"Use Prompt": "Use Prompt",
|
|
"Short Caption": "<CAPTION>",
|
|
"Normal Caption": "<DETAILED_CAPTION>",
|
|
"Long Caption": "<MORE_DETAILED_CAPTION>",
|
|
"Phrase Grounding": "<CAPTION_TO_PHRASE_GROUNDING>",
|
|
"Object Detection": "<OD>",
|
|
"Dense Region Caption": "<DENSE_REGION_CAPTION>",
|
|
"Region Proposal": "<REGION_PROPOSAL>",
|
|
"OCR (Read Text)": "<OCR>",
|
|
"OCR with Regions": "<OCR_WITH_REGION>",
|
|
"Analyze": "<ANALYZE>",
|
|
"Generate Tags": "<GENERATE_TAGS>",
|
|
"Mixed Caption": "<MIXED_CAPTION>",
|
|
"Mixed Caption+": "<MIXED_CAPTION_PLUS>",
|
|
"Point at...": "POINT_MODE",
|
|
"Detect all...": "DETECT_MODE",
|
|
}
|
|
|
|
# Placeholder hints for prompt field based on selected question
|
|
vlm_prompt_placeholders = {
|
|
"Use Prompt": "Enter your question or instruction for the model",
|
|
"Short Caption": "Optional: add specific focus or style instructions",
|
|
"Normal Caption": "Optional: add specific focus or style instructions",
|
|
"Long Caption": "Optional: add specific focus or style instructions",
|
|
"Phrase Grounding": "Optional: specify phrases to ground in the image",
|
|
"Object Detection": "Optional: specify object types to detect",
|
|
"Dense Region Caption": "Optional: add specific instructions",
|
|
"Region Proposal": "Optional: add specific instructions",
|
|
"OCR (Read Text)": "Optional: add specific instructions",
|
|
"OCR with Regions": "Optional: add specific instructions",
|
|
"Analyze": "Optional: add specific analysis instructions",
|
|
"Generate Tags": "Optional: add specific tagging instructions",
|
|
"Mixed Caption": "Optional: add specific instructions",
|
|
"Mixed Caption+": "Optional: add specific instructions",
|
|
"Point at...": "Enter objects to locate, e.g., 'the red car' or 'all the eyes'",
|
|
"Detect all...": "Enter object type to detect, e.g., 'cars' or 'faces'",
|
|
}
|
|
|
|
# Legacy list for backwards compatibility
|
|
vlm_prompts = vlm_prompts_common + vlm_prompts_florence + vlm_prompts_moondream
|
|
|
|
vlm_prefill = 'Answer: the image shows'
|
|
|
|
|
|
def get_prompts_for_model(model_name: str) -> list:
    """Return the list of friendly prompt names offered for the given model.

    Every model gets the common prompts. Names containing ``florence`` or
    ``promptgen`` additionally get the Florence-2 prompts, and names containing
    ``moondream`` additionally get the Moondream prompts (case-insensitive).
    A ``None`` model name yields the common prompts only.
    """
    if model_name is None:
        return vlm_prompts_common

    lowered = model_name.lower()

    # Florence-2 family, including the PromptGen fine-tunes
    if any(tag in lowered for tag in ('florence', 'promptgen')):
        return vlm_prompts_common + vlm_prompts_florence

    # Moondream family
    if 'moondream' in lowered:
        return vlm_prompts_common + vlm_prompts_moondream

    # Anything else only supports the common prompts
    return vlm_prompts_common
|
|
|
|
|
|
def get_internal_prompt(friendly_name: str, user_prompt: str = None) -> str:
    """Translate a friendly prompt name into the internal token or command.

    Names missing from ``vlm_prompt_mapping`` are returned unchanged. For the
    Moondream point/detect modes a trigger phrase is built from ``user_prompt``;
    without a user prompt the raw mode marker is returned.
    """
    token = vlm_prompt_mapping.get(friendly_name, friendly_name)

    # Only the Moondream modes consume the user prompt, and only when one is given
    if not user_prompt:
        return token
    if token == "POINT_MODE":
        return f"Point at {user_prompt}"
    if token == "DETECT_MODE":
        return f"Detect {user_prompt}"
    return token
|
|
|
|
|
|
def get_prompt_placeholder(friendly_name: str) -> str:
    """Return the hint text shown in the prompt field for the selected question.

    Falls back to a generic hint when the name has no dedicated placeholder.
    """
    fallback = "Enter your question or instruction"
    return vlm_prompt_placeholders.get(friendly_name, fallback)
|
|
|
|
|
|
def is_florence_task(question: str) -> bool:
    """Tell whether ``question`` names a Florence-2 task.

    Accepts either a Florence-specific friendly name (an entry of
    ``vlm_prompts_florence``) or one of the raw Florence-2 task tokens, which
    are still recognised for backwards compatibility. Empty input is never a task.
    """
    if not question:
        return False

    # Friendly names shown in the UI
    if question in vlm_prompts_florence:
        return True

    # Raw task tokens
    legacy_tokens = (
        '<CAPTION>', '<DETAILED_CAPTION>', '<MORE_DETAILED_CAPTION>', '<CAPTION_TO_PHRASE_GROUNDING>',
        '<OD>', '<DENSE_REGION_CAPTION>', '<REGION_PROPOSAL>', '<OCR>', '<OCR_WITH_REGION>',
        '<ANALYZE>', '<GENERATE_TAGS>', '<MIXED_CAPTION>', '<MIXED_CAPTION_PLUS>',
    )
    return question in legacy_tokens
|
|
|
|
|
|
def is_thinking_model(model_name: str) -> bool:
    """Guess from the model name whether the model supports thinking mode.

    The check is a case-insensitive substring match against a short list of
    known markers; an empty or ``None`` name is never a thinking model.
    """
    if not model_name:
        return False

    lowered = model_name.lower()
    markers = (
        'thinking',     # Qwen3-VL-*-Thinking models
        'moondream3',   # Moondream 3 supports thinking
        'moondream 3',
        'mimo',
    )
    for marker in markers:
        if marker in lowered:
            return True
    return False
|
|
|
|
|
|
def truncate_b64_in_conversation(conversation, front_chars=50, tail_chars=50, threshold=200):
    """
    Deep copy a conversation structure and shorten long base64 image strings for logging.

    Every string stored under an ``"image"`` key whose length exceeds ``threshold``
    is replaced by its first ``front_chars`` characters, a truncation marker with
    the number of dropped characters, and its last ``tail_chars`` characters.
    Nested dicts and lists are searched recursively. The input is never modified.

    Args:
        conversation: Arbitrary nesting of dicts/lists (chat messages).
        front_chars: Number of leading characters to keep (negative treated as 0).
        tail_chars: Number of trailing characters to keep (negative treated as 0).
        threshold: Strings at or below this length are left untouched.

    Returns:
        The truncated deep copy.
    """
    conv_copy = copy.deepcopy(conversation)
    front = max(0, front_chars)
    tail = max(0, tail_chars)

    def shorten(value):
        kept = front + tail
        if kept >= len(value):
            # Front and tail would overlap: nothing to drop, keep the string as is
            return value
        head = value[:front]
        # Slice from an explicit index: value[-0:] would return the whole string
        end = value[len(value) - tail:]
        return f"{head}...[{len(value) - kept} chars truncated]...{end}"

    def truncate_recursive(obj):
        if isinstance(obj, dict):
            for key, value in obj.items():
                if key == "image" and isinstance(value, str) and len(value) > threshold:
                    # Assigning to an existing key is safe while iterating
                    obj[key] = shorten(value)
                elif isinstance(value, (dict, list)):
                    truncate_recursive(value)
        elif isinstance(obj, list):
            for item in obj:
                truncate_recursive(item)

    truncate_recursive(conv_copy)
    return conv_copy
|
|
|
|
|
|
def keep_think_block_open(text_prompt: str) -> str:
    """Drop the ``</think>`` that closes the last ``<think>`` block so the model can keep reasoning.

    The first closing tag found after the last opening tag is removed, together
    with any spaces/tabs and then any line breaks that directly follow it.
    The prompt is returned unchanged when no opening tag exists or when the
    last opening tag has no closing tag after it.
    """
    open_tag, close_tag = "<think>", "</think>"

    open_pos = text_prompt.rfind(open_tag)
    close_pos = text_prompt.find(close_tag, open_pos) if open_pos != -1 else -1
    if close_pos == -1:
        return text_prompt

    before = text_prompt[:close_pos]
    after = text_prompt[close_pos + len(close_tag):]
    # Swallow horizontal whitespace first, then line breaks, right after the tag
    after = after.lstrip(' \t').lstrip('\r\n')

    debug('VQA interrogate: keep_think_block_open applied to prompt segment near assistant reply')
    return before + after
|
|
|
|
|
|
def b64(image):
    """Encode an image as a base64 string of its JPEG bytes.

    Args:
        image: PIL-style image exposing ``save(stream, format)`` (and ``mode`` /
            ``convert`` when a mode conversion is needed), or ``None``.

    Returns:
        Base64 text of the JPEG-encoded image, or ``''`` when ``image`` is ``None``.
    """
    if image is None:
        return ''
    # The JPEG writer only accepts these modes; images with alpha or a palette
    # (RGBA, LA, P, ...) would make save() raise, so convert them to RGB first.
    jpeg_modes = ('1', 'L', 'RGB', 'RGBX', 'CMYK', 'YCbCr')
    if getattr(image, 'mode', 'RGB') not in jpeg_modes:
        image = image.convert('RGB')
    with io.BytesIO() as stream:
        image.save(stream, 'JPEG')
        return base64.b64encode(stream.getvalue()).decode()
|
|
|
|
|
|
def clean(response, question, prefill=None):
    """Normalize a raw model response into a clean caption/answer string.

    Args:
        response: Model output. A string is used as is; a dict contributes its
            ``answer``, ``caption`` or ``task`` entry (optionally preceded by its
            ``reasoning`` when the keep-thinking option is on), falling back to
            its JSON dump; a list contributes its first element; anything else
            is converted with ``str``.
        question: Question/task that was asked. Angle brackets and underscores
            are removed, and if the result occurs in the response everything up
            to and including it is dropped. May be empty or ``None``.
        prefill: Prefill text used for generation; ``None`` selects ``vlm_prefill``.

    Returns:
        The cleaned response. Depending on ``shared.opts.interrogate_vlm_keep_prefill``
        the prefill is either guaranteed to lead the response or removed from it.
    """
    # NOTE(review): the plain double quote was listed three times here; presumably the
    # typographic quotes were intended, so both variants are stripped - confirm.
    strip = ['---', '\r', '\t', '**', '"', '\u201c', '\u201d', 'Assistant:', 'Caption:', '<|im_end|>', '<pad>']

    if isinstance(response, str):
        response = response.strip()
    elif isinstance(response, dict):
        text_response = ""
        if 'reasoning' in response and shared.opts.interrogate_vlm_keep_thinking:
            r_text = response['reasoning']
            if isinstance(r_text, dict) and 'text' in r_text:
                r_text = r_text['text']
            text_response += f"Reasoning:\n{r_text}\nAnswer:\n"

        if 'answer' in response:
            text_response += str(response['answer'])
        elif 'caption' in response:
            text_response += str(response['caption'])
        elif 'task' in response:
            text_response += str(response['task'])
        elif not text_response:
            # Unknown dict shape and no reasoning captured: show the raw payload
            text_response = json.dumps(response)
        response = text_response
    elif isinstance(response, list):
        # An empty list used to raise IndexError; treat it as an empty response
        first = response[0] if response else ''
        response = first if isinstance(first, str) else str(first)
    else:
        response = str(response)

    # Determine prefill text
    prefill_text = vlm_prefill if prefill is None else prefill
    if prefill_text is None:
        prefill_text = ""
    prefill_text = prefill_text.strip()

    question = (question or '').replace('<', '').replace('>', '').replace('_', ' ')
    # Guard against an empty question: str.split('') raises ValueError
    if question and question in response:
        response = response.split(question, 1)[1]

    # Removing one marker can splice a new one together, so repeat until stable
    while any(s in response for s in strip):
        for s in strip:
            response = response.replace(s, '')
    # Collapse blank lines and double spaces (the original replaced a single space
    # with itself, which did nothing), and turn "* " bullets into "- "
    response = response.replace('\n\n', '\n').replace('  ', ' ').replace('* ', '- ').strip()

    # Handle prefill retention/removal
    if shared.opts.interrogate_vlm_keep_prefill:
        # Add prefill if it's missing from the cleaned response
        if len(prefill_text) > 0 and not response.startswith(prefill_text):
            # No separating space before punctuation or an empty response
            sep = "" if not response or response[0] in ".,!?;:" else " "
            response = f"{prefill_text}{sep}{response}"
    else:
        # Remove prefill if it's present in the cleaned response
        if len(prefill_text) > 0 and response.startswith(prefill_text):
            response = response[len(prefill_text):].strip()

    return response
|
|
|
|
|
|
def get_kwargs():
    """Build the generation keyword arguments from the VLM options in ``shared.opts``.

    ``max_new_tokens`` and ``do_sample`` are always present; ``num_beams``,
    ``temperature``, ``top_k`` and ``top_p`` are only added when their option
    value is greater than zero.
    """
    opts = shared.opts
    kwargs = {
        'max_new_tokens': opts.interrogate_vlm_max_length,
        'do_sample': opts.interrogate_vlm_do_sample,
    }
    optional = (
        ('num_beams', opts.interrogate_vlm_num_beams),
        ('temperature', opts.interrogate_vlm_temperature),
        ('top_k', opts.interrogate_vlm_top_k),
        ('top_p', opts.interrogate_vlm_top_p),
    )
    for name, value in optional:
        if value > 0:
            kwargs[name] = value
    return kwargs
|
|
|
|
|
|
def draw_bounding_boxes(image: Image.Image, detections: list, points: list = None) -> Image.Image:
    """
    Draw bounding boxes and/or points on a copy of an image.

    Args:
        image: PIL Image to annotate (left unmodified)
        detections: List of detection dicts with format:
            [{'label': str, 'bbox': [x1, y1, x2, y2], 'confidence': float}, ...]
            where coordinates are normalized 0-1. Entries without a usable
            'bbox' are skipped.
        points: Optional list of (x, y) tuples with normalized 0-1 coordinates

    Returns:
        Annotated PIL Image with boxes and labels drawn, or None when there is
        nothing to draw (no detections and no points).
    """
    if not detections and not points:
        return None

    # Create a copy to avoid modifying original
    annotated = image.copy()
    draw = ImageDraw.Draw(annotated)
    width, height = image.size
    shortest = min(width, height)

    # Try to load a font, fall back to default if unavailable (best effort by design)
    try:
        font_size = max(12, int(shortest * 0.02))
        font_path = shared.opts.font or "javascript/notosans-nerdfont-regular.ttf"
        font = ImageFont.truetype(font_path, size=font_size)
    except Exception:
        font = ImageFont.load_default()

    # Draw bounding boxes
    if detections:
        colors = ['#FF0000', '#00FF00', '#0000FF', '#FFFF00', '#FF00FF', '#00FFFF', '#FFA500', '#800080']
        line_width = max(2, int(shortest * 0.003))
        for idx, det in enumerate(detections):
            bbox = det.get('bbox')
            if bbox is None or len(bbox) < 4:
                continue
            label = det.get('label', 'object')
            confidence = det.get('confidence', 1.0)
            if confidence is None:
                confidence = 1.0

            # Convert normalized coordinates to pixel coordinates; order each axis so
            # an inverted box does not make ImageDraw.rectangle fail
            x1, x2 = sorted((int(bbox[0] * width), int(bbox[2] * width)))
            y1, y2 = sorted((int(bbox[1] * height), int(bbox[3] * height)))

            # Cycle through the palette
            color = colors[idx % len(colors)]

            # Draw box
            draw.rectangle([x1, y1, x2, y2], outline=color, width=line_width)

            # Draw label with background
            label_text = f"{label} {confidence:.2f}" if confidence < 1.0 else label
            bbox_font = draw.textbbox((x1, y1), label_text, font=font)
            text_width = bbox_font[2] - bbox_font[0]
            text_height = bbox_font[3] - bbox_font[1]
            # Label sits above the box; when that would leave the image, put it
            # just inside the top edge of the box instead
            label_top = y1 - text_height - 4
            if label_top < 0:
                label_top = y1
            draw.rectangle([x1, label_top, x1 + text_width + 4, label_top + text_height + 4], fill=color)
            draw.text((x1 + 2, label_top + 2), label_text, fill='white', font=font)

    # Draw points
    if points:
        point_radius = max(3, int(shortest * 0.01))
        for px, py in points:
            x = int(px * width)
            y = int(py * height)
            # Draw point as a circle
            draw.ellipse(
                [x - point_radius, y - point_radius, x + point_radius, y + point_radius],
                fill='#FF0000', outline='#FFFFFF', width=2,
            )

    return annotated
|
|
|
|
|
|
def fastvlm(question: str, image: Image.Image, repo: str = None, model_name: str = None):
    """Answer a question about an image with a FastVLM model.

    Loads the tokenizer and model for ``repo`` (both with remote code enabled)
    unless that repo is already the loaded one, moves the model to the active
    device and generates an answer.

    Args:
        question: Question or instruction; angle brackets are removed and a
            default description request is used when fewer than two characters remain.
        image: Image to describe.
        repo: Hugging Face repo id of the model.
        model_name: Friendly model name, used for logging only.

    Returns:
        The decoded answer string.
    """
    global processor, model, loaded # pylint: disable=global-statement
    debug(f'VQA interrogate: handler=fastvlm model_name="{model_name}" repo="{repo}" question="{question}" image_size={image.size if image else None}')
    if model is None or loaded != repo:
        shared.log.debug(f'VQA Interrogate load: vlm="{repo}"')
        model = None
        # Use the configured cache dir for the tokenizer as well, same as for the model
        processor = transformers.AutoTokenizer.from_pretrained(
            repo,
            trust_remote_code=True,
            cache_dir=shared.opts.hfcache_dir,
        )
        model = transformers.AutoModelForCausalLM.from_pretrained(
            repo,
            torch_dtype=devices.dtype,
            # device_map="auto",
            trust_remote_code=True,
            cache_dir=shared.opts.hfcache_dir,
            **quant_args,
        )
        loaded = repo
        devices.torch_gc()
    sd_models.move_model(model, devices.device)

    # Strip task-token brackets first so the length check sees the real question
    question = (question or '').replace('<', '').replace('>', '')
    if len(question) < 2:
        question = "Describe the image."

    IMAGE_TOKEN_INDEX = -200 # what the model code looks for
    messages = [{"role": "user", "content": f"<image>\n{question}"}]
    rendered = processor.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
    # Tokenize the text around the image placeholder separately and splice the
    # image token id in between
    pre, post = rendered.split("<image>", 1)
    pre_ids = processor(pre, return_tensors="pt", add_special_tokens=False).input_ids
    post_ids = processor(post, return_tensors="pt", add_special_tokens=False).input_ids
    img_tok = torch.tensor([[IMAGE_TOKEN_INDEX]], dtype=pre_ids.dtype)
    input_ids = torch.cat([pre_ids, img_tok, post_ids], dim=1)
    input_ids = input_ids.to(devices.device)
    attention_mask = torch.ones_like(input_ids, device=devices.device)
    px = model.get_vision_tower().image_processor(images=image, return_tensors="pt")
    px = px["pixel_values"].to(model.device, dtype=model.dtype)
    # Honor the user's generation settings like the other handlers do, instead of
    # a fixed 128-token limit
    gen_kwargs = get_kwargs()
    debug(f'VQA interrogate: handler=fastvlm generation_kwargs={gen_kwargs} input_ids_shape={input_ids.shape}')
    with devices.inference_context():
        outputs = model.generate(
            inputs=input_ids,
            attention_mask=attention_mask,
            images=px,
            **gen_kwargs,
        )
    answer = processor.decode(outputs[0], skip_special_tokens=True)
    return answer
|
|
|
|
|
|
def qwen(
    question: str,
    image: Image.Image,
    repo: str = None,
    system_prompt: str = None,
    model_name: str = None,
    prefill: str = None,
    thinking_mode: bool = False,
):
    """Answer a question about an image with a Qwen-VL style model.

    Loads model and processor for ``repo`` unless already loaded, renders a
    system + user conversation through the processor's chat template, appends
    the optional prefill, generates, and post-processes ``<think>`` tags.

    Args:
        question: Question or task; ``<``, ``>`` are removed and ``_`` becomes a space.
        image: Image to interrogate.
        repo: Hugging Face repo id; selects the model class by substring.
        system_prompt: System prompt; falls back to ``shared.opts.interrogate_vlm_system``.
        model_name: Friendly model name, used to detect thinking models and for logging.
        prefill: Text the assistant reply should start with; ``None`` selects ``vlm_prefill``.
        thinking_mode: Whether reasoning is wanted. For a detected thinking model,
            ``False`` closes the think block right away so reasoning is skipped.

    Returns:
        The list returned by ``processor.batch_decode``; its first entry has its
        think tags either relabelled (keep-thinking option) or removed.
    """
    global processor, model, loaded # pylint: disable=global-statement
    if model is None or loaded != repo:
        shared.log.debug(f'Interrogate load: vlm="{repo}"')
        model = None
        if 'Qwen3-VL' in repo or 'Qwen3VL' in repo:
            cls_name = transformers.Qwen3VLForConditionalGeneration
        elif 'Qwen2.5-VL' in repo or 'Qwen2_5_VL' in repo or 'MiMo-VL' in repo:
            cls_name = transformers.Qwen2_5_VLForConditionalGeneration
        elif 'Qwen2-VL' in repo or 'Qwen2VL' in repo:
            cls_name = transformers.Qwen2VLForConditionalGeneration
        else:
            cls_name = transformers.AutoModelForCausalLM
        model = cls_name.from_pretrained(
            repo,
            torch_dtype=devices.dtype,
            cache_dir=shared.opts.hfcache_dir,
            **quant_args,
        )
        processor = transformers.AutoProcessor.from_pretrained(repo, max_pixels=1024*1024, cache_dir=shared.opts.hfcache_dir)
        if 'LLM' in shared.opts.cuda_compile:
            model = sd_models_compile.compile_torch(model)
        loaded = repo
        devices.torch_gc()
    sd_models.move_model(model, devices.device)
    # Get model class name for logging
    cls_name = model.__class__.__name__
    debug(f'VQA interrogate: handler=qwen model_name="{model_name}" model_class="{cls_name}" repo="{repo}" question="{question}" system_prompt="{system_prompt}" image_size={image.size if image else None}')

    # Warn if using Florence-2 task tokens with non-Florence-2 models
    if is_florence_task(question):
        shared.log.warning(f'Interrogate: Florence-2 task token "{question}" is designed for Florence-2 models. Using it anyway, but results may vary.')
    question = question.replace('<', '').replace('>', '').replace('_', ' ')
    system_prompt = system_prompt or shared.opts.interrogate_vlm_system
    conversation = [
        {
            "role": "system",
            "content": [{"type": "text", "text": system_prompt}],
        },
        {
            "role": "user",
            "content": [
                {"type": "image", "image": b64(image)},
                {"type": "text", "text": question},
            ],
        },
    ]

    # Standardize prefill (computed once; the original derived it twice)
    prefill_value = vlm_prefill if prefill is None else prefill
    prefill_text = prefill_value.strip()
    use_prefill = len(prefill_text) > 0

    # Thinking models emit their own <think> tags via the chat template
    is_thinking = is_thinking_model(model_name)

    if debug_enabled:
        debug(f'VQA interrogate: handler=qwen conversation_roles={[msg["role"] for msg in conversation]}')
        debug(f'VQA interrogate: handler=qwen full_conversation={truncate_b64_in_conversation(conversation)}')
        debug(f'VQA interrogate: handler=qwen is_thinking={is_thinking} thinking_mode={thinking_mode} prefill="{prefill_text}"')

    # Generate base prompt using template
    # Qwen-Thinking template automatically adds "<|im_start|>assistant\n<think>\n" when add_generation_prompt=True
    # (the original wrapped this in a try/except whose fallback repeated the identical call, so it is dropped)
    text_prompt = processor.apply_chat_template(
        conversation,
        add_generation_prompt=True,
        tokenize=False,
    )

    # Manually handle thinking tags and prefill
    if is_thinking and not thinking_mode:
        # User wants to SKIP thinking.
        # Since template opened the block with <think>, we close it immediately.
        text_prompt += "</think>\n"
    if use_prefill:
        # With thinking enabled on a thinking model the prompt still ends inside
        # <think>, so the prefill becomes part of the thought process.
        text_prompt += prefill_text

    if debug_enabled:
        debug(f'VQA interrogate: handler=qwen text_prompt="{text_prompt}"')
    inputs = processor(text=[text_prompt], images=[image], padding=True, return_tensors="pt")
    inputs = inputs.to(devices.device, devices.dtype)
    gen_kwargs = get_kwargs()
    debug(f'VQA interrogate: handler=qwen generation_kwargs={gen_kwargs} input_ids_shape={inputs.input_ids.shape}')
    # Run under the shared inference context, consistent with the other handlers
    with devices.inference_context():
        output_ids = model.generate(
            **inputs,
            **gen_kwargs,
        )
    debug(f'VQA interrogate: handler=qwen output_ids_shape={output_ids.shape}')
    # Keep only the newly generated tokens of each sequence
    generated_ids = [
        out_ids[len(in_ids):]
        for in_ids, out_ids in zip(inputs.input_ids, output_ids)
    ]
    response = processor.batch_decode(generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True)
    if debug_enabled:
        debug(f'VQA interrogate: handler=qwen response_before_clean="{response}"')

    # Clean up thinking tags
    if len(response) > 0:
        text = response[0]
        if shared.opts.interrogate_vlm_keep_thinking:
            text = text.replace('<think>', 'Reasoning:\n').replace('</think>', '\nAnswer:')
        else:
            close_tag = '</think>'
            while close_tag in text:
                start = text.find('<think>')
                end = text.find(close_tag)
                if start != -1 and start < end:
                    # Standard <think>...content...</think> block
                    text = text[:start] + text[end + len(close_tag):]
                else:
                    # Missing <think> (implied at start) or malformed
                    # Remove from start up to </think>
                    text = text[end + len(close_tag):]
        response[0] = text
    return response
|
|
|
|
|
|
def gemma(
    question: str,
    image: Image.Image,
    repo: str = None,
    system_prompt: str = None,
    model_name: str = None,
    prefill: str = None,
    thinking_mode: bool = False,
):
    """Answer a question about an image with a Gemma 3 / Gemma 3n model.

    Loads model and processor for ``repo`` unless already loaded, builds a
    system/user conversation (plus an assistant turn holding the prefill, if
    any), renders it through the chat template, generates and strips or
    relabels ``<think>`` tags in the decoded reply.

    Returns:
        The decoded response string (newly generated tokens only).
    """
    global processor, model, loaded # pylint: disable=global-statement
    needs_load = model is None or loaded != repo
    if needs_load:
        shared.log.debug(f'Interrogate load: vlm="{repo}"')
        model = None
        if '3n' in repo:
            cls = transformers.Gemma3nForConditionalGeneration # pylint: disable=no-member
        else:
            cls = transformers.Gemma3ForConditionalGeneration
        load_args = {
            'torch_dtype': devices.dtype,
            'cache_dir': shared.opts.hfcache_dir,
        }
        model = cls.from_pretrained(repo, **load_args, **quant_args)
        if 'LLM' in shared.opts.cuda_compile:
            model = sd_models_compile.compile_torch(model)
        processor = transformers.AutoProcessor.from_pretrained(repo, max_pixels=1024*1024, cache_dir=shared.opts.hfcache_dir)
        loaded = repo
        devices.torch_gc()
    sd_models.move_model(model, devices.device)

    # Class name is only needed for the log line below
    cls_name = model.__class__.__name__
    debug(f'VQA interrogate: handler=gemma model_name="{model_name}" model_class="{cls_name}" repo="{repo}" question="{question}" system_prompt="{system_prompt}" image_size={image.size if image else None}')

    # Florence-2 task tokens are accepted but flagged, then turned into plain words
    if is_florence_task(question):
        shared.log.warning(f'Interrogate: Florence-2 task token "{question}" is designed for Florence-2 models. Using it anyway, but results may vary.')
    question = question.replace('<', '').replace('>', '').replace('_', ' ')
    system_prompt = system_prompt or shared.opts.interrogate_vlm_system

    # Very short system prompts / questions (4 chars or less) are left out
    has_system = system_prompt is not None and len(system_prompt) > 4
    system_content = [{"type": "text", "text": system_prompt}] if has_system else []

    user_content = []
    if question is not None and len(question) > 4:
        user_content.append({"type": "text", "text": question})
    if image is not None:
        user_content.append({"type": "image", "image": b64(image)})

    conversation = [
        {"role": "system", "content": system_content},
        {"role": "user", "content": user_content},
    ]

    # Prefill: explicit value wins, otherwise the module default
    prefill_text = (vlm_prefill if prefill is None else prefill).strip()
    use_prefill = len(prefill_text) > 0
    # Manual toggle OR auto-detection from the model name
    use_thinking = thinking_mode or is_thinking_model(model_name)

    if not use_prefill:
        debug('VQA interrogate: handler=gemma prefill disabled (empty), relying on add_generation_prompt')
    else:
        assistant_turn = {
            "role": "assistant",
            "content": [{"type": "text", "text": prefill_text}],
        }
        conversation.append(assistant_turn)
        debug(f'VQA interrogate: handler=gemma prefill="{prefill_text}"')

    if debug_enabled:
        debug(f'VQA interrogate: handler=gemma conversation_roles={[msg["role"] for msg in conversation]}')
        debug(f'VQA interrogate: handler=gemma full_conversation={truncate_b64_in_conversation(conversation)}')
        if use_prefill:
            debug_prefill_mode = 'add_generation_prompt=False continue_final_message=True'
        else:
            debug_prefill_mode = 'add_generation_prompt=True'
        debug(f'VQA interrogate: handler=gemma template_mode={debug_prefill_mode}')

    # With a prefill the final assistant turn is continued; otherwise a fresh
    # generation prompt is appended
    if use_prefill:
        template_args = {'add_generation_prompt': False, 'continue_final_message': True, 'tokenize': False}
    else:
        template_args = {'add_generation_prompt': True, 'tokenize': False}
    try:
        text_prompt = processor.apply_chat_template(conversation, **template_args)
    except (TypeError, ValueError) as e:
        debug(f'VQA interrogate: handler=gemma chat_template fallback add_generation_prompt=True: {e}')
        text_prompt = processor.apply_chat_template(
            conversation,
            add_generation_prompt=True,
            tokenize=False,
        )

    if use_prefill and use_thinking:
        text_prompt = keep_think_block_open(text_prompt)
    if debug_enabled:
        debug(f'VQA interrogate: handler=gemma text_prompt="{text_prompt}"')

    inputs = processor(
        text=[text_prompt],
        images=[image],
        padding=True,
        return_tensors="pt",
    )
    inputs = inputs.to(device=devices.device, dtype=devices.dtype)
    input_len = inputs["input_ids"].shape[-1]
    gen_kwargs = get_kwargs()
    debug(f'VQA interrogate: handler=gemma generation_kwargs={gen_kwargs} input_len={input_len}')
    with devices.inference_context():
        generation = model.generate(**inputs, **gen_kwargs)
    debug(f'VQA interrogate: handler=gemma output_ids_shape={generation.shape}')

    # Drop the prompt tokens, decode only what was generated
    generation = generation[0][input_len:]
    response = processor.decode(generation, skip_special_tokens=True)
    if debug_enabled:
        debug(f'VQA interrogate: handler=gemma response_before_clean="{response}"')

    # Think tags: relabel them when reasoning is kept, otherwise cut the blocks out
    if shared.opts.interrogate_vlm_keep_thinking:
        return response.replace('<think>', 'Reasoning:\n').replace('</think>', '\nAnswer:')

    while '</think>' in response:
        open_at = response.find('<think>')
        close_at = response.find('</think>')
        # A closing tag without a preceding opening tag drops everything before it
        kept_prefix = response[:open_at] if open_at != -1 and open_at < close_at else ''
        response = kept_prefix + response[close_at+8:]
    return response
|
|
|
|
|
|
def paligemma(question: str, image: Image.Image, repo: str = None, model_name: str = None):
    """Answer a question about an image with a PaliGemma model.

    Loads processor and model for ``repo`` unless already loaded, moves the
    model to the active device and returns the decoded text of the newly
    generated tokens. ``model_name`` is accepted but not used.
    """
    global processor, model, loaded # pylint: disable=global-statement
    already_loaded = model is not None and loaded == repo
    if not already_loaded:
        shared.log.debug(f'Interrogate load: vlm="{repo}"')
        cache_dir = shared.opts.hfcache_dir
        processor = transformers.PaliGemmaProcessor.from_pretrained(repo, cache_dir=cache_dir)
        model = None
        model = transformers.PaliGemmaForConditionalGeneration.from_pretrained(
            repo,
            cache_dir=cache_dir,
            torch_dtype=devices.dtype,
        )
        loaded = repo
        devices.torch_gc()
    sd_models.move_model(model, devices.device)

    # Task-token punctuation is turned into plain words
    for old, new in (('<', ''), ('>', ''), ('_', ' ')):
        question = question.replace(old, new)

    batch = processor(text=question, images=image, return_tensors="pt")
    batch = batch.to(devices.device, devices.dtype)
    prompt_len = batch["input_ids"].shape[-1]
    with devices.inference_context():
        sequences = model.generate(**batch, **get_kwargs())
    # Skip the prompt tokens so only the answer is decoded
    new_tokens = sequences[0][prompt_len:]
    return processor.decode(new_tokens, skip_special_tokens=True)
|
|
|
|
|
|
def ovis(question: str, image: Image.Image, repo: str = None, model_name: str = None):
    """Answer a question about an image with an Ovis2 model.

    Requires the ``flash_attn`` package; when it cannot be imported an error is
    logged and an empty string is returned. Loads the model for ``repo`` (with
    remote code enabled) unless already loaded and uses the model's own text and
    visual tokenizers for preprocessing. ``model_name`` is accepted but not used.

    Returns:
        The decoded response string, or ``''`` when flash-attn is unavailable.
    """
    global model, loaded # pylint: disable=global-statement
    try:
        import flash_attn # pylint: disable=unused-import
    except Exception:
        # Best effort: any import failure means the model cannot run
        shared.log.error(f'Interrogate: vlm="{repo}" flash-attn is not available')
        return ''
    if model is None or loaded != repo:
        shared.log.debug(f'Interrogate load: vlm="{repo}"')
        model = None
        model = transformers.AutoModelForCausalLM.from_pretrained(
            repo,
            torch_dtype=devices.dtype,
            multimodal_max_length=32768,
            trust_remote_code=True,
            cache_dir=shared.opts.hfcache_dir,
        )
        loaded = repo
        devices.torch_gc()
    sd_models.move_model(model, devices.device)
    text_tokenizer = model.get_text_tokenizer()
    visual_tokenizer = model.get_visual_tokenizer()
    max_partition = 9
    question = question.replace('<', '').replace('>', '').replace('_', ' ')
    question = f'<image>\n{question}'
    _prompt, input_ids, pixel_values = model.preprocess_inputs(question, [image], max_partition=max_partition)
    # Mask is computed before the batch dimension is added
    attention_mask = torch.ne(input_ids, text_tokenizer.pad_token_id)
    input_ids = input_ids.unsqueeze(0).to(device=model.device)
    attention_mask = attention_mask.unsqueeze(0).to(device=model.device)
    if pixel_values is not None:
        pixel_values = pixel_values.to(dtype=visual_tokenizer.dtype, device=visual_tokenizer.device)
    pixel_values = [pixel_values]
    with devices.inference_context():
        output_ids = model.generate(
            input_ids,
            pixel_values=pixel_values,
            attention_mask=attention_mask,
            repetition_penalty=None,
            eos_token_id=model.generation_config.eos_token_id,
            pad_token_id=text_tokenizer.pad_token_id,
            use_cache=True,
            **get_kwargs())
    response = text_tokenizer.decode(output_ids[0], skip_special_tokens=True)
    # Log through the module's debug channel instead of printing to stdout
    debug(f'VQA interrogate: handler=ovis response_before_clean="{response}"')
    return response
|
|
|
|
|
|
def smol(
    question: str,
    image: Image.Image,
    repo: str = None,
    system_prompt: str = None,
    model_name: str = None,
    prefill: str = None,
    thinking_mode: bool = False,
):
    """Answer `question` about `image` with a SmolVLM-style chat model.

    The model/processor are kept in module globals and reloaded only when
    `loaded` differs from `repo`. A system + user conversation is built, an
    optional assistant prefill is appended, and the chat template is applied
    before generation.

    Args:
        question: user question; '<', '>' are removed and '_' becomes a space.
        image: input image.
        repo: Hugging Face repo id to load.
        system_prompt: overrides `shared.opts.interrogate_vlm_system` when truthy.
        model_name: friendly model name, used for logging and thinking-model detection.
        prefill: assistant prefill text; `None` falls back to the module-level `vlm_prefill`.
        thinking_mode: force thinking handling even if the model name is not detected as a thinking model.

    Returns:
        The list produced by `processor.batch_decode`, with `<think>` markup in
        the first entry either relabelled or stripped depending on
        `shared.opts.interrogate_vlm_keep_thinking`.
    """
    global processor, model, loaded # pylint: disable=global-statement
    if model is None or loaded != repo:
        shared.log.debug(f'Interrogate load: vlm="{repo}"')
        model = None
        model = transformers.AutoModelForVision2Seq.from_pretrained(
            repo,
            cache_dir=shared.opts.hfcache_dir,
            torch_dtype=devices.dtype,
            **quant_args,
        )
        processor = transformers.AutoProcessor.from_pretrained(repo, max_pixels=1024*1024, cache_dir=shared.opts.hfcache_dir)
        if 'LLM' in shared.opts.cuda_compile:
            model = sd_models_compile.compile_torch(model)
        loaded = repo
        devices.torch_gc()
    sd_models.move_model(model, devices.device)
    # Get model class name for logging
    cls_name = model.__class__.__name__
    debug(f'VQA interrogate: handler=smol model_name="{model_name}" model_class="{cls_name}" repo="{repo}" question="{question}" system_prompt="{system_prompt}" image_size={image.size if image else None}')

    # Warn if using Florence-2 task tokens with non-Florence-2 models
    if is_florence_task(question):
        shared.log.warning(f'Interrogate: Florence-2 task token "{question}" is designed for Florence-2 models. Using it anyway, but results may vary.')
    question = question.replace('<', '').replace('>', '').replace('_', ' ')
    system_prompt = system_prompt or shared.opts.interrogate_vlm_system
    conversation = [
        {
            "role": "system",
            "content": [{"type": "text", "text": system_prompt}],
        },
        {
            "role": "user",
            "content": [
                {"type": "image", "image": b64(image)},
                {"type": "text", "text": question},
            ],
        }
    ]
    # Add prefill for all models (only if provided)
    prefill_value = vlm_prefill if prefill is None else prefill
    prefill_text = prefill_value.strip()
    use_prefill = len(prefill_text) > 0
    # Thinking models emit their own <think> tags via the chat template
    # Use manual toggle OR auto-detection based on model name
    use_thinking = thinking_mode or is_thinking_model(model_name)
    if use_prefill:
        conversation.append({
            "role": "assistant",
            "content": [{"type": "text", "text": prefill_text}],
        })
        debug(f'VQA interrogate: handler=smol prefill="{prefill_text}"')
    else:
        debug('VQA interrogate: handler=smol prefill disabled (empty), relying on add_generation_prompt')
    if debug_enabled:
        debug(f'VQA interrogate: handler=smol conversation_roles={[msg["role"] for msg in conversation]}')
        debug(f'VQA interrogate: handler=smol full_conversation={truncate_b64_in_conversation(conversation)}')
        debug_prefill_mode = 'add_generation_prompt=False continue_final_message=True' if use_prefill else 'add_generation_prompt=True'
        debug(f'VQA interrogate: handler=smol template_mode={debug_prefill_mode}')
    try:
        if use_prefill:
            # continue the assistant turn that holds the prefill instead of opening a new one
            text_prompt = processor.apply_chat_template(
                conversation,
                add_generation_prompt=False,
                continue_final_message=True,
            )
        else:
            text_prompt = processor.apply_chat_template(conversation, add_generation_prompt=True)
    except (TypeError, ValueError) as e:
        # Fallback for models that don't support continue_final_message or for mismatched kwargs
        debug(f'VQA interrogate: handler=smol chat_template fallback add_generation_prompt=True: {e}')
        text_prompt = processor.apply_chat_template(conversation, add_generation_prompt=True)
    if use_prefill and use_thinking:
        text_prompt = keep_think_block_open(text_prompt)
    if debug_enabled:
        debug(f'VQA interrogate: handler=smol text_prompt="{text_prompt}"')
    inputs = processor(text=text_prompt, images=[image], padding=True, return_tensors="pt")
    inputs = inputs.to(devices.device, devices.dtype)
    gen_kwargs = get_kwargs()
    debug(f'VQA interrogate: handler=smol generation_kwargs={gen_kwargs}')
    # run generation under the shared inference context, as the other handlers in this module do
    with devices.inference_context():
        output_ids = model.generate(
            **inputs,
            **gen_kwargs,
        )
    debug(f'VQA interrogate: handler=smol output_ids_shape={output_ids.shape}')
    response = processor.batch_decode(output_ids,skip_special_tokens=True)
    if debug_enabled:
        debug(f'VQA interrogate: handler=smol response_before_clean="{response}"')

    # Clean up thinking tags
    if len(response) > 0:
        text = response[0]
        if shared.opts.interrogate_vlm_keep_thinking:
            text = text.replace('<think>', 'Reasoning:\n').replace('</think>', '\nAnswer:')
        else:
            close_len = len('</think>')
            while '</think>' in text:
                start = text.find('<think>')
                end = text.find('</think>')
                if start != -1 and start < end:
                    # well-formed block: cut from the opening tag through the closing tag
                    text = text[:start] + text[end+close_len:]
                else:
                    # closing tag without a preceding opener: drop everything up to and including it
                    text = text[end+close_len:]
        response[0] = text

    return response
|
|
|
|
|
|
def git(question: str, image: Image.Image, repo: str = None, model_name: str = None):
    """Caption `image` (optionally conditioned on `question`) with a GIT model.

    Model and processor are cached in module globals and reloaded only when
    `loaded` differs from `repo`. `model_name` is not used by this handler.

    Returns the first decoded sequence produced by `model.generate`.
    """
    global processor, model, loaded # pylint: disable=global-statement
    if model is None or loaded != repo:
        shared.log.debug(f'Interrogate load: vlm="{repo}"')
        model = None
        load_kwargs = {
            'torch_dtype': devices.dtype,
            'cache_dir': shared.opts.hfcache_dir,
        }
        model = transformers.GitForCausalLM.from_pretrained(repo, **load_kwargs)
        processor = transformers.GitProcessor.from_pretrained(repo, cache_dir=shared.opts.hfcache_dir)
        loaded = repo
        devices.torch_gc()
    sd_models.move_model(model, devices.device)
    image_features = processor(images=image, return_tensors="pt")
    generate_args = {
        'pixel_values': image_features.pixel_values.to(devices.device, devices.dtype),
    }
    if len(question) > 0:
        # tokenize without special tokens, then put the CLS token id in front manually
        token_ids = processor(text=question, add_special_tokens=False).input_ids
        token_ids = [processor.tokenizer.cls_token_id, *token_ids]
        token_tensor = torch.tensor(token_ids).unsqueeze(0)
        generate_args['input_ids'] = token_tensor.to(devices.device)
    with devices.inference_context():
        generated = model.generate(**generate_args)
        decoded = processor.batch_decode(generated, skip_special_tokens=True)
        response = decoded[0]
    return response
|
|
|
|
|
|
def blip(question: str, image: Image.Image, repo: str = None, model_name: str = None):
    """Answer `question` about `image` with a BLIP question-answering model.

    Model and processor are cached in module globals and reloaded only when
    `loaded` differs from `repo`. `model_name` is not used by this handler.

    Returns the decoded first generated sequence.
    """
    global processor, model, loaded # pylint: disable=global-statement
    must_load = model is None or loaded != repo
    if must_load:
        shared.log.debug(f'Interrogate load: vlm="{repo}"')
        model = None
        load_kwargs = {
            'torch_dtype': devices.dtype,
            'cache_dir': shared.opts.hfcache_dir,
        }
        model = transformers.BlipForQuestionAnswering.from_pretrained(repo, **load_kwargs)
        processor = transformers.BlipProcessor.from_pretrained(repo, cache_dir=shared.opts.hfcache_dir)
        loaded = repo
        devices.torch_gc()
    sd_models.move_model(model, devices.device)
    # move the processed batch to the compute device and dtype
    batch_inputs = processor(image, question, return_tensors="pt").to(devices.device, devices.dtype)
    with devices.inference_context():
        generated = model.generate(**batch_inputs)
        first_sequence = generated[0]
        response = processor.decode(first_sequence, skip_special_tokens=True)
    return response
|
|
|
|
|
|
def vilt(question: str, image: Image.Image, repo: str = None, model_name: str = None):
    """Answer `question` about `image` with a ViLT question-answering classifier.

    Model and processor are cached in module globals and reloaded only when
    `loaded` differs from `repo`. `model_name` is not used by this handler.

    Returns the label from `model.config.id2label` for the highest logit.
    """
    global processor, model, loaded # pylint: disable=global-statement
    must_load = model is None or loaded != repo
    if must_load:
        shared.log.debug(f'Interrogate load: vlm="{repo}"')
        model = None
        load_kwargs = {
            'torch_dtype': devices.dtype,
            'cache_dir': shared.opts.hfcache_dir,
        }
        model = transformers.ViltForQuestionAnswering.from_pretrained(repo, **load_kwargs)
        processor = transformers.ViltProcessor.from_pretrained(repo, cache_dir=shared.opts.hfcache_dir)
        loaded = repo
        devices.torch_gc()
    sd_models.move_model(model, devices.device)
    # only the device is changed here; no dtype cast is applied to the inputs
    batch_inputs = processor(image, question, return_tensors="pt").to(devices.device)
    with devices.inference_context():
        prediction = model(**batch_inputs)
        best_label = prediction.logits.argmax(-1).item() # index of the highest-scoring answer
        response = model.config.id2label[best_label]
    return response
|
|
|
|
|
|
def pix(question: str, image: Image.Image, repo: str = None, model_name: str = None):
    """Run a Pix2Struct model on `image`, passing `question` as text when non-empty.

    Model and processor are cached in module globals and reloaded only when
    `loaded` differs from `repo`. `model_name` is not used by this handler.

    Returns the decoded first generated sequence.
    """
    global processor, model, loaded # pylint: disable=global-statement
    must_load = model is None or loaded != repo
    if must_load:
        shared.log.debug(f'Interrogate load: vlm="{repo}"')
        model = None
        load_kwargs = {
            'torch_dtype': devices.dtype,
            'cache_dir': shared.opts.hfcache_dir,
        }
        model = transformers.Pix2StructForConditionalGeneration.from_pretrained(repo, **load_kwargs)
        processor = transformers.Pix2StructProcessor.from_pretrained(repo, cache_dir=shared.opts.hfcache_dir)
        loaded = repo
        devices.torch_gc()
    sd_models.move_model(model, devices.device)
    processor_args = {'images': image, 'return_tensors': "pt"}
    if len(question) > 0:
        # the text argument is only supplied when there is an actual question
        processor_args['text'] = question
    batch_inputs = processor(**processor_args).to(devices.device)
    with devices.inference_context():
        generated = model.generate(**batch_inputs)
        response = processor.decode(generated[0], skip_special_tokens=True)
    return response
|
|
|
|
|
|
def moondream(question: str, image: Image.Image, repo: str = None, model_name: str = None):
    """Caption or query `image` with a Moondream 2 checkpoint (remote-code model).

    The model is loaded at the pinned revision "2025-06-21" and, together with
    its tokenizer, cached in module globals; it is reloaded only when `loaded`
    differs from `repo`. `model_name` is not used by this handler.

    After removing '<' / '>' and turning '_' into spaces, the questions
    'CAPTION', 'DETAILED CAPTION' and 'MORE DETAILED CAPTION' map to
    `model.caption` with length "short", "normal" and "long"; anything else is
    sent to `model.answer_question`.

    Returns the caption or answer value taken from the model's result dict.
    """
    global processor, model, loaded # pylint: disable=global-statement
    if model is None or loaded != repo:
        shared.log.debug(f'Interrogate load: vlm="{repo}"')
        model = None
        model = transformers.AutoModelForCausalLM.from_pretrained(
            repo,
            revision="2025-06-21",
            trust_remote_code=True,
            torch_dtype=devices.dtype,
            cache_dir=shared.opts.hfcache_dir,
        )
        processor = transformers.AutoTokenizer.from_pretrained(repo, cache_dir=shared.opts.hfcache_dir)
        loaded = repo
        model.eval()
        devices.torch_gc()
    sd_models.move_model(model, devices.device)
    question = question.replace('<', '').replace('>', '').replace('_', ' ')
    caption_lengths = {
        'CAPTION': 'short',
        'DETAILED CAPTION': 'normal',
        'MORE DETAILED CAPTION': 'long',
    }
    with devices.inference_context():
        caption_length = caption_lengths.get(question)
        if caption_length is not None:
            response = model.caption(image, length=caption_length)['caption']
        else:
            # encode only on the question-answer path; the caption calls take the raw image
            # and never used the separately encoded result
            encoded = model.encode_image(image)
            response = model.answer_question(encoded, question, processor)['answer']
    # other model calls noted by the original author, not used here:
    # model.detect(image, "face")
    # model.point(image, "person")
    # model.detect_gaze(image)
    return response
|
|
|
|
|
|
def florence(question: str, image: Image.Image, repo: str = None, revision: str = None, model_name: str = None):
    """Run a Florence-2 task on `image`.

    Args:
        question: if it starts with '<', the leading `<...>` token is used as
            the task; otherwise the task defaults to '<MORE_DETAILED_CAPTION>'.
        image: input image; its width/height are passed to post-processing.
        repo: Hugging Face repo id, optionally suffixed with '@<revision>'.
            A revision given this way takes precedence over `revision`.
        revision: revision to load when `repo` carries no '@' suffix.
        model_name: not used by this handler.

    Returns:
        Whatever `processor.post_process_generation` returns for the generated text.

    The model/processor are cached in module globals keyed by the full `repo`
    string (including any '@revision' suffix).
    """
    global processor, model, loaded # pylint: disable=global-statement
    _get_imports = transformers.dynamic_module_utils.get_imports

    def get_imports(f):
        R = _get_imports(f)
        if "flash_attn" in R:
            R.remove("flash_attn") # flash_attn is optional
        return R

    # Handle revision splitting and caching
    cache_key = repo
    effective_revision = revision
    repo_name = repo

    if repo and '@' in repo:
        # partition splits on the first '@' only, so an unexpected extra '@' cannot raise on unpacking
        repo_name, _sep, effective_revision = repo.partition('@')

    if model is None or loaded != cache_key:
        shared.log.debug(f'Interrogate load: vlm="{repo_name}" revision="{effective_revision}" path="{shared.opts.hfcache_dir}"')
        # temporarily hide flash_attn from the remote-code import scan while loading
        transformers.dynamic_module_utils.get_imports = get_imports
        try:
            model = None
            # NOTE(review): both `dtype` and `torch_dtype` are passed with different values, and the
            # inputs below are cast to devices.dtype - confirm which dtype is intended to win
            model = transformers.Florence2ForConditionalGeneration.from_pretrained(
                repo_name,
                dtype=torch.bfloat16,
                revision=effective_revision,
                torch_dtype=devices.dtype,
                cache_dir=shared.opts.hfcache_dir,
                **quant_args,
            )
            processor = transformers.AutoProcessor.from_pretrained(repo_name, max_pixels=1024*1024, trust_remote_code=True, revision=effective_revision, cache_dir=shared.opts.hfcache_dir)
        finally:
            # always undo the monkeypatch, even when loading fails
            transformers.dynamic_module_utils.get_imports = _get_imports
        loaded = cache_key
        model.eval()
        devices.torch_gc()
    sd_models.move_model(model, devices.device)
    if question.startswith('<'):
        task = question.split('>', 1)[0] + '>'
    else:
        task = '<MORE_DETAILED_CAPTION>'
    inputs = processor(text=task, images=image, return_tensors="pt")
    input_ids = inputs['input_ids'].to(devices.device)
    pixel_values = inputs['pixel_values'].to(devices.device, devices.dtype)
    with devices.inference_context():
        generated_ids = model.generate(
            input_ids=input_ids,
            pixel_values=pixel_values,
            **get_kwargs()
        )
        generated_text = processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
        # NOTE(review): the literal string "task" is passed here, not the `task` variable; left
        # unchanged because downstream cleanup may depend on the resulting key - confirm
        response = processor.post_process_generation(generated_text, task="task", image_size=(image.width, image.height))
    return response
|
|
|
|
|
|
def sa2(question: str, image: Image.Image, repo: str = None, model_name: str = None):
    """Run a remote-code model exposing `predict_forward` (the "sa2" handler) on `image`.

    The model and its slow tokenizer are cached in module globals and reloaded
    only when `loaded` differs from `repo`. `model_name` is not used by this handler.

    If `question` starts with '<', its leading `<...>` token is used as the
    task; otherwise '<MORE_DETAILED_CAPTION>' is used. The task is sent to the
    model prefixed with '<image>'.

    Returns the "prediction" entry of the dict returned by `model.predict_forward`.
    """
    global processor, model, loaded # pylint: disable=global-statement
    if model is None or loaded != repo:
        # log the load and honour the configured HF cache directory, like the other loaders here
        shared.log.debug(f'Interrogate load: vlm="{repo}"')
        model = None
        model = transformers.AutoModel.from_pretrained(
            repo,
            torch_dtype=devices.dtype,
            low_cpu_mem_usage=True,
            use_flash_attn=False,
            trust_remote_code=True,
            cache_dir=shared.opts.hfcache_dir,
        )
        model = model.eval()
        processor = transformers.AutoTokenizer.from_pretrained(
            repo,
            trust_remote_code=True,
            use_fast=False,
            cache_dir=shared.opts.hfcache_dir,
        )
        loaded = repo
        devices.torch_gc()
    sd_models.move_model(model, devices.device)
    if question.startswith('<'):
        task = question.split('>', 1)[0] + '>'
    else:
        task = '<MORE_DETAILED_CAPTION>'
    input_dict = {
        'image': image,
        'text': f'<image>{task}',
        'past_text': '',
        'mask_prompts': None,
        'tokenizer': processor,
    }
    return_dict = model.predict_forward(**input_dict)
    response = return_dict["prediction"] # the text format answer
    return response
|
|
|
|
|
|
def interrogate(question:str='', system_prompt:str=None, prompt:str=None, image:Image.Image=None, model_name:str=None, prefill:str=None, thinking_mode:bool=False, quiet:bool=False):
    """Run the selected vision-language model on one image and return its cleaned answer.

    Args:
        question: friendly prompt name, internal task token or free text.
            "Use Prompt" substitutes the content of `prompt`; names found in
            `vlm_prompt_mapping` are translated with `get_internal_prompt`.
            Questions shorter than two characters become "Describe the image.".
        system_prompt: forwarded to handlers that accept one.
        prompt: content of the UI prompt field; required (at least two
            non-blank characters) for the point/detect modes.
        image: a PIL image, a list (first element is used) or a dict with a
            'name' path. Images larger than 768px on a side are downscaled in
            place with `thumbnail`; non-RGB images are converted to RGB.
        model_name: key into `vlm_models`; defaults to `shared.opts.interrogate_vlm_model`.
        prefill: assistant prefill; `None` falls back to the module-level `vlm_prefill`.
        thinking_mode: forwarded to handlers that support it.
        quiet: suppress the summary log line.

    Returns:
        `(answer, annotated_image)` where `annotated_image` is `None` unless
        the handler returned detections or points. Input-validation failures
        return `('Error: ...', None)`. A missing or unknown model returns `''`
        (kept as-is for backward compatibility with existing callers).

    The job registered with `shared.state.begin` is always closed, including on
    the early-return paths.
    """
    global quant_args # pylint: disable=global-statement
    jobid = shared.state.begin('Interrogate LLM')
    try:
        t0 = time.time()
        quant_args = model_quant.create_config(module='LLM')
        model_name = model_name or shared.opts.interrogate_vlm_model
        prefill = vlm_prefill if prefill is None else prefill # Use provided prefill when specified
        if isinstance(image, list):
            image = image[0] if len(image) > 0 else None
        if isinstance(image, dict) and 'name' in image:
            image = Image.open(image['name'])
        if isinstance(image, Image.Image):
            if image.width > 768 or image.height > 768:
                image.thumbnail((768, 768), Image.Resampling.LANCZOS)
            if image.mode != 'RGB':
                image = image.convert('RGB')
        if image is None:
            shared.log.error(f'VQA interrogate: model="{model_name}" error="No input image provided"')
            return ('Error: No input image provided. Please upload or select an image.', None)

        # Convert friendly prompt names to internal tokens/commands
        if question == "Use Prompt":
            # Use content from Prompt field directly
            question = prompt if (prompt is not None and len(prompt) > 0) else ""
        elif question in vlm_prompt_mapping:
            # Check if this is a mode that requires user input (Point/Detect)
            raw_mapping = vlm_prompt_mapping.get(question)
            if raw_mapping in ("POINT_MODE", "DETECT_MODE"):
                # These modes require user input in the prompt field
                if not prompt or len(prompt.strip()) < 2:
                    shared.log.error(f'VQA interrogate: model="{model_name}" error="Please specify what to find in the prompt field"')
                    return ('Error: Please specify what to find in the prompt field (e.g., "the red car" or "faces").', None)
            # Convert friendly name to internal token (handles Point/Detect prefix)
            question = get_internal_prompt(question, prompt)
        # else: question is already an internal token or custom text

        # Fallback for empty questions
        if len(question) < 2:
            question = "Describe the image."

        # balanced offload of the loaded SD model before VLM inference is currently disabled

        from modules import modelloader
        modelloader.hf_login()

        # pre-initialised so the logging after the dispatch can never hit an unbound name
        handler = 'unknown'
        vqa_model = None
        try:
            if model_name is None:
                shared.log.error(f'Interrogate: type=vlm model="{model_name}" no model selected')
                return ''
            vqa_model = vlm_models.get(model_name, None)
            if vqa_model is None:
                shared.log.error(f'Interrogate: type=vlm model="{model_name}" unknown')
                return ''

            # handlers are selected by substring of the repo id; order matters
            # (e.g. 'paligemma' must be tested before 'gemma')
            repo_id = vqa_model.lower()
            if 'git' in repo_id:
                handler = 'git'
                answer = git(question, image, vqa_model, model_name)
            elif 'vilt' in repo_id:
                handler = 'vilt'
                answer = vilt(question, image, vqa_model, model_name)
            elif 'blip' in repo_id:
                handler = 'blip'
                answer = blip(question, image, vqa_model, model_name)
            elif 'pix' in repo_id:
                handler = 'pix'
                answer = pix(question, image, vqa_model, model_name)
            elif 'moondream3' in repo_id:
                handler = 'moondream3'
                from modules.interrogate import moondream3
                answer = moondream3.predict(question, image, vqa_model, model_name, thinking_mode=thinking_mode)
            elif 'moondream2' in repo_id:
                handler = 'moondream'
                answer = moondream(question, image, vqa_model, model_name)
            elif 'florence' in repo_id:
                handler = 'florence'
                answer = florence(question, image, vqa_model, None, model_name)
            elif 'qwen' in repo_id or 'torii' in repo_id or 'mimo' in repo_id:
                handler = 'qwen'
                answer = qwen(question, image, vqa_model, system_prompt, model_name, prefill, thinking_mode)
            elif 'smol' in repo_id:
                handler = 'smol'
                answer = smol(question, image, vqa_model, system_prompt, model_name, prefill, thinking_mode)
            elif 'joytag' in repo_id:
                handler = 'joytag'
                from modules.interrogate import joytag
                answer = joytag.predict(image)
            elif 'joycaption' in repo_id:
                handler = 'joycaption'
                from modules.interrogate import joycaption
                answer = joycaption.predict(question, image, vqa_model)
            elif 'deepseek' in repo_id:
                handler = 'deepseek'
                from modules.interrogate import deepseek
                answer = deepseek.predict(question, image, vqa_model)
            elif 'paligemma' in repo_id:
                handler = 'paligemma'
                answer = paligemma(question, image, vqa_model, model_name)
            elif 'gemma' in repo_id:
                handler = 'gemma'
                answer = gemma(question, image, vqa_model, system_prompt, model_name, prefill, thinking_mode)
            elif 'ovis' in repo_id:
                handler = 'ovis'
                answer = ovis(question, image, vqa_model, model_name)
            elif 'sa2' in repo_id:
                handler = 'sa2'
                answer = sa2(question, image, vqa_model, model_name)
            elif 'fastvlm' in repo_id:
                handler = 'fastvlm'
                answer = fastvlm(question, image, vqa_model, model_name)
            else:
                answer = 'unknown model'
        except Exception as e:
            # any handler failure is reported and turned into a plain 'error' answer
            errors.display(e, 'VQA')
            answer = 'error'

        if shared.opts.interrogate_offload and model is not None:
            sd_models.move_model(model, devices.cpu, force=True)
        devices.torch_gc(force=True, reason='vqa')

        # Handle tuple returns with detection data
        annotated_image = None
        if isinstance(answer, tuple) and len(answer) == 2:
            text, data_dict = answer
            text = clean(text, question, prefill)
            # Draw bounding boxes or points if available
            if data_dict and isinstance(data_dict, dict) and image:
                detections = data_dict.get('detections', None)
                points = data_dict.get('points', None)
                if detections or points:
                    annotated_image = draw_bounding_boxes(image, detections or [], points)
                    debug(f'VQA interrogate: handler={handler} created annotated image detections={len(detections) if detections else 0} points={len(points) if points else 0}')
            answer = text
        else:
            answer = clean(answer, question, prefill)

        debug(f'VQA interrogate: handler={handler} response_after_clean="{answer}" has_annotation={annotated_image is not None}')
        t1 = time.time()
        if not quiet:
            shared.log.debug(f'Interrogate: type=vlm model="{model_name}" repo="{vqa_model}" args={get_kwargs()} time={t1-t0:.2f}')
        return (answer, annotated_image)
    finally:
        # previously skipped on every early return, leaving the job registered as running
        shared.state.end(jobid)
|
|
|
|
|
|
def batch(model_name, system_prompt, batch_files, batch_folder, batch_str, question, prompt, write, append, recursive, prefill=None, thinking_mode=False):
    """Interrogate a set of images and return all captions joined by blank lines.

    Args:
        model_name, system_prompt, question, prompt, prefill, thinking_mode:
            forwarded unchanged to `interrogate` for every image.
        batch_files, batch_folder: optional iterables of objects exposing a
            `.name` path.
        batch_str: optional directory path; image files in it are collected
            with `list_files` (recursively when `recursive` is set).
        write: when truthy, write each caption to a `.txt` file and save any
            annotated image as `<image stem>_annotated.png`.
        append: with `write`, append to existing `.txt` files (each caption is
            prefixed with a newline) instead of overwriting them.

    Returns:
        The captions joined with two newlines, or `''` when no images were found.

    `shared.opts.interrogate_offload` is forced to False while the batch runs
    and is restored afterwards, also when the batch aborts with an exception.
    """
    class BatchWriter:
        # writes one caption text file per image; `csv` and `file` are never assigned elsewhere in this block
        def __init__(self, folder, mode='w'):
            self.folder = folder
            self.csv = None
            self.file = None
            self.mode = mode

        def add(self, file, prompt):
            # target is os.path.join(folder, <image path without extension> + ".txt");
            # an absolute image path therefore takes precedence over `folder`
            txt_file = os.path.splitext(file)[0] + ".txt"
            if self.mode == 'a':
                prompt = '\n' + prompt
            with open(os.path.join(self.folder, txt_file), self.mode, encoding='utf-8') as f:
                f.write(prompt)

        def close(self):
            if self.file is not None:
                self.file.close()

    files = []
    if batch_files is not None:
        files += [f.name for f in batch_files]
    if batch_folder is not None:
        files += [f.name for f in batch_folder]
    if batch_str is not None and len(batch_str) > 0 and os.path.exists(batch_str) and os.path.isdir(batch_str):
        from modules.files_cache import list_files
        files += list(list_files(batch_str, ext_filter=['.png', '.jpg', '.jpeg', '.webp', '.jxl'], recursive=recursive))
    if len(files) == 0:
        shared.log.warning('Interrogate batch: type=vlm no images')
        return ''
    jobid = shared.state.begin('Interrogate batch')
    prompts = []
    writer = None
    orig_offload = shared.opts.interrogate_offload
    try:
        if write:
            mode = 'w' if not append else 'a'
            writer = BatchWriter(os.path.dirname(files[0]), mode=mode)
        # keep the model on the device between images; restored in the finally block
        shared.opts.interrogate_offload = False
        import rich.progress as rp
        pbar = rp.Progress(rp.TextColumn('[cyan]Caption:'), rp.BarColumn(), rp.MofNCompleteColumn(), rp.TaskProgressColumn(), rp.TimeRemainingColumn(), rp.TimeElapsedColumn(), rp.TextColumn('[cyan]{task.description}'), console=shared.console)
        with pbar:
            task = pbar.add_task(total=len(files), description='starting...')
            for file in files:
                pbar.update(task, advance=1, description=file)
                if shared.state.interrupted:
                    break
                try:
                    image = Image.open(file)
                    # `prompt` must stay the caller's value for every file, so the result is
                    # bound to a separate name instead of overwriting it
                    result = interrogate(question, system_prompt, prompt, image, model_name, prefill, thinking_mode, quiet=True)
                    # Handle tuple return (text, annotated_image)
                    if isinstance(result, tuple):
                        caption, annotated_img = result
                        # Optionally save annotated image
                        if annotated_img and write:
                            annotated_path = os.path.splitext(file)[0] + "_annotated.png"
                            annotated_img.save(annotated_path)
                    else:
                        caption = result
                    prompts.append(caption)
                    if write:
                        writer.add(file, caption)
                except Exception as e:
                    # a failing image is logged and skipped so the rest of the batch continues
                    shared.log.error(f'Interrogate batch: {e}')
    finally:
        if writer is not None:
            writer.close()
        shared.opts.interrogate_offload = orig_offload
        shared.state.end(jobid)
    return '\n\n'.join(prompts)
|