diff --git a/interrogators/moondream_interrogator.py b/interrogators/moondream_interrogator.py new file mode 100644 index 0000000..3e594f3 --- /dev/null +++ b/interrogators/moondream_interrogator.py @@ -0,0 +1,45 @@ +from PIL.Image import Image + +from interrogators.interrogator import Interrogator +from model_download import fetch_model +from process_params import ProcessParams +from torch import float16 +from transformers import AutoModelForCausalLM, AutoTokenizer + + +class MoondreamInterrogator(Interrogator): + model = None + tokenizer = None + params = {"interrogation_prompt": "Describe this image in one detailed sentence."} + + def __init__(self, params: ProcessParams): + super().__init__(params) + self.params = params + self.model = None + self.tokenizer = None + + def interrogate(self, image: Image, params: ProcessParams = None, unload: bool = False) -> str: + self.load() + enc_image = self.model.encode_image(image) + caption = self.model.answer_question(enc_image, "Describe this image in one detailed sentence.", self.tokenizer) + if unload: + self.unload() + return caption + + def _to_gpu(self): + self.model = self.model.to("cuda") + + def _to_cpu(self): + self.model = self.model.to("cpu") + + def load(self): + if self.model is None: + model_id = "vikhyatk/moondream2" + revision = "2024-04-02" + model_path = fetch_model(model_id, "llm") + self.tokenizer = AutoTokenizer.from_pretrained(model_path) + + self.model = AutoModelForCausalLM.from_pretrained( + model_path, trust_remote_code=True, + torch_dtype=float16, attn_implementation="flash_attention_2" + ).to("cuda") diff --git a/mplug_owl2/mm_utils.py b/mplug_owl2/mm_utils.py index 128b2cf..615a1e5 100644 --- a/mplug_owl2/mm_utils.py +++ b/mplug_owl2/mm_utils.py @@ -4,7 +4,7 @@ import base64 import torch from transformers import StoppingCriteria -from mplug_owl2.constants import IMAGE_TOKEN_INDEX,DEFAULT_IMAGE_TOKEN +from mplug_owl2.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN from icecream import ic @@ -34,7 +34,7 @@ def process_images(images, image_processor, model_cfg=None): new_images = [] if image_aspect_ratio == 'pad': for image in images: - image = expand2square(image, tuple(int(x*255) for x in image_processor.image_mean)) + image = expand2square(image, tuple(int(x * 255) for x in image_processor.image_mean)) image = image_processor.preprocess(image, return_tensors='pt')['pixel_values'][0] new_images.append(image) elif image_aspect_ratio == 'resize': @@ -51,10 +51,11 @@ def process_images(images, image_processor, model_cfg=None): def tokenizer_image_token(prompt, tokenizer, image_token_index=IMAGE_TOKEN_INDEX, return_tensors=None): - prompt_chunks = [tokenizer(chunk).input_ids if len(chunk) > 0 else [] for chunk in prompt.split(DEFAULT_IMAGE_TOKEN)] + prompt_chunks = [tokenizer(chunk).input_ids if len(chunk) > 0 else [] for chunk in + prompt.split(DEFAULT_IMAGE_TOKEN)] def insert_separator(X, sep): - return [ele for sublist in zip(X, [sep]*len(X)) for ele in sublist][:-1] + return [ele for sublist in zip(X, [sep] * len(X)) for ele in sublist][:-1] input_ids = [] offset = 0 @@ -81,8 +82,6 @@ def get_model_name_from_path(model_path): return model_paths[-1] - - class KeywordsStoppingCriteria(StoppingCriteria): def __init__(self, keywords, tokenizer, input_ids): self.keywords = keywords @@ -109,4 +108,4 @@ class KeywordsStoppingCriteria(StoppingCriteria): for keyword in self.keywords: if keyword in outputs: return True - return False \ No newline at end of file + return False diff --git a/smartprocess.py b/smartprocess.py index fc226f4..11aa0cd 100644 --- a/smartprocess.py +++ b/smartprocess.py @@ -69,7 +69,7 @@ def save_img_caption(image_path: str, img_caption: str, params: ProcessParams): os.rename(src_name, backup_name) if img_caption is not None and len(img_caption) > 0: with open(src_name, "w", encoding="utf8") as file: - file.write(src_name) + file.write(img_caption) return src_name @@ -286,7 +286,7 @@ def build_caption(image, captions_list, tags_to_ignore, caption_length, subject_ caption_txt = ", ".join(tags_list) if subject != "" and insert_subject: if subject not in caption_txt: - caption_txt = f"{subject}, {caption_txt}" + caption_txt = f"{caption_txt}, {subject}" return caption_txt