Experimental: Add Moondream interrogator?

dev
d8ahazard 2024-04-03 10:46:34 -05:00
parent 8de60cfcb9
commit b8a6ac23d0
3 changed files with 53 additions and 9 deletions

View File

@ -0,0 +1,45 @@
from PIL.Image import Image
from interrogators.interrogator import Interrogator
from model_download import fetch_model
from process_params import ProcessParams
from torch import float16
from transformers import AutoModelForCausalLM, AutoTokenizer
class MoondreamInterrogator(Interrogator):
    """Caption images with the Moondream2 vision-language model.

    The model is downloaded/loaded lazily on first use (see ``load``) and
    kept on the GPU; ``interrogate(..., unload=True)`` releases it via the
    inherited ``unload`` hook.
    """

    model = None
    tokenizer = None
    # Default UI-configurable parameters exposed by this interrogator.
    # NOTE: the instance attribute ``self.params`` is replaced with the
    # ProcessParams object in __init__; this dict only advertises defaults.
    params = {"interrogation_prompt": "Describe this image in one detailed sentence."}

    def __init__(self, params: ProcessParams):
        super().__init__(params)
        self.params = params
        self.model = None
        self.tokenizer = None

    def interrogate(self, image: Image, params: ProcessParams = None, unload: bool = False) -> str:
        """Return a caption for *image*.

        Args:
            image: PIL image to caption.
            params: Optional per-call parameters. If present and carrying a
                non-empty ``interrogation_prompt`` attribute, it overrides
                the default prompt. (Previously this argument was ignored
                and the prompt was always hard-coded.)
            unload: When True, unload the model after captioning.

        Returns:
            The model's answer to the interrogation prompt.
        """
        self.load()
        # Honor a caller-supplied prompt; fall back to the advertised default.
        source = params if params is not None else self.params
        prompt = getattr(source, "interrogation_prompt", None) \
            or "Describe this image in one detailed sentence."
        enc_image = self.model.encode_image(image)
        caption = self.model.answer_question(enc_image, prompt, self.tokenizer)
        if unload:
            self.unload()
        return caption

    def _to_gpu(self):
        # No-op until the model has been loaded (previously raised
        # AttributeError when called before load()).
        if self.model is not None:
            self.model = self.model.to("cuda")

    def _to_cpu(self):
        # No-op until the model has been loaded.
        if self.model is not None:
            self.model = self.model.to("cpu")

    def load(self):
        """Lazily fetch and load moondream2 onto the GPU (idempotent)."""
        if self.model is None:
            model_id = "vikhyatk/moondream2"
            # NOTE(review): upstream pinned revision "2024-04-02" in a dead
            # local that was never passed anywhere; fetch_model takes no
            # revision here — confirm whether it pins internally.
            model_path = fetch_model(model_id, "llm")
            self.tokenizer = AutoTokenizer.from_pretrained(model_path)
            # flash_attention_2 requires the flash-attn package to be
            # installed; float16 keeps GPU memory usage down.
            self.model = AutoModelForCausalLM.from_pretrained(
                model_path, trust_remote_code=True,
                torch_dtype=float16, attn_implementation="flash_attention_2"
            ).to("cuda")

View File

@ -4,7 +4,7 @@ import base64
import torch
from transformers import StoppingCriteria
from mplug_owl2.constants import IMAGE_TOKEN_INDEX,DEFAULT_IMAGE_TOKEN
from mplug_owl2.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN
from icecream import ic
@ -34,7 +34,7 @@ def process_images(images, image_processor, model_cfg=None):
new_images = []
if image_aspect_ratio == 'pad':
for image in images:
image = expand2square(image, tuple(int(x*255) for x in image_processor.image_mean))
image = expand2square(image, tuple(int(x * 255) for x in image_processor.image_mean))
image = image_processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
new_images.append(image)
elif image_aspect_ratio == 'resize':
@ -51,10 +51,11 @@ def process_images(images, image_processor, model_cfg=None):
def tokenizer_image_token(prompt, tokenizer, image_token_index=IMAGE_TOKEN_INDEX, return_tensors=None):
prompt_chunks = [tokenizer(chunk).input_ids if len(chunk) > 0 else [] for chunk in prompt.split(DEFAULT_IMAGE_TOKEN)]
prompt_chunks = [tokenizer(chunk).input_ids if len(chunk) > 0 else [] for chunk in
prompt.split(DEFAULT_IMAGE_TOKEN)]
def insert_separator(X, sep):
return [ele for sublist in zip(X, [sep]*len(X)) for ele in sublist][:-1]
return [ele for sublist in zip(X, [sep] * len(X)) for ele in sublist][:-1]
input_ids = []
offset = 0
@ -81,8 +82,6 @@ def get_model_name_from_path(model_path):
return model_paths[-1]
class KeywordsStoppingCriteria(StoppingCriteria):
def __init__(self, keywords, tokenizer, input_ids):
self.keywords = keywords
@ -109,4 +108,4 @@ class KeywordsStoppingCriteria(StoppingCriteria):
for keyword in self.keywords:
if keyword in outputs:
return True
return False
return False

View File

@ -69,7 +69,7 @@ def save_img_caption(image_path: str, img_caption: str, params: ProcessParams):
os.rename(src_name, backup_name)
if img_caption is not None and len(img_caption) > 0:
with open(src_name, "w", encoding="utf8") as file:
file.write(src_name)
file.write(img_caption)
return src_name
@ -286,7 +286,7 @@ def build_caption(image, captions_list, tags_to_ignore, caption_length, subject_
caption_txt = ", ".join(tags_list)
if subject != "" and insert_subject:
if subject not in caption_txt:
caption_txt = f"{subject}, {caption_txt}"
caption_txt = f"{caption_txt}, {subject}"
return caption_txt