Experimental: Add Moondream interrogator?
parent
8de60cfcb9
commit
b8a6ac23d0
|
|
@ -0,0 +1,45 @@
|
|||
from PIL.Image import Image
|
||||
|
||||
from interrogators.interrogator import Interrogator
|
||||
from model_download import fetch_model
|
||||
from process_params import ProcessParams
|
||||
from torch import float16
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
|
||||
|
||||
class MoondreamInterrogator(Interrogator):
    """Caption images with the Moondream2 vision-language model.

    The model and tokenizer are loaded lazily on first use (see load()),
    kept as instance state, and can be released with the inherited
    unload() after each interrogation.
    """

    # Single source of truth for the prompt; the original duplicated this
    # literal in the class-level params default and inside interrogate().
    DEFAULT_PROMPT = "Describe this image in one detailed sentence."

    model = None
    tokenizer = None
    params = {"interrogation_prompt": DEFAULT_PROMPT}

    def __init__(self, params: ProcessParams):
        """Store processing params; model loading is deferred to load()."""
        super().__init__(params)
        self.params = params
        self.model = None
        self.tokenizer = None

    def interrogate(self, image: Image, params: ProcessParams = None, unload: bool = False) -> str:
        """Return a one-sentence caption for *image*.

        Loads the model on first call; when *unload* is True the model is
        released afterwards via the inherited unload().

        NOTE(review): the *params* argument is accepted but ignored — the
        prompt is always DEFAULT_PROMPT. Confirm whether a caller-supplied
        interrogation_prompt should take precedence.
        """
        self.load()
        enc_image = self.model.encode_image(image)
        caption = self.model.answer_question(enc_image, self.DEFAULT_PROMPT, self.tokenizer)
        if unload:
            self.unload()
        return caption

    def _to_gpu(self):
        # Move the loaded model to the CUDA device.
        self.model = self.model.to("cuda")

    def _to_cpu(self):
        # Move the loaded model back to host memory.
        self.model = self.model.to("cpu")

    def load(self):
        """Download (if needed) and initialise the model/tokenizer on the GPU.

        Idempotent: does nothing when the model is already loaded.
        """
        if self.model is None:
            model_id = "vikhyatk/moondream2"
            # NOTE(review): the original assigned revision = "2024-04-02"
            # but never passed it anywhere, so the checkout was never
            # actually pinned. If reproducibility matters, the pin belongs
            # in fetch_model(); the dead local has been removed here.
            model_path = fetch_model(model_id, "llm")
            self.tokenizer = AutoTokenizer.from_pretrained(model_path)

            self.model = AutoModelForCausalLM.from_pretrained(
                model_path,
                trust_remote_code=True,  # moondream2 ships custom modeling code
                torch_dtype=float16,
                attn_implementation="flash_attention_2",
            )
            self._to_gpu()  # equivalent to .to("cuda"); kept symmetric with _to_cpu
|
||||
|
|
@ -4,7 +4,7 @@ import base64
|
|||
|
||||
import torch
|
||||
from transformers import StoppingCriteria
|
||||
from mplug_owl2.constants import IMAGE_TOKEN_INDEX,DEFAULT_IMAGE_TOKEN
|
||||
from mplug_owl2.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN
|
||||
from icecream import ic
|
||||
|
||||
|
||||
|
|
@ -34,7 +34,7 @@ def process_images(images, image_processor, model_cfg=None):
|
|||
new_images = []
|
||||
if image_aspect_ratio == 'pad':
|
||||
for image in images:
|
||||
image = expand2square(image, tuple(int(x*255) for x in image_processor.image_mean))
|
||||
image = expand2square(image, tuple(int(x * 255) for x in image_processor.image_mean))
|
||||
image = image_processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
|
||||
new_images.append(image)
|
||||
elif image_aspect_ratio == 'resize':
|
||||
|
|
@ -51,10 +51,11 @@ def process_images(images, image_processor, model_cfg=None):
|
|||
|
||||
|
||||
def tokenizer_image_token(prompt, tokenizer, image_token_index=IMAGE_TOKEN_INDEX, return_tensors=None):
|
||||
prompt_chunks = [tokenizer(chunk).input_ids if len(chunk) > 0 else [] for chunk in prompt.split(DEFAULT_IMAGE_TOKEN)]
|
||||
prompt_chunks = [tokenizer(chunk).input_ids if len(chunk) > 0 else [] for chunk in
|
||||
prompt.split(DEFAULT_IMAGE_TOKEN)]
|
||||
|
||||
def insert_separator(X, sep):
|
||||
return [ele for sublist in zip(X, [sep]*len(X)) for ele in sublist][:-1]
|
||||
return [ele for sublist in zip(X, [sep] * len(X)) for ele in sublist][:-1]
|
||||
|
||||
input_ids = []
|
||||
offset = 0
|
||||
|
|
@ -81,8 +82,6 @@ def get_model_name_from_path(model_path):
|
|||
return model_paths[-1]
|
||||
|
||||
|
||||
|
||||
|
||||
class KeywordsStoppingCriteria(StoppingCriteria):
|
||||
def __init__(self, keywords, tokenizer, input_ids):
|
||||
self.keywords = keywords
|
||||
|
|
@ -109,4 +108,4 @@ class KeywordsStoppingCriteria(StoppingCriteria):
|
|||
for keyword in self.keywords:
|
||||
if keyword in outputs:
|
||||
return True
|
||||
return False
|
||||
return False
|
||||
|
|
|
|||
|
|
@ -69,7 +69,7 @@ def save_img_caption(image_path: str, img_caption: str, params: ProcessParams):
|
|||
os.rename(src_name, backup_name)
|
||||
if img_caption is not None and len(img_caption) > 0:
|
||||
with open(src_name, "w", encoding="utf8") as file:
|
||||
file.write(src_name)
|
||||
file.write(img_caption)
|
||||
return src_name
|
||||
|
||||
|
||||
|
|
@ -286,7 +286,7 @@ def build_caption(image, captions_list, tags_to_ignore, caption_length, subject_
|
|||
caption_txt = ", ".join(tags_list)
|
||||
if subject != "" and insert_subject:
|
||||
if subject not in caption_txt:
|
||||
caption_txt = f"{subject}, {caption_txt}"
|
||||
caption_txt = f"{caption_txt}, {subject}"
|
||||
return caption_txt
|
||||
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue