refactor: rename interrogate module to caption

Move all caption-related modules from modules/interrogate/ to modules/caption/
for better naming consistency:
- Rename deepbooru, deepseek, joycaption, joytag, moondream3, openclip, tagger,
  vqa, vqa_detection, waifudiffusion modules
- Add new caption.py dispatcher module
- Remove old interrogate.py (functionality moved to caption.py)
pull/4613/head
CalamitousFelicitousness 2026-01-26 01:14:53 +00:00
parent 83fa8e39ba
commit 5183ebec58
13 changed files with 257 additions and 244 deletions

View File

@ -0,0 +1,48 @@
import time
from PIL import Image
from modules import shared
def caption(image):
    """Generate a caption for an image using the configured default captioner.

    Dispatches on ``shared.opts.caption_default_type``:
      - 'OpenCLiP': CLIP interrogation via ``modules.caption.openclip``
      - 'Tagger':   booru-style tagging via ``modules.caption.tagger``
      - 'VLM':      vision-language model via ``modules.caption.vqa``

    Args:
        image: a PIL image, a gradio file dict with a ``'name'`` key, or a
            list of either (only the first element is used).

    Returns:
        The generated caption string, or ``''`` when no image is provided or
        the configured caption type is unknown.
    """
    # Normalize gradio-style inputs: gallery lists and uploaded-file dicts.
    if isinstance(image, list):
        image = image[0] if len(image) > 0 else None
    if isinstance(image, dict) and 'name' in image:
        image = Image.open(image['name'])
    if image is None:
        shared.log.error('Caption: no image provided')
        return ''
    t0 = time.time()
    caption_type = shared.opts.caption_default_type
    # Backends are imported lazily inside each branch so that only the
    # selected captioner (and its heavy dependencies) gets loaded.
    if caption_type == 'OpenCLiP':
        shared.log.info(f'Caption: type={caption_type} clip="{shared.opts.caption_openclip_model}" blip="{shared.opts.caption_openclip_blip_model}" mode="{shared.opts.caption_openclip_mode}"')
        from modules.caption import openclip
        openclip.load_captioner(clip_model=shared.opts.caption_openclip_model, blip_model=shared.opts.caption_openclip_blip_model)
        openclip.update_caption_params()
        prompt = openclip.caption(image, mode=shared.opts.caption_openclip_mode)
    elif caption_type == 'Tagger':
        shared.log.info(f'Caption: type={caption_type} model="{shared.opts.waifudiffusion_model}"')
        from modules.caption import tagger
        prompt = tagger.tag(
            image=image,
            model_name=shared.opts.waifudiffusion_model,
            general_threshold=shared.opts.tagger_threshold,
            character_threshold=shared.opts.waifudiffusion_character_threshold,
            include_rating=shared.opts.tagger_include_rating,
            exclude_tags=shared.opts.tagger_exclude_tags,
            max_tags=shared.opts.tagger_max_tags,
            sort_alpha=shared.opts.tagger_sort_alpha,
            use_spaces=shared.opts.tagger_use_spaces,
            escape_brackets=shared.opts.tagger_escape_brackets,
        )
    elif caption_type == 'VLM':
        shared.log.info(f'Caption: type={caption_type} vlm="{shared.opts.caption_vlm_model}" prompt="{shared.opts.caption_vlm_prompt}"')
        from modules.caption import vqa
        prompt = vqa.caption(image=image, model_name=shared.opts.caption_vlm_model, question=shared.opts.caption_vlm_prompt, prompt=None, system_prompt=shared.opts.caption_vlm_system)
    else:
        shared.log.error(f'Caption: type="{caption_type}" unknown')
        return ''
    # Single timing/answer log shared by all successful branches.
    shared.log.debug(f'Caption: time={time.time()-t0:.2f} answer="{prompt}"')
    return prompt

View File

@ -19,7 +19,7 @@ class DeepDanbooru:
if self.model is not None:
return
model_path = os.path.join(shared.opts.clip_models_path, "DeepDanbooru")
shared.log.debug(f'Interrogate load: module=DeepDanbooru folder="{model_path}"')
shared.log.debug(f'Caption load: module=DeepDanbooru folder="{model_path}"')
files = modelloader.load_models(
model_path=model_path,
model_url='https://github.com/AUTOMATIC1111/TorchDeepDanbooru/releases/download/v1/model-resnet_custom_v3.pt',
@ -27,7 +27,7 @@ class DeepDanbooru:
download_name='model-resnet_custom_v3.pt',
)
from modules.interrogate.deepbooru_model import DeepDanbooruModel
from modules.caption.deepbooru_model import DeepDanbooruModel
self.model = DeepDanbooruModel()
self.model.load_state_dict(torch.load(files[0], map_location="cpu"))
self.model.eval()
@ -38,7 +38,7 @@ class DeepDanbooru:
self.model.to(devices.device)
def stop(self):
if shared.opts.interrogate_offload:
if shared.opts.caption_offload:
self.model.to(devices.cpu)
devices.torch_gc()

View File

@ -32,11 +32,11 @@ def load(repo: str):
"""Load DeepSeek VL2 model (experimental)."""
global vl_gpt, vl_chat_processor, loaded_repo # pylint: disable=global-statement
if not shared.cmd_opts.experimental:
shared.log.error(f'Interrogate: type=vlm model="DeepSeek VL2" repo="{repo}" is experimental-only')
shared.log.error(f'Caption: type=vlm model="DeepSeek VL2" repo="{repo}" is experimental-only')
return False
folder = os.path.join(paths.script_path, 'repositories', 'deepseek-vl2')
if not os.path.exists(folder):
shared.log.error(f'Interrogate: type=vlm model="DeepSeek VL2" repo="{repo}" deepseek-vl2 repo not found')
shared.log.error(f'Caption: type=vlm model="DeepSeek VL2" repo="{repo}" deepseek-vl2 repo not found')
return False
if vl_gpt is None or loaded_repo != repo:
sys.modules['attrdict'] = fake_attrdict
@ -53,7 +53,7 @@ def load(repo: str):
vl_gpt.to(dtype=devices.dtype)
vl_gpt.eval()
loaded_repo = repo
shared.log.info(f'Interrogate: type=vlm model="DeepSeek VL2" repo="{repo}"')
shared.log.info(f'Caption: type=vlm model="DeepSeek VL2" repo="{repo}"')
sd_models.move_model(vl_gpt, devices.device)
return True
@ -105,7 +105,7 @@ def predict(question, image, repo):
pad_token_id=vl_chat_processor.tokenizer.eos_token_id,
bos_token_id=vl_chat_processor.tokenizer.bos_token_id,
eos_token_id=vl_chat_processor.tokenizer.eos_token_id,
max_new_tokens=shared.opts.interrogate_vlm_max_length,
max_new_tokens=shared.opts.caption_vlm_max_length,
do_sample=False,
use_cache=True
)

View File

@ -64,7 +64,7 @@ def load(repo: str = None):
if llava_model is None or opts.repo != repo:
opts.repo = repo
llava_model = None
shared.log.info(f'Interrogate: type=vlm model="JoyCaption" {str(opts)}')
shared.log.info(f'Caption: type=vlm model="JoyCaption" {str(opts)}')
processor = AutoProcessor.from_pretrained(repo, max_pixels=1024*1024, cache_dir=shared.opts.hfcache_dir)
quant_args = model_quant.create_config(module='LLM')
llava_model = LlavaForConditionalGeneration.from_pretrained(
@ -92,7 +92,7 @@ def unload():
@torch.no_grad()
def predict(question: str, image, vqa_model: str = None) -> str:
opts.max_new_tokens = shared.opts.interrogate_vlm_max_length
opts.max_new_tokens = shared.opts.caption_vlm_max_length
load(vqa_model)
if len(question) < 2:
@ -121,7 +121,7 @@ def predict(question: str, image, vqa_model: str = None) -> str:
)[0]
generate_ids = generate_ids[inputs['input_ids'].shape[1]:] # Trim off the prompt
caption = processor.tokenizer.decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False) # Decode the caption
if shared.opts.interrogate_offload:
if shared.opts.caption_offload:
sd_models.move_model(llava_model, devices.cpu, force=True)
caption = caption.replace('\n\n', '\n').strip()
return caption

View File

@ -1044,7 +1044,7 @@ def load():
model.eval()
with open(os.path.join(folder, 'top_tags.txt'), 'r', encoding='utf8') as f:
tags = [line.strip() for line in f.readlines() if line.strip()]
shared.log.info(f'Interrogate: type=vlm model="JoyTag" repo="{MODEL_REPO}" tags={len(tags)}')
shared.log.info(f'Caption: type=vlm model="JoyTag" repo="{MODEL_REPO}" tags={len(tags)}')
sd_models.move_model(model, devices.device)
@ -1068,7 +1068,7 @@ def predict(image: Image.Image):
preds = model({'image': image_tensor})
tag_preds = preds['tags'].sigmoid().cpu()
scores = {tags[i]: tag_preds[0][i] for i in range(len(tags))}
if shared.opts.interrogate_score:
if shared.opts.tagger_show_scores:
predicted_tags = [f'{tag}:{score:.2f}' for tag, score in scores.items() if score > THRESHOLD]
else:
predicted_tags = [tag for tag, score in scores.items() if score > THRESHOLD]

View File

@ -7,11 +7,11 @@ import re
import transformers
from PIL import Image
from modules import shared, devices, sd_models
from modules.interrogate import vqa_detection
from modules.caption import vqa_detection
# Debug logging - function-based to avoid circular import
debug_enabled = os.environ.get('SD_INTERROGATE_DEBUG', None) is not None
debug_enabled = os.environ.get('SD_CAPTION_DEBUG', None) is not None
def debug(*args, **kwargs):
if debug_enabled:
@ -30,12 +30,12 @@ def get_settings():
Moondream 3 accepts: temperature, top_p, max_tokens
"""
settings = {}
if shared.opts.interrogate_vlm_max_length > 0:
settings['max_tokens'] = shared.opts.interrogate_vlm_max_length
if shared.opts.interrogate_vlm_temperature > 0:
settings['temperature'] = shared.opts.interrogate_vlm_temperature
if shared.opts.interrogate_vlm_top_p > 0:
settings['top_p'] = shared.opts.interrogate_vlm_top_p
if shared.opts.caption_vlm_max_length > 0:
settings['max_tokens'] = shared.opts.caption_vlm_max_length
if shared.opts.caption_vlm_temperature > 0:
settings['temperature'] = shared.opts.caption_vlm_temperature
if shared.opts.caption_vlm_top_p > 0:
settings['top_p'] = shared.opts.caption_vlm_top_p
return settings if settings else None
@ -44,7 +44,7 @@ def load_model(repo: str):
global moondream3_model, loaded # pylint: disable=global-statement
if moondream3_model is None or loaded != repo:
shared.log.debug(f'Interrogate load: vlm="{repo}"')
shared.log.debug(f'Caption load: vlm="{repo}"')
moondream3_model = None
moondream3_model = transformers.AutoModelForCausalLM.from_pretrained(
@ -84,7 +84,7 @@ def encode_image(image: Image.Image, cache_key: str = None):
Encoded image tensor
"""
if cache_key and cache_key in image_cache:
debug(f'VQA interrogate: handler=moondream3 using cached encoding for cache_key="{cache_key}"')
debug(f'VQA caption: handler=moondream3 using cached encoding for cache_key="{cache_key}"')
return image_cache[cache_key]
model = load_model(loaded)
@ -94,7 +94,7 @@ def encode_image(image: Image.Image, cache_key: str = None):
if cache_key:
image_cache[cache_key] = encoded
debug(f'VQA interrogate: handler=moondream3 cached encoding cache_key="{cache_key}" cache_size={len(image_cache)}')
debug(f'VQA caption: handler=moondream3 cached encoding cache_key="{cache_key}" cache_size={len(image_cache)}')
return encoded
@ -129,7 +129,7 @@ def query(image: Image.Image, question: str, repo: str, stream: bool = False,
if max_tokens is not None:
settings['max_tokens'] = max_tokens
debug(f'VQA interrogate: handler=moondream3 method=query question="{question}" stream={stream} settings={settings}')
debug(f'VQA caption: handler=moondream3 method=query question="{question}" stream={stream} settings={settings}')
# Use cached encoding if requested
if use_cache:
@ -150,12 +150,12 @@ def query(image: Image.Image, question: str, repo: str, stream: bool = False,
# Log response structure (for non-streaming)
if not stream:
if isinstance(response, dict):
debug(f'VQA interrogate: handler=moondream3 response_type=dict keys={list(response.keys())}')
debug(f'VQA caption: handler=moondream3 response_type=dict keys={list(response.keys())}')
if 'reasoning' in response:
reasoning_text = response['reasoning'].get('text', '')[:100] + '...' if len(response['reasoning'].get('text', '')) > 100 else response['reasoning'].get('text', '')
debug(f'VQA interrogate: handler=moondream3 reasoning="{reasoning_text}"')
debug(f'VQA caption: handler=moondream3 reasoning="{reasoning_text}"')
if 'answer' in response:
debug(f'VQA interrogate: handler=moondream3 answer="{response["answer"]}"')
debug(f'VQA caption: handler=moondream3 answer="{response["answer"]}"')
return response
@ -188,7 +188,7 @@ def caption(image: Image.Image, repo: str, length: str = 'normal', stream: bool
if max_tokens is not None:
settings['max_tokens'] = max_tokens
debug(f'VQA interrogate: handler=moondream3 method=caption length={length} stream={stream} settings={settings}')
debug(f'VQA caption: handler=moondream3 method=caption length={length} stream={stream} settings={settings}')
with devices.inference_context():
response = model.caption(
@ -200,7 +200,7 @@ def caption(image: Image.Image, repo: str, length: str = 'normal', stream: bool
# Log response structure (for non-streaming)
if not stream and isinstance(response, dict):
debug(f'VQA interrogate: handler=moondream3 response_type=dict keys={list(response.keys())}')
debug(f'VQA caption: handler=moondream3 response_type=dict keys={list(response.keys())}')
return response
@ -220,21 +220,21 @@ def point(image: Image.Image, object_name: str, repo: str):
"""
model = load_model(repo)
debug(f'VQA interrogate: handler=moondream3 method=point object_name="{object_name}"')
debug(f'VQA caption: handler=moondream3 method=point object_name="{object_name}"')
with devices.inference_context():
result = model.point(image, object_name)
debug(f'VQA interrogate: handler=moondream3 point_raw_result="{result}" type={type(result)}')
debug(f'VQA caption: handler=moondream3 point_raw_result="{result}" type={type(result)}')
if isinstance(result, dict):
debug(f'VQA interrogate: handler=moondream3 point_raw_result_keys={list(result.keys())}')
debug(f'VQA caption: handler=moondream3 point_raw_result_keys={list(result.keys())}')
points = vqa_detection.parse_points(result)
if points:
debug(f'VQA interrogate: handler=moondream3 point_result={len(points)} points found')
debug(f'VQA caption: handler=moondream3 point_result={len(points)} points found')
return points
debug('VQA interrogate: handler=moondream3 point_result=not found')
debug('VQA caption: handler=moondream3 point_result=not found')
return None
@ -257,17 +257,17 @@ def detect(image: Image.Image, object_name: str, repo: str, max_objects: int = 1
"""
model = load_model(repo)
debug(f'VQA interrogate: handler=moondream3 method=detect object_name="{object_name}" max_objects={max_objects}')
debug(f'VQA caption: handler=moondream3 method=detect object_name="{object_name}" max_objects={max_objects}')
with devices.inference_context():
result = model.detect(image, object_name)
debug(f'VQA interrogate: handler=moondream3 detect_raw_result="{result}" type={type(result)}')
debug(f'VQA caption: handler=moondream3 detect_raw_result="{result}" type={type(result)}')
if isinstance(result, dict):
debug(f'VQA interrogate: handler=moondream3 detect_raw_result_keys={list(result.keys())}')
debug(f'VQA caption: handler=moondream3 detect_raw_result_keys={list(result.keys())}')
detections = vqa_detection.parse_detections(result, object_name, max_objects)
debug(f'VQA interrogate: handler=moondream3 detect_result={len(detections)} objects found')
debug(f'VQA caption: handler=moondream3 detect_result={len(detections)} objects found')
return detections
@ -291,7 +291,7 @@ def predict(question: str, image: Image.Image, repo: str, model_name: str = None
Response string (detection data stored on VQA singleton instance.last_detection_data)
(or generator if stream=True for query/caption modes)
"""
debug(f'VQA interrogate: handler=moondream3 model_name="{model_name}" repo="{repo}" question="{question}" image_size={image.size if image else None} mode={mode} stream={stream}')
debug(f'VQA caption: handler=moondream3 model_name="{model_name}" repo="{repo}" question="{question}" image_size={image.size if image else None} mode={mode} stream={stream}')
# Clean question
question = question.replace('<', '').replace('>', '').replace('_', ' ') if question else ''
@ -331,7 +331,7 @@ def predict(question: str, image: Image.Image, repo: str, model_name: str = None
else:
mode = 'query'
debug(f'VQA interrogate: handler=moondream3 mode_selected={mode}')
debug(f'VQA caption: handler=moondream3 mode_selected={mode}')
# Dispatch to appropriate method
try:
@ -348,10 +348,10 @@ def predict(question: str, image: Image.Image, repo: str, model_name: str = None
object_name = re.sub(rf'\b{phrase}\b', '', object_name, flags=re.IGNORECASE)
object_name = re.sub(r'[?.!,]', '', object_name).strip()
object_name = re.sub(r'^\s*the\s+', '', object_name, flags=re.IGNORECASE)
debug(f'VQA interrogate: handler=moondream3 point_extracted_object="{object_name}"')
debug(f'VQA caption: handler=moondream3 point_extracted_object="{object_name}"')
result = point(image, object_name, repo)
if result:
from modules.interrogate import vqa
from modules.caption import vqa
vqa.get_instance().last_detection_data = {'points': result}
return vqa_detection.format_points_text(result)
return "Object not found"
@ -364,11 +364,11 @@ def predict(question: str, image: Image.Image, repo: str, model_name: str = None
object_name = re.sub(r'^\s*the\s+', '', object_name, flags=re.IGNORECASE)
if ' and ' in object_name.lower():
object_name = re.split(r'\s+and\s+', object_name, flags=re.IGNORECASE)[0].strip()
debug(f'VQA interrogate: handler=moondream3 detect_extracted_object="{object_name}"')
debug(f'VQA caption: handler=moondream3 detect_extracted_object="{object_name}"')
results = detect(image, object_name, repo, max_objects=kwargs.get('max_objects', 10))
if results:
from modules.interrogate import vqa
from modules.caption import vqa
vqa.get_instance().last_detection_data = {'detections': results}
return vqa_detection.format_detections_text(results)
return "No objects detected"
@ -377,7 +377,7 @@ def predict(question: str, image: Image.Image, repo: str, model_name: str = None
question = "Describe this image."
response = query(image, question, repo, stream=stream, use_cache=use_cache, reasoning=thinking_mode)
debug(f'VQA interrogate: handler=moondream3 response_before_clean="{response}"')
debug(f'VQA caption: handler=moondream3 response_before_clean="{response}"')
return response
except Exception as e:
@ -390,7 +390,7 @@ def clear_cache():
"""Clear image encoding cache."""
cache_size = len(image_cache)
image_cache.clear()
debug(f'VQA interrogate: handler=moondream3 cleared image cache cache_size_was={cache_size}')
debug(f'VQA caption: handler=moondream3 cleared image cache cache_size_was={cache_size}')
shared.log.debug(f'Moondream3: Cleared image cache ({cache_size} entries)')

View File

@ -8,7 +8,7 @@ from PIL import Image
from modules import devices, shared, errors, sd_models
debug_enabled = os.environ.get('SD_INTERROGATE_DEBUG', None) is not None
debug_enabled = os.environ.get('SD_CAPTION_DEBUG', None) is not None
debug_log = shared.log.trace if debug_enabled else lambda *args, **kwargs: None
# Per-request overrides for API calls
@ -19,7 +19,7 @@ def get_clip_setting(name):
"""Get CLIP setting with per-request override support.
Args:
name: Setting name without 'interrogate_clip_' prefix (e.g., 'min_flavors', 'max_length')
name: Setting name without 'caption_openclip_' prefix (e.g., 'min_flavors', 'max_length')
Returns:
Override value if set, otherwise the value from shared.opts
@ -28,7 +28,7 @@ def get_clip_setting(name):
value = _clip_overrides.get(name)
if value is not None:
return value
return getattr(shared.opts, f'interrogate_clip_{name}')
return getattr(shared.opts, f'caption_openclip_{name}')
def _apply_blip2_fix(model, processor):
@ -87,13 +87,14 @@ class BatchWriter:
self.file.close()
def update_interrogate_params():
def update_caption_params():
if ci is not None:
ci.caption_max_length = get_clip_setting('max_length')
ci.chunk_size = get_clip_setting('chunk_size')
ci.flavor_intermediate_count = get_clip_setting('flavor_count')
ci.clip_offload = shared.opts.interrogate_offload
ci.caption_offload = shared.opts.interrogate_offload
ci.clip_offload = shared.opts.caption_offload
ci.caption_offload = shared.opts.caption_offload
def get_clip_models():
@ -104,12 +105,12 @@ def refresh_clip_models():
global clip_models # pylint: disable=global-statement
import open_clip
models = sorted(open_clip.list_pretrained())
shared.log.debug(f'Interrogate: pkg=openclip version={open_clip.__version__} models={len(models)}')
shared.log.debug(f'Caption: pkg=openclip version={open_clip.__version__} models={len(models)}')
clip_models = ['/'.join(x) for x in models]
return clip_models
def load_interrogator(clip_model, blip_model):
def load_captioner(clip_model, blip_model):
from installer import install
install('clip_interrogator==0.6.0')
import clip_interrogator
@ -120,20 +121,21 @@ def load_interrogator(clip_model, blip_model):
device = devices.get_optimal_device()
cache_path = shared.opts.clip_models_path
shared.log.info(f'CLIP load: clip="{clip_model}" blip="{blip_model}" device={device}')
debug_log(f'CLIP load: cache_path="{cache_path}" max_length={shared.opts.interrogate_clip_max_length} chunk_size={shared.opts.interrogate_clip_chunk_size} flavor_count={shared.opts.interrogate_clip_flavor_count} offload={shared.opts.interrogate_offload}')
interrogator_config = clip_interrogator.Config(
debug_log(f'CLIP load: cache_path="{cache_path}" max_length={shared.opts.caption_openclip_max_length} chunk_size={shared.opts.caption_openclip_chunk_size} flavor_count={shared.opts.caption_openclip_flavor_count} offload={shared.opts.caption_offload}')
captioner_config = clip_interrogator.Config(
device=device,
cache_path=cache_path,
clip_model_name=clip_model,
caption_model_name=blip_model,
quiet=True,
caption_max_length=shared.opts.interrogate_clip_max_length,
chunk_size=shared.opts.interrogate_clip_chunk_size,
flavor_intermediate_count=shared.opts.interrogate_clip_flavor_count,
clip_offload=shared.opts.interrogate_offload,
caption_offload=shared.opts.interrogate_offload,
caption_max_length=shared.opts.caption_openclip_max_length,
chunk_size=shared.opts.caption_openclip_chunk_size,
flavor_intermediate_count=shared.opts.caption_openclip_flavor_count,
clip_offload=shared.opts.caption_offload,
caption_offload=shared.opts.caption_offload,
)
ci = clip_interrogator.Interrogator(interrogator_config)
ci = clip_interrogator.Interrogator(captioner_config)
if blip_model.startswith('blip2-'):
_apply_blip2_fix(ci.caption_model, ci.caption_processor)
shared.log.debug(f'CLIP load: time={time.time()-t0:.2f}')
@ -145,12 +147,14 @@ def load_interrogator(clip_model, blip_model):
ci.config.clip_model_name = clip_model
ci.config.clip_model = None
ci.load_clip_model()
ci.clip_offloaded = True # Reset flag so _prepare_clip() will move model to device
if blip_model != ci.config.caption_model_name:
shared.log.info(f'CLIP load: blip="{blip_model}" reloading')
debug_log(f'CLIP load: previous blip="{ci.config.caption_model_name}"')
ci.config.caption_model_name = blip_model
ci.config.caption_model = None
ci.load_caption_model()
ci.caption_offloaded = True # Reset flag so _prepare_caption() will move model to device
if blip_model.startswith('blip2-'):
_apply_blip2_fix(ci.caption_model, ci.caption_processor)
shared.log.debug(f'CLIP load: time={time.time()-t0:.2f}')
@ -159,7 +163,7 @@ def load_interrogator(clip_model, blip_model):
def unload_clip_model():
if ci is not None and shared.opts.interrogate_offload:
if ci is not None and shared.opts.caption_offload:
shared.log.debug('CLIP unload: offloading models to CPU')
sd_models.move_model(ci.caption_model, devices.cpu)
sd_models.move_model(ci.clip_model, devices.cpu)
@ -169,7 +173,7 @@ def unload_clip_model():
debug_log('CLIP unload: complete')
def interrogate(image, mode, caption=None):
def caption(image, mode, base_caption=None):
if isinstance(image, list):
image = image[0] if len(image) > 0 else None
if isinstance(image, dict) and 'name' in image:
@ -180,15 +184,17 @@ def interrogate(image, mode, caption=None):
t0 = time.time()
min_flavors = get_clip_setting('min_flavors')
max_flavors = get_clip_setting('max_flavors')
debug_log(f'CLIP: mode="{mode}" image_size={image.size} caption={caption is not None} min_flavors={min_flavors} max_flavors={max_flavors}')
debug_log(f'CLIP: mode="{mode}" image_size={image.size} caption={base_caption is not None} min_flavors={min_flavors} max_flavors={max_flavors}')
# NOTE: Method names like .interrogate(), .interrogate_classic(), etc. come from the external
# clip-interrogator library (https://github.com/pharmapsychotic/clip-interrogator) and cannot be renamed.
if mode == 'best':
prompt = ci.interrogate(image, caption=caption, min_flavors=min_flavors, max_flavors=max_flavors)
prompt = ci.interrogate(image, caption=base_caption, min_flavors=min_flavors, max_flavors=max_flavors)
elif mode == 'caption':
prompt = ci.generate_caption(image) if caption is None else caption
prompt = ci.generate_caption(image) if base_caption is None else base_caption
elif mode == 'classic':
prompt = ci.interrogate_classic(image, caption=caption, max_flavors=max_flavors)
prompt = ci.interrogate_classic(image, caption=base_caption, max_flavors=max_flavors)
elif mode == 'fast':
prompt = ci.interrogate_fast(image, caption=caption, max_flavors=max_flavors)
prompt = ci.interrogate_fast(image, caption=base_caption, max_flavors=max_flavors)
elif mode == 'negative':
prompt = ci.interrogate_negative(image, max_flavors=max_flavors)
else:
@ -197,9 +203,10 @@ def interrogate(image, mode, caption=None):
return prompt
def interrogate_image(image, clip_model, blip_model, mode, overrides=None):
def caption_image(image, clip_model, blip_model, mode, overrides=None):
global _clip_overrides # pylint: disable=global-statement
jobid = shared.state.begin('Interrogate CLiP')
jobid = shared.state.begin('Caption CLiP')
t0 = time.time()
shared.log.info(f'CLIP: mode="{mode}" clip="{clip_model}" blip="{blip_model}" image_size={image.size if image else None}')
if overrides:
@ -211,17 +218,19 @@ def interrogate_image(image, clip_model, blip_model, mode, overrides=None):
from modules.sd_models import apply_balanced_offload # prevent circular import
apply_balanced_offload(shared.sd_model)
debug_log('CLIP: applied balanced offload to sd_model')
load_interrogator(clip_model, blip_model)
# Apply overrides to loaded interrogator
update_interrogate_params()
load_captioner(clip_model, blip_model)
# Apply overrides to loaded captioner
update_caption_params()
image = image.convert('RGB')
prompt = interrogate(image, mode)
prompt = caption(image, mode)
if shared.opts.caption_offload:
unload_clip_model()
devices.torch_gc()
shared.log.debug(f'CLIP: complete time={time.time()-t0:.2f}')
except Exception as e:
prompt = f"Exception {type(e)}"
shared.log.error(f'CLIP: {e}')
errors.display(e, 'Interrogate')
errors.display(e, 'Caption')
finally:
# Clear per-request overrides
_clip_overrides = None
@ -229,7 +238,8 @@ def interrogate_image(image, clip_model, blip_model, mode, overrides=None):
return prompt
def interrogate_batch(batch_files, batch_folder, batch_str, clip_model, blip_model, mode, write, append, recursive):
def caption_batch(batch_files, batch_folder, batch_str, clip_model, blip_model, mode, write, append, recursive):
files = []
if batch_files is not None:
files += [f.name for f in batch_files]
@ -244,10 +254,10 @@ def interrogate_batch(batch_files, batch_folder, batch_str, clip_model, blip_mod
t0 = time.time()
shared.log.info(f'CLIP batch: mode="{mode}" images={len(files)} clip="{clip_model}" blip="{blip_model}" write={write} append={append}')
debug_log(f'CLIP batch: recursive={recursive} files={files[:5]}{"..." if len(files) > 5 else ""}')
jobid = shared.state.begin('Interrogate batch')
jobid = shared.state.begin('Caption batch')
prompts = []
load_interrogator(clip_model, blip_model)
load_captioner(clip_model, blip_model)
if write:
file_mode = 'w' if not append else 'a'
writer = BatchWriter(os.path.dirname(files[0]), mode=file_mode)
@ -263,7 +273,7 @@ def interrogate_batch(batch_files, batch_folder, batch_str, clip_model, blip_mod
shared.log.info('CLIP batch: interrupted')
break
image = Image.open(file).convert('RGB')
prompt = interrogate(image, mode)
prompt = caption(image, mode)
prompts.append(prompt)
if write:
writer.add(file, prompt)
@ -278,10 +288,11 @@ def interrogate_batch(batch_files, batch_folder, batch_str, clip_model, blip_mod
return '\n\n'.join(prompts)
def analyze_image(image, clip_model, blip_model):
t0 = time.time()
shared.log.info(f'CLIP analyze: clip="{clip_model}" blip="{blip_model}" image_size={image.size if image else None}')
load_interrogator(clip_model, blip_model)
load_captioner(clip_model, blip_model)
image = image.convert('RGB')
image_features = ci.image_to_features(image)
debug_log(f'CLIP analyze: features shape={image_features.shape if hasattr(image_features, "shape") else "unknown"}')

View File

@ -8,7 +8,7 @@ DEEPBOORU_MODEL = "DeepBooru"
def get_models() -> list:
"""Return combined list: DeepBooru + WaifuDiffusion models."""
from modules.interrogate import waifudiffusion
from modules.caption import waifudiffusion
return [DEEPBOORU_MODEL] + waifudiffusion.get_models()
@ -25,16 +25,16 @@ def is_deepbooru(model_name: str) -> bool:
def load_model(model_name: str) -> bool:
"""Load appropriate backend."""
if is_deepbooru(model_name):
from modules.interrogate import deepbooru
from modules.caption import deepbooru
return deepbooru.load_model()
else:
from modules.interrogate import waifudiffusion
from modules.caption import waifudiffusion
return waifudiffusion.load_model(model_name)
def unload_model():
"""Unload both backends to ensure memory is freed."""
from modules.interrogate import deepbooru, waifudiffusion
from modules.caption import deepbooru, waifudiffusion
deepbooru.unload_model()
waifudiffusion.unload_model()
@ -54,10 +54,10 @@ def tag(image, model_name: str = None, **kwargs) -> str:
model_name = shared.opts.waifudiffusion_model
if is_deepbooru(model_name):
from modules.interrogate import deepbooru
from modules.caption import deepbooru
return deepbooru.tag(image, **kwargs)
else:
from modules.interrogate import waifudiffusion
from modules.caption import waifudiffusion
return waifudiffusion.tag(image, model_name=model_name, **kwargs)
@ -72,8 +72,8 @@ def batch(model_name: str, **kwargs) -> str:
Combined tag results
"""
if is_deepbooru(model_name):
from modules.interrogate import deepbooru
from modules.caption import deepbooru
return deepbooru.batch(model_name=model_name, **kwargs)
else:
from modules.interrogate import waifudiffusion
from modules.caption import waifudiffusion
return waifudiffusion.batch(model_name=model_name, **kwargs)

View File

@ -9,11 +9,11 @@ import transformers
import transformers.dynamic_module_utils
from PIL import Image
from modules import shared, devices, errors, model_quant, sd_models, sd_models_compile, ui_symbols
from modules.interrogate import vqa_detection
from modules.caption import vqa_detection
# Debug logging - function-based to avoid circular import
debug_enabled = os.environ.get('SD_INTERROGATE_DEBUG', None) is not None
debug_enabled = os.environ.get('SD_CAPTION_DEBUG', None) is not None
def debug(*args, **kwargs):
if debug_enabled:
@ -265,7 +265,7 @@ def keep_think_block_open(text_prompt: str) -> str:
while end_close < len(text_prompt) and text_prompt[end_close] in ('\r', '\n'):
end_close += 1
trimmed_prompt = text_prompt[:close_index] + text_prompt[end_close:]
debug('VQA interrogate: keep_think_block_open applied to prompt segment near assistant reply')
debug('VQA caption: keep_think_block_open applied to prompt segment near assistant reply')
return trimmed_prompt
@ -348,7 +348,7 @@ def get_keep_thinking():
overrides = _get_overrides()
if overrides.get('keep_thinking') is not None:
return overrides['keep_thinking']
return shared.opts.interrogate_vlm_keep_thinking
return shared.opts.caption_vlm_keep_thinking
def get_keep_prefill():
@ -356,7 +356,7 @@ def get_keep_prefill():
overrides = _get_overrides()
if overrides.get('keep_prefill') is not None:
return overrides['keep_prefill']
return shared.opts.interrogate_vlm_keep_prefill
return shared.opts.caption_vlm_keep_prefill
def get_kwargs():
@ -370,12 +370,12 @@ def get_kwargs():
overrides = _get_overrides()
# Get base values from settings, apply overrides if provided
max_tokens = overrides.get('max_tokens') if overrides.get('max_tokens') is not None else shared.opts.interrogate_vlm_max_length
do_sample = overrides.get('do_sample') if overrides.get('do_sample') is not None else shared.opts.interrogate_vlm_do_sample
num_beams = overrides.get('num_beams') if overrides.get('num_beams') is not None else shared.opts.interrogate_vlm_num_beams
temperature = overrides.get('temperature') if overrides.get('temperature') is not None else shared.opts.interrogate_vlm_temperature
top_k = overrides.get('top_k') if overrides.get('top_k') is not None else shared.opts.interrogate_vlm_top_k
top_p = overrides.get('top_p') if overrides.get('top_p') is not None else shared.opts.interrogate_vlm_top_p
max_tokens = overrides.get('max_tokens') if overrides.get('max_tokens') is not None else shared.opts.caption_vlm_max_length
do_sample = overrides.get('do_sample') if overrides.get('do_sample') is not None else shared.opts.caption_vlm_do_sample
num_beams = overrides.get('num_beams') if overrides.get('num_beams') is not None else shared.opts.caption_vlm_num_beams
temperature = overrides.get('temperature') if overrides.get('temperature') is not None else shared.opts.caption_vlm_temperature
top_k = overrides.get('top_k') if overrides.get('top_k') is not None else shared.opts.caption_vlm_top_k
top_p = overrides.get('top_p') if overrides.get('top_p') is not None else shared.opts.caption_vlm_top_p
kwargs = {
'max_new_tokens': max_tokens,
@ -419,7 +419,7 @@ class VQA:
def load(self, model_name: str = None):
"""Load VLM model into memory for the specified model name."""
model_name = model_name or shared.opts.interrogate_vlm_model
model_name = model_name or shared.opts.caption_vlm_model
if not model_name:
shared.log.warning('VQA load: no model specified')
return
@ -430,7 +430,7 @@ class VQA:
shared.log.debug(f'VQA load: pre-loading model="{model_name}" repo="{repo}"')
# Dispatch to appropriate loader (same logic as interrogate)
# Dispatch to appropriate loader (same logic as caption)
repo_lower = repo.lower()
if 'qwen' in repo_lower or 'torii' in repo_lower or 'mimo' in repo_lower:
self._load_qwen(repo)
@ -459,22 +459,22 @@ class VQA:
elif 'fastvlm' in repo_lower:
self._load_fastvlm(repo)
elif 'moondream3' in repo_lower:
from modules.interrogate import moondream3
from modules.caption import moondream3
moondream3.load_model(repo)
shared.log.info(f'VQA load: model="{model_name}" loaded (external handler)')
return
elif 'joytag' in repo_lower:
from modules.interrogate import joytag
from modules.caption import joytag
joytag.load()
shared.log.info(f'VQA load: model="{model_name}" loaded (external handler)')
return
elif 'joycaption' in repo_lower:
from modules.interrogate import joycaption
from modules.caption import joycaption
joycaption.load(repo)
shared.log.info(f'VQA load: model="{model_name}" loaded (external handler)')
return
elif 'deepseek' in repo_lower:
from modules.interrogate import deepseek
from modules.caption import deepseek
deepseek.load(repo)
shared.log.info(f'VQA load: model="{model_name}" loaded (external handler)')
return
@ -488,7 +488,7 @@ class VQA:
def _load_fastvlm(self, repo: str):
"""Load FastVLM model and tokenizer."""
if self.model is None or self.loaded != repo:
shared.log.debug(f'Interrogate load: vlm="{repo}"')
shared.log.debug(f'Caption load: vlm="{repo}"')
quant_args = model_quant.create_config(module='LLM')
self.model = None
self.processor = transformers.AutoTokenizer.from_pretrained(repo, trust_remote_code=True, cache_dir=shared.opts.hfcache_dir)
@ -503,7 +503,7 @@ class VQA:
devices.torch_gc()
def _fastvlm(self, question: str, image: Image.Image, repo: str, model_name: str = None):
debug(f'VQA interrogate: handler=fastvlm model_name="{model_name}" repo="{repo}" question="{question}" image_size={image.size if image else None}')
debug(f'VQA caption: handler=fastvlm model_name="{model_name}" repo="{repo}" question="{question}" image_size={image.size if image else None}')
self._load_fastvlm(repo)
sd_models.move_model(self.model, devices.device)
if len(question) < 2:
@ -534,7 +534,7 @@ class VQA:
def _load_qwen(self, repo: str):
"""Load Qwen VL model and processor."""
if self.model is None or self.loaded != repo:
shared.log.debug(f'Interrogate load: vlm="{repo}"')
shared.log.debug(f'Caption load: vlm="{repo}"')
self.model = None
if 'Qwen3-VL' in repo or 'Qwen3VL' in repo:
cls_name = transformers.Qwen3VLForConditionalGeneration
@ -562,10 +562,10 @@ class VQA:
sd_models.move_model(self.model, devices.device)
# Get model class name for logging
cls_name = self.model.__class__.__name__
debug(f'VQA interrogate: handler=qwen model_name="{model_name}" model_class="{cls_name}" repo="{repo}" question="{question}" system_prompt="{system_prompt}" image_size={image.size if image else None}')
debug(f'VQA caption: handler=qwen model_name="{model_name}" model_class="{cls_name}" repo="{repo}" question="{question}" system_prompt="{system_prompt}" image_size={image.size if image else None}')
question = question.replace('<', '').replace('>', '').replace('_', ' ')
system_prompt = system_prompt or shared.opts.interrogate_vlm_system
system_prompt = system_prompt or shared.opts.caption_vlm_system
conversation = [
{
"role": "system",
@ -593,9 +593,9 @@ class VQA:
use_prefill = len(prefill_text) > 0
if debug_enabled:
debug(f'VQA interrogate: handler=qwen conversation_roles={[msg["role"] for msg in conversation]}')
debug(f'VQA interrogate: handler=qwen full_conversation={truncate_b64_in_conversation(conversation)}')
debug(f'VQA interrogate: handler=qwen is_thinking={is_thinking} thinking_mode={thinking_mode} prefill="{prefill_text}"')
debug(f'VQA caption: handler=qwen conversation_roles={[msg["role"] for msg in conversation]}')
debug(f'VQA caption: handler=qwen full_conversation={truncate_b64_in_conversation(conversation)}')
debug(f'VQA caption: handler=qwen is_thinking={is_thinking} thinking_mode={thinking_mode} prefill="{prefill_text}"')
# Generate base prompt using template
# Qwen-Thinking template automatically adds "<|im_start|>assistant\n<think>\n" when add_generation_prompt=True
@ -605,7 +605,7 @@ class VQA:
add_generation_prompt=True,
)
except (TypeError, ValueError) as e:
debug(f'VQA interrogate: handler=qwen chat_template fallback add_generation_prompt=True: {e}')
debug(f'VQA caption: handler=qwen chat_template fallback add_generation_prompt=True: {e}')
text_prompt = self.processor.apply_chat_template(conversation, add_generation_prompt=True)
# Manually handle thinking tags and prefill
@ -627,23 +627,23 @@ class VQA:
text_prompt += prefill_text
if debug_enabled:
debug(f'VQA interrogate: handler=qwen text_prompt="{text_prompt}"')
debug(f'VQA caption: handler=qwen text_prompt="{text_prompt}"')
inputs = self.processor(text=[text_prompt], images=[image], padding=True, return_tensors="pt")
inputs = inputs.to(devices.device, devices.dtype)
gen_kwargs = get_kwargs()
debug(f'VQA interrogate: handler=qwen generation_kwargs={gen_kwargs} input_ids_shape={inputs.input_ids.shape}')
debug(f'VQA caption: handler=qwen generation_kwargs={gen_kwargs} input_ids_shape={inputs.input_ids.shape}')
output_ids = self.model.generate(
**inputs,
**gen_kwargs,
)
debug(f'VQA interrogate: handler=qwen output_ids_shape={output_ids.shape}')
debug(f'VQA caption: handler=qwen output_ids_shape={output_ids.shape}')
generated_ids = [
output_ids[len(input_ids):]
for input_ids, output_ids in zip(inputs.input_ids, output_ids)
]
response = self.processor.batch_decode(generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True)
if debug_enabled:
debug(f'VQA interrogate: handler=qwen response_before_clean="{response}"')
debug(f'VQA caption: handler=qwen response_before_clean="{response}"')
# Clean up thinking tags
# Note: <think> is in the prompt, not the response - only </think> appears in generated output
if len(response) > 0:
@ -672,7 +672,7 @@ class VQA:
def _load_gemma(self, repo: str):
"""Load Gemma 3 model and processor."""
if self.model is None or self.loaded != repo:
shared.log.debug(f'Interrogate load: vlm="{repo}"')
shared.log.debug(f'Caption load: vlm="{repo}"')
self.model = None
if '3n' in repo:
cls = transformers.Gemma3nForConditionalGeneration # pylint: disable=no-member
@ -696,10 +696,10 @@ class VQA:
sd_models.move_model(self.model, devices.device)
# Get model class name for logging
cls_name = self.model.__class__.__name__
debug(f'VQA interrogate: handler=gemma model_name="{model_name}" model_class="{cls_name}" repo="{repo}" question="{question}" system_prompt="{system_prompt}" image_size={image.size if image else None}')
debug(f'VQA caption: handler=gemma model_name="{model_name}" model_class="{cls_name}" repo="{repo}" question="{question}" system_prompt="{system_prompt}" image_size={image.size if image else None}')
question = question.replace('<', '').replace('>', '').replace('_', ' ')
system_prompt = system_prompt or shared.opts.interrogate_vlm_system
system_prompt = system_prompt or shared.opts.caption_vlm_system
system_content = []
if system_prompt is not None and len(system_prompt) > 4:
@ -726,14 +726,14 @@ class VQA:
"role": "assistant",
"content": [{"type": "text", "text": prefill_text}],
})
debug(f'VQA interrogate: handler=gemma prefill="{prefill_text}"')
debug(f'VQA caption: handler=gemma prefill="{prefill_text}"')
else:
debug('VQA interrogate: handler=gemma prefill disabled (empty), relying on add_generation_prompt')
debug('VQA caption: handler=gemma prefill disabled (empty), relying on add_generation_prompt')
if debug_enabled:
debug(f'VQA interrogate: handler=gemma conversation_roles={[msg["role"] for msg in conversation]}')
debug(f'VQA interrogate: handler=gemma full_conversation={truncate_b64_in_conversation(conversation)}')
debug(f'VQA caption: handler=gemma conversation_roles={[msg["role"] for msg in conversation]}')
debug(f'VQA caption: handler=gemma full_conversation={truncate_b64_in_conversation(conversation)}')
debug_prefill_mode = 'add_generation_prompt=False continue_final_message=True' if use_prefill else 'add_generation_prompt=True'
debug(f'VQA interrogate: handler=gemma template_mode={debug_prefill_mode}')
debug(f'VQA caption: handler=gemma template_mode={debug_prefill_mode}')
try:
if use_prefill:
text_prompt = self.processor.apply_chat_template(
@ -749,7 +749,7 @@ class VQA:
tokenize=False,
)
except (TypeError, ValueError) as e:
debug(f'VQA interrogate: handler=gemma chat_template fallback add_generation_prompt=True: {e}')
debug(f'VQA caption: handler=gemma chat_template fallback add_generation_prompt=True: {e}')
text_prompt = self.processor.apply_chat_template(
conversation,
add_generation_prompt=True,
@ -758,7 +758,7 @@ class VQA:
if use_prefill and use_thinking:
text_prompt = keep_think_block_open(text_prompt)
if debug_enabled:
debug(f'VQA interrogate: handler=gemma text_prompt="{text_prompt}"')
debug(f'VQA caption: handler=gemma text_prompt="{text_prompt}"')
inputs = self.processor(
text=[text_prompt],
images=[image],
@ -767,17 +767,17 @@ class VQA:
).to(device=devices.device, dtype=devices.dtype)
input_len = inputs["input_ids"].shape[-1]
gen_kwargs = get_kwargs()
debug(f'VQA interrogate: handler=gemma generation_kwargs={gen_kwargs} input_len={input_len}')
debug(f'VQA caption: handler=gemma generation_kwargs={gen_kwargs} input_len={input_len}')
with devices.inference_context():
generation = self.model.generate(
**inputs,
**gen_kwargs,
)
debug(f'VQA interrogate: handler=gemma output_ids_shape={generation.shape}')
debug(f'VQA caption: handler=gemma output_ids_shape={generation.shape}')
generation = generation[0][input_len:]
response = self.processor.decode(generation, skip_special_tokens=True)
if debug_enabled:
debug(f'VQA interrogate: handler=gemma response_before_clean="{response}"')
debug(f'VQA caption: handler=gemma response_before_clean="{response}"')
# Clean up thinking tags (if any remain)
if get_keep_thinking():
@ -798,7 +798,7 @@ class VQA:
def _load_paligemma(self, repo: str):
"""Load PaliGemma model and processor."""
if self.model is None or self.loaded != repo:
shared.log.debug(f'Interrogate load: vlm="{repo}"')
shared.log.debug(f'Caption load: vlm="{repo}"')
self.processor = transformers.PaliGemmaProcessor.from_pretrained(repo, cache_dir=shared.opts.hfcache_dir)
self.model = None
self.model = transformers.PaliGemmaForConditionalGeneration.from_pretrained(
@ -827,7 +827,7 @@ class VQA:
def _load_ovis(self, repo: str):
"""Load Ovis model (requires flash-attn)."""
if self.model is None or self.loaded != repo:
shared.log.debug(f'Interrogate load: vlm="{repo}"')
shared.log.debug(f'Caption load: vlm="{repo}"')
self.model = None
self.model = transformers.AutoModelForCausalLM.from_pretrained(
repo,
@ -843,7 +843,7 @@ class VQA:
try:
import flash_attn # pylint: disable=unused-import
except Exception:
shared.log.error(f'Interrogate: vlm="{repo}" flash-attn is not available')
shared.log.error(f'Caption: vlm="{repo}" flash-attn is not available')
return ''
self._load_ovis(repo)
sd_models.move_model(self.model, devices.device)
@ -875,7 +875,7 @@ class VQA:
def _load_smol(self, repo: str):
"""Load SmolVLM model and processor."""
if self.model is None or self.loaded != repo:
shared.log.debug(f'Interrogate load: vlm="{repo}"')
shared.log.debug(f'Caption load: vlm="{repo}"')
self.model = None
quant_args = model_quant.create_config(module='LLM')
self.model = transformers.AutoModelForVision2Seq.from_pretrained(
@ -895,10 +895,10 @@ class VQA:
sd_models.move_model(self.model, devices.device)
# Get model class name for logging
cls_name = self.model.__class__.__name__
debug(f'VQA interrogate: handler=smol model_name="{model_name}" model_class="{cls_name}" repo="{repo}" question="{question}" system_prompt="{system_prompt}" image_size={image.size if image else None}')
debug(f'VQA caption: handler=smol model_name="{model_name}" model_class="{cls_name}" repo="{repo}" question="{question}" system_prompt="{system_prompt}" image_size={image.size if image else None}')
question = question.replace('<', '').replace('>', '').replace('_', ' ')
system_prompt = system_prompt or shared.opts.interrogate_vlm_system
system_prompt = system_prompt or shared.opts.caption_vlm_system
conversation = [
{
"role": "system",
@ -924,14 +924,14 @@ class VQA:
"role": "assistant",
"content": [{"type": "text", "text": prefill_text}],
})
debug(f'VQA interrogate: handler=smol prefill="{prefill_text}"')
debug(f'VQA caption: handler=smol prefill="{prefill_text}"')
else:
debug('VQA interrogate: handler=smol prefill disabled (empty), relying on add_generation_prompt')
debug('VQA caption: handler=smol prefill disabled (empty), relying on add_generation_prompt')
if debug_enabled:
debug(f'VQA interrogate: handler=smol conversation_roles={[msg["role"] for msg in conversation]}')
debug(f'VQA interrogate: handler=smol full_conversation={truncate_b64_in_conversation(conversation)}')
debug(f'VQA caption: handler=smol conversation_roles={[msg["role"] for msg in conversation]}')
debug(f'VQA caption: handler=smol full_conversation={truncate_b64_in_conversation(conversation)}')
debug_prefill_mode = 'add_generation_prompt=False continue_final_message=True' if use_prefill else 'add_generation_prompt=True'
debug(f'VQA interrogate: handler=smol template_mode={debug_prefill_mode}')
debug(f'VQA caption: handler=smol template_mode={debug_prefill_mode}')
try:
if use_prefill:
text_prompt = self.processor.apply_chat_template(
@ -942,24 +942,24 @@ class VQA:
else:
text_prompt = self.processor.apply_chat_template(conversation, add_generation_prompt=True)
except (TypeError, ValueError) as e:
debug(f'VQA interrogate: handler=smol chat_template fallback add_generation_prompt=True: {e}')
debug(f'VQA caption: handler=smol chat_template fallback add_generation_prompt=True: {e}')
text_prompt = self.processor.apply_chat_template(conversation, add_generation_prompt=True)
if use_prefill and use_thinking:
text_prompt = keep_think_block_open(text_prompt)
if debug_enabled:
debug(f'VQA interrogate: handler=smol text_prompt="{text_prompt}"')
debug(f'VQA caption: handler=smol text_prompt="{text_prompt}"')
inputs = self.processor(text=text_prompt, images=[image], padding=True, return_tensors="pt")
inputs = inputs.to(devices.device, devices.dtype)
gen_kwargs = get_kwargs()
debug(f'VQA interrogate: handler=smol generation_kwargs={gen_kwargs}')
debug(f'VQA caption: handler=smol generation_kwargs={gen_kwargs}')
output_ids = self.model.generate(
**inputs,
**gen_kwargs,
)
debug(f'VQA interrogate: handler=smol output_ids_shape={output_ids.shape}')
debug(f'VQA caption: handler=smol output_ids_shape={output_ids.shape}')
response = self.processor.batch_decode(output_ids, skip_special_tokens=True)
if debug_enabled:
debug(f'VQA interrogate: handler=smol response_before_clean="{response}"')
debug(f'VQA caption: handler=smol response_before_clean="{response}"')
# Clean up thinking tags
if len(response) > 0:
@ -981,7 +981,7 @@ class VQA:
def _load_git(self, repo: str):
"""Load Microsoft GIT model and processor."""
if self.model is None or self.loaded != repo:
shared.log.debug(f'Interrogate load: vlm="{repo}"')
shared.log.debug(f'Caption load: vlm="{repo}"')
self.model = None
self.model = transformers.GitForCausalLM.from_pretrained(
repo,
@ -1011,7 +1011,7 @@ class VQA:
def _load_blip(self, repo: str):
"""Load Salesforce BLIP model and processor."""
if self.model is None or self.loaded != repo:
shared.log.debug(f'Interrogate load: vlm="{repo}"')
shared.log.debug(f'Caption load: vlm="{repo}"')
self.model = None
self.model = transformers.BlipForQuestionAnswering.from_pretrained(
repo,
@ -1035,7 +1035,7 @@ class VQA:
def _load_vilt(self, repo: str):
"""Load ViLT model and processor."""
if self.model is None or self.loaded != repo:
shared.log.debug(f'Interrogate load: vlm="{repo}"')
shared.log.debug(f'Caption load: vlm="{repo}"')
self.model = None
self.model = transformers.ViltForQuestionAnswering.from_pretrained(
repo,
@ -1061,7 +1061,7 @@ class VQA:
def _load_pix(self, repo: str):
"""Load Pix2Struct model and processor."""
if self.model is None or self.loaded != repo:
shared.log.debug(f'Interrogate load: vlm="{repo}"')
shared.log.debug(f'Caption load: vlm="{repo}"')
self.model = None
self.model = transformers.Pix2StructForConditionalGeneration.from_pretrained(
repo,
@ -1087,7 +1087,7 @@ class VQA:
def _load_moondream(self, repo: str):
"""Load Moondream 2 model and tokenizer."""
if self.model is None or self.loaded != repo:
shared.log.debug(f'Interrogate load: vlm="{repo}"')
shared.log.debug(f'Caption load: vlm="{repo}"')
self.model = None
self.model = transformers.AutoModelForCausalLM.from_pretrained(
repo,
@ -1102,7 +1102,7 @@ class VQA:
devices.torch_gc()
def _moondream(self, question: str, image: Image.Image, repo: str, model_name: str = None, thinking_mode: bool = False):
debug(f'VQA interrogate: handler=moondream model_name="{model_name}" repo="{repo}" question="{question}" thinking_mode={thinking_mode}')
debug(f'VQA caption: handler=moondream model_name="{model_name}" repo="{repo}" question="{question}" thinking_mode={thinking_mode}')
self._load_moondream(repo)
sd_models.move_model(self.model, devices.device)
question = question.replace('<', '').replace('>', '').replace('_', ' ')
@ -1117,9 +1117,9 @@ class VQA:
target = question[9:].strip() if question.lower().startswith('point at ') else ''
if not target:
return "Please specify an object to locate"
debug(f'VQA interrogate: handler=moondream method=point target="{target}"')
debug(f'VQA caption: handler=moondream method=point target="{target}"')
result = self.model.point(image, target)
debug(f'VQA interrogate: handler=moondream point_raw_result={result}')
debug(f'VQA caption: handler=moondream point_raw_result={result}')
points = vqa_detection.parse_points(result)
if points:
self.last_detection_data = {'points': points}
@ -1129,35 +1129,35 @@ class VQA:
target = question[7:].strip() if question.lower().startswith('detect ') else ''
if not target:
return "Please specify an object to detect"
debug(f'VQA interrogate: handler=moondream method=detect target="{target}"')
debug(f'VQA caption: handler=moondream method=detect target="{target}"')
result = self.model.detect(image, target)
debug(f'VQA interrogate: handler=moondream detect_raw_result={result}')
debug(f'VQA caption: handler=moondream detect_raw_result={result}')
detections = vqa_detection.parse_detections(result, target)
if detections:
self.last_detection_data = {'detections': detections}
return vqa_detection.format_detections_text(detections, include_confidence=False)
return "No objects detected"
elif question == 'DETECT_GAZE' or question.lower() == 'detect gaze':
debug('VQA interrogate: handler=moondream method=detect_gaze')
debug('VQA caption: handler=moondream method=detect_gaze')
faces = self.model.detect(image, "face")
debug(f'VQA interrogate: handler=moondream detect_gaze faces={faces}')
debug(f'VQA caption: handler=moondream detect_gaze faces={faces}')
if faces.get('objects'):
eye_x, eye_y = vqa_detection.calculate_eye_position(faces['objects'][0])
result = self.model.detect_gaze(image, eye=(eye_x, eye_y))
debug(f'VQA interrogate: handler=moondream detect_gaze result={result}')
debug(f'VQA caption: handler=moondream detect_gaze result={result}')
if result.get('gaze'):
gaze = result['gaze']
self.last_detection_data = {'points': [(gaze['x'], gaze['y'])]}
return f"Gaze direction: ({gaze['x']:.3f}, {gaze['y']:.3f})"
return "No face/gaze detected"
else:
debug(f'VQA interrogate: handler=moondream method=query question="{question}" reasoning={thinking_mode}')
debug(f'VQA caption: handler=moondream method=query question="{question}" reasoning={thinking_mode}')
result = self.model.query(image, question, reasoning=thinking_mode)
response = result['answer']
debug(f'VQA interrogate: handler=moondream query_result keys={list(result.keys()) if isinstance(result, dict) else "not dict"}')
debug(f'VQA caption: handler=moondream query_result keys={list(result.keys()) if isinstance(result, dict) else "not dict"}')
if thinking_mode and 'reasoning' in result:
reasoning_text = result['reasoning'].get('text', '') if isinstance(result['reasoning'], dict) else str(result['reasoning'])
debug(f'VQA interrogate: handler=moondream reasoning_text="{reasoning_text[:100]}..."')
debug(f'VQA caption: handler=moondream reasoning_text="{reasoning_text[:100]}..."')
if get_keep_thinking():
response = f"Reasoning:\n{reasoning_text}\n\nAnswer:\n{response}"
# When keep_thinking is False, just use the answer (reasoning is discarded)
@ -1183,7 +1183,7 @@ class VQA:
effective_revision = revision_from_repo
if self.model is None or self.loaded != cache_key:
shared.log.debug(f'Interrogate load: vlm="{repo_name}" revision="{effective_revision}" path="{shared.opts.hfcache_dir}"')
shared.log.debug(f'Caption load: vlm="{repo_name}" revision="{effective_revision}" path="{shared.opts.hfcache_dir}"')
transformers.dynamic_module_utils.get_imports = get_imports
self.model = None
quant_args = model_quant.create_config(module='LLM')
@ -1213,7 +1213,7 @@ class VQA:
pixel_values = inputs['pixel_values'].to(devices.device, devices.dtype)
# Florence-2 requires beam search, not sampling - sampling causes probability tensor errors
overrides = _get_overrides()
max_tokens = overrides.get('max_tokens') if overrides.get('max_tokens') is not None else shared.opts.interrogate_vlm_max_length
max_tokens = overrides.get('max_tokens') if overrides.get('max_tokens') is not None else shared.opts.caption_vlm_max_length
with devices.inference_context():
generated_ids = self.model.generate(
input_ids=input_ids,
@ -1263,9 +1263,9 @@ class VQA:
response = return_dict["prediction"] # the text format answer
return response
def interrogate(self, question: str = '', system_prompt: str = None, prompt: str = None, image: Image.Image = None, model_name: str = None, prefill: str = None, thinking_mode: bool = None, quiet: bool = False, generation_kwargs: dict = None) -> str:
def caption(self, question: str = '', system_prompt: str = None, prompt: str = None, image: Image.Image = None, model_name: str = None, prefill: str = None, thinking_mode: bool = None, quiet: bool = False, generation_kwargs: dict = None) -> str:
"""
Main entry point for VQA interrogation. Returns string answer.
Main entry point for VQA captioning. Returns string answer.
Detection data stored in self.last_detection_data for annotated image creation.
Args:
@ -1283,11 +1283,11 @@ class VQA:
self.last_annotated_image = None
self.last_detection_data = None
self._generation_overrides = generation_kwargs # Set per-request overrides
jobid = shared.state.begin('Interrogate LLM')
jobid = shared.state.begin('Caption LLM')
t0 = time.time()
model_name = model_name or shared.opts.interrogate_vlm_model
model_name = model_name or shared.opts.caption_vlm_model
prefill = vlm_prefill if prefill is None else prefill # Use provided prefill when specified
thinking_mode = shared.opts.interrogate_vlm_thinking_mode if thinking_mode is None else thinking_mode # Resolve from settings if not specified
thinking_mode = shared.opts.caption_vlm_thinking_mode if thinking_mode is None else thinking_mode # Resolve from settings if not specified
if isinstance(image, list):
image = image[0] if len(image) > 0 else None
if isinstance(image, dict) and 'name' in image:
@ -1298,7 +1298,7 @@ class VQA:
if image.mode != 'RGB':
image = image.convert('RGB')
if image is None:
shared.log.error(f'VQA interrogate: model="{model_name}" error="No input image provided"')
shared.log.error(f'VQA caption: model="{model_name}" error="No input image provided"')
shared.state.end(jobid)
return 'Error: No input image provided. Please upload or select an image.'
@ -1306,7 +1306,7 @@ class VQA:
if question == "Use Prompt":
# Use content from Prompt field directly - requires user input
if not prompt or len(prompt.strip()) < 2:
shared.log.error(f'VQA interrogate: model="{model_name}" error="Please enter a prompt"')
shared.log.error(f'VQA caption: model="{model_name}" error="Please enter a prompt"')
shared.state.end(jobid)
return 'Error: Please enter a question or instruction in the Prompt field.'
question = prompt
@ -1316,7 +1316,7 @@ class VQA:
if raw_mapping in ("POINT_MODE", "DETECT_MODE"):
# These modes require user input in the prompt field
if not prompt or len(prompt.strip()) < 2:
shared.log.error(f'VQA interrogate: model="{model_name}" error="Please specify what to find in the prompt field"')
shared.log.error(f'VQA caption: model="{model_name}" error="Please specify what to find in the prompt field"')
shared.state.end(jobid)
return 'Error: Please specify what to find in the prompt field (e.g., "the red car" or "faces").'
# Convert friendly name to internal token (handles Point/Detect prefix)
@ -1328,12 +1328,12 @@ class VQA:
try:
if model_name is None:
shared.log.error(f'Interrogate: type=vlm model="{model_name}" no model selected')
shared.log.error(f'Caption: type=vlm model="{model_name}" no model selected')
shared.state.end(jobid)
return ''
vqa_model = vlm_models.get(model_name, None)
if vqa_model is None:
shared.log.error(f'Interrogate: type=vlm model="{model_name}" unknown')
shared.log.error(f'Caption: type=vlm model="{model_name}" unknown')
shared.state.end(jobid)
return ''
@ -1352,7 +1352,7 @@ class VQA:
answer = self._pix(question, image, vqa_model, model_name)
elif 'moondream3' in vqa_model.lower():
handler = 'moondream3'
from modules.interrogate import moondream3
from modules.caption import moondream3
answer = moondream3.predict(question, image, vqa_model, model_name, thinking_mode=thinking_mode)
elif 'moondream2' in vqa_model.lower():
handler = 'moondream'
@ -1368,15 +1368,15 @@ class VQA:
answer = self._smol(question, image, vqa_model, system_prompt, model_name, prefill, thinking_mode)
elif 'joytag' in vqa_model.lower():
handler = 'joytag'
from modules.interrogate import joytag
from modules.caption import joytag
answer = joytag.predict(image)
elif 'joycaption' in vqa_model.lower():
handler = 'joycaption'
from modules.interrogate import joycaption
from modules.caption import joycaption
answer = joycaption.predict(question, image, vqa_model)
elif 'deepseek' in vqa_model.lower():
handler = 'deepseek'
from modules.interrogate import deepseek
from modules.caption import deepseek
answer = deepseek.predict(question, image, vqa_model)
elif 'paligemma' in vqa_model.lower():
handler = 'paligemma'
@ -1399,7 +1399,7 @@ class VQA:
errors.display(e, 'VQA')
answer = 'error'
if shared.opts.interrogate_offload and self.model is not None:
if shared.opts.caption_offload and self.model is not None:
sd_models.move_model(self.model, devices.cpu, force=True)
devices.torch_gc(force=True, reason='vqa')
@ -1412,16 +1412,17 @@ class VQA:
points = self.last_detection_data.get('points', None)
if detections or points:
self.last_annotated_image = vqa_detection.draw_bounding_boxes(image, detections or [], points)
debug(f'VQA interrogate: handler={handler} created annotated image detections={len(detections) if detections else 0} points={len(points) if points else 0}')
debug(f'VQA caption: handler={handler} created annotated image detections={len(detections) if detections else 0} points={len(points) if points else 0}')
debug(f'VQA interrogate: handler={handler} response_after_clean="{answer}" has_annotation={self.last_annotated_image is not None}')
debug(f'VQA caption: handler={handler} response_after_clean="{answer}" has_annotation={self.last_annotated_image is not None}')
t1 = time.time()
if not quiet:
shared.log.debug(f'Interrogate: type=vlm model="{model_name}" repo="{vqa_model}" args={get_kwargs()} time={t1-t0:.2f}')
shared.log.debug(f'Caption: type=vlm model="{model_name}" repo="{vqa_model}" args={get_kwargs()} time={t1-t0:.2f}')
self._generation_overrides = None # Clear per-request overrides
shared.state.end(jobid)
return answer
def batch(self, model_name, system_prompt, batch_files, batch_folder, batch_str, question, prompt, write, append, recursive, prefill=None, thinking_mode=False):
class BatchWriter:
def __init__(self, folder, mode='w'):
@ -1450,15 +1451,15 @@ class VQA:
from modules.files_cache import list_files
files += list(list_files(batch_str, ext_filter=['.png', '.jpg', '.jpeg', '.webp', '.jxl'], recursive=recursive))
if len(files) == 0:
shared.log.warning('Interrogate batch: type=vlm no images')
shared.log.warning('Caption batch: type=vlm no images')
return ''
jobid = shared.state.begin('Interrogate batch')
jobid = shared.state.begin('Caption batch')
prompts = []
if write:
mode = 'w' if not append else 'a'
writer = BatchWriter(os.path.dirname(files[0]), mode=mode)
orig_offload = shared.opts.interrogate_offload
shared.opts.interrogate_offload = False
orig_offload = shared.opts.caption_offload
shared.opts.caption_offload = False
import rich.progress as rp
pbar = rp.Progress(rp.TextColumn('[cyan]Caption:'), rp.BarColumn(), rp.MofNCompleteColumn(), rp.TaskProgressColumn(), rp.TimeRemainingColumn(), rp.TimeElapsedColumn(), rp.TextColumn('[cyan]{task.description}'), console=shared.console)
with pbar:
@ -1469,7 +1470,7 @@ class VQA:
if shared.state.interrupted:
break
img = Image.open(file)
caption = self.interrogate(question, system_prompt, prompt, img, model_name, prefill, thinking_mode, quiet=True)
caption = self.caption(question, system_prompt, prompt, img, model_name, prefill, thinking_mode, quiet=True)
# Save annotated image if available
if self.last_annotated_image and write:
annotated_path = os.path.splitext(file)[0] + "_annotated.png"
@ -1478,10 +1479,10 @@ class VQA:
if write:
writer.add(file, caption)
except Exception as e:
shared.log.error(f'Interrogate batch: {e}')
shared.log.error(f'Caption batch: {e}')
if write:
writer.close()
shared.opts.interrogate_offload = orig_offload
shared.opts.caption_offload = orig_offload
shared.state.end(jobid)
return '\n\n'.join(prompts)
@ -1499,8 +1500,9 @@ def get_instance() -> VQA:
# Backwards-compatible module-level functions
def interrogate(*args, **kwargs):
return get_instance().interrogate(*args, **kwargs)
def caption(*args, **kwargs):
return get_instance().caption(*args, **kwargs)
def unload_model():

View File

@ -10,8 +10,8 @@ from PIL import Image
from modules import shared, devices, errors
# Debug logging - enable with SD_INTERROGATE_DEBUG environment variable
debug_enabled = os.environ.get('SD_INTERROGATE_DEBUG', None) is not None
# Debug logging - enable with SD_CAPTION_DEBUG environment variable
debug_enabled = os.environ.get('SD_CAPTION_DEBUG', None) is not None
debug_log = shared.log.trace if debug_enabled else lambda *args, **kwargs: None
re_special = re.compile(r'([\\()])')
@ -405,7 +405,7 @@ def tag(image: Image.Image, model_name: str = None, **kwargs) -> str:
result = tagger.predict(image, **kwargs)
shared.log.debug(f'WaifuDiffusion: complete time={time.time()-t0:.2f} tags={len(result.split(", ")) if result else 0}')
# Offload model if setting enabled
if shared.opts.interrogate_offload:
if shared.opts.caption_offload:
tagger.unload()
except Exception as e:
result = f"Exception {type(e)}"

View File

@ -1,48 +0,0 @@
import time
from PIL import Image
from modules import shared
def interrogate(image):
if isinstance(image, list):
image = image[0] if len(image) > 0 else None
if isinstance(image, dict) and 'name' in image:
image = Image.open(image['name'])
if image is None:
shared.log.error('Interrogate: no image provided')
return ''
t0 = time.time()
if shared.opts.interrogate_default_type == 'OpenCLiP':
shared.log.info(f'Interrogate: type={shared.opts.interrogate_default_type} clip="{shared.opts.interrogate_clip_model}" blip="{shared.opts.interrogate_blip_model}" mode="{shared.opts.interrogate_clip_mode}"')
from modules.interrogate import openclip
openclip.load_interrogator(clip_model=shared.opts.interrogate_clip_model, blip_model=shared.opts.interrogate_blip_model)
openclip.update_interrogate_params()
prompt = openclip.interrogate(image, mode=shared.opts.interrogate_clip_mode)
shared.log.debug(f'Interrogate: time={time.time()-t0:.2f} answer="{prompt}"')
return prompt
elif shared.opts.interrogate_default_type == 'Tagger':
shared.log.info(f'Interrogate: type={shared.opts.interrogate_default_type} model="{shared.opts.waifudiffusion_model}"')
from modules.interrogate import tagger
prompt = tagger.tag(
image=image,
model_name=shared.opts.waifudiffusion_model,
general_threshold=shared.opts.tagger_threshold,
character_threshold=shared.opts.waifudiffusion_character_threshold,
include_rating=shared.opts.tagger_include_rating,
exclude_tags=shared.opts.tagger_exclude_tags,
max_tags=shared.opts.tagger_max_tags,
sort_alpha=shared.opts.tagger_sort_alpha,
use_spaces=shared.opts.tagger_use_spaces,
escape_brackets=shared.opts.tagger_escape_brackets,
)
shared.log.debug(f'Interrogate: time={time.time()-t0:.2f} answer="{prompt}"')
return prompt
elif shared.opts.interrogate_default_type == 'VLM':
shared.log.info(f'Interrogate: type={shared.opts.interrogate_default_type} vlm="{shared.opts.interrogate_vlm_model}" prompt="{shared.opts.interrogate_vlm_prompt}"')
from modules.interrogate import vqa
prompt = vqa.interrogate(image=image, model_name=shared.opts.interrogate_vlm_model, question=shared.opts.interrogate_vlm_prompt, prompt=None, system_prompt=shared.opts.interrogate_vlm_system)
shared.log.debug(f'Interrogate: time={time.time()-t0:.2f} answer="{prompt}"')
return prompt
else:
shared.log.error(f'Interrogate: type="{shared.opts.interrogate_default_type}" unknown')
return ''