enable quants for vlm-captioning

Signed-off-by: Vladimir Mandic <mandic00@live.com>
pull/4005/head
Vladimir Mandic 2025-06-29 11:48:02 -04:00
parent 99f323d105
commit 1b4e1ff0ef
4 changed files with 42 additions and 23 deletions


@@ -36,6 +36,7 @@
- **Changes**
- Update all core requirements
- Support Remote VAE with *Omnigen, Lumina 2 and PixArt*
- Enable quantization for captioning: *Gemma, Qwen, SMOL, Florence, JoyCaption*
- Add `--trace` command line param that enables trace logging
- Use Diffusers version of *OmniGen*
- *Control*: move global settings for control elements into the control settings tab
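
The "Enable quantization for captioning" entry above is the core of this commit: each captioning backend now builds its quantization kwargs once via `model_quant.create_config(module='LLM')` and spreads them into the `transformers` `from_pretrained()` call. Below is a minimal sketch of that pattern, assuming a bitsandbytes-style config; the actual dict returned by `create_config` is not shown in this diff, so `make_quant_args` and the repo id are placeholders.

```python
# Sketch only: pass an optional quantization config through from_pretrained().
# make_quant_args() stands in for model_quant.create_config(module='LLM');
# its real return value is not visible in this commit.
import torch
from transformers import BitsAndBytesConfig, LlavaForConditionalGeneration

def make_quant_args(enabled: bool) -> dict:
    if not enabled:
        return {}  # empty dict -> model loads unquantized, same call site either way
    return {"quantization_config": BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.float16)}

quant_args = make_quant_args(enabled=True)
model = LlavaForConditionalGeneration.from_pretrained(
    "some-org/some-llava-checkpoint",  # placeholder repo id, not taken from this diff
    torch_dtype=torch.float16,
    device_map="auto",
    **quant_args,  # mirrors the **quant_args spread added in this commit
)
```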


@@ -3,7 +3,7 @@
from dataclasses import dataclass
import torch
from transformers import AutoProcessor, LlavaForConditionalGeneration
from modules import shared, devices
from modules import shared, devices, sd_models, model_quant
"""
@@ -66,8 +66,16 @@ def predict(question: str, image, vqa_model: str = None) -> str:
llava_model = None
if llava_model is None:
shared.log.info(f'Interrogate: type=vlm model="JoyCaption" {str(opts)}')
processor = AutoProcessor.from_pretrained(opts.repo)
llava_model = LlavaForConditionalGeneration.from_pretrained(opts.repo, torch_dtype=devices.dtype, device_map="auto", cache_dir=shared.opts.hfcache_dir)
quant_args = model_quant.create_config(module='LLM')
llava_model = LlavaForConditionalGeneration.from_pretrained(
opts.repo,
torch_dtype=devices.dtype,
device_map="auto",
cache_dir=shared.opts.hfcache_dir,
**quant_args,
)
llava_model.eval()
if len(question) < 2:
@@ -80,7 +88,7 @@ def predict(question: str, image, vqa_model: str = None) -> str:
convo_string = processor.apply_chat_template(convo, tokenize=False, add_generation_prompt=True)
inputs = processor(text=[convo_string], images=[image], return_tensors="pt").to(devices.device) # Process the inputs
inputs['pixel_values'] = inputs['pixel_values'].to(devices.dtype)
llava_model = llava_model.to(devices.device)
sd_models.move_model(llava_model, devices.device)
with devices.inference_context():
generate_ids = llava_model.generate( # Generate the captions
**inputs,
@@ -97,6 +105,6 @@ def predict(question: str, image, vqa_model: str = None) -> str:
)[0]
generate_ids = generate_ids[inputs['input_ids'].shape[1]:] # Trim off the prompt
caption = processor.tokenizer.decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False) # Decode the caption
llava_model = llava_model.to(devices.cpu)
sd_models.move_model(llava_model, devices.cpu, force=True)
caption = caption.replace('\n\n', '\n').strip()
return caption
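
The JoyCaption changes above condense to one lifecycle, shared with the other captioners after this commit: load the checkpoint once into a module-level cache with the optional quantization kwargs, move it onto the active device only for generation, then force it back to CPU. A sketch of that flow using the helpers visible in the diff; the wrapper function and the `max_new_tokens` value are illustrative, not part of the commit, and the code only runs inside the SD.Next tree.

```python
# Illustrative only: the caching/offload pattern this diff applies to JoyCaption.
# load_and_caption() is a hypothetical wrapper; the helper calls
# (model_quant.create_config, sd_models.move_model, devices.inference_context)
# are the ones used in the diff.
from transformers import LlavaForConditionalGeneration
from modules import shared, devices, sd_models, model_quant

cached_model = None  # module-level cache so the checkpoint is loaded once

def load_and_caption(repo: str, inputs: dict):
    global cached_model
    if cached_model is None:
        quant_args = model_quant.create_config(module='LLM')  # may be empty -> no quantization
        cached_model = LlavaForConditionalGeneration.from_pretrained(
            repo,
            torch_dtype=devices.dtype,
            device_map="auto",
            cache_dir=shared.opts.hfcache_dir,
            **quant_args,
        )
        cached_model.eval()
    sd_models.move_model(cached_model, devices.device)           # on-device only for generation
    with devices.inference_context():
        ids = cached_model.generate(**inputs, max_new_tokens=512)  # generation args are illustrative
    sd_models.move_model(cached_model, devices.cpu, force=True)  # force back to CPU afterwards
    return ids
```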


@@ -7,12 +7,13 @@ import torch
import transformers
import transformers.dynamic_module_utils
from PIL import Image
from modules import shared, devices, errors, sd_models
from modules import shared, devices, errors, sd_models, model_quant
processor = None
model = None
loaded: str = None
quant_args = {}
vlm_models = {
"Microsoft Florence 2 Base": "microsoft/Florence-2-base", # 0.5GB
"Microsoft Florence 2 Large": "microsoft/Florence-2-large", # 1.5GB
@@ -121,9 +122,10 @@ def qwen(question: str, image: Image.Image, repo: str = None, system_prompt: str
model = None
model = transformers.Qwen2VLForConditionalGeneration.from_pretrained(
repo,
cache_dir=shared.opts.hfcache_dir
torch_dtype=devices.dtype,
cache_dir=shared.opts.hfcache_dir,
**quant_args,
)
model = model.to(devices.device, devices.dtype)
processor = transformers.AutoProcessor.from_pretrained(repo, cache_dir=shared.opts.hfcache_dir)
loaded = repo
devices.torch_gc()
@@ -166,8 +168,12 @@ def gemma(question: str, image: Image.Image, repo: str = None, system_prompt: st
if model is None or loaded != repo:
shared.log.debug(f'Interrogate load: vlm="{repo}"')
model = None
model = transformers.Gemma3ForConditionalGeneration.from_pretrained(repo, cache_dir=shared.opts.hfcache_dir)
model = model.to(devices.device, devices.dtype)
model = transformers.Gemma3ForConditionalGeneration.from_pretrained(
repo,
torch_dtype=devices.dtype,
cache_dir=shared.opts.hfcache_dir,
**quant_args,
)
processor = transformers.AutoProcessor.from_pretrained(repo, cache_dir=shared.opts.hfcache_dir)
loaded = repo
devices.torch_gc()
@@ -217,7 +223,6 @@ def paligemma(question: str, image: Image.Image, repo: str = None):
cache_dir=shared.opts.hfcache_dir,
torch_dtype=devices.dtype,
)
model = model.to(devices.device, devices.dtype)
loaded = repo
devices.torch_gc()
sd_models.move_model(model, devices.device)
@@ -251,7 +256,6 @@ def ovis(question: str, image: Image.Image, repo: str = None):
trust_remote_code=True,
cache_dir=shared.opts.hfcache_dir,
)
model = model.to(devices.device, devices.dtype)
loaded = repo
devices.torch_gc()
sd_models.move_model(model, devices.device)
@@ -292,8 +296,8 @@ def smol(question: str, image: Image.Image, repo: str = None, system_prompt: str
cache_dir=shared.opts.hfcache_dir,
torch_dtype=devices.dtype,
_attn_implementation="eager",
**quant_args,
)
model.to(devices.device, devices.dtype)
processor = transformers.AutoProcessor.from_pretrained(repo, cache_dir=shared.opts.hfcache_dir)
loaded = repo
devices.torch_gc()
@@ -331,9 +335,9 @@ def git(question: str, image: Image.Image, repo: str = None):
model = None
model = transformers.GitForCausalLM.from_pretrained(
repo,
torch_dtype=devices.dtype,
cache_dir=shared.opts.hfcache_dir,
)
model.to(devices.device, devices.dtype)
processor = transformers.GitProcessor.from_pretrained(repo, cache_dir=shared.opts.hfcache_dir)
loaded = repo
devices.torch_gc()
@@ -359,9 +363,9 @@ def blip(question: str, image: Image.Image, repo: str = None):
model = None
model = transformers.BlipForQuestionAnswering.from_pretrained(
repo,
torch_dtype=devices.dtype,
cache_dir=shared.opts.hfcache_dir,
)
model.to(devices.device, devices.dtype)
processor = transformers.BlipProcessor.from_pretrained(repo, cache_dir=shared.opts.hfcache_dir)
loaded = repo
devices.torch_gc()
@@ -381,9 +385,9 @@ def vilt(question: str, image: Image.Image, repo: str = None):
model = None
model = transformers.ViltForQuestionAnswering.from_pretrained(
repo,
torch_dtype=devices.dtype,
cache_dir=shared.opts.hfcache_dir,
)
model.to(devices.device)
processor = transformers.ViltProcessor.from_pretrained(repo, cache_dir=shared.opts.hfcache_dir)
loaded = repo
devices.torch_gc()
@@ -405,9 +409,9 @@ def pix(question: str, image: Image.Image, repo: str = None):
model = None
model = transformers.Pix2StructForConditionalGeneration.from_pretrained(
repo,
torch_dtype=devices.dtype,
cache_dir=shared.opts.hfcache_dir,
)
model.to(devices.device)
processor = transformers.Pix2StructProcessor.from_pretrained(repo, cache_dir=shared.opts.hfcache_dir)
loaded = repo
devices.torch_gc()
@@ -431,11 +435,11 @@ def moondream(question: str, image: Image.Image, repo: str = None):
repo,
revision="2025-06-21",
trust_remote_code=True,
cache_dir=shared.opts.hfcache_dir
torch_dtype=devices.dtype,
cache_dir=shared.opts.hfcache_dir,
)
processor = transformers.AutoTokenizer.from_pretrained(repo, cache_dir=shared.opts.hfcache_dir)
loaded = repo
model.to(devices.device, devices.dtype)
model.eval()
devices.torch_gc()
sd_models.move_model(model, devices.device)
@@ -475,12 +479,13 @@ def florence(question: str, image: Image.Image, repo: str = None, revision: str
repo,
trust_remote_code=True,
revision=revision,
torch_dtype=devices.dtype,
cache_dir=shared.opts.hfcache_dir,
**quant_args,
)
processor = transformers.AutoProcessor.from_pretrained(repo, trust_remote_code=True, revision=revision, cache_dir=shared.opts.hfcache_dir)
transformers.dynamic_module_utils.get_imports = _get_imports
loaded = repo
model.to(devices.device, devices.dtype)
model.eval()
devices.torch_gc()
sd_models.move_model(model, devices.device)
@@ -512,7 +517,6 @@ def sa2(question: str, image: Image.Image, repo: str = None):
low_cpu_mem_usage=True,
use_flash_attn=False,
trust_remote_code=True)
model = model.to(devices.device, devices.dtype)
model = model.eval()
processor = transformers.AutoTokenizer.from_pretrained(
repo,
@@ -539,9 +543,11 @@
def interrogate(question:str='', system_prompt:str=None, prompt:str=None, image:Image.Image=None, model_name:str=None, quiet:bool=False):
global quant_args # pylint: disable=global-statement
if not quiet:
shared.state.begin('Interrogate')
t0 = time.time()
quant_args = model_quant.create_config(module='LLM')
model_name = model_name or shared.opts.interrogate_vlm_model
if isinstance(image, list):
image = image[0] if len(image) > 0 else None
@@ -560,8 +566,10 @@ def interrogate(question:str='', system_prompt:str=None, prompt:str=None, image:
if shared.native and shared.sd_loaded:
from modules.sd_models import apply_balanced_offload # prevent circular import
apply_balanced_offload(shared.sd_model)
from modules import modelloader
modelloader.hf_login()
try:
if model_name is None:
shared.log.error(f'Interrogate: type=vlm model="{model_name}" no model selected')
@@ -573,6 +581,7 @@ def interrogate(question:str='', system_prompt:str=None, prompt:str=None, image:
if image is None:
shared.log.error(f'Interrogate: type=vlm model="{model_name}" no input image')
return ''
if 'git' in vqa_model.lower():
answer = git(question, image, vqa_model)
elif 'vilt' in vqa_model.lower():
@@ -611,9 +620,10 @@ def interrogate(question:str='', system_prompt:str=None, prompt:str=None, image:
except Exception as e:
errors.display(e, 'VQA')
answer = 'error'
if shared.opts.interrogate_offload and model is not None:
sd_models.move_model(model, devices.cpu)
devices.torch_gc()
sd_models.move_model(model, devices.cpu, force=True)
devices.torch_gc(force=True)
answer = clean(answer, question)
t1 = time.time()
if not quiet:
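
Taken together, the `interrogate()` changes above mean each call now builds fresh quantization kwargs via `model_quant.create_config(module='LLM')`, logs in to Hugging Face through `modelloader.hf_login()` before loading, and moves the model back to CPU with `force=True` plus a forced `devices.torch_gc()` when it finishes. A hypothetical caller follows; the import path and model name are assumptions (the file path is not shown in this view), while the keyword arguments come from the signature shown above.

```python
# Hypothetical usage of the captioning entry point changed in this commit.
# The module path is assumed; the kwargs match the interrogate() signature in the diff.
from PIL import Image
from modules.interrogate import vqa  # assumed import path, not visible in this diff view

image = Image.open('example.jpg')
caption = vqa.interrogate(
    question='Describe this image in detail.',
    image=image,
    model_name='Microsoft Florence 2 Base',  # key from the vlm_models table above
)
print(caption)
```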

wiki (submodule)

@@ -1 +1 @@
Subproject commit b95139063c469ab11a379804ddb606ccfaf6c5c6
Subproject commit 45c389a3747cb1caceb4035535041c400c29cad2