# automatic/modules/vqa.py

import torch
import transformers
from PIL import Image
from modules import shared, devices
processor = None  # lazily-loaded processor/tokenizer for the current model
model = None  # lazily-loaded model; only one is kept in memory at a time
loaded: str = None  # repo id of the model currently held in `model`
MODELS = {
    "Moondream 2": "vikhyatk/moondream2", # 3.7GB
    "GIT TextCaps Base": "microsoft/git-base-textcaps", # 0.7GB
    "GIT VQA Base": "microsoft/git-base-vqav2", # 0.7GB
    "GIT VQA Large": "microsoft/git-large-vqav2", # 1.6GB
    "BLIP Base": "Salesforce/blip-vqa-base", # 1.5GB
    "BLIP Large": "Salesforce/blip-vqa-capfilt-large", # 1.5GB
    "ViLT Base": "dandelin/vilt-b32-finetuned-vqa", # 0.5GB
    "Pix Textcaps": "google/pix2struct-textcaps-base", # 1.1GB
}
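
# Usage sketch: a UI passes one of the display-name keys above and
# `interrogate()` below resolves it to the repo id, e.g.:
#   MODELS.get("BLIP Base")  # -> "Salesforce/blip-vqa-base"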


def git(question: str, image: Image.Image, repo: str = None):
    # Generative captioning/VQA via Microsoft GIT; with no question it produces a plain caption
    global processor, model, loaded # pylint: disable=global-statement
    if model is None or loaded != repo:
        model = transformers.GitForCausalLM.from_pretrained(repo)
        processor = transformers.GitProcessor.from_pretrained(repo)
        loaded = repo
    model.to(devices.device, devices.dtype)
    shared.log.debug(f'VQA: class={model.__class__.__name__} processor={processor.__class__} model={repo}')
    pixel_values = processor(images=image, return_tensors="pt").pixel_values
    git_dict = {}
    git_dict['pixel_values'] = pixel_values.to(devices.device, devices.dtype)
    if len(question) > 0:
        input_ids = processor(text=question, add_special_tokens=False).input_ids
        input_ids = [processor.tokenizer.cls_token_id] + input_ids  # prepend CLS as the generation prompt
        input_ids = torch.tensor(input_ids).unsqueeze(0)
        git_dict['input_ids'] = input_ids.to(devices.device)
    with devices.inference_context():
        generated_ids = model.generate(**git_dict)
    response = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
    shared.log.debug(f'VQA: response={response}')
    return response


def blip(question: str, image: Image.Image, repo: str = None):
    # Generative VQA via Salesforce BLIP
    global processor, model, loaded # pylint: disable=global-statement
    if model is None or loaded != repo:
        model = transformers.BlipForQuestionAnswering.from_pretrained(repo)
        processor = transformers.BlipProcessor.from_pretrained(repo)
        loaded = repo
    model.to(devices.device, devices.dtype)
    inputs = processor(image, question, return_tensors="pt")
    inputs = inputs.to(devices.device, devices.dtype)
    with devices.inference_context():
        outputs = model.generate(**inputs)
    response = processor.decode(outputs[0], skip_special_tokens=True)
    model.to(devices.cpu)
    shared.log.debug(f'VQA: response={response}')
    return response


def vilt(question: str, image: Image.Image, repo: str = None):
    # Classification-style VQA via ViLT: picks the best answer from a fixed vocabulary
    global processor, model, loaded # pylint: disable=global-statement
    if model is None or loaded != repo:
        model = transformers.ViltForQuestionAnswering.from_pretrained(repo)
        processor = transformers.ViltProcessor.from_pretrained(repo)
        loaded = repo
    model.to(devices.device)
    shared.log.debug(f'VQA: class={model.__class__.__name__} processor={processor.__class__} model={repo}')
    inputs = processor(image, question, return_tensors="pt")
    inputs = inputs.to(devices.device)
    with devices.inference_context():
        outputs = model(**inputs)
    logits = outputs.logits
    idx = logits.argmax(-1).item()
    response = model.config.id2label[idx]
    shared.log.debug(f'VQA: response={response}')
    return response
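
# Hypothetical extension (sketch, not wired in): because ViLT scores a fixed
# answer vocabulary, the same logits can also rank alternative answers:
#   probs = outputs.logits.softmax(-1)
#   values, indices = probs.topk(5)
#   top5 = [(model.config.id2label[i.item()], v.item()) for v, i in zip(values[0], indices[0])]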


def pix(question: str, image: Image.Image, repo: str = None):
    global processor, model, loaded # pylint: disable=global-statement
    if model is None or loaded != repo:
        model = transformers.Pix2StructForConditionalGeneration.from_pretrained(repo)
        processor = transformers.Pix2StructProcessor.from_pretrained(repo)
        loaded = repo
    model.to(devices.device)
    shared.log.debug(f'VQA: class={model.__class__.__name__} processor={processor.__class__} model={repo}')
    if len(question) > 0:
        inputs = processor(images=image, text=question, return_tensors="pt").to(devices.device)
    else:
        inputs = processor(images=image, return_tensors="pt").to(devices.device)
    with devices.inference_context():
        outputs = model.generate(**inputs)
    response = processor.decode(outputs[0], skip_special_tokens=True)
    shared.log.debug(f'VQA: response={response}')
    return response


def moondream(question: str, image: Image.Image, repo: str = None):
    # Moondream2 ships its own modeling code, hence trust_remote_code=True
    global processor, model, loaded # pylint: disable=global-statement
    if model is None or loaded != repo:
        model = transformers.AutoModelForCausalLM.from_pretrained(repo, trust_remote_code=True) # revision = "2024-03-05"
        processor = transformers.AutoTokenizer.from_pretrained(repo) # revision = "2024-03-05"
        loaded = repo
    model.eval()
    model.to(devices.device, devices.dtype)
    shared.log.debug(f'VQA: class={model.__class__.__name__} processor={processor.__class__} model={repo}')
    if len(question) < 2:
        question = "Describe the image."  # fall back to plain captioning on an empty prompt
    encoded = model.encode_image(image)
    with devices.inference_context():
        response = model.answer_question(encoded, question, processor)
    shared.log.debug(f'VQA: response="{response}"')
    return response
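
# Note: `encode_image` and `answer_question` are provided by moondream2's
# bundled remote code, not by the transformers API itself; pinning the
# commented-out revision, e.g.
#   transformers.AutoModelForCausalLM.from_pretrained(repo, revision="2024-03-05", trust_remote_code=True)
# would guard against upstream API changes (sketch, mirrors the comments above).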


def interrogate(vqa_question, vqa_image, vqa_model_req):
    vqa_model = MODELS.get(vqa_model_req, None)
    shared.log.debug(f'VQA: model="{vqa_model}" question="{vqa_question}" image={vqa_image}')
    if vqa_image is None:
        return 'no image provided'
    if vqa_model_req is None:
        return 'no model selected'
    if vqa_model is None:
        return f'unknown model: request={vqa_model_req} available={list(MODELS)}'
    # dispatch on the repo id; `elif` keeps a successful answer from being
    # overwritten by the final `else` branch
    if 'git' in vqa_model.lower():
        answer = git(vqa_question, vqa_image, vqa_model)
    elif 'vilt' in vqa_model.lower():
        answer = vilt(vqa_question, vqa_image, vqa_model)
    elif 'blip' in vqa_model.lower():
        answer = blip(vqa_question, vqa_image, vqa_model)
    elif 'pix' in vqa_model.lower():
        answer = pix(vqa_question, vqa_image, vqa_model)
    elif 'moondream2' in vqa_model.lower():
        answer = moondream(vqa_question, vqa_image, vqa_model)
    else:
        answer = 'unknown model'
    if model is not None:
        model.to(devices.cpu)  # offload after inference to free VRAM
    devices.torch_gc()
    return answer
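

# Minimal smoke test; a sketch that assumes the webui environment is active so
# `modules.shared` and `modules.devices` import correctly, and uses a
# hypothetical local image path:
if __name__ == "__main__":
    sample = Image.open("test.png").convert("RGB")  # hypothetical local image
    print(interrogate("What is in this image?", sample, "BLIP Base"))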