sd_smartprocess/interrogators/mplug2_interrogator.py

110 lines
4.2 KiB
Python

import gc
import logging
import os
from typing import Dict
import torch
from PIL import Image
from huggingface_hub import snapshot_download
from transformers import TextStreamer
from extensions.sd_smartprocess.interrogators.interrogator import Interrogator
from extensions.sd_smartprocess.model_download import fetch_model
from extensions.sd_smartprocess.process_params import ProcessParams
from modules.paths_internal import models_path
from extensions.sd_smartprocess.mplug_owl2.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN
from extensions.sd_smartprocess.mplug_owl2.conversation import conv_templates
from extensions.sd_smartprocess.mplug_owl2.mm_utils import KeywordsStoppingCriteria, tokenizer_image_token, \
process_images, \
get_model_name_from_path
from extensions.sd_smartprocess.mplug_owl2.model.builder import load_pretrained_model
# This is basically broken until we can update transformers in AUTO past the current version supported
logger = logging.getLogger(__name__)
class MPLUG2Interrogator(Interrogator):
model = None
processor = None
params = {"max_tokens": 75}
def __init__(self, params: ProcessParams):
super().__init__(params)
logger.debug("Initializing LLM model...")
model_path = fetch_model('MAGAer13/mplug-owl2-llama2-7b', "llm")
model_name = get_model_name_from_path(model_path)
self.tokenizer, self.model, self.image_processor, self.context_len = load_pretrained_model(model_path, None,
model_name,
load_8bit=False,
load_4bit=False,
device="cuda")
self._to_cpu()
logger.debug("Initialized LLM model.")
def interrogate(self, image: Image, params=None, unload: bool = False) -> str:
self.load()
if params is None:
params = {}
query = "Describe the image with a caption that can be used to generate a similar image."
conv = conv_templates["mplug_owl2"].copy()
roles = conv.roles
max_edge = max(image.size) # We recommand you to resize to squared image for BEST performance.
image = image.resize((max_edge, max_edge))
image_tensor = process_images([image], self.image_processor)
image_tensor = image_tensor.to(self.model.device, dtype=torch.float16)
inp = DEFAULT_IMAGE_TOKEN + query
conv.append_message(conv.roles[0], inp)
conv.append_message(conv.roles[1], None)
prompt = conv.get_prompt()
input_ids = tokenizer_image_token(prompt, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(
0).to(
self.model.device)
stop_str = conv.sep2
keywords = [stop_str]
stopping_criteria = KeywordsStoppingCriteria(keywords, self.tokenizer, input_ids)
streamer = TextStreamer(self.tokenizer, skip_prompt=True, skip_special_tokens=True)
temperature = 0.7
max_new_tokens = 512
with torch.inference_mode():
output_ids = self.model.generate(
input_ids,
images=image_tensor,
do_sample=True,
temperature=temperature,
max_new_tokens=max_new_tokens,
streamer=streamer,
use_cache=True,
stopping_criteria=[stopping_criteria])
caption = self.tokenizer.decode(output_ids[0, input_ids.shape[1]:]).strip()
return caption
def _to_cpu(self):
self.model.to('cpu')
#self.image_processor.to('cpu')
#self.tokenizer.to('cpu')
def _to_gpu(self):
self.model.to(self.device)
#self.image_processor.to(self.device)
#self.tokenizer.to(self.device)
def unload(self):
self._to_cpu()
if torch.cuda.is_available():
torch.cuda.empty_cache()
gc.collect()
def load(self):
self._to_gpu()