import os
import time
from collections import namedtuple
import threading
import re
import gradio as gr
from PIL import Image
from modules import devices, shared, errors
from modules.logger import log, console

# Verbose tracing is opt-in via environment variable; otherwise debug_log is a no-op
debug_enabled = os.environ.get('SD_CAPTION_DEBUG', None) is not None
debug_log = log.trace if debug_enabled else lambda *args, **kwargs: None

# Per-request overrides for API calls (dict of setting-name -> value, or None)
_clip_overrides = None


def get_clip_setting(name):
    """Get CLIP setting with per-request override support.

    Args:
        name: Setting name without 'caption_openclip_' prefix (e.g., 'min_flavors', 'max_length')

    Returns:
        Override value if set, otherwise the value from shared.opts
    """
    if _clip_overrides is not None:
        value = _clip_overrides.get(name)
        if value is not None:
            return value
    return getattr(shared.opts, f'caption_openclip_{name}')


def _apply_blip2_fix(model, processor):
    """Apply compatibility fix for BLIP2 models with newer transformers versions.

    Registers an extra image token on the tokenizer, resizes the model embeddings
    to match, and records the token index on the model config. No-op when the
    model config has no num_query_tokens attribute.
    """
    from transformers import AddedToken
    if not hasattr(model.config, 'num_query_tokens'):
        return
    processor.num_query_tokens = model.config.num_query_tokens
    image_token = AddedToken("", normalized=False, special=True)
    processor.tokenizer.add_tokens([image_token], special_tokens=True)
    model.resize_token_embeddings(len(processor.tokenizer), pad_to_multiple_of=64)
    model.config.image_token_index = len(processor.tokenizer) - 1
    debug_log(f'CLIP load: applied BLIP2 tokenizer fix num_query_tokens={model.config.num_query_tokens}')


# Short model alias -> HuggingFace repo id
caption_models = {
    'blip-base': 'Salesforce/blip-image-captioning-base',
    'blip-large': 'Salesforce/blip-image-captioning-large',
    'blip2-opt-2.7b': 'Salesforce/blip2-opt-2.7b-coco',
    'blip2-opt-6.7b': 'Salesforce/blip2-opt-6.7b',
    'blip2-flip-t5-xl': 'Salesforce/blip2-flan-t5-xl',
    'blip2-flip-t5-xxl': 'Salesforce/blip2-flan-t5-xxl',
}
# Interrogation modes accepted by caption()
caption_types = [
    'best',
    'fast',
    'classic',
    'caption',
    'negative',
]
clip_models = []
ci = None  # singleton clip_interrogator.Interrogator instance, created lazily by load_captioner()
blip_image_eval_size = 384
clip_model_name = 'ViT-L/14'
Category = namedtuple("Category", ["name", "topn", "items"])
re_topn = re.compile(r"\.top(\d+)\.")
load_lock = threading.Lock()


class BatchWriter:
    """Writes per-image caption sidecar .txt files during batch captioning."""

    def __init__(self, folder, mode='w'):
        self.folder = folder
        self.file = None  # kept for interface compatibility; add() opens/closes per call
        self.mode = mode  # 'w' to overwrite, 'a' to append to existing sidecar files

    def add(self, file, prompt):
        """Write (or append) prompt to the .txt sidecar next to the given image file."""
        txt_file = os.path.splitext(file)[0] + ".txt"
        if self.mode == 'a':
            prompt = '\n' + prompt  # separate appended caption from existing content
        with open(os.path.join(self.folder, txt_file), self.mode, encoding='utf-8') as f:
            f.write(prompt)

    def close(self):
        # add() uses a context manager, so there is normally nothing to close;
        # self.file is kept so callers may always safely call close()
        if self.file is not None:
            self.file.close()


def update_caption_params():
    """Push current (possibly overridden) caption settings onto the loaded interrogator."""
    if ci is not None:
        ci.config.caption_max_length = get_clip_setting('max_length')
        ci.config.chunk_size = get_clip_setting('chunk_size')
        ci.config.flavor_intermediate_count = get_clip_setting('flavor_count')
        ci.clip_offload = shared.opts.caption_offload
        ci.caption_offload = shared.opts.caption_offload


def get_clip_models():
    """Return the cached list of available OpenCLIP model names."""
    return clip_models


def refresh_clip_models():
    """Refresh and return the list of pretrained OpenCLIP models ('arch/pretrained' strings)."""
    from installer import install
    install('open-clip-torch', no_deps=True, quiet=True)
    global clip_models  # pylint: disable=global-statement
    import open_clip
    models = sorted(open_clip.list_pretrained())
    log.debug(f'Caption: pkg=openclip version={open_clip.__version__} models={len(models)}')
    clip_models = ['/'.join(x) for x in models]
    return clip_models


def _load_blip_model(blip_model: str, device):
    """Pre-load BLIP caption model with cache_dir so downloads go to hfcache_dir.

    Returns:
        Tuple of (caption_model, caption_processor); the model is moved to device
        unless offloading is enabled.
    """
    import transformers
    model_path = caption_models.get(blip_model, blip_model)  # allow raw repo ids too
    cache_dir = shared.opts.clip_models_path
    dtype = devices.dtype
    # Model class depends on the family encoded in the short name prefix
    if blip_model.startswith('git-'):
        caption_model = transformers.AutoModelForCausalLM.from_pretrained(model_path, torch_dtype=dtype, cache_dir=cache_dir)
    elif blip_model.startswith('blip2-'):
        caption_model = transformers.Blip2ForConditionalGeneration.from_pretrained(model_path, torch_dtype=dtype, cache_dir=cache_dir)
    else:
        caption_model = transformers.BlipForConditionalGeneration.from_pretrained(model_path, torch_dtype=dtype, cache_dir=cache_dir)
    caption_processor = transformers.AutoProcessor.from_pretrained(model_path, cache_dir=cache_dir)
    caption_model.eval()
    if not shared.opts.caption_offload:
        caption_model = caption_model.to(device)
    return caption_model, caption_processor


def load_captioner(clip_model, blip_model):
    """Create or update the global Interrogator (ci) for the requested CLIP/BLIP pair.

    Three paths: first-time creation, partial reload when either model name
    changed, or no-op when both are already loaded.
    """
    from installer import install
    install('clip_interrogator==0.6.0')
    import clip_interrogator
    clip_interrogator.clip_interrogator.CAPTION_MODELS = caption_models
    global ci  # pylint: disable=global-statement
    if ci is None:
        t0 = time.time()
        device = devices.get_optimal_device()
        cache_path = shared.opts.clip_models_path
        log.info(f'CLIP load: clip="{clip_model}" blip="{blip_model}" device={device}')
        debug_log(f'CLIP load: cache_path="{cache_path}" max_length={shared.opts.caption_openclip_max_length} chunk_size={shared.opts.caption_openclip_chunk_size} flavor_count={shared.opts.caption_openclip_flavor_count} offload={shared.opts.caption_offload}')
        # Pre-load the caption model ourselves so it lands in our cache dir,
        # then hand it to clip_interrogator via the config object
        caption_model, caption_processor = _load_blip_model(blip_model, device)
        captioner_config = clip_interrogator.Config(
            device=device,
            cache_path=cache_path,
            clip_model_path=cache_path,
            clip_model_name=clip_model,
            caption_model_name=blip_model,
            quiet=True,
            caption_max_length=shared.opts.caption_openclip_max_length,
            chunk_size=shared.opts.caption_openclip_chunk_size,
            flavor_intermediate_count=shared.opts.caption_openclip_flavor_count,
            clip_offload=shared.opts.caption_offload,
            caption_offload=shared.opts.caption_offload,
        )
        captioner_config.caption_model = caption_model
        captioner_config.caption_processor = caption_processor
        ci = clip_interrogator.Interrogator(captioner_config)
        if blip_model.startswith('blip2-'):
            _apply_blip2_fix(ci.caption_model, ci.caption_processor)
        log.debug(f'CLIP load: time={time.time()-t0:.2f}')
    elif clip_model != ci.config.clip_model_name or blip_model != ci.config.caption_model_name:
        t0 = time.time()
        if clip_model != ci.config.clip_model_name:
            log.info(f'CLIP load: clip="{clip_model}" reloading')
            debug_log(f'CLIP load: previous clip="{ci.config.clip_model_name}"')
            ci.config.clip_model_name = clip_model
            ci.config.clip_model = None
            ci.load_clip_model()
            ci.clip_offloaded = True  # Reset flag so _prepare_clip() will move model to device
        if blip_model != ci.config.caption_model_name:
            log.info(f'CLIP load: blip="{blip_model}" reloading')
            debug_log(f'CLIP load: previous blip="{ci.config.caption_model_name}"')
            ci.config.caption_model_name = blip_model
            caption_model, caption_processor = _load_blip_model(blip_model, ci.device)
            ci.caption_model = caption_model
            ci.caption_processor = caption_processor
            ci.caption_offloaded = True  # Reset flag so _prepare_caption() will move model to device
            if blip_model.startswith('blip2-'):
                _apply_blip2_fix(ci.caption_model, ci.caption_processor)
        log.debug(f'CLIP load: time={time.time()-t0:.2f}')
    else:
        debug_log(f'CLIP: models already loaded clip="{clip_model}" blip="{blip_model}"')


def unload_clip_model():
    """Move caption and CLIP models to CPU when offloading is enabled; frees VRAM."""
    if ci is not None and shared.opts.caption_offload:
        log.debug('CLIP unload: offloading models to CPU')
        # Direct .to() instead of sd_models.move_model — models are from clip_interrogator, not transformers
        if ci.caption_model is not None and hasattr(ci.caption_model, 'to'):
            ci.caption_model.to(devices.cpu)
        if ci.clip_model is not None and hasattr(ci.clip_model, 'to'):
            ci.clip_model.to(devices.cpu)
        ci.caption_offloaded = True
        ci.clip_offloaded = True
        devices.torch_gc()
        debug_log('CLIP unload: complete')


def caption(image, mode, base_caption=None):
    """Interrogate a single image and return the generated prompt string.

    Args:
        image: PIL image, list of images (first is used), or gradio file dict
        mode: one of caption_types ('best', 'fast', 'classic', 'caption', 'negative')
        base_caption: optional pre-computed caption passed through to the interrogator

    Raises:
        RuntimeError: if mode is not recognized
    """
    # Normalize the various input shapes gradio may hand us down to one PIL image
    if isinstance(image, list):
        image = image[0] if len(image) > 0 else None
    if isinstance(image, dict) and 'name' in image:
        image = Image.open(image['name'])
    if image is None:
        return ''
    image = image.convert("RGB")
    t0 = time.time()
    min_flavors = get_clip_setting('min_flavors')
    max_flavors = get_clip_setting('max_flavors')
    debug_log(f'CLIP: mode="{mode}" image_size={image.size} caption={base_caption is not None} min_flavors={min_flavors} max_flavors={max_flavors}')
    # NOTE: Method names like .interrogate(), .interrogate_classic(), etc. come from the external
    # clip-interrogator library (https://github.com/pharmapsychotic/clip-interrogator) and cannot be renamed.
    if mode == 'best':
        prompt = ci.interrogate(image, caption=base_caption, min_flavors=min_flavors, max_flavors=max_flavors)
    elif mode == 'caption':
        prompt = ci.generate_caption(image) if base_caption is None else base_caption
    elif mode == 'classic':
        prompt = ci.interrogate_classic(image, caption=base_caption, max_flavors=max_flavors)
    elif mode == 'fast':
        prompt = ci.interrogate_fast(image, caption=base_caption, max_flavors=max_flavors)
    elif mode == 'negative':
        prompt = ci.interrogate_negative(image, max_flavors=max_flavors)
    else:
        raise RuntimeError(f"Unknown mode {mode}")
    # Truncate long prompts in the trace output only
    debug_log(f'CLIP: mode="{mode}" time={time.time()-t0:.2f} result="{prompt[:100]}..."' if len(prompt) > 100 else f'CLIP: mode="{mode}" time={time.time()-t0:.2f} result="{prompt}"')
    return prompt


def caption_image(image, clip_model, blip_model, mode, overrides=None):
    """Caption one image end-to-end: load models, apply overrides, interrogate, cleanup.

    Returns the prompt string, or an 'Exception ...' string on failure.
    """
    global _clip_overrides  # pylint: disable=global-statement
    jobid = shared.state.begin('Caption CLiP')
    t0 = time.time()
    log.info(f'CLIP: mode="{mode}" clip="{clip_model}" blip="{blip_model}" image_size={image.size if image else None}')
    if overrides:
        debug_log(f'CLIP: overrides={overrides}')
    try:
        # Set per-request overrides
        _clip_overrides = overrides
        if shared.sd_loaded:
            from modules.sd_models import apply_balanced_offload  # prevent circular import
            apply_balanced_offload(shared.sd_model)
            debug_log('CLIP: applied balanced offload to sd_model')
        load_captioner(clip_model, blip_model)
        # Apply overrides to loaded captioner
        update_caption_params()
        image = image.convert('RGB')
        prompt = caption(image, mode)
        if shared.opts.caption_offload:
            unload_clip_model()
        devices.torch_gc()
        log.debug(f'CLIP: complete time={time.time()-t0:.2f}')
    except Exception as e:
        prompt = f"Exception {type(e)}"
        log.error(f'CLIP: {e}')
        errors.display(e, 'Caption')
    finally:
        # Clear per-request overrides
        _clip_overrides = None
        shared.state.end(jobid)
    return prompt


def caption_batch(batch_files, batch_folder, batch_str, clip_model, blip_model, mode, write, append, recursive):
    """Caption a batch of images from uploaded files, a folder widget, and/or a path string.

    Optionally writes each caption to a .txt sidecar next to its image.
    Returns all prompts joined by blank lines.
    """
    files = []
    if batch_files is not None:
        files += [f.name for f in batch_files]
    if batch_folder is not None:
        files += [f.name for f in batch_folder]
    if batch_str is not None and len(batch_str) > 0 and os.path.exists(batch_str) and os.path.isdir(batch_str):
        from modules.files_cache import list_files
        files += list(list_files(batch_str, ext_filter=['.png', '.jpg', '.jpeg', '.webp', '.jxl'], recursive=recursive))
    if len(files) == 0:
        log.warning('CLIP batch: no images found')
        return ''
    t0 = time.time()
    log.info(f'CLIP batch: mode="{mode}" images={len(files)} clip="{clip_model}" blip="{blip_model}" write={write} append={append}')
    debug_log(f'CLIP batch: recursive={recursive} files={files[:5]}{"..." if len(files) > 5 else ""}')
    jobid = shared.state.begin('Caption batch')
    prompts = []
    load_captioner(clip_model, blip_model)
    writer = None  # only created when sidecar writing is requested
    if write:
        file_mode = 'w' if not append else 'a'
        writer = BatchWriter(os.path.dirname(files[0]), mode=file_mode)
        debug_log(f'CLIP batch: writing to "{os.path.dirname(files[0])}" mode="{file_mode}"')
    import rich.progress as rp
    pbar = rp.Progress(rp.TextColumn('[cyan]Caption:'), rp.BarColumn(), rp.MofNCompleteColumn(), rp.TaskProgressColumn(), rp.TimeRemainingColumn(), rp.TimeElapsedColumn(), rp.TextColumn('[cyan]{task.description}'), console=console)
    try:
        with pbar:
            task = pbar.add_task(total=len(files), description='starting...')
            for file in files:
                pbar.update(task, advance=1, description=file)
                try:
                    if shared.state.interrupted:
                        log.info('CLIP batch: interrupted')
                        break
                    image = Image.open(file).convert('RGB')
                    prompt = caption(image, mode)
                    prompts.append(prompt)
                    if write:
                        writer.add(file, prompt)
                except OSError as e:
                    # Unreadable/corrupt file: log and continue with the rest of the batch
                    log.error(f'CLIP batch: file="{file}" error={e}')
    finally:
        # Always release resources and end the job, even if an unexpected
        # exception (non-OSError) escapes the loop above
        if writer is not None:
            writer.close()
        if ci is not None:
            ci.config.quiet = False
        unload_clip_model()
        shared.state.end(jobid)
    log.info(f'CLIP batch: complete images={len(prompts)} time={time.time()-t0:.2f}')
    return '\n\n'.join(prompts)


def analyze_image(image, clip_model, blip_model):
    """Rank the top-5 mediums/artists/movements/trendings/flavors for an image.

    Returns a list of gr.update objects: five label-rank dicts plus one
    formatted text summary, all made visible.
    """
    t0 = time.time()
    log.info(f'CLIP analyze: clip="{clip_model}" blip="{blip_model}" image_size={image.size if image else None}')
    load_captioner(clip_model, blip_model)
    image = image.convert('RGB')
    image_features = ci.image_to_features(image)
    debug_log(f'CLIP analyze: features shape={image_features.shape if hasattr(image_features, "shape") else "unknown"}')
    top_mediums = ci.mediums.rank(image_features, 5)
    top_artists = ci.artists.rank(image_features, 5)
    top_movements = ci.movements.rank(image_features, 5)
    top_trendings = ci.trendings.rank(image_features, 5)
    top_flavors = ci.flavors.rank(image_features, 5)
    # Pair each label with its similarity score, sorted best-first
    medium_ranks = dict(sorted(zip(top_mediums, ci.similarities(image_features, top_mediums), strict=False), key=lambda x: x[1], reverse=True))
    artist_ranks = dict(sorted(zip(top_artists, ci.similarities(image_features, top_artists), strict=False), key=lambda x: x[1], reverse=True))
    movement_ranks = dict(sorted(zip(top_movements, ci.similarities(image_features, top_movements), strict=False), key=lambda x: x[1], reverse=True))
    trending_ranks = dict(sorted(zip(top_trendings, ci.similarities(image_features, top_trendings), strict=False), key=lambda x: x[1], reverse=True))
    flavor_ranks = dict(sorted(zip(top_flavors, ci.similarities(image_features, top_flavors), strict=False), key=lambda x: x[1], reverse=True))
    log.debug(f'CLIP analyze: complete time={time.time()-t0:.2f}')

    # Format labels as text
    def format_category(name, ranks):
        lines = [f"{name}:"]
        for item, score in ranks.items():
            lines.append(f" • {item} - {score*100:.1f}%")
        return '\n'.join(lines)

    formatted_text = '\n\n'.join([
        format_category("Medium", medium_ranks),
        format_category("Artist", artist_ranks),
        format_category("Movement", movement_ranks),
        format_category("Trending", trending_ranks),
        format_category("Flavor", flavor_ranks),
    ])
    return [
        gr.update(value=medium_ranks, visible=True),
        gr.update(value=artist_ranks, visible=True),
        gr.update(value=movement_ranks, visible=True),
        gr.update(value=trending_ranks, visible=True),
        gr.update(value=flavor_ranks, visible=True),
        gr.update(value=formatted_text, visible=True),  # New text output for the textbox
    ]