diff --git a/CHANGELOG.md b/CHANGELOG.md
index 17ba58a3e..d78d96980 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -36,6 +36,7 @@
 - **Changes**
   - Update all core requirements
   - Support Remote VAE with *Omnigen, Lumina 2 and PixArt*
+  - Enable quantization for captioning: *Gemma, Qwen, SMOL, Florence, JoyCaption*
   - Add `--trace` command line param that enables trace logging
   - Use Diffusers version of *OmniGen*
   - Control move global settings to control elements -> control settings tab
diff --git a/modules/interrogate/joycaption.py b/modules/interrogate/joycaption.py
index 4941f7899..dc5e07213 100644
--- a/modules/interrogate/joycaption.py
+++ b/modules/interrogate/joycaption.py
@@ -3,7 +3,7 @@
 from dataclasses import dataclass
 import torch
 from transformers import AutoProcessor, LlavaForConditionalGeneration
-from modules import shared, devices
+from modules import shared, devices, sd_models, model_quant


 """
@@ -66,8 +66,16 @@ def predict(question: str, image, vqa_model: str = None) -> str:
         llava_model = None
     if llava_model is None:
         shared.log.info(f'Interrogate: type=vlm model="JoyCaption" {str(opts)}')
+        processor = AutoProcessor.from_pretrained(opts.repo)
-        llava_model = LlavaForConditionalGeneration.from_pretrained(opts.repo, torch_dtype=devices.dtype, device_map="auto", cache_dir=shared.opts.hfcache_dir)
+        quant_args = model_quant.create_config(module='LLM')
+        llava_model = LlavaForConditionalGeneration.from_pretrained(
+            opts.repo,
+            torch_dtype=devices.dtype,
+            device_map="auto",
+            cache_dir=shared.opts.hfcache_dir,
+            **quant_args,
+        )
         llava_model.eval()


     if len(question) < 2:
@@ -80,7 +88,7 @@ def predict(question: str, image, vqa_model: str = None) -> str:
     convo_string = processor.apply_chat_template(convo, tokenize=False, add_generation_prompt=True)
     inputs = processor(text=[convo_string], images=[image], return_tensors="pt").to(devices.device) # Process the inputs
     inputs['pixel_values'] = inputs['pixel_values'].to(devices.dtype)
-    llava_model = llava_model.to(devices.device)
+    sd_models.move_model(llava_model, devices.device)
     with devices.inference_context():
         generate_ids = llava_model.generate( # Generate the captions
             **inputs,
@@ -97,6 +105,6 @@ def predict(question: str, image, vqa_model: str = None) -> str:
         )[0]
     generate_ids = generate_ids[inputs['input_ids'].shape[1]:] # Trim off the prompt
     caption = processor.tokenizer.decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False) # Decode the caption
-    llava_model = llava_model.to(devices.cpu)
+    sd_models.move_model(llava_model, devices.cpu, force=True)
     caption = caption.replace('\n\n', '\n').strip()
     return caption
diff --git a/modules/interrogate/vqa.py b/modules/interrogate/vqa.py
index f11065617..d2aeab61a 100644
--- a/modules/interrogate/vqa.py
+++ b/modules/interrogate/vqa.py
@@ -7,12 +7,13 @@
 import torch
 import transformers
 import transformers.dynamic_module_utils
 from PIL import Image
-from modules import shared, devices, errors, sd_models
+from modules import shared, devices, errors, sd_models, model_quant

 processor = None
 model = None
 loaded: str = None
+quant_args = {}
 vlm_models = {
     "Microsoft Florence 2 Base": "microsoft/Florence-2-base", # 0.5GB
     "Microsoft Florence 2 Large": "microsoft/Florence-2-large", # 1.5GB
@@ -121,9 +122,10 @@ def qwen(question: str, image: Image.Image, repo: str = None, system_prompt: str
         model = None
         model = transformers.Qwen2VLForConditionalGeneration.from_pretrained(
             repo,
-            cache_dir=shared.opts.hfcache_dir
+            torch_dtype=devices.dtype,
+            cache_dir=shared.opts.hfcache_dir,
+            **quant_args,
         )
-        model = model.to(devices.device, devices.dtype)
         processor = transformers.AutoProcessor.from_pretrained(repo, cache_dir=shared.opts.hfcache_dir)
         loaded = repo
     devices.torch_gc()
@@ -166,8 +168,12 @@ def gemma(question: str, image: Image.Image, repo: str = None, system_prompt: st
     if model is None or loaded != repo:
         shared.log.debug(f'Interrogate load: vlm="{repo}"')
         model = None
-        model = transformers.Gemma3ForConditionalGeneration.from_pretrained(repo, cache_dir=shared.opts.hfcache_dir)
-        model = model.to(devices.device, devices.dtype)
+        model = transformers.Gemma3ForConditionalGeneration.from_pretrained(
+            repo,
+            torch_dtype=devices.dtype,
+            cache_dir=shared.opts.hfcache_dir,
+            **quant_args,
+        )
         processor = transformers.AutoProcessor.from_pretrained(repo, cache_dir=shared.opts.hfcache_dir)
         loaded = repo
     devices.torch_gc()
@@ -217,7 +223,6 @@ def paligemma(question: str, image: Image.Image, repo: str = None):
             cache_dir=shared.opts.hfcache_dir,
             torch_dtype=devices.dtype,
         )
-        model = model.to(devices.device, devices.dtype)
         loaded = repo
     devices.torch_gc()
     sd_models.move_model(model, devices.device)
@@ -251,7 +256,6 @@ def ovis(question: str, image: Image.Image, repo: str = None):
             trust_remote_code=True,
             cache_dir=shared.opts.hfcache_dir,
         )
-        model = model.to(devices.device, devices.dtype)
         loaded = repo
     devices.torch_gc()
     sd_models.move_model(model, devices.device)
@@ -292,8 +296,8 @@ def smol(question: str, image: Image.Image, repo: str = None, system_prompt: str
             cache_dir=shared.opts.hfcache_dir,
             torch_dtype=devices.dtype,
             _attn_implementation="eager",
+            **quant_args,
         )
-        model.to(devices.device, devices.dtype)
         processor = transformers.AutoProcessor.from_pretrained(repo, cache_dir=shared.opts.hfcache_dir)
         loaded = repo
     devices.torch_gc()
@@ -331,9 +335,9 @@ def git(question: str, image: Image.Image, repo: str = None):
         model = None
         model = transformers.GitForCausalLM.from_pretrained(
             repo,
+            torch_dtype=devices.dtype,
             cache_dir=shared.opts.hfcache_dir,
         )
-        model.to(devices.device, devices.dtype)
         processor = transformers.GitProcessor.from_pretrained(repo, cache_dir=shared.opts.hfcache_dir)
         loaded = repo
     devices.torch_gc()
@@ -359,9 +363,9 @@ def blip(question: str, image: Image.Image, repo: str = None):
         model = None
         model = transformers.BlipForQuestionAnswering.from_pretrained(
             repo,
+            torch_dtype=devices.dtype,
             cache_dir=shared.opts.hfcache_dir,
         )
-        model.to(devices.device, devices.dtype)
         processor = transformers.BlipProcessor.from_pretrained(repo, cache_dir=shared.opts.hfcache_dir)
         loaded = repo
     devices.torch_gc()
@@ -381,9 +385,9 @@ def vilt(question: str, image: Image.Image, repo: str = None):
         model = None
         model = transformers.ViltForQuestionAnswering.from_pretrained(
             repo,
+            torch_dtype=devices.dtype,
             cache_dir=shared.opts.hfcache_dir,
         )
-        model.to(devices.device)
         processor = transformers.ViltProcessor.from_pretrained(repo, cache_dir=shared.opts.hfcache_dir)
         loaded = repo
     devices.torch_gc()
@@ -405,9 +409,9 @@ def pix(question: str, image: Image.Image, repo: str = None):
         model = None
         model = transformers.Pix2StructForConditionalGeneration.from_pretrained(
             repo,
+            torch_dtype=devices.dtype,
             cache_dir=shared.opts.hfcache_dir,
         )
-        model.to(devices.device)
         processor = transformers.Pix2StructProcessor.from_pretrained(repo, cache_dir=shared.opts.hfcache_dir)
         loaded = repo
     devices.torch_gc()
@@ -431,11 +435,11 @@ def moondream(question: str, image: Image.Image, repo: str = None):
             repo,
             revision="2025-06-21",
             trust_remote_code=True,
-            cache_dir=shared.opts.hfcache_dir
+            torch_dtype=devices.dtype,
+            cache_dir=shared.opts.hfcache_dir,
         )
         processor = transformers.AutoTokenizer.from_pretrained(repo, cache_dir=shared.opts.hfcache_dir)
         loaded = repo
-        model.to(devices.device, devices.dtype)
         model.eval()
     devices.torch_gc()
     sd_models.move_model(model, devices.device)
@@ -475,12 +479,13 @@ def florence(question: str, image: Image.Image, repo: str = None, revision: str
             repo,
             trust_remote_code=True,
             revision=revision,
+            torch_dtype=devices.dtype,
             cache_dir=shared.opts.hfcache_dir,
+            **quant_args,
         )
         processor = transformers.AutoProcessor.from_pretrained(repo, trust_remote_code=True, revision=revision, cache_dir=shared.opts.hfcache_dir)
         transformers.dynamic_module_utils.get_imports = _get_imports
         loaded = repo
-        model.to(devices.device, devices.dtype)
         model.eval()
     devices.torch_gc()
     sd_models.move_model(model, devices.device)
@@ -512,7 +517,6 @@ def sa2(question: str, image: Image.Image, repo: str = None):
         low_cpu_mem_usage=True,
         use_flash_attn=False,
         trust_remote_code=True)
-    model = model.to(devices.device, devices.dtype)
     model = model.eval()
     processor = transformers.AutoTokenizer.from_pretrained(
         repo,
@@ -539,9 +543,11 @@ def sa2(question: str, image: Image.Image, repo: str = None):


 def interrogate(question:str='', system_prompt:str=None, prompt:str=None, image:Image.Image=None, model_name:str=None, quiet:bool=False):
+    global quant_args # pylint: disable=global-statement
     if not quiet:
         shared.state.begin('Interrogate')
     t0 = time.time()
+    quant_args = model_quant.create_config(module='LLM')
     model_name = model_name or shared.opts.interrogate_vlm_model
     if isinstance(image, list):
         image = image[0] if len(image) > 0 else None
@@ -560,8 +566,10 @@ def interrogate(question:str='', system_prompt:str=None, prompt:str=None, image:
     if shared.native and shared.sd_loaded:
         from modules.sd_models import apply_balanced_offload # prevent circular import
         apply_balanced_offload(shared.sd_model)
+
     from modules import modelloader
     modelloader.hf_login()
+
     try:
         if model_name is None:
             shared.log.error(f'Interrogate: type=vlm model="{model_name}" no model selected')
@@ -573,6 +581,7 @@ def interrogate(question:str='', system_prompt:str=None, prompt:str=None, image:
         if image is None:
             shared.log.error(f'Interrogate: type=vlm model="{model_name}" no input image')
             return ''
+
         if 'git' in vqa_model.lower():
             answer = git(question, image, vqa_model)
         elif 'vilt' in vqa_model.lower():
@@ -611,9 +620,10 @@ def interrogate(question:str='', system_prompt:str=None, prompt:str=None, image:
     except Exception as e:
         errors.display(e, 'VQA')
         answer = 'error'
+
     if shared.opts.interrogate_offload and model is not None:
-        sd_models.move_model(model, devices.cpu)
-        devices.torch_gc()
+        sd_models.move_model(model, devices.cpu, force=True)
+        devices.torch_gc(force=True)
     answer = clean(answer, question)
     t1 = time.time()
     if not quiet:
diff --git a/wiki b/wiki
index b95139063..45c389a37 160000
--- a/wiki
+++ b/wiki
@@ -1 +1 @@
-Subproject commit b95139063c469ab11a379804ddb606ccfaf6c5c6
+Subproject commit 45c389a3747cb1caceb4035535041c400c29cad2
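
Note: below is a minimal sketch of the loader pattern this patch applies across the captioning models, for reference only. It assumes `model_quant.create_config(module='LLM')` returns a kwargs dict suitable for `from_pretrained` (for example `{"quantization_config": BitsAndBytesConfig(...)}` when a bitsandbytes backend is selected, or `{}` when quantization is disabled); the stand-in helper, the enable flag, and the repo id shown are illustrative assumptions, not part of the patch.

    # Sketch of the **quant_args pattern used in modules/interrogate above (assumptions noted in comments).
    import torch
    from transformers import AutoProcessor, LlavaForConditionalGeneration, BitsAndBytesConfig

    def create_llm_quant_config(enabled: bool) -> dict:
        # Stand-in for modules.model_quant.create_config(module='LLM'); the real helper
        # picks the backend from user settings. This version is an assumption for illustration.
        if not enabled:
            return {}
        return {"quantization_config": BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.float16)}

    repo = "fancyfeast/llama-joycaption-alpha-two-hf-llava"  # illustrative repo id, not taken from the patch
    quant_args = create_llm_quant_config(enabled=True)

    processor = AutoProcessor.from_pretrained(repo)
    model = LlavaForConditionalGeneration.from_pretrained(
        repo,
        torch_dtype=torch.float16,  # devices.dtype in the actual modules
        device_map="auto",
        **quant_args,               # expands to quantization_config=... when quantization is enabled
    )
    model.eval()

Because the quantized weights are placed by `device_map`/`quantization_config` at load time, the explicit `model.to(device, dtype)` calls become unnecessary, which is why the patch replaces them with the offload-aware `sd_models.move_model(...)` helper.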