mirror of https://github.com/vladmandic/automatic
enable quants for vlm-captioning
Signed-off-by: Vladimir Mandic <mandic00@live.com>
pull/4005/head
parent
99f323d105
commit
1b4e1ff0ef
|
|
@ -36,6 +36,7 @@
|
|||
- **Changes**
|
||||
- Update all core requirements
|
||||
- Support Remote VAE with *Omnigen, Lumina 2 and PixArt*
|
||||
- Enable quantization for captioning: *Gemma, Qwen, SMOL, Florence, JoyCaption*
|
||||
- Add `--trace` command line param that enables trace logging
|
||||
- Use Diffusers version of *OmniGen*
|
||||
- Control: move global settings to control elements -> control settings tab
|
||||
|
|
|
|||
|
|
@ -3,7 +3,7 @@
|
|||
from dataclasses import dataclass
|
||||
import torch
|
||||
from transformers import AutoProcessor, LlavaForConditionalGeneration
|
||||
from modules import shared, devices
|
||||
from modules import shared, devices, sd_models, model_quant
|
||||
|
||||
|
||||
"""
|
||||
|
|
@ -66,8 +66,16 @@ def predict(question: str, image, vqa_model: str = None) -> str:
|
|||
llava_model = None
|
||||
if llava_model is None:
|
||||
shared.log.info(f'Interrogate: type=vlm model="JoyCaption" {str(opts)}')
|
||||
|
||||
processor = AutoProcessor.from_pretrained(opts.repo)
|
||||
llava_model = LlavaForConditionalGeneration.from_pretrained(opts.repo, torch_dtype=devices.dtype, device_map="auto", cache_dir=shared.opts.hfcache_dir)
|
||||
quant_args = model_quant.create_config(module='LLM')
|
||||
llava_model = LlavaForConditionalGeneration.from_pretrained(
|
||||
opts.repo,
|
||||
torch_dtype=devices.dtype,
|
||||
device_map="auto",
|
||||
cache_dir=shared.opts.hfcache_dir,
|
||||
**quant_args,
|
||||
)
|
||||
llava_model.eval()
|
||||
|
||||
if len(question) < 2:
|
||||
|
|
@ -80,7 +88,7 @@ def predict(question: str, image, vqa_model: str = None) -> str:
|
|||
convo_string = processor.apply_chat_template(convo, tokenize=False, add_generation_prompt=True)
|
||||
inputs = processor(text=[convo_string], images=[image], return_tensors="pt").to(devices.device) # Process the inputs
|
||||
inputs['pixel_values'] = inputs['pixel_values'].to(devices.dtype)
|
||||
llava_model = llava_model.to(devices.device)
|
||||
sd_models.move_model(llava_model, devices.device)
|
||||
with devices.inference_context():
|
||||
generate_ids = llava_model.generate( # Generate the captions
|
||||
**inputs,
|
||||
|
|
@ -97,6 +105,6 @@ def predict(question: str, image, vqa_model: str = None) -> str:
|
|||
)[0]
|
||||
generate_ids = generate_ids[inputs['input_ids'].shape[1]:] # Trim off the prompt
|
||||
caption = processor.tokenizer.decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False) # Decode the caption
|
||||
llava_model = llava_model.to(devices.cpu)
|
||||
sd_models.move_model(llava_model, devices.cpu, force=True)
|
||||
caption = caption.replace('\n\n', '\n').strip()
|
||||
return caption
|
||||
|
|
|
|||
|
|
@ -7,12 +7,13 @@ import torch
|
|||
import transformers
|
||||
import transformers.dynamic_module_utils
|
||||
from PIL import Image
|
||||
from modules import shared, devices, errors, sd_models
|
||||
from modules import shared, devices, errors, sd_models, model_quant
|
||||
|
||||
|
||||
processor = None
|
||||
model = None
|
||||
loaded: str = None
|
||||
quant_args = {}
|
||||
vlm_models = {
|
||||
"Microsoft Florence 2 Base": "microsoft/Florence-2-base", # 0.5GB
|
||||
"Microsoft Florence 2 Large": "microsoft/Florence-2-large", # 1.5GB
|
||||
|
|
@ -121,9 +122,10 @@ def qwen(question: str, image: Image.Image, repo: str = None, system_prompt: str
|
|||
model = None
|
||||
model = transformers.Qwen2VLForConditionalGeneration.from_pretrained(
|
||||
repo,
|
||||
cache_dir=shared.opts.hfcache_dir
|
||||
torch_dtype=devices.dtype,
|
||||
cache_dir=shared.opts.hfcache_dir,
|
||||
**quant_args,
|
||||
)
|
||||
model = model.to(devices.device, devices.dtype)
|
||||
processor = transformers.AutoProcessor.from_pretrained(repo, cache_dir=shared.opts.hfcache_dir)
|
||||
loaded = repo
|
||||
devices.torch_gc()
|
||||
|
|
@ -166,8 +168,12 @@ def gemma(question: str, image: Image.Image, repo: str = None, system_prompt: st
|
|||
if model is None or loaded != repo:
|
||||
shared.log.debug(f'Interrogate load: vlm="{repo}"')
|
||||
model = None
|
||||
model = transformers.Gemma3ForConditionalGeneration.from_pretrained(repo, cache_dir=shared.opts.hfcache_dir)
|
||||
model = model.to(devices.device, devices.dtype)
|
||||
model = transformers.Gemma3ForConditionalGeneration.from_pretrained(
|
||||
repo,
|
||||
torch_dtype=devices.dtype,
|
||||
cache_dir=shared.opts.hfcache_dir,
|
||||
**quant_args,
|
||||
)
|
||||
processor = transformers.AutoProcessor.from_pretrained(repo, cache_dir=shared.opts.hfcache_dir)
|
||||
loaded = repo
|
||||
devices.torch_gc()
|
||||
|
|
@ -217,7 +223,6 @@ def paligemma(question: str, image: Image.Image, repo: str = None):
|
|||
cache_dir=shared.opts.hfcache_dir,
|
||||
torch_dtype=devices.dtype,
|
||||
)
|
||||
model = model.to(devices.device, devices.dtype)
|
||||
loaded = repo
|
||||
devices.torch_gc()
|
||||
sd_models.move_model(model, devices.device)
|
||||
|
|
@ -251,7 +256,6 @@ def ovis(question: str, image: Image.Image, repo: str = None):
|
|||
trust_remote_code=True,
|
||||
cache_dir=shared.opts.hfcache_dir,
|
||||
)
|
||||
model = model.to(devices.device, devices.dtype)
|
||||
loaded = repo
|
||||
devices.torch_gc()
|
||||
sd_models.move_model(model, devices.device)
|
||||
|
|
@ -292,8 +296,8 @@ def smol(question: str, image: Image.Image, repo: str = None, system_prompt: str
|
|||
cache_dir=shared.opts.hfcache_dir,
|
||||
torch_dtype=devices.dtype,
|
||||
_attn_implementation="eager",
|
||||
**quant_args,
|
||||
)
|
||||
model.to(devices.device, devices.dtype)
|
||||
processor = transformers.AutoProcessor.from_pretrained(repo, cache_dir=shared.opts.hfcache_dir)
|
||||
loaded = repo
|
||||
devices.torch_gc()
|
||||
|
|
@ -331,9 +335,9 @@ def git(question: str, image: Image.Image, repo: str = None):
|
|||
model = None
|
||||
model = transformers.GitForCausalLM.from_pretrained(
|
||||
repo,
|
||||
torch_dtype=devices.dtype,
|
||||
cache_dir=shared.opts.hfcache_dir,
|
||||
)
|
||||
model.to(devices.device, devices.dtype)
|
||||
processor = transformers.GitProcessor.from_pretrained(repo, cache_dir=shared.opts.hfcache_dir)
|
||||
loaded = repo
|
||||
devices.torch_gc()
|
||||
|
|
@ -359,9 +363,9 @@ def blip(question: str, image: Image.Image, repo: str = None):
|
|||
model = None
|
||||
model = transformers.BlipForQuestionAnswering.from_pretrained(
|
||||
repo,
|
||||
torch_dtype=devices.dtype,
|
||||
cache_dir=shared.opts.hfcache_dir,
|
||||
)
|
||||
model.to(devices.device, devices.dtype)
|
||||
processor = transformers.BlipProcessor.from_pretrained(repo, cache_dir=shared.opts.hfcache_dir)
|
||||
loaded = repo
|
||||
devices.torch_gc()
|
||||
|
|
@ -381,9 +385,9 @@ def vilt(question: str, image: Image.Image, repo: str = None):
|
|||
model = None
|
||||
model = transformers.ViltForQuestionAnswering.from_pretrained(
|
||||
repo,
|
||||
torch_dtype=devices.dtype,
|
||||
cache_dir=shared.opts.hfcache_dir,
|
||||
)
|
||||
model.to(devices.device)
|
||||
processor = transformers.ViltProcessor.from_pretrained(repo, cache_dir=shared.opts.hfcache_dir)
|
||||
loaded = repo
|
||||
devices.torch_gc()
|
||||
|
|
@ -405,9 +409,9 @@ def pix(question: str, image: Image.Image, repo: str = None):
|
|||
model = None
|
||||
model = transformers.Pix2StructForConditionalGeneration.from_pretrained(
|
||||
repo,
|
||||
torch_dtype=devices.dtype,
|
||||
cache_dir=shared.opts.hfcache_dir,
|
||||
)
|
||||
model.to(devices.device)
|
||||
processor = transformers.Pix2StructProcessor.from_pretrained(repo, cache_dir=shared.opts.hfcache_dir)
|
||||
loaded = repo
|
||||
devices.torch_gc()
|
||||
|
|
@ -431,11 +435,11 @@ def moondream(question: str, image: Image.Image, repo: str = None):
|
|||
repo,
|
||||
revision="2025-06-21",
|
||||
trust_remote_code=True,
|
||||
cache_dir=shared.opts.hfcache_dir
|
||||
torch_dtype=devices.dtype,
|
||||
cache_dir=shared.opts.hfcache_dir,
|
||||
)
|
||||
processor = transformers.AutoTokenizer.from_pretrained(repo, cache_dir=shared.opts.hfcache_dir)
|
||||
loaded = repo
|
||||
model.to(devices.device, devices.dtype)
|
||||
model.eval()
|
||||
devices.torch_gc()
|
||||
sd_models.move_model(model, devices.device)
|
||||
|
|
@ -475,12 +479,13 @@ def florence(question: str, image: Image.Image, repo: str = None, revision: str
|
|||
repo,
|
||||
trust_remote_code=True,
|
||||
revision=revision,
|
||||
torch_dtype=devices.dtype,
|
||||
cache_dir=shared.opts.hfcache_dir,
|
||||
**quant_args,
|
||||
)
|
||||
processor = transformers.AutoProcessor.from_pretrained(repo, trust_remote_code=True, revision=revision, cache_dir=shared.opts.hfcache_dir)
|
||||
transformers.dynamic_module_utils.get_imports = _get_imports
|
||||
loaded = repo
|
||||
model.to(devices.device, devices.dtype)
|
||||
model.eval()
|
||||
devices.torch_gc()
|
||||
sd_models.move_model(model, devices.device)
|
||||
|
|
@ -512,7 +517,6 @@ def sa2(question: str, image: Image.Image, repo: str = None):
|
|||
low_cpu_mem_usage=True,
|
||||
use_flash_attn=False,
|
||||
trust_remote_code=True)
|
||||
model = model.to(devices.device, devices.dtype)
|
||||
model = model.eval()
|
||||
processor = transformers.AutoTokenizer.from_pretrained(
|
||||
repo,
|
||||
|
|
@ -539,9 +543,11 @@ def sa2(question: str, image: Image.Image, repo: str = None):
|
|||
|
||||
|
||||
def interrogate(question:str='', system_prompt:str=None, prompt:str=None, image:Image.Image=None, model_name:str=None, quiet:bool=False):
|
||||
global quant_args # pylint: disable=global-statement
|
||||
if not quiet:
|
||||
shared.state.begin('Interrogate')
|
||||
t0 = time.time()
|
||||
quant_args = model_quant.create_config(module='LLM')
|
||||
model_name = model_name or shared.opts.interrogate_vlm_model
|
||||
if isinstance(image, list):
|
||||
image = image[0] if len(image) > 0 else None
|
||||
|
|
@ -560,8 +566,10 @@ def interrogate(question:str='', system_prompt:str=None, prompt:str=None, image:
|
|||
if shared.native and shared.sd_loaded:
|
||||
from modules.sd_models import apply_balanced_offload # prevent circular import
|
||||
apply_balanced_offload(shared.sd_model)
|
||||
|
||||
from modules import modelloader
|
||||
modelloader.hf_login()
|
||||
|
||||
try:
|
||||
if model_name is None:
|
||||
shared.log.error(f'Interrogate: type=vlm model="{model_name}" no model selected')
|
||||
|
|
@ -573,6 +581,7 @@ def interrogate(question:str='', system_prompt:str=None, prompt:str=None, image:
|
|||
if image is None:
|
||||
shared.log.error(f'Interrogate: type=vlm model="{model_name}" no input image')
|
||||
return ''
|
||||
|
||||
if 'git' in vqa_model.lower():
|
||||
answer = git(question, image, vqa_model)
|
||||
elif 'vilt' in vqa_model.lower():
|
||||
|
|
@ -611,9 +620,10 @@ def interrogate(question:str='', system_prompt:str=None, prompt:str=None, image:
|
|||
except Exception as e:
|
||||
errors.display(e, 'VQA')
|
||||
answer = 'error'
|
||||
|
||||
if shared.opts.interrogate_offload and model is not None:
|
||||
sd_models.move_model(model, devices.cpu)
|
||||
devices.torch_gc()
|
||||
sd_models.move_model(model, devices.cpu, force=True)
|
||||
devices.torch_gc(force=True)
|
||||
answer = clean(answer, question)
|
||||
t1 = time.time()
|
||||
if not quiet:
|
||||
|
|
|
|||
2
wiki
2
wiki
|
|
@ -1 +1 @@
|
|||
Subproject commit b95139063c469ab11a379804ddb606ccfaf6c5c6
|
||||
Subproject commit 45c389a3747cb1caceb4035535041c400c29cad2
|
||||
Loading…
Reference in New Issue