mirror of https://github.com/vladmandic/automatic
refactor: rename interrogate module to caption
Move all caption-related modules from modules/interrogate/ to modules/caption/ for better naming consistency:
- Rename deepbooru, deepseek, joycaption, joytag, moondream3, openclip, tagger, vqa, vqa_detection, waifudiffusion modules
- Add new caption.py dispatcher module
- Remove old interrogate.py (functionality moved to caption.py)
pull/4613/head
parent
83fa8e39ba
commit
5183ebec58
|
|
@ -0,0 +1,48 @@
|
|||
import time
|
||||
from PIL import Image
|
||||
from modules import shared
|
||||
|
||||
|
||||
def caption(image):
    """Caption one image with the backend named by ``shared.opts.caption_default_type``.

    Recognised backend types are ``'OpenCLiP'``, ``'Tagger'`` and ``'VLM'``;
    each backend module is imported only when its branch is taken.

    Args:
        image: The image to caption. A list is reduced to its first element
            (an empty list counts as "no image"), and a dict carrying a
            ``'name'`` key is opened with ``Image.open`` on that value.

    Returns:
        The value returned by the selected backend, or ``''`` when no image
        was supplied or the configured backend type is not recognised.
    """
    # Reduce the accepted input shapes to a single image (or None).
    if isinstance(image, list):
        image = image[0] if image else None
    if isinstance(image, dict) and 'name' in image:
        image = Image.open(image['name'])
    if image is None:
        shared.log.error('Caption: no image provided')
        return ''

    t0 = time.time()

    def _finish(prompt):
        # Shared tail of every backend branch: log elapsed time and answer, then return it.
        shared.log.debug(f'Caption: time={time.time()-t0:.2f} answer="{prompt}"')
        return prompt

    if shared.opts.caption_default_type == 'OpenCLiP':
        shared.log.info(f'Caption: type={shared.opts.caption_default_type} clip="{shared.opts.caption_openclip_model}" blip="{shared.opts.caption_openclip_blip_model}" mode="{shared.opts.caption_openclip_mode}"')
        from modules.caption import openclip
        openclip.load_captioner(
            clip_model=shared.opts.caption_openclip_model,
            blip_model=shared.opts.caption_openclip_blip_model,
        )
        openclip.update_caption_params()
        return _finish(openclip.caption(image, mode=shared.opts.caption_openclip_mode))

    if shared.opts.caption_default_type == 'Tagger':
        shared.log.info(f'Caption: type={shared.opts.caption_default_type} model="{shared.opts.waifudiffusion_model}"')
        from modules.caption import tagger
        # Collect the tagger settings first so the call itself stays short.
        tagger_settings = {
            'image': image,
            'model_name': shared.opts.waifudiffusion_model,
            'general_threshold': shared.opts.tagger_threshold,
            'character_threshold': shared.opts.waifudiffusion_character_threshold,
            'include_rating': shared.opts.tagger_include_rating,
            'exclude_tags': shared.opts.tagger_exclude_tags,
            'max_tags': shared.opts.tagger_max_tags,
            'sort_alpha': shared.opts.tagger_sort_alpha,
            'use_spaces': shared.opts.tagger_use_spaces,
            'escape_brackets': shared.opts.tagger_escape_brackets,
        }
        return _finish(tagger.tag(**tagger_settings))

    if shared.opts.caption_default_type == 'VLM':
        shared.log.info(f'Caption: type={shared.opts.caption_default_type} vlm="{shared.opts.caption_vlm_model}" prompt="{shared.opts.caption_vlm_prompt}"')
        from modules.caption import vqa
        answer = vqa.caption(
            image=image,
            model_name=shared.opts.caption_vlm_model,
            question=shared.opts.caption_vlm_prompt,
            prompt=None,
            system_prompt=shared.opts.caption_vlm_system,
        )
        return _finish(answer)

    # None of the known backend types matched.
    shared.log.error(f'Caption: type="{shared.opts.caption_default_type}" unknown')
    return ''
|
||||
|
|
@ -19,7 +19,7 @@ class DeepDanbooru:
|
|||
if self.model is not None:
|
||||
return
|
||||
model_path = os.path.join(shared.opts.clip_models_path, "DeepDanbooru")
|
||||
shared.log.debug(f'Interrogate load: module=DeepDanbooru folder="{model_path}"')
|
||||
shared.log.debug(f'Caption load: module=DeepDanbooru folder="{model_path}"')
|
||||
files = modelloader.load_models(
|
||||
model_path=model_path,
|
||||
model_url='https://github.com/AUTOMATIC1111/TorchDeepDanbooru/releases/download/v1/model-resnet_custom_v3.pt',
|
||||
|
|
@ -27,7 +27,7 @@ class DeepDanbooru:
|
|||
download_name='model-resnet_custom_v3.pt',
|
||||
)
|
||||
|
||||
from modules.interrogate.deepbooru_model import DeepDanbooruModel
|
||||
from modules.caption.deepbooru_model import DeepDanbooruModel
|
||||
self.model = DeepDanbooruModel()
|
||||
self.model.load_state_dict(torch.load(files[0], map_location="cpu"))
|
||||
self.model.eval()
|
||||
|
|
@ -38,7 +38,7 @@ class DeepDanbooru:
|
|||
self.model.to(devices.device)
|
||||
|
||||
def stop(self):
|
||||
if shared.opts.interrogate_offload:
|
||||
if shared.opts.caption_offload:
|
||||
self.model.to(devices.cpu)
|
||||
devices.torch_gc()
|
||||
|
||||
|
|
@ -32,11 +32,11 @@ def load(repo: str):
|
|||
"""Load DeepSeek VL2 model (experimental)."""
|
||||
global vl_gpt, vl_chat_processor, loaded_repo # pylint: disable=global-statement
|
||||
if not shared.cmd_opts.experimental:
|
||||
shared.log.error(f'Interrogate: type=vlm model="DeepSeek VL2" repo="{repo}" is experimental-only')
|
||||
shared.log.error(f'Caption: type=vlm model="DeepSeek VL2" repo="{repo}" is experimental-only')
|
||||
return False
|
||||
folder = os.path.join(paths.script_path, 'repositories', 'deepseek-vl2')
|
||||
if not os.path.exists(folder):
|
||||
shared.log.error(f'Interrogate: type=vlm model="DeepSeek VL2" repo="{repo}" deepseek-vl2 repo not found')
|
||||
shared.log.error(f'Caption: type=vlm model="DeepSeek VL2" repo="{repo}" deepseek-vl2 repo not found')
|
||||
return False
|
||||
if vl_gpt is None or loaded_repo != repo:
|
||||
sys.modules['attrdict'] = fake_attrdict
|
||||
|
|
@ -53,7 +53,7 @@ def load(repo: str):
|
|||
vl_gpt.to(dtype=devices.dtype)
|
||||
vl_gpt.eval()
|
||||
loaded_repo = repo
|
||||
shared.log.info(f'Interrogate: type=vlm model="DeepSeek VL2" repo="{repo}"')
|
||||
shared.log.info(f'Caption: type=vlm model="DeepSeek VL2" repo="{repo}"')
|
||||
sd_models.move_model(vl_gpt, devices.device)
|
||||
return True
|
||||
|
||||
|
|
@ -105,7 +105,7 @@ def predict(question, image, repo):
|
|||
pad_token_id=vl_chat_processor.tokenizer.eos_token_id,
|
||||
bos_token_id=vl_chat_processor.tokenizer.bos_token_id,
|
||||
eos_token_id=vl_chat_processor.tokenizer.eos_token_id,
|
||||
max_new_tokens=shared.opts.interrogate_vlm_max_length,
|
||||
max_new_tokens=shared.opts.caption_vlm_max_length,
|
||||
do_sample=False,
|
||||
use_cache=True
|
||||
)
|
||||
|
|
@ -64,7 +64,7 @@ def load(repo: str = None):
|
|||
if llava_model is None or opts.repo != repo:
|
||||
opts.repo = repo
|
||||
llava_model = None
|
||||
shared.log.info(f'Interrogate: type=vlm model="JoyCaption" {str(opts)}')
|
||||
shared.log.info(f'Caption: type=vlm model="JoyCaption" {str(opts)}')
|
||||
processor = AutoProcessor.from_pretrained(repo, max_pixels=1024*1024, cache_dir=shared.opts.hfcache_dir)
|
||||
quant_args = model_quant.create_config(module='LLM')
|
||||
llava_model = LlavaForConditionalGeneration.from_pretrained(
|
||||
|
|
@ -92,7 +92,7 @@ def unload():
|
|||
|
||||
@torch.no_grad()
|
||||
def predict(question: str, image, vqa_model: str = None) -> str:
|
||||
opts.max_new_tokens = shared.opts.interrogate_vlm_max_length
|
||||
opts.max_new_tokens = shared.opts.caption_vlm_max_length
|
||||
load(vqa_model)
|
||||
|
||||
if len(question) < 2:
|
||||
|
|
@ -121,7 +121,7 @@ def predict(question: str, image, vqa_model: str = None) -> str:
|
|||
)[0]
|
||||
generate_ids = generate_ids[inputs['input_ids'].shape[1]:] # Trim off the prompt
|
||||
caption = processor.tokenizer.decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False) # Decode the caption
|
||||
if shared.opts.interrogate_offload:
|
||||
if shared.opts.caption_offload:
|
||||
sd_models.move_model(llava_model, devices.cpu, force=True)
|
||||
caption = caption.replace('\n\n', '\n').strip()
|
||||
return caption
|
||||
|
|
@ -1044,7 +1044,7 @@ def load():
|
|||
model.eval()
|
||||
with open(os.path.join(folder, 'top_tags.txt'), 'r', encoding='utf8') as f:
|
||||
tags = [line.strip() for line in f.readlines() if line.strip()]
|
||||
shared.log.info(f'Interrogate: type=vlm model="JoyTag" repo="{MODEL_REPO}" tags={len(tags)}')
|
||||
shared.log.info(f'Caption: type=vlm model="JoyTag" repo="{MODEL_REPO}" tags={len(tags)}')
|
||||
sd_models.move_model(model, devices.device)
|
||||
|
||||
|
||||
|
|
@ -1068,7 +1068,7 @@ def predict(image: Image.Image):
|
|||
preds = model({'image': image_tensor})
|
||||
tag_preds = preds['tags'].sigmoid().cpu()
|
||||
scores = {tags[i]: tag_preds[0][i] for i in range(len(tags))}
|
||||
if shared.opts.interrogate_score:
|
||||
if shared.opts.tagger_show_scores:
|
||||
predicted_tags = [f'{tag}:{score:.2f}' for tag, score in scores.items() if score > THRESHOLD]
|
||||
else:
|
||||
predicted_tags = [tag for tag, score in scores.items() if score > THRESHOLD]
|
||||
|
|
@ -7,11 +7,11 @@ import re
|
|||
import transformers
|
||||
from PIL import Image
|
||||
from modules import shared, devices, sd_models
|
||||
from modules.interrogate import vqa_detection
|
||||
from modules.caption import vqa_detection
|
||||
|
||||
|
||||
# Debug logging - function-based to avoid circular import
|
||||
debug_enabled = os.environ.get('SD_INTERROGATE_DEBUG', None) is not None
|
||||
debug_enabled = os.environ.get('SD_CAPTION_DEBUG', None) is not None
|
||||
|
||||
def debug(*args, **kwargs):
|
||||
if debug_enabled:
|
||||
|
|
@ -30,12 +30,12 @@ def get_settings():
|
|||
Moondream 3 accepts: temperature, top_p, max_tokens
|
||||
"""
|
||||
settings = {}
|
||||
if shared.opts.interrogate_vlm_max_length > 0:
|
||||
settings['max_tokens'] = shared.opts.interrogate_vlm_max_length
|
||||
if shared.opts.interrogate_vlm_temperature > 0:
|
||||
settings['temperature'] = shared.opts.interrogate_vlm_temperature
|
||||
if shared.opts.interrogate_vlm_top_p > 0:
|
||||
settings['top_p'] = shared.opts.interrogate_vlm_top_p
|
||||
if shared.opts.caption_vlm_max_length > 0:
|
||||
settings['max_tokens'] = shared.opts.caption_vlm_max_length
|
||||
if shared.opts.caption_vlm_temperature > 0:
|
||||
settings['temperature'] = shared.opts.caption_vlm_temperature
|
||||
if shared.opts.caption_vlm_top_p > 0:
|
||||
settings['top_p'] = shared.opts.caption_vlm_top_p
|
||||
return settings if settings else None
|
||||
|
||||
|
||||
|
|
@ -44,7 +44,7 @@ def load_model(repo: str):
|
|||
global moondream3_model, loaded # pylint: disable=global-statement
|
||||
|
||||
if moondream3_model is None or loaded != repo:
|
||||
shared.log.debug(f'Interrogate load: vlm="{repo}"')
|
||||
shared.log.debug(f'Caption load: vlm="{repo}"')
|
||||
moondream3_model = None
|
||||
|
||||
moondream3_model = transformers.AutoModelForCausalLM.from_pretrained(
|
||||
|
|
@ -84,7 +84,7 @@ def encode_image(image: Image.Image, cache_key: str = None):
|
|||
Encoded image tensor
|
||||
"""
|
||||
if cache_key and cache_key in image_cache:
|
||||
debug(f'VQA interrogate: handler=moondream3 using cached encoding for cache_key="{cache_key}"')
|
||||
debug(f'VQA caption: handler=moondream3 using cached encoding for cache_key="{cache_key}"')
|
||||
return image_cache[cache_key]
|
||||
|
||||
model = load_model(loaded)
|
||||
|
|
@ -94,7 +94,7 @@ def encode_image(image: Image.Image, cache_key: str = None):
|
|||
|
||||
if cache_key:
|
||||
image_cache[cache_key] = encoded
|
||||
debug(f'VQA interrogate: handler=moondream3 cached encoding cache_key="{cache_key}" cache_size={len(image_cache)}')
|
||||
debug(f'VQA caption: handler=moondream3 cached encoding cache_key="{cache_key}" cache_size={len(image_cache)}')
|
||||
|
||||
return encoded
|
||||
|
||||
|
|
@ -129,7 +129,7 @@ def query(image: Image.Image, question: str, repo: str, stream: bool = False,
|
|||
if max_tokens is not None:
|
||||
settings['max_tokens'] = max_tokens
|
||||
|
||||
debug(f'VQA interrogate: handler=moondream3 method=query question="{question}" stream={stream} settings={settings}')
|
||||
debug(f'VQA caption: handler=moondream3 method=query question="{question}" stream={stream} settings={settings}')
|
||||
|
||||
# Use cached encoding if requested
|
||||
if use_cache:
|
||||
|
|
@ -150,12 +150,12 @@ def query(image: Image.Image, question: str, repo: str, stream: bool = False,
|
|||
# Log response structure (for non-streaming)
|
||||
if not stream:
|
||||
if isinstance(response, dict):
|
||||
debug(f'VQA interrogate: handler=moondream3 response_type=dict keys={list(response.keys())}')
|
||||
debug(f'VQA caption: handler=moondream3 response_type=dict keys={list(response.keys())}')
|
||||
if 'reasoning' in response:
|
||||
reasoning_text = response['reasoning'].get('text', '')[:100] + '...' if len(response['reasoning'].get('text', '')) > 100 else response['reasoning'].get('text', '')
|
||||
debug(f'VQA interrogate: handler=moondream3 reasoning="{reasoning_text}"')
|
||||
debug(f'VQA caption: handler=moondream3 reasoning="{reasoning_text}"')
|
||||
if 'answer' in response:
|
||||
debug(f'VQA interrogate: handler=moondream3 answer="{response["answer"]}"')
|
||||
debug(f'VQA caption: handler=moondream3 answer="{response["answer"]}"')
|
||||
|
||||
return response
|
||||
|
||||
|
|
@ -188,7 +188,7 @@ def caption(image: Image.Image, repo: str, length: str = 'normal', stream: bool
|
|||
if max_tokens is not None:
|
||||
settings['max_tokens'] = max_tokens
|
||||
|
||||
debug(f'VQA interrogate: handler=moondream3 method=caption length={length} stream={stream} settings={settings}')
|
||||
debug(f'VQA caption: handler=moondream3 method=caption length={length} stream={stream} settings={settings}')
|
||||
|
||||
with devices.inference_context():
|
||||
response = model.caption(
|
||||
|
|
@ -200,7 +200,7 @@ def caption(image: Image.Image, repo: str, length: str = 'normal', stream: bool
|
|||
|
||||
# Log response structure (for non-streaming)
|
||||
if not stream and isinstance(response, dict):
|
||||
debug(f'VQA interrogate: handler=moondream3 response_type=dict keys={list(response.keys())}')
|
||||
debug(f'VQA caption: handler=moondream3 response_type=dict keys={list(response.keys())}')
|
||||
|
||||
return response
|
||||
|
||||
|
|
@ -220,21 +220,21 @@ def point(image: Image.Image, object_name: str, repo: str):
|
|||
"""
|
||||
model = load_model(repo)
|
||||
|
||||
debug(f'VQA interrogate: handler=moondream3 method=point object_name="{object_name}"')
|
||||
debug(f'VQA caption: handler=moondream3 method=point object_name="{object_name}"')
|
||||
|
||||
with devices.inference_context():
|
||||
result = model.point(image, object_name)
|
||||
|
||||
debug(f'VQA interrogate: handler=moondream3 point_raw_result="{result}" type={type(result)}')
|
||||
debug(f'VQA caption: handler=moondream3 point_raw_result="{result}" type={type(result)}')
|
||||
if isinstance(result, dict):
|
||||
debug(f'VQA interrogate: handler=moondream3 point_raw_result_keys={list(result.keys())}')
|
||||
debug(f'VQA caption: handler=moondream3 point_raw_result_keys={list(result.keys())}')
|
||||
|
||||
points = vqa_detection.parse_points(result)
|
||||
if points:
|
||||
debug(f'VQA interrogate: handler=moondream3 point_result={len(points)} points found')
|
||||
debug(f'VQA caption: handler=moondream3 point_result={len(points)} points found')
|
||||
return points
|
||||
|
||||
debug('VQA interrogate: handler=moondream3 point_result=not found')
|
||||
debug('VQA caption: handler=moondream3 point_result=not found')
|
||||
return None
|
||||
|
||||
|
||||
|
|
@ -257,17 +257,17 @@ def detect(image: Image.Image, object_name: str, repo: str, max_objects: int = 1
|
|||
"""
|
||||
model = load_model(repo)
|
||||
|
||||
debug(f'VQA interrogate: handler=moondream3 method=detect object_name="{object_name}" max_objects={max_objects}')
|
||||
debug(f'VQA caption: handler=moondream3 method=detect object_name="{object_name}" max_objects={max_objects}')
|
||||
|
||||
with devices.inference_context():
|
||||
result = model.detect(image, object_name)
|
||||
|
||||
debug(f'VQA interrogate: handler=moondream3 detect_raw_result="{result}" type={type(result)}')
|
||||
debug(f'VQA caption: handler=moondream3 detect_raw_result="{result}" type={type(result)}')
|
||||
if isinstance(result, dict):
|
||||
debug(f'VQA interrogate: handler=moondream3 detect_raw_result_keys={list(result.keys())}')
|
||||
debug(f'VQA caption: handler=moondream3 detect_raw_result_keys={list(result.keys())}')
|
||||
|
||||
detections = vqa_detection.parse_detections(result, object_name, max_objects)
|
||||
debug(f'VQA interrogate: handler=moondream3 detect_result={len(detections)} objects found')
|
||||
debug(f'VQA caption: handler=moondream3 detect_result={len(detections)} objects found')
|
||||
return detections
|
||||
|
||||
|
||||
|
|
@ -291,7 +291,7 @@ def predict(question: str, image: Image.Image, repo: str, model_name: str = None
|
|||
Response string (detection data stored on VQA singleton instance.last_detection_data)
|
||||
(or generator if stream=True for query/caption modes)
|
||||
"""
|
||||
debug(f'VQA interrogate: handler=moondream3 model_name="{model_name}" repo="{repo}" question="{question}" image_size={image.size if image else None} mode={mode} stream={stream}')
|
||||
debug(f'VQA caption: handler=moondream3 model_name="{model_name}" repo="{repo}" question="{question}" image_size={image.size if image else None} mode={mode} stream={stream}')
|
||||
|
||||
# Clean question
|
||||
question = question.replace('<', '').replace('>', '').replace('_', ' ') if question else ''
|
||||
|
|
@ -331,7 +331,7 @@ def predict(question: str, image: Image.Image, repo: str, model_name: str = None
|
|||
else:
|
||||
mode = 'query'
|
||||
|
||||
debug(f'VQA interrogate: handler=moondream3 mode_selected={mode}')
|
||||
debug(f'VQA caption: handler=moondream3 mode_selected={mode}')
|
||||
|
||||
# Dispatch to appropriate method
|
||||
try:
|
||||
|
|
@ -348,10 +348,10 @@ def predict(question: str, image: Image.Image, repo: str, model_name: str = None
|
|||
object_name = re.sub(rf'\b{phrase}\b', '', object_name, flags=re.IGNORECASE)
|
||||
object_name = re.sub(r'[?.!,]', '', object_name).strip()
|
||||
object_name = re.sub(r'^\s*the\s+', '', object_name, flags=re.IGNORECASE)
|
||||
debug(f'VQA interrogate: handler=moondream3 point_extracted_object="{object_name}"')
|
||||
debug(f'VQA caption: handler=moondream3 point_extracted_object="{object_name}"')
|
||||
result = point(image, object_name, repo)
|
||||
if result:
|
||||
from modules.interrogate import vqa
|
||||
from modules.caption import vqa
|
||||
vqa.get_instance().last_detection_data = {'points': result}
|
||||
return vqa_detection.format_points_text(result)
|
||||
return "Object not found"
|
||||
|
|
@ -364,11 +364,11 @@ def predict(question: str, image: Image.Image, repo: str, model_name: str = None
|
|||
object_name = re.sub(r'^\s*the\s+', '', object_name, flags=re.IGNORECASE)
|
||||
if ' and ' in object_name.lower():
|
||||
object_name = re.split(r'\s+and\s+', object_name, flags=re.IGNORECASE)[0].strip()
|
||||
debug(f'VQA interrogate: handler=moondream3 detect_extracted_object="{object_name}"')
|
||||
debug(f'VQA caption: handler=moondream3 detect_extracted_object="{object_name}"')
|
||||
|
||||
results = detect(image, object_name, repo, max_objects=kwargs.get('max_objects', 10))
|
||||
if results:
|
||||
from modules.interrogate import vqa
|
||||
from modules.caption import vqa
|
||||
vqa.get_instance().last_detection_data = {'detections': results}
|
||||
return vqa_detection.format_detections_text(results)
|
||||
return "No objects detected"
|
||||
|
|
@ -377,7 +377,7 @@ def predict(question: str, image: Image.Image, repo: str, model_name: str = None
|
|||
question = "Describe this image."
|
||||
response = query(image, question, repo, stream=stream, use_cache=use_cache, reasoning=thinking_mode)
|
||||
|
||||
debug(f'VQA interrogate: handler=moondream3 response_before_clean="{response}"')
|
||||
debug(f'VQA caption: handler=moondream3 response_before_clean="{response}"')
|
||||
return response
|
||||
|
||||
except Exception as e:
|
||||
|
|
@ -390,7 +390,7 @@ def clear_cache():
|
|||
"""Clear image encoding cache."""
|
||||
cache_size = len(image_cache)
|
||||
image_cache.clear()
|
||||
debug(f'VQA interrogate: handler=moondream3 cleared image cache cache_size_was={cache_size}')
|
||||
debug(f'VQA caption: handler=moondream3 cleared image cache cache_size_was={cache_size}')
|
||||
shared.log.debug(f'Moondream3: Cleared image cache ({cache_size} entries)')
|
||||
|
||||
|
||||
|
|
@ -8,7 +8,7 @@ from PIL import Image
|
|||
from modules import devices, shared, errors, sd_models
|
||||
|
||||
|
||||
debug_enabled = os.environ.get('SD_INTERROGATE_DEBUG', None) is not None
|
||||
debug_enabled = os.environ.get('SD_CAPTION_DEBUG', None) is not None
|
||||
debug_log = shared.log.trace if debug_enabled else lambda *args, **kwargs: None
|
||||
|
||||
# Per-request overrides for API calls
|
||||
|
|
@ -19,7 +19,7 @@ def get_clip_setting(name):
|
|||
"""Get CLIP setting with per-request override support.
|
||||
|
||||
Args:
|
||||
name: Setting name without 'interrogate_clip_' prefix (e.g., 'min_flavors', 'max_length')
|
||||
name: Setting name without 'caption_openclip_' prefix (e.g., 'min_flavors', 'max_length')
|
||||
|
||||
Returns:
|
||||
Override value if set, otherwise the value from shared.opts
|
||||
|
|
@ -28,7 +28,7 @@ def get_clip_setting(name):
|
|||
value = _clip_overrides.get(name)
|
||||
if value is not None:
|
||||
return value
|
||||
return getattr(shared.opts, f'interrogate_clip_{name}')
|
||||
return getattr(shared.opts, f'caption_openclip_{name}')
|
||||
|
||||
|
||||
def _apply_blip2_fix(model, processor):
|
||||
|
|
@ -87,13 +87,14 @@ class BatchWriter:
|
|||
self.file.close()
|
||||
|
||||
|
||||
def update_interrogate_params():
|
||||
def update_caption_params():
|
||||
if ci is not None:
|
||||
ci.caption_max_length = get_clip_setting('max_length')
|
||||
ci.chunk_size = get_clip_setting('chunk_size')
|
||||
ci.flavor_intermediate_count = get_clip_setting('flavor_count')
|
||||
ci.clip_offload = shared.opts.interrogate_offload
|
||||
ci.caption_offload = shared.opts.interrogate_offload
|
||||
ci.clip_offload = shared.opts.caption_offload
|
||||
ci.caption_offload = shared.opts.caption_offload
|
||||
|
||||
|
||||
|
||||
def get_clip_models():
|
||||
|
|
@ -104,12 +105,12 @@ def refresh_clip_models():
|
|||
global clip_models # pylint: disable=global-statement
|
||||
import open_clip
|
||||
models = sorted(open_clip.list_pretrained())
|
||||
shared.log.debug(f'Interrogate: pkg=openclip version={open_clip.__version__} models={len(models)}')
|
||||
shared.log.debug(f'Caption: pkg=openclip version={open_clip.__version__} models={len(models)}')
|
||||
clip_models = ['/'.join(x) for x in models]
|
||||
return clip_models
|
||||
|
||||
|
||||
def load_interrogator(clip_model, blip_model):
|
||||
def load_captioner(clip_model, blip_model):
|
||||
from installer import install
|
||||
install('clip_interrogator==0.6.0')
|
||||
import clip_interrogator
|
||||
|
|
@ -120,20 +121,21 @@ def load_interrogator(clip_model, blip_model):
|
|||
device = devices.get_optimal_device()
|
||||
cache_path = shared.opts.clip_models_path
|
||||
shared.log.info(f'CLIP load: clip="{clip_model}" blip="{blip_model}" device={device}')
|
||||
debug_log(f'CLIP load: cache_path="{cache_path}" max_length={shared.opts.interrogate_clip_max_length} chunk_size={shared.opts.interrogate_clip_chunk_size} flavor_count={shared.opts.interrogate_clip_flavor_count} offload={shared.opts.interrogate_offload}')
|
||||
interrogator_config = clip_interrogator.Config(
|
||||
debug_log(f'CLIP load: cache_path="{cache_path}" max_length={shared.opts.caption_openclip_max_length} chunk_size={shared.opts.caption_openclip_chunk_size} flavor_count={shared.opts.caption_openclip_flavor_count} offload={shared.opts.caption_offload}')
|
||||
captioner_config = clip_interrogator.Config(
|
||||
device=device,
|
||||
cache_path=cache_path,
|
||||
clip_model_name=clip_model,
|
||||
caption_model_name=blip_model,
|
||||
quiet=True,
|
||||
caption_max_length=shared.opts.interrogate_clip_max_length,
|
||||
chunk_size=shared.opts.interrogate_clip_chunk_size,
|
||||
flavor_intermediate_count=shared.opts.interrogate_clip_flavor_count,
|
||||
clip_offload=shared.opts.interrogate_offload,
|
||||
caption_offload=shared.opts.interrogate_offload,
|
||||
caption_max_length=shared.opts.caption_openclip_max_length,
|
||||
chunk_size=shared.opts.caption_openclip_chunk_size,
|
||||
flavor_intermediate_count=shared.opts.caption_openclip_flavor_count,
|
||||
clip_offload=shared.opts.caption_offload,
|
||||
caption_offload=shared.opts.caption_offload,
|
||||
)
|
||||
ci = clip_interrogator.Interrogator(interrogator_config)
|
||||
ci = clip_interrogator.Interrogator(captioner_config)
|
||||
|
||||
if blip_model.startswith('blip2-'):
|
||||
_apply_blip2_fix(ci.caption_model, ci.caption_processor)
|
||||
shared.log.debug(f'CLIP load: time={time.time()-t0:.2f}')
|
||||
|
|
@ -145,12 +147,14 @@ def load_interrogator(clip_model, blip_model):
|
|||
ci.config.clip_model_name = clip_model
|
||||
ci.config.clip_model = None
|
||||
ci.load_clip_model()
|
||||
ci.clip_offloaded = True # Reset flag so _prepare_clip() will move model to device
|
||||
if blip_model != ci.config.caption_model_name:
|
||||
shared.log.info(f'CLIP load: blip="{blip_model}" reloading')
|
||||
debug_log(f'CLIP load: previous blip="{ci.config.caption_model_name}"')
|
||||
ci.config.caption_model_name = blip_model
|
||||
ci.config.caption_model = None
|
||||
ci.load_caption_model()
|
||||
ci.caption_offloaded = True # Reset flag so _prepare_caption() will move model to device
|
||||
if blip_model.startswith('blip2-'):
|
||||
_apply_blip2_fix(ci.caption_model, ci.caption_processor)
|
||||
shared.log.debug(f'CLIP load: time={time.time()-t0:.2f}')
|
||||
|
|
@ -159,7 +163,7 @@ def load_interrogator(clip_model, blip_model):
|
|||
|
||||
|
||||
def unload_clip_model():
|
||||
if ci is not None and shared.opts.interrogate_offload:
|
||||
if ci is not None and shared.opts.caption_offload:
|
||||
shared.log.debug('CLIP unload: offloading models to CPU')
|
||||
sd_models.move_model(ci.caption_model, devices.cpu)
|
||||
sd_models.move_model(ci.clip_model, devices.cpu)
|
||||
|
|
@ -169,7 +173,7 @@ def unload_clip_model():
|
|||
debug_log('CLIP unload: complete')
|
||||
|
||||
|
||||
def interrogate(image, mode, caption=None):
|
||||
def caption(image, mode, base_caption=None):
|
||||
if isinstance(image, list):
|
||||
image = image[0] if len(image) > 0 else None
|
||||
if isinstance(image, dict) and 'name' in image:
|
||||
|
|
@ -180,15 +184,17 @@ def interrogate(image, mode, caption=None):
|
|||
t0 = time.time()
|
||||
min_flavors = get_clip_setting('min_flavors')
|
||||
max_flavors = get_clip_setting('max_flavors')
|
||||
debug_log(f'CLIP: mode="{mode}" image_size={image.size} caption={caption is not None} min_flavors={min_flavors} max_flavors={max_flavors}')
|
||||
debug_log(f'CLIP: mode="{mode}" image_size={image.size} caption={base_caption is not None} min_flavors={min_flavors} max_flavors={max_flavors}')
|
||||
# NOTE: Method names like .interrogate(), .interrogate_classic(), etc. come from the external
|
||||
# clip-interrogator library (https://github.com/pharmapsychotic/clip-interrogator) and cannot be renamed.
|
||||
if mode == 'best':
|
||||
prompt = ci.interrogate(image, caption=caption, min_flavors=min_flavors, max_flavors=max_flavors)
|
||||
prompt = ci.interrogate(image, caption=base_caption, min_flavors=min_flavors, max_flavors=max_flavors)
|
||||
elif mode == 'caption':
|
||||
prompt = ci.generate_caption(image) if caption is None else caption
|
||||
prompt = ci.generate_caption(image) if base_caption is None else base_caption
|
||||
elif mode == 'classic':
|
||||
prompt = ci.interrogate_classic(image, caption=caption, max_flavors=max_flavors)
|
||||
prompt = ci.interrogate_classic(image, caption=base_caption, max_flavors=max_flavors)
|
||||
elif mode == 'fast':
|
||||
prompt = ci.interrogate_fast(image, caption=caption, max_flavors=max_flavors)
|
||||
prompt = ci.interrogate_fast(image, caption=base_caption, max_flavors=max_flavors)
|
||||
elif mode == 'negative':
|
||||
prompt = ci.interrogate_negative(image, max_flavors=max_flavors)
|
||||
else:
|
||||
|
|
@ -197,9 +203,10 @@ def interrogate(image, mode, caption=None):
|
|||
return prompt
|
||||
|
||||
|
||||
def interrogate_image(image, clip_model, blip_model, mode, overrides=None):
|
||||
|
||||
def caption_image(image, clip_model, blip_model, mode, overrides=None):
|
||||
global _clip_overrides # pylint: disable=global-statement
|
||||
jobid = shared.state.begin('Interrogate CLiP')
|
||||
jobid = shared.state.begin('Caption CLiP')
|
||||
t0 = time.time()
|
||||
shared.log.info(f'CLIP: mode="{mode}" clip="{clip_model}" blip="{blip_model}" image_size={image.size if image else None}')
|
||||
if overrides:
|
||||
|
|
@ -211,17 +218,19 @@ def interrogate_image(image, clip_model, blip_model, mode, overrides=None):
|
|||
from modules.sd_models import apply_balanced_offload # prevent circular import
|
||||
apply_balanced_offload(shared.sd_model)
|
||||
debug_log('CLIP: applied balanced offload to sd_model')
|
||||
load_interrogator(clip_model, blip_model)
|
||||
# Apply overrides to loaded interrogator
|
||||
update_interrogate_params()
|
||||
load_captioner(clip_model, blip_model)
|
||||
# Apply overrides to loaded captioner
|
||||
update_caption_params()
|
||||
image = image.convert('RGB')
|
||||
prompt = interrogate(image, mode)
|
||||
prompt = caption(image, mode)
|
||||
if shared.opts.caption_offload:
|
||||
unload_clip_model()
|
||||
devices.torch_gc()
|
||||
shared.log.debug(f'CLIP: complete time={time.time()-t0:.2f}')
|
||||
except Exception as e:
|
||||
prompt = f"Exception {type(e)}"
|
||||
shared.log.error(f'CLIP: {e}')
|
||||
errors.display(e, 'Interrogate')
|
||||
errors.display(e, 'Caption')
|
||||
finally:
|
||||
# Clear per-request overrides
|
||||
_clip_overrides = None
|
||||
|
|
@ -229,7 +238,8 @@ def interrogate_image(image, clip_model, blip_model, mode, overrides=None):
|
|||
return prompt
|
||||
|
||||
|
||||
def interrogate_batch(batch_files, batch_folder, batch_str, clip_model, blip_model, mode, write, append, recursive):
|
||||
|
||||
def caption_batch(batch_files, batch_folder, batch_str, clip_model, blip_model, mode, write, append, recursive):
|
||||
files = []
|
||||
if batch_files is not None:
|
||||
files += [f.name for f in batch_files]
|
||||
|
|
@ -244,10 +254,10 @@ def interrogate_batch(batch_files, batch_folder, batch_str, clip_model, blip_mod
|
|||
t0 = time.time()
|
||||
shared.log.info(f'CLIP batch: mode="{mode}" images={len(files)} clip="{clip_model}" blip="{blip_model}" write={write} append={append}')
|
||||
debug_log(f'CLIP batch: recursive={recursive} files={files[:5]}{"..." if len(files) > 5 else ""}')
|
||||
jobid = shared.state.begin('Interrogate batch')
|
||||
jobid = shared.state.begin('Caption batch')
|
||||
prompts = []
|
||||
|
||||
load_interrogator(clip_model, blip_model)
|
||||
load_captioner(clip_model, blip_model)
|
||||
if write:
|
||||
file_mode = 'w' if not append else 'a'
|
||||
writer = BatchWriter(os.path.dirname(files[0]), mode=file_mode)
|
||||
|
|
@ -263,7 +273,7 @@ def interrogate_batch(batch_files, batch_folder, batch_str, clip_model, blip_mod
|
|||
shared.log.info('CLIP batch: interrupted')
|
||||
break
|
||||
image = Image.open(file).convert('RGB')
|
||||
prompt = interrogate(image, mode)
|
||||
prompt = caption(image, mode)
|
||||
prompts.append(prompt)
|
||||
if write:
|
||||
writer.add(file, prompt)
|
||||
|
|
@ -278,10 +288,11 @@ def interrogate_batch(batch_files, batch_folder, batch_str, clip_model, blip_mod
|
|||
return '\n\n'.join(prompts)
|
||||
|
||||
|
||||
|
||||
def analyze_image(image, clip_model, blip_model):
|
||||
t0 = time.time()
|
||||
shared.log.info(f'CLIP analyze: clip="{clip_model}" blip="{blip_model}" image_size={image.size if image else None}')
|
||||
load_interrogator(clip_model, blip_model)
|
||||
load_captioner(clip_model, blip_model)
|
||||
image = image.convert('RGB')
|
||||
image_features = ci.image_to_features(image)
|
||||
debug_log(f'CLIP analyze: features shape={image_features.shape if hasattr(image_features, "shape") else "unknown"}')
|
||||
|
|
@ -8,7 +8,7 @@ DEEPBOORU_MODEL = "DeepBooru"
|
|||
|
||||
def get_models() -> list:
|
||||
"""Return combined list: DeepBooru + WaifuDiffusion models."""
|
||||
from modules.interrogate import waifudiffusion
|
||||
from modules.caption import waifudiffusion
|
||||
return [DEEPBOORU_MODEL] + waifudiffusion.get_models()
|
||||
|
||||
|
||||
|
|
@ -25,16 +25,16 @@ def is_deepbooru(model_name: str) -> bool:
|
|||
def load_model(model_name: str) -> bool:
|
||||
"""Load appropriate backend."""
|
||||
if is_deepbooru(model_name):
|
||||
from modules.interrogate import deepbooru
|
||||
from modules.caption import deepbooru
|
||||
return deepbooru.load_model()
|
||||
else:
|
||||
from modules.interrogate import waifudiffusion
|
||||
from modules.caption import waifudiffusion
|
||||
return waifudiffusion.load_model(model_name)
|
||||
|
||||
|
||||
def unload_model():
|
||||
"""Unload both backends to ensure memory is freed."""
|
||||
from modules.interrogate import deepbooru, waifudiffusion
|
||||
from modules.caption import deepbooru, waifudiffusion
|
||||
deepbooru.unload_model()
|
||||
waifudiffusion.unload_model()
|
||||
|
||||
|
|
@ -54,10 +54,10 @@ def tag(image, model_name: str = None, **kwargs) -> str:
|
|||
model_name = shared.opts.waifudiffusion_model
|
||||
|
||||
if is_deepbooru(model_name):
|
||||
from modules.interrogate import deepbooru
|
||||
from modules.caption import deepbooru
|
||||
return deepbooru.tag(image, **kwargs)
|
||||
else:
|
||||
from modules.interrogate import waifudiffusion
|
||||
from modules.caption import waifudiffusion
|
||||
return waifudiffusion.tag(image, model_name=model_name, **kwargs)
|
||||
|
||||
|
||||
|
|
@ -72,8 +72,8 @@ def batch(model_name: str, **kwargs) -> str:
|
|||
Combined tag results
|
||||
"""
|
||||
if is_deepbooru(model_name):
|
||||
from modules.interrogate import deepbooru
|
||||
from modules.caption import deepbooru
|
||||
return deepbooru.batch(model_name=model_name, **kwargs)
|
||||
else:
|
||||
from modules.interrogate import waifudiffusion
|
||||
from modules.caption import waifudiffusion
|
||||
return waifudiffusion.batch(model_name=model_name, **kwargs)
|
||||
|
|
@ -9,11 +9,11 @@ import transformers
|
|||
import transformers.dynamic_module_utils
|
||||
from PIL import Image
|
||||
from modules import shared, devices, errors, model_quant, sd_models, sd_models_compile, ui_symbols
|
||||
from modules.interrogate import vqa_detection
|
||||
from modules.caption import vqa_detection
|
||||
|
||||
|
||||
# Debug logging - function-based to avoid circular import
|
||||
debug_enabled = os.environ.get('SD_INTERROGATE_DEBUG', None) is not None
|
||||
debug_enabled = os.environ.get('SD_CAPTION_DEBUG', None) is not None
|
||||
|
||||
def debug(*args, **kwargs):
|
||||
if debug_enabled:
|
||||
|
|
@ -265,7 +265,7 @@ def keep_think_block_open(text_prompt: str) -> str:
|
|||
while end_close < len(text_prompt) and text_prompt[end_close] in ('\r', '\n'):
|
||||
end_close += 1
|
||||
trimmed_prompt = text_prompt[:close_index] + text_prompt[end_close:]
|
||||
debug('VQA interrogate: keep_think_block_open applied to prompt segment near assistant reply')
|
||||
debug('VQA caption: keep_think_block_open applied to prompt segment near assistant reply')
|
||||
return trimmed_prompt
|
||||
|
||||
|
||||
|
|
@ -348,7 +348,7 @@ def get_keep_thinking():
|
|||
overrides = _get_overrides()
|
||||
if overrides.get('keep_thinking') is not None:
|
||||
return overrides['keep_thinking']
|
||||
return shared.opts.interrogate_vlm_keep_thinking
|
||||
return shared.opts.caption_vlm_keep_thinking
|
||||
|
||||
|
||||
def get_keep_prefill():
|
||||
|
|
@ -356,7 +356,7 @@ def get_keep_prefill():
|
|||
overrides = _get_overrides()
|
||||
if overrides.get('keep_prefill') is not None:
|
||||
return overrides['keep_prefill']
|
||||
return shared.opts.interrogate_vlm_keep_prefill
|
||||
return shared.opts.caption_vlm_keep_prefill
|
||||
|
||||
|
||||
def get_kwargs():
|
||||
|
|
@ -370,12 +370,12 @@ def get_kwargs():
|
|||
overrides = _get_overrides()
|
||||
|
||||
# Get base values from settings, apply overrides if provided
|
||||
max_tokens = overrides.get('max_tokens') if overrides.get('max_tokens') is not None else shared.opts.interrogate_vlm_max_length
|
||||
do_sample = overrides.get('do_sample') if overrides.get('do_sample') is not None else shared.opts.interrogate_vlm_do_sample
|
||||
num_beams = overrides.get('num_beams') if overrides.get('num_beams') is not None else shared.opts.interrogate_vlm_num_beams
|
||||
temperature = overrides.get('temperature') if overrides.get('temperature') is not None else shared.opts.interrogate_vlm_temperature
|
||||
top_k = overrides.get('top_k') if overrides.get('top_k') is not None else shared.opts.interrogate_vlm_top_k
|
||||
top_p = overrides.get('top_p') if overrides.get('top_p') is not None else shared.opts.interrogate_vlm_top_p
|
||||
max_tokens = overrides.get('max_tokens') if overrides.get('max_tokens') is not None else shared.opts.caption_vlm_max_length
|
||||
do_sample = overrides.get('do_sample') if overrides.get('do_sample') is not None else shared.opts.caption_vlm_do_sample
|
||||
num_beams = overrides.get('num_beams') if overrides.get('num_beams') is not None else shared.opts.caption_vlm_num_beams
|
||||
temperature = overrides.get('temperature') if overrides.get('temperature') is not None else shared.opts.caption_vlm_temperature
|
||||
top_k = overrides.get('top_k') if overrides.get('top_k') is not None else shared.opts.caption_vlm_top_k
|
||||
top_p = overrides.get('top_p') if overrides.get('top_p') is not None else shared.opts.caption_vlm_top_p
|
||||
|
||||
kwargs = {
|
||||
'max_new_tokens': max_tokens,
|
||||
|
|
@ -419,7 +419,7 @@ class VQA:
|
|||
|
||||
def load(self, model_name: str = None):
|
||||
"""Load VLM model into memory for the specified model name."""
|
||||
model_name = model_name or shared.opts.interrogate_vlm_model
|
||||
model_name = model_name or shared.opts.caption_vlm_model
|
||||
if not model_name:
|
||||
shared.log.warning('VQA load: no model specified')
|
||||
return
|
||||
|
|
@ -430,7 +430,7 @@ class VQA:
|
|||
|
||||
shared.log.debug(f'VQA load: pre-loading model="{model_name}" repo="{repo}"')
|
||||
|
||||
# Dispatch to appropriate loader (same logic as interrogate)
|
||||
# Dispatch to appropriate loader (same logic as caption)
|
||||
repo_lower = repo.lower()
|
||||
if 'qwen' in repo_lower or 'torii' in repo_lower or 'mimo' in repo_lower:
|
||||
self._load_qwen(repo)
|
||||
|
|
@ -459,22 +459,22 @@ class VQA:
|
|||
elif 'fastvlm' in repo_lower:
|
||||
self._load_fastvlm(repo)
|
||||
elif 'moondream3' in repo_lower:
|
||||
from modules.interrogate import moondream3
|
||||
from modules.caption import moondream3
|
||||
moondream3.load_model(repo)
|
||||
shared.log.info(f'VQA load: model="{model_name}" loaded (external handler)')
|
||||
return
|
||||
elif 'joytag' in repo_lower:
|
||||
from modules.interrogate import joytag
|
||||
from modules.caption import joytag
|
||||
joytag.load()
|
||||
shared.log.info(f'VQA load: model="{model_name}" loaded (external handler)')
|
||||
return
|
||||
elif 'joycaption' in repo_lower:
|
||||
from modules.interrogate import joycaption
|
||||
from modules.caption import joycaption
|
||||
joycaption.load(repo)
|
||||
shared.log.info(f'VQA load: model="{model_name}" loaded (external handler)')
|
||||
return
|
||||
elif 'deepseek' in repo_lower:
|
||||
from modules.interrogate import deepseek
|
||||
from modules.caption import deepseek
|
||||
deepseek.load(repo)
|
||||
shared.log.info(f'VQA load: model="{model_name}" loaded (external handler)')
|
||||
return
|
||||
|
|
@ -488,7 +488,7 @@ class VQA:
|
|||
def _load_fastvlm(self, repo: str):
|
||||
"""Load FastVLM model and tokenizer."""
|
||||
if self.model is None or self.loaded != repo:
|
||||
shared.log.debug(f'Interrogate load: vlm="{repo}"')
|
||||
shared.log.debug(f'Caption load: vlm="{repo}"')
|
||||
quant_args = model_quant.create_config(module='LLM')
|
||||
self.model = None
|
||||
self.processor = transformers.AutoTokenizer.from_pretrained(repo, trust_remote_code=True, cache_dir=shared.opts.hfcache_dir)
|
||||
|
|
@ -503,7 +503,7 @@ class VQA:
|
|||
devices.torch_gc()
|
||||
|
||||
def _fastvlm(self, question: str, image: Image.Image, repo: str, model_name: str = None):
|
||||
debug(f'VQA interrogate: handler=fastvlm model_name="{model_name}" repo="{repo}" question="{question}" image_size={image.size if image else None}')
|
||||
debug(f'VQA caption: handler=fastvlm model_name="{model_name}" repo="{repo}" question="{question}" image_size={image.size if image else None}')
|
||||
self._load_fastvlm(repo)
|
||||
sd_models.move_model(self.model, devices.device)
|
||||
if len(question) < 2:
|
||||
|
|
@ -534,7 +534,7 @@ class VQA:
|
|||
def _load_qwen(self, repo: str):
|
||||
"""Load Qwen VL model and processor."""
|
||||
if self.model is None or self.loaded != repo:
|
||||
shared.log.debug(f'Interrogate load: vlm="{repo}"')
|
||||
shared.log.debug(f'Caption load: vlm="{repo}"')
|
||||
self.model = None
|
||||
if 'Qwen3-VL' in repo or 'Qwen3VL' in repo:
|
||||
cls_name = transformers.Qwen3VLForConditionalGeneration
|
||||
|
|
@ -562,10 +562,10 @@ class VQA:
|
|||
sd_models.move_model(self.model, devices.device)
|
||||
# Get model class name for logging
|
||||
cls_name = self.model.__class__.__name__
|
||||
debug(f'VQA interrogate: handler=qwen model_name="{model_name}" model_class="{cls_name}" repo="{repo}" question="{question}" system_prompt="{system_prompt}" image_size={image.size if image else None}')
|
||||
debug(f'VQA caption: handler=qwen model_name="{model_name}" model_class="{cls_name}" repo="{repo}" question="{question}" system_prompt="{system_prompt}" image_size={image.size if image else None}')
|
||||
|
||||
question = question.replace('<', '').replace('>', '').replace('_', ' ')
|
||||
system_prompt = system_prompt or shared.opts.interrogate_vlm_system
|
||||
system_prompt = system_prompt or shared.opts.caption_vlm_system
|
||||
conversation = [
|
||||
{
|
||||
"role": "system",
|
||||
|
|
@ -593,9 +593,9 @@ class VQA:
|
|||
use_prefill = len(prefill_text) > 0
|
||||
|
||||
if debug_enabled:
|
||||
debug(f'VQA interrogate: handler=qwen conversation_roles={[msg["role"] for msg in conversation]}')
|
||||
debug(f'VQA interrogate: handler=qwen full_conversation={truncate_b64_in_conversation(conversation)}')
|
||||
debug(f'VQA interrogate: handler=qwen is_thinking={is_thinking} thinking_mode={thinking_mode} prefill="{prefill_text}"')
|
||||
debug(f'VQA caption: handler=qwen conversation_roles={[msg["role"] for msg in conversation]}')
|
||||
debug(f'VQA caption: handler=qwen full_conversation={truncate_b64_in_conversation(conversation)}')
|
||||
debug(f'VQA caption: handler=qwen is_thinking={is_thinking} thinking_mode={thinking_mode} prefill="{prefill_text}"')
|
||||
|
||||
# Generate base prompt using template
|
||||
# Qwen-Thinking template automatically adds "<|im_start|>assistant\n<think>\n" when add_generation_prompt=True
|
||||
|
|
@ -605,7 +605,7 @@ class VQA:
|
|||
add_generation_prompt=True,
|
||||
)
|
||||
except (TypeError, ValueError) as e:
|
||||
debug(f'VQA interrogate: handler=qwen chat_template fallback add_generation_prompt=True: {e}')
|
||||
debug(f'VQA caption: handler=qwen chat_template fallback add_generation_prompt=True: {e}')
|
||||
text_prompt = self.processor.apply_chat_template(conversation, add_generation_prompt=True)
|
||||
|
||||
# Manually handle thinking tags and prefill
|
||||
|
|
@ -627,23 +627,23 @@ class VQA:
|
|||
text_prompt += prefill_text
|
||||
|
||||
if debug_enabled:
|
||||
debug(f'VQA interrogate: handler=qwen text_prompt="{text_prompt}"')
|
||||
debug(f'VQA caption: handler=qwen text_prompt="{text_prompt}"')
|
||||
inputs = self.processor(text=[text_prompt], images=[image], padding=True, return_tensors="pt")
|
||||
inputs = inputs.to(devices.device, devices.dtype)
|
||||
gen_kwargs = get_kwargs()
|
||||
debug(f'VQA interrogate: handler=qwen generation_kwargs={gen_kwargs} input_ids_shape={inputs.input_ids.shape}')
|
||||
debug(f'VQA caption: handler=qwen generation_kwargs={gen_kwargs} input_ids_shape={inputs.input_ids.shape}')
|
||||
output_ids = self.model.generate(
|
||||
**inputs,
|
||||
**gen_kwargs,
|
||||
)
|
||||
debug(f'VQA interrogate: handler=qwen output_ids_shape={output_ids.shape}')
|
||||
debug(f'VQA caption: handler=qwen output_ids_shape={output_ids.shape}')
|
||||
generated_ids = [
|
||||
output_ids[len(input_ids):]
|
||||
for input_ids, output_ids in zip(inputs.input_ids, output_ids)
|
||||
]
|
||||
response = self.processor.batch_decode(generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True)
|
||||
if debug_enabled:
|
||||
debug(f'VQA interrogate: handler=qwen response_before_clean="{response}"')
|
||||
debug(f'VQA caption: handler=qwen response_before_clean="{response}"')
|
||||
# Clean up thinking tags
|
||||
# Note: <think> is in the prompt, not the response - only </think> appears in generated output
|
||||
if len(response) > 0:
|
||||
|
|
@ -672,7 +672,7 @@ class VQA:
|
|||
def _load_gemma(self, repo: str):
|
||||
"""Load Gemma 3 model and processor."""
|
||||
if self.model is None or self.loaded != repo:
|
||||
shared.log.debug(f'Interrogate load: vlm="{repo}"')
|
||||
shared.log.debug(f'Caption load: vlm="{repo}"')
|
||||
self.model = None
|
||||
if '3n' in repo:
|
||||
cls = transformers.Gemma3nForConditionalGeneration # pylint: disable=no-member
|
||||
|
|
@ -696,10 +696,10 @@ class VQA:
|
|||
sd_models.move_model(self.model, devices.device)
|
||||
# Get model class name for logging
|
||||
cls_name = self.model.__class__.__name__
|
||||
debug(f'VQA interrogate: handler=gemma model_name="{model_name}" model_class="{cls_name}" repo="{repo}" question="{question}" system_prompt="{system_prompt}" image_size={image.size if image else None}')
|
||||
debug(f'VQA caption: handler=gemma model_name="{model_name}" model_class="{cls_name}" repo="{repo}" question="{question}" system_prompt="{system_prompt}" image_size={image.size if image else None}')
|
||||
|
||||
question = question.replace('<', '').replace('>', '').replace('_', ' ')
|
||||
system_prompt = system_prompt or shared.opts.interrogate_vlm_system
|
||||
system_prompt = system_prompt or shared.opts.caption_vlm_system
|
||||
|
||||
system_content = []
|
||||
if system_prompt is not None and len(system_prompt) > 4:
|
||||
|
|
@ -726,14 +726,14 @@ class VQA:
|
|||
"role": "assistant",
|
||||
"content": [{"type": "text", "text": prefill_text}],
|
||||
})
|
||||
debug(f'VQA interrogate: handler=gemma prefill="{prefill_text}"')
|
||||
debug(f'VQA caption: handler=gemma prefill="{prefill_text}"')
|
||||
else:
|
||||
debug('VQA interrogate: handler=gemma prefill disabled (empty), relying on add_generation_prompt')
|
||||
debug('VQA caption: handler=gemma prefill disabled (empty), relying on add_generation_prompt')
|
||||
if debug_enabled:
|
||||
debug(f'VQA interrogate: handler=gemma conversation_roles={[msg["role"] for msg in conversation]}')
|
||||
debug(f'VQA interrogate: handler=gemma full_conversation={truncate_b64_in_conversation(conversation)}')
|
||||
debug(f'VQA caption: handler=gemma conversation_roles={[msg["role"] for msg in conversation]}')
|
||||
debug(f'VQA caption: handler=gemma full_conversation={truncate_b64_in_conversation(conversation)}')
|
||||
debug_prefill_mode = 'add_generation_prompt=False continue_final_message=True' if use_prefill else 'add_generation_prompt=True'
|
||||
debug(f'VQA interrogate: handler=gemma template_mode={debug_prefill_mode}')
|
||||
debug(f'VQA caption: handler=gemma template_mode={debug_prefill_mode}')
|
||||
try:
|
||||
if use_prefill:
|
||||
text_prompt = self.processor.apply_chat_template(
|
||||
|
|
@ -749,7 +749,7 @@ class VQA:
|
|||
tokenize=False,
|
||||
)
|
||||
except (TypeError, ValueError) as e:
|
||||
debug(f'VQA interrogate: handler=gemma chat_template fallback add_generation_prompt=True: {e}')
|
||||
debug(f'VQA caption: handler=gemma chat_template fallback add_generation_prompt=True: {e}')
|
||||
text_prompt = self.processor.apply_chat_template(
|
||||
conversation,
|
||||
add_generation_prompt=True,
|
||||
|
|
@ -758,7 +758,7 @@ class VQA:
|
|||
if use_prefill and use_thinking:
|
||||
text_prompt = keep_think_block_open(text_prompt)
|
||||
if debug_enabled:
|
||||
debug(f'VQA interrogate: handler=gemma text_prompt="{text_prompt}"')
|
||||
debug(f'VQA caption: handler=gemma text_prompt="{text_prompt}"')
|
||||
inputs = self.processor(
|
||||
text=[text_prompt],
|
||||
images=[image],
|
||||
|
|
@ -767,17 +767,17 @@ class VQA:
|
|||
).to(device=devices.device, dtype=devices.dtype)
|
||||
input_len = inputs["input_ids"].shape[-1]
|
||||
gen_kwargs = get_kwargs()
|
||||
debug(f'VQA interrogate: handler=gemma generation_kwargs={gen_kwargs} input_len={input_len}')
|
||||
debug(f'VQA caption: handler=gemma generation_kwargs={gen_kwargs} input_len={input_len}')
|
||||
with devices.inference_context():
|
||||
generation = self.model.generate(
|
||||
**inputs,
|
||||
**gen_kwargs,
|
||||
)
|
||||
debug(f'VQA interrogate: handler=gemma output_ids_shape={generation.shape}')
|
||||
debug(f'VQA caption: handler=gemma output_ids_shape={generation.shape}')
|
||||
generation = generation[0][input_len:]
|
||||
response = self.processor.decode(generation, skip_special_tokens=True)
|
||||
if debug_enabled:
|
||||
debug(f'VQA interrogate: handler=gemma response_before_clean="{response}"')
|
||||
debug(f'VQA caption: handler=gemma response_before_clean="{response}"')
|
||||
|
||||
# Clean up thinking tags (if any remain)
|
||||
if get_keep_thinking():
|
||||
|
|
@ -798,7 +798,7 @@ class VQA:
|
|||
def _load_paligemma(self, repo: str):
|
||||
"""Load PaliGemma model and processor."""
|
||||
if self.model is None or self.loaded != repo:
|
||||
shared.log.debug(f'Interrogate load: vlm="{repo}"')
|
||||
shared.log.debug(f'Caption load: vlm="{repo}"')
|
||||
self.processor = transformers.PaliGemmaProcessor.from_pretrained(repo, cache_dir=shared.opts.hfcache_dir)
|
||||
self.model = None
|
||||
self.model = transformers.PaliGemmaForConditionalGeneration.from_pretrained(
|
||||
|
|
@ -827,7 +827,7 @@ class VQA:
|
|||
def _load_ovis(self, repo: str):
|
||||
"""Load Ovis model (requires flash-attn)."""
|
||||
if self.model is None or self.loaded != repo:
|
||||
shared.log.debug(f'Interrogate load: vlm="{repo}"')
|
||||
shared.log.debug(f'Caption load: vlm="{repo}"')
|
||||
self.model = None
|
||||
self.model = transformers.AutoModelForCausalLM.from_pretrained(
|
||||
repo,
|
||||
|
|
@ -843,7 +843,7 @@ class VQA:
|
|||
try:
|
||||
import flash_attn # pylint: disable=unused-import
|
||||
except Exception:
|
||||
shared.log.error(f'Interrogate: vlm="{repo}" flash-attn is not available')
|
||||
shared.log.error(f'Caption: vlm="{repo}" flash-attn is not available')
|
||||
return ''
|
||||
self._load_ovis(repo)
|
||||
sd_models.move_model(self.model, devices.device)
|
||||
|
|
@ -875,7 +875,7 @@ class VQA:
|
|||
def _load_smol(self, repo: str):
|
||||
"""Load SmolVLM model and processor."""
|
||||
if self.model is None or self.loaded != repo:
|
||||
shared.log.debug(f'Interrogate load: vlm="{repo}"')
|
||||
shared.log.debug(f'Caption load: vlm="{repo}"')
|
||||
self.model = None
|
||||
quant_args = model_quant.create_config(module='LLM')
|
||||
self.model = transformers.AutoModelForVision2Seq.from_pretrained(
|
||||
|
|
@ -895,10 +895,10 @@ class VQA:
|
|||
sd_models.move_model(self.model, devices.device)
|
||||
# Get model class name for logging
|
||||
cls_name = self.model.__class__.__name__
|
||||
debug(f'VQA interrogate: handler=smol model_name="{model_name}" model_class="{cls_name}" repo="{repo}" question="{question}" system_prompt="{system_prompt}" image_size={image.size if image else None}')
|
||||
debug(f'VQA caption: handler=smol model_name="{model_name}" model_class="{cls_name}" repo="{repo}" question="{question}" system_prompt="{system_prompt}" image_size={image.size if image else None}')
|
||||
|
||||
question = question.replace('<', '').replace('>', '').replace('_', ' ')
|
||||
system_prompt = system_prompt or shared.opts.interrogate_vlm_system
|
||||
system_prompt = system_prompt or shared.opts.caption_vlm_system
|
||||
conversation = [
|
||||
{
|
||||
"role": "system",
|
||||
|
|
@ -924,14 +924,14 @@ class VQA:
|
|||
"role": "assistant",
|
||||
"content": [{"type": "text", "text": prefill_text}],
|
||||
})
|
||||
debug(f'VQA interrogate: handler=smol prefill="{prefill_text}"')
|
||||
debug(f'VQA caption: handler=smol prefill="{prefill_text}"')
|
||||
else:
|
||||
debug('VQA interrogate: handler=smol prefill disabled (empty), relying on add_generation_prompt')
|
||||
debug('VQA caption: handler=smol prefill disabled (empty), relying on add_generation_prompt')
|
||||
if debug_enabled:
|
||||
debug(f'VQA interrogate: handler=smol conversation_roles={[msg["role"] for msg in conversation]}')
|
||||
debug(f'VQA interrogate: handler=smol full_conversation={truncate_b64_in_conversation(conversation)}')
|
||||
debug(f'VQA caption: handler=smol conversation_roles={[msg["role"] for msg in conversation]}')
|
||||
debug(f'VQA caption: handler=smol full_conversation={truncate_b64_in_conversation(conversation)}')
|
||||
debug_prefill_mode = 'add_generation_prompt=False continue_final_message=True' if use_prefill else 'add_generation_prompt=True'
|
||||
debug(f'VQA interrogate: handler=smol template_mode={debug_prefill_mode}')
|
||||
debug(f'VQA caption: handler=smol template_mode={debug_prefill_mode}')
|
||||
try:
|
||||
if use_prefill:
|
||||
text_prompt = self.processor.apply_chat_template(
|
||||
|
|
@ -942,24 +942,24 @@ class VQA:
|
|||
else:
|
||||
text_prompt = self.processor.apply_chat_template(conversation, add_generation_prompt=True)
|
||||
except (TypeError, ValueError) as e:
|
||||
debug(f'VQA interrogate: handler=smol chat_template fallback add_generation_prompt=True: {e}')
|
||||
debug(f'VQA caption: handler=smol chat_template fallback add_generation_prompt=True: {e}')
|
||||
text_prompt = self.processor.apply_chat_template(conversation, add_generation_prompt=True)
|
||||
if use_prefill and use_thinking:
|
||||
text_prompt = keep_think_block_open(text_prompt)
|
||||
if debug_enabled:
|
||||
debug(f'VQA interrogate: handler=smol text_prompt="{text_prompt}"')
|
||||
debug(f'VQA caption: handler=smol text_prompt="{text_prompt}"')
|
||||
inputs = self.processor(text=text_prompt, images=[image], padding=True, return_tensors="pt")
|
||||
inputs = inputs.to(devices.device, devices.dtype)
|
||||
gen_kwargs = get_kwargs()
|
||||
debug(f'VQA interrogate: handler=smol generation_kwargs={gen_kwargs}')
|
||||
debug(f'VQA caption: handler=smol generation_kwargs={gen_kwargs}')
|
||||
output_ids = self.model.generate(
|
||||
**inputs,
|
||||
**gen_kwargs,
|
||||
)
|
||||
debug(f'VQA interrogate: handler=smol output_ids_shape={output_ids.shape}')
|
||||
debug(f'VQA caption: handler=smol output_ids_shape={output_ids.shape}')
|
||||
response = self.processor.batch_decode(output_ids, skip_special_tokens=True)
|
||||
if debug_enabled:
|
||||
debug(f'VQA interrogate: handler=smol response_before_clean="{response}"')
|
||||
debug(f'VQA caption: handler=smol response_before_clean="{response}"')
|
||||
|
||||
# Clean up thinking tags
|
||||
if len(response) > 0:
|
||||
|
|
@ -981,7 +981,7 @@ class VQA:
|
|||
def _load_git(self, repo: str):
|
||||
"""Load Microsoft GIT model and processor."""
|
||||
if self.model is None or self.loaded != repo:
|
||||
shared.log.debug(f'Interrogate load: vlm="{repo}"')
|
||||
shared.log.debug(f'Caption load: vlm="{repo}"')
|
||||
self.model = None
|
||||
self.model = transformers.GitForCausalLM.from_pretrained(
|
||||
repo,
|
||||
|
|
@ -1011,7 +1011,7 @@ class VQA:
|
|||
def _load_blip(self, repo: str):
|
||||
"""Load Salesforce BLIP model and processor."""
|
||||
if self.model is None or self.loaded != repo:
|
||||
shared.log.debug(f'Interrogate load: vlm="{repo}"')
|
||||
shared.log.debug(f'Caption load: vlm="{repo}"')
|
||||
self.model = None
|
||||
self.model = transformers.BlipForQuestionAnswering.from_pretrained(
|
||||
repo,
|
||||
|
|
@ -1035,7 +1035,7 @@ class VQA:
|
|||
def _load_vilt(self, repo: str):
|
||||
"""Load ViLT model and processor."""
|
||||
if self.model is None or self.loaded != repo:
|
||||
shared.log.debug(f'Interrogate load: vlm="{repo}"')
|
||||
shared.log.debug(f'Caption load: vlm="{repo}"')
|
||||
self.model = None
|
||||
self.model = transformers.ViltForQuestionAnswering.from_pretrained(
|
||||
repo,
|
||||
|
|
@ -1061,7 +1061,7 @@ class VQA:
|
|||
def _load_pix(self, repo: str):
|
||||
"""Load Pix2Struct model and processor."""
|
||||
if self.model is None or self.loaded != repo:
|
||||
shared.log.debug(f'Interrogate load: vlm="{repo}"')
|
||||
shared.log.debug(f'Caption load: vlm="{repo}"')
|
||||
self.model = None
|
||||
self.model = transformers.Pix2StructForConditionalGeneration.from_pretrained(
|
||||
repo,
|
||||
|
|
@ -1087,7 +1087,7 @@ class VQA:
|
|||
def _load_moondream(self, repo: str):
|
||||
"""Load Moondream 2 model and tokenizer."""
|
||||
if self.model is None or self.loaded != repo:
|
||||
shared.log.debug(f'Interrogate load: vlm="{repo}"')
|
||||
shared.log.debug(f'Caption load: vlm="{repo}"')
|
||||
self.model = None
|
||||
self.model = transformers.AutoModelForCausalLM.from_pretrained(
|
||||
repo,
|
||||
|
|
@ -1102,7 +1102,7 @@ class VQA:
|
|||
devices.torch_gc()
|
||||
|
||||
def _moondream(self, question: str, image: Image.Image, repo: str, model_name: str = None, thinking_mode: bool = False):
|
||||
debug(f'VQA interrogate: handler=moondream model_name="{model_name}" repo="{repo}" question="{question}" thinking_mode={thinking_mode}')
|
||||
debug(f'VQA caption: handler=moondream model_name="{model_name}" repo="{repo}" question="{question}" thinking_mode={thinking_mode}')
|
||||
self._load_moondream(repo)
|
||||
sd_models.move_model(self.model, devices.device)
|
||||
question = question.replace('<', '').replace('>', '').replace('_', ' ')
|
||||
|
|
@ -1117,9 +1117,9 @@ class VQA:
|
|||
target = question[9:].strip() if question.lower().startswith('point at ') else ''
|
||||
if not target:
|
||||
return "Please specify an object to locate"
|
||||
debug(f'VQA interrogate: handler=moondream method=point target="{target}"')
|
||||
debug(f'VQA caption: handler=moondream method=point target="{target}"')
|
||||
result = self.model.point(image, target)
|
||||
debug(f'VQA interrogate: handler=moondream point_raw_result={result}')
|
||||
debug(f'VQA caption: handler=moondream point_raw_result={result}')
|
||||
points = vqa_detection.parse_points(result)
|
||||
if points:
|
||||
self.last_detection_data = {'points': points}
|
||||
|
|
@ -1129,35 +1129,35 @@ class VQA:
|
|||
target = question[7:].strip() if question.lower().startswith('detect ') else ''
|
||||
if not target:
|
||||
return "Please specify an object to detect"
|
||||
debug(f'VQA interrogate: handler=moondream method=detect target="{target}"')
|
||||
debug(f'VQA caption: handler=moondream method=detect target="{target}"')
|
||||
result = self.model.detect(image, target)
|
||||
debug(f'VQA interrogate: handler=moondream detect_raw_result={result}')
|
||||
debug(f'VQA caption: handler=moondream detect_raw_result={result}')
|
||||
detections = vqa_detection.parse_detections(result, target)
|
||||
if detections:
|
||||
self.last_detection_data = {'detections': detections}
|
||||
return vqa_detection.format_detections_text(detections, include_confidence=False)
|
||||
return "No objects detected"
|
||||
elif question == 'DETECT_GAZE' or question.lower() == 'detect gaze':
|
||||
debug('VQA interrogate: handler=moondream method=detect_gaze')
|
||||
debug('VQA caption: handler=moondream method=detect_gaze')
|
||||
faces = self.model.detect(image, "face")
|
||||
debug(f'VQA interrogate: handler=moondream detect_gaze faces={faces}')
|
||||
debug(f'VQA caption: handler=moondream detect_gaze faces={faces}')
|
||||
if faces.get('objects'):
|
||||
eye_x, eye_y = vqa_detection.calculate_eye_position(faces['objects'][0])
|
||||
result = self.model.detect_gaze(image, eye=(eye_x, eye_y))
|
||||
debug(f'VQA interrogate: handler=moondream detect_gaze result={result}')
|
||||
debug(f'VQA caption: handler=moondream detect_gaze result={result}')
|
||||
if result.get('gaze'):
|
||||
gaze = result['gaze']
|
||||
self.last_detection_data = {'points': [(gaze['x'], gaze['y'])]}
|
||||
return f"Gaze direction: ({gaze['x']:.3f}, {gaze['y']:.3f})"
|
||||
return "No face/gaze detected"
|
||||
else:
|
||||
debug(f'VQA interrogate: handler=moondream method=query question="{question}" reasoning={thinking_mode}')
|
||||
debug(f'VQA caption: handler=moondream method=query question="{question}" reasoning={thinking_mode}')
|
||||
result = self.model.query(image, question, reasoning=thinking_mode)
|
||||
response = result['answer']
|
||||
debug(f'VQA interrogate: handler=moondream query_result keys={list(result.keys()) if isinstance(result, dict) else "not dict"}')
|
||||
debug(f'VQA caption: handler=moondream query_result keys={list(result.keys()) if isinstance(result, dict) else "not dict"}')
|
||||
if thinking_mode and 'reasoning' in result:
|
||||
reasoning_text = result['reasoning'].get('text', '') if isinstance(result['reasoning'], dict) else str(result['reasoning'])
|
||||
debug(f'VQA interrogate: handler=moondream reasoning_text="{reasoning_text[:100]}..."')
|
||||
debug(f'VQA caption: handler=moondream reasoning_text="{reasoning_text[:100]}..."')
|
||||
if get_keep_thinking():
|
||||
response = f"Reasoning:\n{reasoning_text}\n\nAnswer:\n{response}"
|
||||
# When keep_thinking is False, just use the answer (reasoning is discarded)
|
||||
|
|
@ -1183,7 +1183,7 @@ class VQA:
|
|||
effective_revision = revision_from_repo
|
||||
|
||||
if self.model is None or self.loaded != cache_key:
|
||||
shared.log.debug(f'Interrogate load: vlm="{repo_name}" revision="{effective_revision}" path="{shared.opts.hfcache_dir}"')
|
||||
shared.log.debug(f'Caption load: vlm="{repo_name}" revision="{effective_revision}" path="{shared.opts.hfcache_dir}"')
|
||||
transformers.dynamic_module_utils.get_imports = get_imports
|
||||
self.model = None
|
||||
quant_args = model_quant.create_config(module='LLM')
|
||||
|
|
@ -1213,7 +1213,7 @@ class VQA:
|
|||
pixel_values = inputs['pixel_values'].to(devices.device, devices.dtype)
|
||||
# Florence-2 requires beam search, not sampling - sampling causes probability tensor errors
|
||||
overrides = _get_overrides()
|
||||
max_tokens = overrides.get('max_tokens') if overrides.get('max_tokens') is not None else shared.opts.interrogate_vlm_max_length
|
||||
max_tokens = overrides.get('max_tokens') if overrides.get('max_tokens') is not None else shared.opts.caption_vlm_max_length
|
||||
with devices.inference_context():
|
||||
generated_ids = self.model.generate(
|
||||
input_ids=input_ids,
|
||||
|
|
@ -1263,9 +1263,9 @@ class VQA:
|
|||
response = return_dict["prediction"] # the text format answer
|
||||
return response
|
||||
|
||||
def interrogate(self, question: str = '', system_prompt: str = None, prompt: str = None, image: Image.Image = None, model_name: str = None, prefill: str = None, thinking_mode: bool = None, quiet: bool = False, generation_kwargs: dict = None) -> str:
|
||||
def caption(self, question: str = '', system_prompt: str = None, prompt: str = None, image: Image.Image = None, model_name: str = None, prefill: str = None, thinking_mode: bool = None, quiet: bool = False, generation_kwargs: dict = None) -> str:
|
||||
"""
|
||||
Main entry point for VQA interrogation. Returns string answer.
|
||||
Main entry point for VQA captioning. Returns string answer.
|
||||
Detection data stored in self.last_detection_data for annotated image creation.
|
||||
|
||||
Args:
|
||||
|
|
@ -1283,11 +1283,11 @@ class VQA:
|
|||
self.last_annotated_image = None
|
||||
self.last_detection_data = None
|
||||
self._generation_overrides = generation_kwargs # Set per-request overrides
|
||||
jobid = shared.state.begin('Interrogate LLM')
|
||||
jobid = shared.state.begin('Caption LLM')
|
||||
t0 = time.time()
|
||||
model_name = model_name or shared.opts.interrogate_vlm_model
|
||||
model_name = model_name or shared.opts.caption_vlm_model
|
||||
prefill = vlm_prefill if prefill is None else prefill # Use provided prefill when specified
|
||||
thinking_mode = shared.opts.interrogate_vlm_thinking_mode if thinking_mode is None else thinking_mode # Resolve from settings if not specified
|
||||
thinking_mode = shared.opts.caption_vlm_thinking_mode if thinking_mode is None else thinking_mode # Resolve from settings if not specified
|
||||
if isinstance(image, list):
|
||||
image = image[0] if len(image) > 0 else None
|
||||
if isinstance(image, dict) and 'name' in image:
|
||||
|
|
@ -1298,7 +1298,7 @@ class VQA:
|
|||
if image.mode != 'RGB':
|
||||
image = image.convert('RGB')
|
||||
if image is None:
|
||||
shared.log.error(f'VQA interrogate: model="{model_name}" error="No input image provided"')
|
||||
shared.log.error(f'VQA caption: model="{model_name}" error="No input image provided"')
|
||||
shared.state.end(jobid)
|
||||
return 'Error: No input image provided. Please upload or select an image.'
|
||||
|
||||
|
|
@ -1306,7 +1306,7 @@ class VQA:
|
|||
if question == "Use Prompt":
|
||||
# Use content from Prompt field directly - requires user input
|
||||
if not prompt or len(prompt.strip()) < 2:
|
||||
shared.log.error(f'VQA interrogate: model="{model_name}" error="Please enter a prompt"')
|
||||
shared.log.error(f'VQA caption: model="{model_name}" error="Please enter a prompt"')
|
||||
shared.state.end(jobid)
|
||||
return 'Error: Please enter a question or instruction in the Prompt field.'
|
||||
question = prompt
|
||||
|
|
@ -1316,7 +1316,7 @@ class VQA:
|
|||
if raw_mapping in ("POINT_MODE", "DETECT_MODE"):
|
||||
# These modes require user input in the prompt field
|
||||
if not prompt or len(prompt.strip()) < 2:
|
||||
shared.log.error(f'VQA interrogate: model="{model_name}" error="Please specify what to find in the prompt field"')
|
||||
shared.log.error(f'VQA caption: model="{model_name}" error="Please specify what to find in the prompt field"')
|
||||
shared.state.end(jobid)
|
||||
return 'Error: Please specify what to find in the prompt field (e.g., "the red car" or "faces").'
|
||||
# Convert friendly name to internal token (handles Point/Detect prefix)
|
||||
|
|
@ -1328,12 +1328,12 @@ class VQA:
|
|||
|
||||
try:
|
||||
if model_name is None:
|
||||
shared.log.error(f'Interrogate: type=vlm model="{model_name}" no model selected')
|
||||
shared.log.error(f'Caption: type=vlm model="{model_name}" no model selected')
|
||||
shared.state.end(jobid)
|
||||
return ''
|
||||
vqa_model = vlm_models.get(model_name, None)
|
||||
if vqa_model is None:
|
||||
shared.log.error(f'Interrogate: type=vlm model="{model_name}" unknown')
|
||||
shared.log.error(f'Caption: type=vlm model="{model_name}" unknown')
|
||||
shared.state.end(jobid)
|
||||
return ''
|
||||
|
||||
|
|
@ -1352,7 +1352,7 @@ class VQA:
|
|||
answer = self._pix(question, image, vqa_model, model_name)
|
||||
elif 'moondream3' in vqa_model.lower():
|
||||
handler = 'moondream3'
|
||||
from modules.interrogate import moondream3
|
||||
from modules.caption import moondream3
|
||||
answer = moondream3.predict(question, image, vqa_model, model_name, thinking_mode=thinking_mode)
|
||||
elif 'moondream2' in vqa_model.lower():
|
||||
handler = 'moondream'
|
||||
|
|
@ -1368,15 +1368,15 @@ class VQA:
|
|||
answer = self._smol(question, image, vqa_model, system_prompt, model_name, prefill, thinking_mode)
|
||||
elif 'joytag' in vqa_model.lower():
|
||||
handler = 'joytag'
|
||||
from modules.interrogate import joytag
|
||||
from modules.caption import joytag
|
||||
answer = joytag.predict(image)
|
||||
elif 'joycaption' in vqa_model.lower():
|
||||
handler = 'joycaption'
|
||||
from modules.interrogate import joycaption
|
||||
from modules.caption import joycaption
|
||||
answer = joycaption.predict(question, image, vqa_model)
|
||||
elif 'deepseek' in vqa_model.lower():
|
||||
handler = 'deepseek'
|
||||
from modules.interrogate import deepseek
|
||||
from modules.caption import deepseek
|
||||
answer = deepseek.predict(question, image, vqa_model)
|
||||
elif 'paligemma' in vqa_model.lower():
|
||||
handler = 'paligemma'
|
||||
|
|
@ -1399,7 +1399,7 @@ class VQA:
|
|||
errors.display(e, 'VQA')
|
||||
answer = 'error'
|
||||
|
||||
if shared.opts.interrogate_offload and self.model is not None:
|
||||
if shared.opts.caption_offload and self.model is not None:
|
||||
sd_models.move_model(self.model, devices.cpu, force=True)
|
||||
devices.torch_gc(force=True, reason='vqa')
|
||||
|
||||
|
|
@ -1412,16 +1412,17 @@ class VQA:
|
|||
points = self.last_detection_data.get('points', None)
|
||||
if detections or points:
|
||||
self.last_annotated_image = vqa_detection.draw_bounding_boxes(image, detections or [], points)
|
||||
debug(f'VQA interrogate: handler={handler} created annotated image detections={len(detections) if detections else 0} points={len(points) if points else 0}')
|
||||
debug(f'VQA caption: handler={handler} created annotated image detections={len(detections) if detections else 0} points={len(points) if points else 0}')
|
||||
|
||||
debug(f'VQA interrogate: handler={handler} response_after_clean="{answer}" has_annotation={self.last_annotated_image is not None}')
|
||||
debug(f'VQA caption: handler={handler} response_after_clean="{answer}" has_annotation={self.last_annotated_image is not None}')
|
||||
t1 = time.time()
|
||||
if not quiet:
|
||||
shared.log.debug(f'Interrogate: type=vlm model="{model_name}" repo="{vqa_model}" args={get_kwargs()} time={t1-t0:.2f}')
|
||||
shared.log.debug(f'Caption: type=vlm model="{model_name}" repo="{vqa_model}" args={get_kwargs()} time={t1-t0:.2f}')
|
||||
self._generation_overrides = None # Clear per-request overrides
|
||||
shared.state.end(jobid)
|
||||
return answer
|
||||
|
||||
|
||||
def batch(self, model_name, system_prompt, batch_files, batch_folder, batch_str, question, prompt, write, append, recursive, prefill=None, thinking_mode=False):
|
||||
class BatchWriter:
|
||||
def __init__(self, folder, mode='w'):
|
||||
|
|
@ -1450,15 +1451,15 @@ class VQA:
|
|||
from modules.files_cache import list_files
|
||||
files += list(list_files(batch_str, ext_filter=['.png', '.jpg', '.jpeg', '.webp', '.jxl'], recursive=recursive))
|
||||
if len(files) == 0:
|
||||
shared.log.warning('Interrogate batch: type=vlm no images')
|
||||
shared.log.warning('Caption batch: type=vlm no images')
|
||||
return ''
|
||||
jobid = shared.state.begin('Interrogate batch')
|
||||
jobid = shared.state.begin('Caption batch')
|
||||
prompts = []
|
||||
if write:
|
||||
mode = 'w' if not append else 'a'
|
||||
writer = BatchWriter(os.path.dirname(files[0]), mode=mode)
|
||||
orig_offload = shared.opts.interrogate_offload
|
||||
shared.opts.interrogate_offload = False
|
||||
orig_offload = shared.opts.caption_offload
|
||||
shared.opts.caption_offload = False
|
||||
import rich.progress as rp
|
||||
pbar = rp.Progress(rp.TextColumn('[cyan]Caption:'), rp.BarColumn(), rp.MofNCompleteColumn(), rp.TaskProgressColumn(), rp.TimeRemainingColumn(), rp.TimeElapsedColumn(), rp.TextColumn('[cyan]{task.description}'), console=shared.console)
|
||||
with pbar:
|
||||
|
|
@ -1469,7 +1470,7 @@ class VQA:
|
|||
if shared.state.interrupted:
|
||||
break
|
||||
img = Image.open(file)
|
||||
caption = self.interrogate(question, system_prompt, prompt, img, model_name, prefill, thinking_mode, quiet=True)
|
||||
caption = self.caption(question, system_prompt, prompt, img, model_name, prefill, thinking_mode, quiet=True)
|
||||
# Save annotated image if available
|
||||
if self.last_annotated_image and write:
|
||||
annotated_path = os.path.splitext(file)[0] + "_annotated.png"
|
||||
|
|
@ -1478,10 +1479,10 @@ class VQA:
|
|||
if write:
|
||||
writer.add(file, caption)
|
||||
except Exception as e:
|
||||
shared.log.error(f'Interrogate batch: {e}')
|
||||
shared.log.error(f'Caption batch: {e}')
|
||||
if write:
|
||||
writer.close()
|
||||
shared.opts.interrogate_offload = orig_offload
|
||||
shared.opts.caption_offload = orig_offload
|
||||
shared.state.end(jobid)
|
||||
return '\n\n'.join(prompts)
|
||||
|
||||
|
|
@ -1499,8 +1500,9 @@ def get_instance() -> VQA:
|
|||
|
||||
|
||||
# Backwards-compatible module-level functions
|
||||
def interrogate(*args, **kwargs):
|
||||
return get_instance().interrogate(*args, **kwargs)
|
||||
def caption(*args, **kwargs):
|
||||
return get_instance().caption(*args, **kwargs)
|
||||
|
||||
|
||||
|
||||
def unload_model():
|
||||
|
|
@ -10,8 +10,8 @@ from PIL import Image
|
|||
from modules import shared, devices, errors
|
||||
|
||||
|
||||
# Debug logging - enable with SD_INTERROGATE_DEBUG environment variable
|
||||
debug_enabled = os.environ.get('SD_INTERROGATE_DEBUG', None) is not None
|
||||
# Debug logging - enable with SD_CAPTION_DEBUG environment variable
|
||||
debug_enabled = os.environ.get('SD_CAPTION_DEBUG', None) is not None
|
||||
debug_log = shared.log.trace if debug_enabled else lambda *args, **kwargs: None
|
||||
|
||||
re_special = re.compile(r'([\\()])')
|
||||
|
|
@ -405,7 +405,7 @@ def tag(image: Image.Image, model_name: str = None, **kwargs) -> str:
|
|||
result = tagger.predict(image, **kwargs)
|
||||
shared.log.debug(f'WaifuDiffusion: complete time={time.time()-t0:.2f} tags={len(result.split(", ")) if result else 0}')
|
||||
# Offload model if setting enabled
|
||||
if shared.opts.interrogate_offload:
|
||||
if shared.opts.caption_offload:
|
||||
tagger.unload()
|
||||
except Exception as e:
|
||||
result = f"Exception {type(e)}"
|
||||
|
|
@ -1,48 +0,0 @@
|
|||
import time
|
||||
from PIL import Image
|
||||
from modules import shared
|
||||
|
||||
|
||||
def interrogate(image):
|
||||
if isinstance(image, list):
|
||||
image = image[0] if len(image) > 0 else None
|
||||
if isinstance(image, dict) and 'name' in image:
|
||||
image = Image.open(image['name'])
|
||||
if image is None:
|
||||
shared.log.error('Interrogate: no image provided')
|
||||
return ''
|
||||
t0 = time.time()
|
||||
if shared.opts.interrogate_default_type == 'OpenCLiP':
|
||||
shared.log.info(f'Interrogate: type={shared.opts.interrogate_default_type} clip="{shared.opts.interrogate_clip_model}" blip="{shared.opts.interrogate_blip_model}" mode="{shared.opts.interrogate_clip_mode}"')
|
||||
from modules.interrogate import openclip
|
||||
openclip.load_interrogator(clip_model=shared.opts.interrogate_clip_model, blip_model=shared.opts.interrogate_blip_model)
|
||||
openclip.update_interrogate_params()
|
||||
prompt = openclip.interrogate(image, mode=shared.opts.interrogate_clip_mode)
|
||||
shared.log.debug(f'Interrogate: time={time.time()-t0:.2f} answer="{prompt}"')
|
||||
return prompt
|
||||
elif shared.opts.interrogate_default_type == 'Tagger':
|
||||
shared.log.info(f'Interrogate: type={shared.opts.interrogate_default_type} model="{shared.opts.waifudiffusion_model}"')
|
||||
from modules.interrogate import tagger
|
||||
prompt = tagger.tag(
|
||||
image=image,
|
||||
model_name=shared.opts.waifudiffusion_model,
|
||||
general_threshold=shared.opts.tagger_threshold,
|
||||
character_threshold=shared.opts.waifudiffusion_character_threshold,
|
||||
include_rating=shared.opts.tagger_include_rating,
|
||||
exclude_tags=shared.opts.tagger_exclude_tags,
|
||||
max_tags=shared.opts.tagger_max_tags,
|
||||
sort_alpha=shared.opts.tagger_sort_alpha,
|
||||
use_spaces=shared.opts.tagger_use_spaces,
|
||||
escape_brackets=shared.opts.tagger_escape_brackets,
|
||||
)
|
||||
shared.log.debug(f'Interrogate: time={time.time()-t0:.2f} answer="{prompt}"')
|
||||
return prompt
|
||||
elif shared.opts.interrogate_default_type == 'VLM':
|
||||
shared.log.info(f'Interrogate: type={shared.opts.interrogate_default_type} vlm="{shared.opts.interrogate_vlm_model}" prompt="{shared.opts.interrogate_vlm_prompt}"')
|
||||
from modules.interrogate import vqa
|
||||
prompt = vqa.interrogate(image=image, model_name=shared.opts.interrogate_vlm_model, question=shared.opts.interrogate_vlm_prompt, prompt=None, system_prompt=shared.opts.interrogate_vlm_system)
|
||||
shared.log.debug(f'Interrogate: time={time.time()-t0:.2f} answer="{prompt}"')
|
||||
return prompt
|
||||
else:
|
||||
shared.log.error(f'Interrogate: type="{shared.opts.interrogate_default_type}" unknown')
|
||||
return ''
|
||||
Loading…
Reference in New Issue