enable quants for vlm-captioning

Signed-off-by: Vladimir Mandic <mandic00@live.com>
pull/4005/head
Vladimir Mandic 2025-06-29 11:48:02 -04:00
parent 99f323d105
commit 1b4e1ff0ef
4 changed files with 42 additions and 23 deletions

View File

@ -36,6 +36,7 @@
- **Changes** - **Changes**
- Update all core requirements - Update all core requirements
- Support Remote VAE with *Omnigen, Lumina 2 and PixArt* - Support Remote VAE with *Omnigen, Lumina 2 and PixArt*
- Enable quantization for captioning: *Gemma, Qwen, SMOL, Florence, JoyCaption*
- Add `--trace` command line param that enables trace logging - Add `--trace` command line param that enables trace logging
- Use Diffusers version of *OmniGen* - Use Diffusers version of *OmniGen*
- Control move global settings to control elements -> control settings tab - Control move global settings to control elements -> control settings tab

View File

@ -3,7 +3,7 @@
from dataclasses import dataclass from dataclasses import dataclass
import torch import torch
from transformers import AutoProcessor, LlavaForConditionalGeneration from transformers import AutoProcessor, LlavaForConditionalGeneration
from modules import shared, devices from modules import shared, devices, sd_models, model_quant
""" """
@ -66,8 +66,16 @@ def predict(question: str, image, vqa_model: str = None) -> str:
llava_model = None llava_model = None
if llava_model is None: if llava_model is None:
shared.log.info(f'Interrogate: type=vlm model="JoyCaption" {str(opts)}') shared.log.info(f'Interrogate: type=vlm model="JoyCaption" {str(opts)}')
processor = AutoProcessor.from_pretrained(opts.repo) processor = AutoProcessor.from_pretrained(opts.repo)
llava_model = LlavaForConditionalGeneration.from_pretrained(opts.repo, torch_dtype=devices.dtype, device_map="auto", cache_dir=shared.opts.hfcache_dir) quant_args = model_quant.create_config(module='LLM')
llava_model = LlavaForConditionalGeneration.from_pretrained(
opts.repo,
torch_dtype=devices.dtype,
device_map="auto",
cache_dir=shared.opts.hfcache_dir,
**quant_args,
)
llava_model.eval() llava_model.eval()
if len(question) < 2: if len(question) < 2:
@ -80,7 +88,7 @@ def predict(question: str, image, vqa_model: str = None) -> str:
convo_string = processor.apply_chat_template(convo, tokenize=False, add_generation_prompt=True) convo_string = processor.apply_chat_template(convo, tokenize=False, add_generation_prompt=True)
inputs = processor(text=[convo_string], images=[image], return_tensors="pt").to(devices.device) # Process the inputs inputs = processor(text=[convo_string], images=[image], return_tensors="pt").to(devices.device) # Process the inputs
inputs['pixel_values'] = inputs['pixel_values'].to(devices.dtype) inputs['pixel_values'] = inputs['pixel_values'].to(devices.dtype)
llava_model = llava_model.to(devices.device) sd_models.move_model(llava_model, devices.device)
with devices.inference_context(): with devices.inference_context():
generate_ids = llava_model.generate( # Generate the captions generate_ids = llava_model.generate( # Generate the captions
**inputs, **inputs,
@ -97,6 +105,6 @@ def predict(question: str, image, vqa_model: str = None) -> str:
)[0] )[0]
generate_ids = generate_ids[inputs['input_ids'].shape[1]:] # Trim off the prompt generate_ids = generate_ids[inputs['input_ids'].shape[1]:] # Trim off the prompt
caption = processor.tokenizer.decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False) # Decode the caption caption = processor.tokenizer.decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False) # Decode the caption
llava_model = llava_model.to(devices.cpu) sd_models.move_model(llava_model, devices.cpu, force=True)
caption = caption.replace('\n\n', '\n').strip() caption = caption.replace('\n\n', '\n').strip()
return caption return caption

View File

@ -7,12 +7,13 @@ import torch
import transformers import transformers
import transformers.dynamic_module_utils import transformers.dynamic_module_utils
from PIL import Image from PIL import Image
from modules import shared, devices, errors, sd_models from modules import shared, devices, errors, sd_models, model_quant
processor = None processor = None
model = None model = None
loaded: str = None loaded: str = None
quant_args = {}
vlm_models = { vlm_models = {
"Microsoft Florence 2 Base": "microsoft/Florence-2-base", # 0.5GB "Microsoft Florence 2 Base": "microsoft/Florence-2-base", # 0.5GB
"Microsoft Florence 2 Large": "microsoft/Florence-2-large", # 1.5GB "Microsoft Florence 2 Large": "microsoft/Florence-2-large", # 1.5GB
@ -121,9 +122,10 @@ def qwen(question: str, image: Image.Image, repo: str = None, system_prompt: str
model = None model = None
model = transformers.Qwen2VLForConditionalGeneration.from_pretrained( model = transformers.Qwen2VLForConditionalGeneration.from_pretrained(
repo, repo,
cache_dir=shared.opts.hfcache_dir torch_dtype=devices.dtype,
cache_dir=shared.opts.hfcache_dir,
**quant_args,
) )
model = model.to(devices.device, devices.dtype)
processor = transformers.AutoProcessor.from_pretrained(repo, cache_dir=shared.opts.hfcache_dir) processor = transformers.AutoProcessor.from_pretrained(repo, cache_dir=shared.opts.hfcache_dir)
loaded = repo loaded = repo
devices.torch_gc() devices.torch_gc()
@ -166,8 +168,12 @@ def gemma(question: str, image: Image.Image, repo: str = None, system_prompt: st
if model is None or loaded != repo: if model is None or loaded != repo:
shared.log.debug(f'Interrogate load: vlm="{repo}"') shared.log.debug(f'Interrogate load: vlm="{repo}"')
model = None model = None
model = transformers.Gemma3ForConditionalGeneration.from_pretrained(repo, cache_dir=shared.opts.hfcache_dir) model = transformers.Gemma3ForConditionalGeneration.from_pretrained(
model = model.to(devices.device, devices.dtype) repo,
torch_dtype=devices.dtype,
cache_dir=shared.opts.hfcache_dir,
**quant_args,
)
processor = transformers.AutoProcessor.from_pretrained(repo, cache_dir=shared.opts.hfcache_dir) processor = transformers.AutoProcessor.from_pretrained(repo, cache_dir=shared.opts.hfcache_dir)
loaded = repo loaded = repo
devices.torch_gc() devices.torch_gc()
@ -217,7 +223,6 @@ def paligemma(question: str, image: Image.Image, repo: str = None):
cache_dir=shared.opts.hfcache_dir, cache_dir=shared.opts.hfcache_dir,
torch_dtype=devices.dtype, torch_dtype=devices.dtype,
) )
model = model.to(devices.device, devices.dtype)
loaded = repo loaded = repo
devices.torch_gc() devices.torch_gc()
sd_models.move_model(model, devices.device) sd_models.move_model(model, devices.device)
@ -251,7 +256,6 @@ def ovis(question: str, image: Image.Image, repo: str = None):
trust_remote_code=True, trust_remote_code=True,
cache_dir=shared.opts.hfcache_dir, cache_dir=shared.opts.hfcache_dir,
) )
model = model.to(devices.device, devices.dtype)
loaded = repo loaded = repo
devices.torch_gc() devices.torch_gc()
sd_models.move_model(model, devices.device) sd_models.move_model(model, devices.device)
@ -292,8 +296,8 @@ def smol(question: str, image: Image.Image, repo: str = None, system_prompt: str
cache_dir=shared.opts.hfcache_dir, cache_dir=shared.opts.hfcache_dir,
torch_dtype=devices.dtype, torch_dtype=devices.dtype,
_attn_implementation="eager", _attn_implementation="eager",
**quant_args,
) )
model.to(devices.device, devices.dtype)
processor = transformers.AutoProcessor.from_pretrained(repo, cache_dir=shared.opts.hfcache_dir) processor = transformers.AutoProcessor.from_pretrained(repo, cache_dir=shared.opts.hfcache_dir)
loaded = repo loaded = repo
devices.torch_gc() devices.torch_gc()
@ -331,9 +335,9 @@ def git(question: str, image: Image.Image, repo: str = None):
model = None model = None
model = transformers.GitForCausalLM.from_pretrained( model = transformers.GitForCausalLM.from_pretrained(
repo, repo,
torch_dtype=devices.dtype,
cache_dir=shared.opts.hfcache_dir, cache_dir=shared.opts.hfcache_dir,
) )
model.to(devices.device, devices.dtype)
processor = transformers.GitProcessor.from_pretrained(repo, cache_dir=shared.opts.hfcache_dir) processor = transformers.GitProcessor.from_pretrained(repo, cache_dir=shared.opts.hfcache_dir)
loaded = repo loaded = repo
devices.torch_gc() devices.torch_gc()
@ -359,9 +363,9 @@ def blip(question: str, image: Image.Image, repo: str = None):
model = None model = None
model = transformers.BlipForQuestionAnswering.from_pretrained( model = transformers.BlipForQuestionAnswering.from_pretrained(
repo, repo,
torch_dtype=devices.dtype,
cache_dir=shared.opts.hfcache_dir, cache_dir=shared.opts.hfcache_dir,
) )
model.to(devices.device, devices.dtype)
processor = transformers.BlipProcessor.from_pretrained(repo, cache_dir=shared.opts.hfcache_dir) processor = transformers.BlipProcessor.from_pretrained(repo, cache_dir=shared.opts.hfcache_dir)
loaded = repo loaded = repo
devices.torch_gc() devices.torch_gc()
@ -381,9 +385,9 @@ def vilt(question: str, image: Image.Image, repo: str = None):
model = None model = None
model = transformers.ViltForQuestionAnswering.from_pretrained( model = transformers.ViltForQuestionAnswering.from_pretrained(
repo, repo,
torch_dtype=devices.dtype,
cache_dir=shared.opts.hfcache_dir, cache_dir=shared.opts.hfcache_dir,
) )
model.to(devices.device)
processor = transformers.ViltProcessor.from_pretrained(repo, cache_dir=shared.opts.hfcache_dir) processor = transformers.ViltProcessor.from_pretrained(repo, cache_dir=shared.opts.hfcache_dir)
loaded = repo loaded = repo
devices.torch_gc() devices.torch_gc()
@ -405,9 +409,9 @@ def pix(question: str, image: Image.Image, repo: str = None):
model = None model = None
model = transformers.Pix2StructForConditionalGeneration.from_pretrained( model = transformers.Pix2StructForConditionalGeneration.from_pretrained(
repo, repo,
torch_dtype=devices.dtype,
cache_dir=shared.opts.hfcache_dir, cache_dir=shared.opts.hfcache_dir,
) )
model.to(devices.device)
processor = transformers.Pix2StructProcessor.from_pretrained(repo, cache_dir=shared.opts.hfcache_dir) processor = transformers.Pix2StructProcessor.from_pretrained(repo, cache_dir=shared.opts.hfcache_dir)
loaded = repo loaded = repo
devices.torch_gc() devices.torch_gc()
@ -431,11 +435,11 @@ def moondream(question: str, image: Image.Image, repo: str = None):
repo, repo,
revision="2025-06-21", revision="2025-06-21",
trust_remote_code=True, trust_remote_code=True,
cache_dir=shared.opts.hfcache_dir torch_dtype=devices.dtype,
cache_dir=shared.opts.hfcache_dir,
) )
processor = transformers.AutoTokenizer.from_pretrained(repo, cache_dir=shared.opts.hfcache_dir) processor = transformers.AutoTokenizer.from_pretrained(repo, cache_dir=shared.opts.hfcache_dir)
loaded = repo loaded = repo
model.to(devices.device, devices.dtype)
model.eval() model.eval()
devices.torch_gc() devices.torch_gc()
sd_models.move_model(model, devices.device) sd_models.move_model(model, devices.device)
@ -475,12 +479,13 @@ def florence(question: str, image: Image.Image, repo: str = None, revision: str
repo, repo,
trust_remote_code=True, trust_remote_code=True,
revision=revision, revision=revision,
torch_dtype=devices.dtype,
cache_dir=shared.opts.hfcache_dir, cache_dir=shared.opts.hfcache_dir,
**quant_args,
) )
processor = transformers.AutoProcessor.from_pretrained(repo, trust_remote_code=True, revision=revision, cache_dir=shared.opts.hfcache_dir) processor = transformers.AutoProcessor.from_pretrained(repo, trust_remote_code=True, revision=revision, cache_dir=shared.opts.hfcache_dir)
transformers.dynamic_module_utils.get_imports = _get_imports transformers.dynamic_module_utils.get_imports = _get_imports
loaded = repo loaded = repo
model.to(devices.device, devices.dtype)
model.eval() model.eval()
devices.torch_gc() devices.torch_gc()
sd_models.move_model(model, devices.device) sd_models.move_model(model, devices.device)
@ -512,7 +517,6 @@ def sa2(question: str, image: Image.Image, repo: str = None):
low_cpu_mem_usage=True, low_cpu_mem_usage=True,
use_flash_attn=False, use_flash_attn=False,
trust_remote_code=True) trust_remote_code=True)
model = model.to(devices.device, devices.dtype)
model = model.eval() model = model.eval()
processor = transformers.AutoTokenizer.from_pretrained( processor = transformers.AutoTokenizer.from_pretrained(
repo, repo,
@ -539,9 +543,11 @@ def sa2(question: str, image: Image.Image, repo: str = None):
def interrogate(question:str='', system_prompt:str=None, prompt:str=None, image:Image.Image=None, model_name:str=None, quiet:bool=False): def interrogate(question:str='', system_prompt:str=None, prompt:str=None, image:Image.Image=None, model_name:str=None, quiet:bool=False):
global quant_args # pylint: disable=global-statement
if not quiet: if not quiet:
shared.state.begin('Interrogate') shared.state.begin('Interrogate')
t0 = time.time() t0 = time.time()
quant_args = model_quant.create_config(module='LLM')
model_name = model_name or shared.opts.interrogate_vlm_model model_name = model_name or shared.opts.interrogate_vlm_model
if isinstance(image, list): if isinstance(image, list):
image = image[0] if len(image) > 0 else None image = image[0] if len(image) > 0 else None
@ -560,8 +566,10 @@ def interrogate(question:str='', system_prompt:str=None, prompt:str=None, image:
if shared.native and shared.sd_loaded: if shared.native and shared.sd_loaded:
from modules.sd_models import apply_balanced_offload # prevent circular import from modules.sd_models import apply_balanced_offload # prevent circular import
apply_balanced_offload(shared.sd_model) apply_balanced_offload(shared.sd_model)
from modules import modelloader from modules import modelloader
modelloader.hf_login() modelloader.hf_login()
try: try:
if model_name is None: if model_name is None:
shared.log.error(f'Interrogate: type=vlm model="{model_name}" no model selected') shared.log.error(f'Interrogate: type=vlm model="{model_name}" no model selected')
@ -573,6 +581,7 @@ def interrogate(question:str='', system_prompt:str=None, prompt:str=None, image:
if image is None: if image is None:
shared.log.error(f'Interrogate: type=vlm model="{model_name}" no input image') shared.log.error(f'Interrogate: type=vlm model="{model_name}" no input image')
return '' return ''
if 'git' in vqa_model.lower(): if 'git' in vqa_model.lower():
answer = git(question, image, vqa_model) answer = git(question, image, vqa_model)
elif 'vilt' in vqa_model.lower(): elif 'vilt' in vqa_model.lower():
@ -611,9 +620,10 @@ def interrogate(question:str='', system_prompt:str=None, prompt:str=None, image:
except Exception as e: except Exception as e:
errors.display(e, 'VQA') errors.display(e, 'VQA')
answer = 'error' answer = 'error'
if shared.opts.interrogate_offload and model is not None: if shared.opts.interrogate_offload and model is not None:
sd_models.move_model(model, devices.cpu) sd_models.move_model(model, devices.cpu, force=True)
devices.torch_gc() devices.torch_gc(force=True)
answer = clean(answer, question) answer = clean(answer, question)
t1 = time.time() t1 = time.time()
if not quiet: if not quiet:

2
wiki

@ -1 +1 @@
Subproject commit b95139063c469ab11a379804ddb606ccfaf6c5c6 Subproject commit 45c389a3747cb1caceb4035535041c400c29cad2