Fix Mplug owl2
parent
2f728a0dc8
commit
81dcb53ccb
|
|
@ -9,13 +9,15 @@ from huggingface_hub import snapshot_download
|
|||
from transformers import TextStreamer
|
||||
|
||||
from extensions.sd_smartprocess.interrogators.interrogator import Interrogator
|
||||
from extensions.sd_smartprocess.model_download import fetch_model
|
||||
from extensions.sd_smartprocess.process_params import ProcessParams
|
||||
from modules.paths_internal import models_path
|
||||
from mplug_owl2.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN
|
||||
from mplug_owl2.conversation import conv_templates
|
||||
from mplug_owl2.mm_utils import KeywordsStoppingCriteria, tokenizer_image_token, process_images, \
|
||||
from extensions.sd_smartprocess.mplug_owl2.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN
|
||||
from extensions.sd_smartprocess.mplug_owl2.conversation import conv_templates
|
||||
from extensions.sd_smartprocess.mplug_owl2.mm_utils import KeywordsStoppingCriteria, tokenizer_image_token, \
|
||||
process_images, \
|
||||
get_model_name_from_path
|
||||
from mplug_owl2.model.builder import load_pretrained_model
|
||||
from extensions.sd_smartprocess.mplug_owl2.model.builder import load_pretrained_model
|
||||
|
||||
# NOTE: This is effectively broken until the transformers version supported by AUTO can be updated past the current one.
|
||||
|
||||
|
|
@ -30,19 +32,13 @@ class MPLUG2Interrogator(Interrogator):
|
|||
def __init__(self, params: ProcessParams):
|
||||
super().__init__(params)
|
||||
logger.debug("Initializing LLM model...")
|
||||
pretrained_ckpt = 'MAGAer13/mplug-owl2-llama2-7b'
|
||||
scripts_dir = os.path.join(models_path, "llm")
|
||||
os.makedirs(scripts_dir, exist_ok=True)
|
||||
model_name = "mplug-owl2-llama2-7b"
|
||||
model_path = os.path.join(scripts_dir, model_name)
|
||||
if not os.path.exists(model_path):
|
||||
os.makedirs(model_path, exist_ok=True)
|
||||
snapshot_download(pretrained_ckpt, repo_type="model", local_dir=model_path, local_dir_use_symlinks=False)
|
||||
|
||||
model_path = fetch_model('MAGAer13/mplug-owl2-llama2-7b', "llm")
|
||||
model_name = get_model_name_from_path(model_path)
|
||||
self.tokenizer, self.model, self.image_processor, self.context_len = load_pretrained_model(model_path, None, model_name,
|
||||
load_8bit=False, load_4bit=False,
|
||||
device="cuda")
|
||||
self.tokenizer, self.model, self.image_processor, self.context_len = load_pretrained_model(model_path, None,
|
||||
model_name,
|
||||
load_8bit=False,
|
||||
load_4bit=False,
|
||||
device="cuda")
|
||||
|
||||
self._to_cpu()
|
||||
logger.debug("Initialized LLM model.")
|
||||
|
|
@ -51,7 +47,7 @@ class MPLUG2Interrogator(Interrogator):
|
|||
self.load()
|
||||
if params is None:
|
||||
params = {}
|
||||
query = params.get("query", "Describe the image.")
|
||||
query = "Describe the image with a caption that can be used to generate a similar image."
|
||||
|
||||
conv = conv_templates["mplug_owl2"].copy()
|
||||
roles = conv.roles
|
||||
|
|
@ -67,7 +63,8 @@ class MPLUG2Interrogator(Interrogator):
|
|||
conv.append_message(conv.roles[1], None)
|
||||
prompt = conv.get_prompt()
|
||||
|
||||
input_ids = tokenizer_image_token(prompt, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(
|
||||
input_ids = tokenizer_image_token(prompt, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(
|
||||
0).to(
|
||||
self.model.device)
|
||||
stop_str = conv.sep2
|
||||
keywords = [stop_str]
|
||||
|
|
@ -94,13 +91,13 @@ class MPLUG2Interrogator(Interrogator):
|
|||
|
||||
def _to_cpu(self):
|
||||
self.model.to('cpu')
|
||||
self.image_processor.to('cpu')
|
||||
self.tokenizer.to('cpu')
|
||||
#self.image_processor.to('cpu')
|
||||
#self.tokenizer.to('cpu')
|
||||
|
||||
def _to_gpu(self):
|
||||
self.model.to(self.device)
|
||||
self.image_processor.to(self.device)
|
||||
self.tokenizer.to(self.device)
|
||||
#self.image_processor.to(self.device)
|
||||
#self.tokenizer.to(self.device)
|
||||
|
||||
def unload(self):
|
||||
self._to_cpu()
|
||||
|
|
|
|||
|
|
@ -193,7 +193,7 @@ class MplugOwlConfig(PretrainedConfig):
|
|||
Example:
|
||||
|
||||
```python
|
||||
>>> from transformers import (
|
||||
>>> from transformers.models.owlvit import (
|
||||
... MplugOwlVisionConfig,
|
||||
... MplugOwlVisualAbstractorConfig,
|
||||
... OPTConfig,
|
||||
|
|
@ -236,7 +236,7 @@ class MplugOwlConfig(PretrainedConfig):
|
|||
|
||||
if text_config is None:
|
||||
# we use LLAMA 7b by default
|
||||
from transformers.llama.configuration_llama import LlamaConfig
|
||||
from transformers.models.llama.configuration_llama import LlamaConfig
|
||||
|
||||
text_config = LlamaConfig(pad_token_id=2).to_dict()
|
||||
logger.info("text_config is None.")
|
||||
|
|
|
|||
|
|
@ -15,6 +15,7 @@
|
|||
""" PyTorch MplugOwl model."""
|
||||
|
||||
import math
|
||||
import warnings
|
||||
from typing import Any, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
|
|
@ -1651,7 +1652,7 @@ def bloom_forward(
|
|||
|
||||
if self.gradient_checkpointing and self.training:
|
||||
if use_cache:
|
||||
logger.warning_once(
|
||||
logger.warning(
|
||||
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
|
||||
)
|
||||
use_cache = False
|
||||
|
|
|
|||
|
|
@ -118,13 +118,14 @@ if __name__ == '__main__':
|
|||
|
||||
torch.cuda.set_device(int(os.getenv('LOCAL_RANK', 0)))
|
||||
|
||||
prompt = 'USER: <|image|>Provide a one-sentence caption for the provided image. ASSISTANT:'
|
||||
prompt = 'USER: <|image|>Provide a one-sentence caption for the provided image. ASSISTANT: '
|
||||
|
||||
model_path = args.checkpoint
|
||||
model_name = get_model_name_from_path(model_path)
|
||||
tokenizer, model, image_processor, context_len = load_pretrained_model(model_path, None, model_name, load_8bit=False, load_4bit=False, device_map="cuda", device="cuda")
|
||||
tokenizer, model, image_processor, context_len = load_pretrained_model(model_path, None, model_name, load_8bit=False, load_4bit=False, device_map={"":f"cuda:{os.getenv('LOCAL_RANK', '0')}"}, device="cuda")
|
||||
tokenizer.padding_side = 'left'
|
||||
tokenizer.pad_token_id = tokenizer.eos_token_id
|
||||
if not hasattr(tokenizer, 'pad_token_id'):
|
||||
tokenizer.pad_token_id = tokenizer.eos_token_id
|
||||
|
||||
random.seed(args.seed)
|
||||
dataset = CaptionDataset(
|
||||
|
|
@ -146,14 +147,14 @@ if __name__ == '__main__':
|
|||
|
||||
image_ids = []
|
||||
captions = []
|
||||
for _, (ids, image_tensor, input_ids, attention_mask) in tqdm(enumerate(coco_karpathy_test_loader)):
|
||||
for _, (ids, image_tensor, input_ids, attention_mask) in enumerate(tqdm(coco_karpathy_test_loader)):
|
||||
pred = model.generate(
|
||||
input_ids=input_ids.cuda(),
|
||||
attention_mask=attention_mask.cuda(),
|
||||
images=image_tensor.to(dtype=model.dtype).cuda(),
|
||||
do_sample=False,
|
||||
num_beams=1,
|
||||
max_new_tokens=30,
|
||||
max_new_tokens=60,
|
||||
min_new_tokens=8,
|
||||
length_penalty=0,
|
||||
num_return_sequences=1,
|
||||
|
|
@ -164,7 +165,7 @@ if __name__ == '__main__':
|
|||
tokenizer.decode(_[input_ids.size(1):].cpu(),
|
||||
skip_special_tokens=True).strip() for _ in pred
|
||||
])
|
||||
print(captions)
|
||||
print(captions[-len(pred):])
|
||||
|
||||
torch.distributed.barrier()
|
||||
|
||||
|
|
|
|||
|
|
@ -29,6 +29,21 @@ ds_collections = {
|
|||
'annotation': 'mmbench_test_20230712.jsonl',
|
||||
'max_new_tokens': 10,
|
||||
},
|
||||
'mmbench_test_en_20231003': {
|
||||
'raw_file': 'mmbench_test_en_20231003.tsv',
|
||||
'annotation': 'mmbench_test_en_20231003.jsonl',
|
||||
'max_new_tokens': 10,
|
||||
},
|
||||
'mmbench_test_cn_20231003': {
|
||||
'raw_file': 'mmbench_test_cn_20231003.tsv',
|
||||
'annotation': 'mmbench_test_cn_20231003.jsonl',
|
||||
'max_new_tokens': 10,
|
||||
},
|
||||
'ccbench_1003': {
|
||||
'raw_file': 'ccbench_1003.tsv',
|
||||
'annotation': 'ccbench_1003.jsonl',
|
||||
'max_new_tokens': 10,
|
||||
},
|
||||
}
|
||||
|
||||
multiple_choices = ['A', 'B', 'C', 'D', 'E']
|
||||
|
|
@ -45,7 +60,7 @@ def mapping_to_annotation(results, raw_annotation):
|
|||
"answer": row_df.get('answer', None),
|
||||
"options": [y for y in [row_df.get(x, None) for x in 'ABCD'] if isinstance(y, str)],
|
||||
"prediction": prediction,
|
||||
"l2-category": row_df['l2-category']
|
||||
"l2-category": row_df['l2-category'] if 'l2-category' in row_df else None
|
||||
}
|
||||
outputs.append(output)
|
||||
return outputs
|
||||
|
|
@ -163,6 +178,7 @@ if __name__ == '__main__':
|
|||
parser.add_argument('--seed', type=int, default=0)
|
||||
args = parser.parse_args()
|
||||
|
||||
|
||||
torch.distributed.init_process_group(
|
||||
backend='nccl',
|
||||
world_size=int(os.getenv('WORLD_SIZE', '1')),
|
||||
|
|
@ -175,11 +191,12 @@ if __name__ == '__main__':
|
|||
model_path = args.checkpoint
|
||||
model_name = get_model_name_from_path(model_path)
|
||||
|
||||
tokenizer, model, image_processor, context_len = load_pretrained_model(model_path, None, model_name, load_8bit=False, load_4bit=False, device_map="cuda", device="cuda")
|
||||
tokenizer, model, image_processor, context_len = load_pretrained_model(model_path, None, model_name, load_8bit=False, load_4bit=False, device_map={"":f"cuda:{os.getenv('LOCAL_RANK', '0')}"}, device="cuda")
|
||||
tokenizer.padding_side = 'left'
|
||||
tokenizer.pad_token_id = tokenizer.eos_token_id
|
||||
if not hasattr(tokenizer, 'pad_token_id'):
|
||||
tokenizer.pad_token_id = tokenizer.eos_token_id
|
||||
|
||||
prompt = "USER: <|image|>{}\n{}\n{}\nAnswer with the option’s letter from the given choices directly. ASSISTANT:"
|
||||
prompt = "USER: <|image|>{}\n{}\n{}\nAnswer with the option’s letter from the given choices directly. ASSISTANT: "
|
||||
|
||||
random.seed(args.seed)
|
||||
dataset = VQADataset(
|
||||
|
|
@ -199,7 +216,7 @@ if __name__ == '__main__':
|
|||
)
|
||||
|
||||
outputs = []
|
||||
for _, (image_tensor, input_ids, attention_mask, indices) in tqdm(enumerate(dataloader)):
|
||||
for _, (image_tensor, input_ids, attention_mask, indices) in enumerate(tqdm(dataloader)):
|
||||
pred = model.generate(
|
||||
input_ids=input_ids.cuda(),
|
||||
attention_mask=attention_mask.cuda(),
|
||||
|
|
|
|||
|
|
@ -296,11 +296,12 @@ if __name__ == '__main__':
|
|||
model_path = args.checkpoint
|
||||
model_name = get_model_name_from_path(model_path)
|
||||
|
||||
tokenizer, model, image_processor, context_len = load_pretrained_model(model_path, None, model_name, load_8bit=False, load_4bit=False, device_map="cuda", device="cuda")
|
||||
tokenizer, model, image_processor, context_len = load_pretrained_model(model_path, None, model_name, load_8bit=False, load_4bit=False, device_map={"":f"cuda:{os.getenv('LOCAL_RANK', '0')}"}, device="cuda")
|
||||
tokenizer.padding_side = 'left'
|
||||
tokenizer.pad_token_id = tokenizer.eos_token_id
|
||||
if not hasattr(tokenizer, 'pad_token_id'):
|
||||
tokenizer.pad_token_id = tokenizer.eos_token_id
|
||||
|
||||
prompt = 'USER: <|image|>{}\nAnswer the question using a single word or phrase. ASSISTANT:'
|
||||
prompt = 'USER: <|image|>{}\nAnswer the question using a single word or phrase. ASSISTANT: '
|
||||
|
||||
random.seed(args.seed)
|
||||
dataset = VQADataset(
|
||||
|
|
|
|||
|
|
@ -13,12 +13,13 @@ from PIL import Image
|
|||
import pandas as pd
|
||||
import re
|
||||
|
||||
from datasets import load_dataset
|
||||
|
||||
from mplug_owl2.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN
|
||||
from mplug_owl2.conversation import conv_templates, SeparatorStyle
|
||||
from mplug_owl2.model.builder import load_pretrained_model
|
||||
from mplug_owl2.mm_utils import process_images, tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria
|
||||
from pathlib import Path
|
||||
from datasets import load_dataset, concatenate_datasets
|
||||
|
||||
DOMAIN_CAT2SUB_CAT = {
|
||||
'Art and Design': ['Art', 'Art_Theory', 'Design', 'Music'],
|
||||
|
|
@ -317,11 +318,9 @@ def collate_fn(batches, tokenizer):
|
|||
for input_text in questions:
|
||||
input_ids.append(tokenizer_image_token(input_text, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').tolist())
|
||||
input_tokens_max_length = max([len(x) for x in input_ids])
|
||||
pad_token_id = tokenizer.pad_token_id
|
||||
|
||||
input_ids = [([pad_token_id] * (input_tokens_max_length - len(_)) + _) for _ in input_ids] # pad in the left
|
||||
input_ids = [([tokenizer.pad_token_id] * (input_tokens_max_length - len(_)) + _) for _ in input_ids] # pad in the left
|
||||
input_ids = torch.LongTensor(input_ids)
|
||||
attention_mask = 1 - input_ids.eq(pad_token_id).long()
|
||||
attention_mask = 1 - input_ids.eq(tokenizer.pad_token_id).long()
|
||||
|
||||
image_tensor = torch.cat(image_tensor, dim=0)
|
||||
return image_tensor, input_ids, attention_mask, answers, ids, origin_questions, question_types, subfields, question_splits
|
||||
|
|
@ -332,7 +331,15 @@ class VQADataset(torch.utils.data.Dataset):
|
|||
def __init__(self, split, image_processor, eval_split='dev'):
|
||||
|
||||
self.image_processor = image_processor
|
||||
self.data = load_dataset("MMMU/MMMU", split)[eval_split]
|
||||
# self.data = load_dataset("/nas-alinlp/qinghao.yqh/datasets/mm_chatgpt/Evaluation/MMMU/MMMU", split)[eval_split]
|
||||
sub_dataset_list = []
|
||||
for subject in CAT_SHORT2LONG.values():
|
||||
sub_dataset = load_dataset(str(Path("/nas-alinlp/qinghao.yqh/datasets/mm_chatgpt/Evaluation/MMMU/MMMU", subject)), split=eval_split)
|
||||
sub_dataset_list.append(sub_dataset)
|
||||
|
||||
# merge all dataset
|
||||
self.data = concatenate_datasets(sub_dataset_list)
|
||||
|
||||
self.question_split = split
|
||||
|
||||
def __len__(self):
|
||||
|
|
@ -352,9 +359,9 @@ class VQADataset(torch.utils.data.Dataset):
|
|||
for i, c in enumerate(choices):
|
||||
choice_list.append('{}. {}'.format(multiple_choices[i], c))
|
||||
choice_txt = '\n'.join(choice_list)
|
||||
prompt = f"USER: {question}\n{choice_txt}\nAnswer with the option’s letter from the given choices directly. ASSISTANT:"
|
||||
prompt = f"USER: {question}\n{choice_txt}\nAnswer with the option’s letter from the given choices directly. ASSISTANT: "
|
||||
else:
|
||||
prompt = f"USER: {question}\nAnswer the question using a single word or phrase. ASSISTANT:"
|
||||
prompt = f"USER: {question}\nAnswer the question using a single word or phrase. ASSISTANT: "
|
||||
|
||||
image_nums = re.findall(r'<image (\d+)>', prompt)
|
||||
images = []
|
||||
|
|
@ -430,9 +437,10 @@ if __name__ == '__main__':
|
|||
model_path = args.checkpoint
|
||||
model_name = get_model_name_from_path(model_path)
|
||||
|
||||
tokenizer, model, image_processor, context_len = load_pretrained_model(model_path, None, model_name, load_8bit=False, load_4bit=False, device_map="cuda", device="cuda")
|
||||
tokenizer, model, image_processor, context_len = load_pretrained_model(model_path, None, model_name, load_8bit=False, load_4bit=False, device_map={"":f"cuda:{os.getenv('LOCAL_RANK', '0')}"}, device="cuda")
|
||||
if not hasattr(tokenizer, 'pad_token_id'):
|
||||
tokenizer.pad_token_id = tokenizer.eos_token_id
|
||||
tokenizer.padding_side = 'left'
|
||||
tokenizer.pad_token_id = tokenizer.eos_token_id
|
||||
|
||||
|
||||
random.seed(args.seed)
|
||||
|
|
|
|||
|
|
@ -154,12 +154,13 @@ if __name__ == '__main__':
|
|||
model_path = args.checkpoint
|
||||
model_name = get_model_name_from_path(model_path)
|
||||
|
||||
tokenizer, model, image_processor, context_len = load_pretrained_model(model_path, None, model_name, load_8bit=False, load_4bit=False, device_map="cuda", device="cuda")
|
||||
tokenizer, model, image_processor, context_len = load_pretrained_model(model_path, None, model_name, load_8bit=False, load_4bit=False, device_map={"":f"cuda:{os.getenv('LOCAL_RANK', '0')}"}, device="cuda")
|
||||
tokenizer.padding_side = 'left'
|
||||
tokenizer.pad_token_id = tokenizer.eos_token_id
|
||||
|
||||
prompt = 'USER: <|image|>{}\nAnswer the question using a single word or phrase. ASSISTANT:'
|
||||
if not hasattr(tokenizer, 'pad_token_id'):
|
||||
tokenizer.pad_token_id = tokenizer.eos_token_id
|
||||
|
||||
prompt = 'USER: <|image|>{} Answer the question using a single word or phrase. ASSISTANT: '
|
||||
answer_processor = EvalAIAnswerProcessor()
|
||||
random.seed(args.seed)
|
||||
dataset = VQADataset(
|
||||
train=ds_collections[args.dataset]['train'],
|
||||
|
|
@ -181,13 +182,13 @@ if __name__ == '__main__':
|
|||
|
||||
outputs = []
|
||||
for _, (question_ids, image_tensor, input_ids, attention_mask,
|
||||
annotations) in tqdm(enumerate(dataloader)):
|
||||
annotations) in enumerate(tqdm(dataloader)):
|
||||
pred = model.generate(
|
||||
input_ids=input_ids.cuda(),
|
||||
attention_mask=attention_mask.cuda(),
|
||||
images=image_tensor.to(dtype=model.dtype).cuda(),
|
||||
do_sample=False,
|
||||
num_beams=5,
|
||||
num_beams=1,
|
||||
max_new_tokens=ds_collections[args.dataset]['max_new_tokens'],
|
||||
min_new_tokens=1,
|
||||
length_penalty=1,
|
||||
|
|
@ -202,11 +203,16 @@ if __name__ == '__main__':
|
|||
|
||||
for question_id, answer, annotation in zip(question_ids, answers,
|
||||
annotations):
|
||||
if args.dataset in ['vqav2_val', 'vqav2_testdev', 'okvqa_val', 'textvqa_val']:
|
||||
if args.dataset in ['vqav2_val', 'okvqa_val', 'textvqa_val']:
|
||||
outputs.append({
|
||||
'question_id': question_id,
|
||||
'answer': answer,
|
||||
})
|
||||
elif args.dataset == 'vqav2_testdev':
|
||||
outputs.append({
|
||||
'question_id': question_id,
|
||||
'answer': answer_processor(answer),
|
||||
})
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
|
|
@ -223,7 +229,7 @@ if __name__ == '__main__':
|
|||
print(f"Evaluating {args.dataset} ...")
|
||||
time_prefix = time.strftime('%y%m%d%H%M%S', time.localtime())
|
||||
results_file = f'{args.dataset}_{time_prefix}_fs{args.few_shot}_s{args.seed}.json'
|
||||
json.dump(merged_outputs, open(results_file, 'w'), ensure_ascii=False)
|
||||
json.dump(merged_outputs, open(results_file, 'w', encoding='utf-8'), ensure_ascii=False)
|
||||
|
||||
if ds_collections[args.dataset]['metric'] == 'vqa_score':
|
||||
vqa = VQA(ds_collections[args.dataset]['annotation'],
|
||||
|
|
|
|||
|
|
@ -327,4 +327,225 @@ class VQAEval:
|
|||
'#' * block + '-' * (barLength - block), int(progress * 100),
|
||||
status)
|
||||
sys.stdout.write(text)
|
||||
sys.stdout.flush()
|
||||
sys.stdout.flush()
|
||||
|
||||
import re
|
||||
|
||||
from tqdm import tqdm
|
||||
|
||||
|
||||
class EvalAIAnswerProcessor:
    """
    Processes an answer similar to Eval AI
    copied from
    https://github.com/facebookresearch/mmf/blob/c46b3b3391275b4181567db80943473a89ab98ab/pythia/tasks/processors.py#L897

    Calling an instance on a string lower-cases it, strips or spaces out
    punctuation, maps number words to digits, drops articles and restores
    apostrophes in known contractions.  The class-level tables are shared
    by all instances and are treated as read-only.
    """

    # Apostrophe-less (or misplaced-apostrophe) spellings -> canonical form.
    # Entries are kept exactly as in the reference table, quirks included.
    CONTRACTIONS = {
        "aint": "ain't",
        "arent": "aren't",
        "cant": "can't",
        "couldve": "could've",
        "couldnt": "couldn't",
        "couldn'tve": "couldn't've",
        "couldnt've": "couldn't've",
        "didnt": "didn't",
        "doesnt": "doesn't",
        "dont": "don't",
        "hadnt": "hadn't",
        "hadnt've": "hadn't've",
        "hadn'tve": "hadn't've",
        "hasnt": "hasn't",
        "havent": "haven't",
        "hed": "he'd",
        "hed've": "he'd've",
        "he'dve": "he'd've",
        "hes": "he's",
        "howd": "how'd",
        "howll": "how'll",
        "hows": "how's",
        "Id've": "I'd've",
        "I'dve": "I'd've",
        "Im": "I'm",
        "Ive": "I've",
        "isnt": "isn't",
        "itd": "it'd",
        "itd've": "it'd've",
        "it'dve": "it'd've",
        "itll": "it'll",
        "let's": "let's",
        "maam": "ma'am",
        "mightnt": "mightn't",
        "mightnt've": "mightn't've",
        "mightn'tve": "mightn't've",
        "mightve": "might've",
        "mustnt": "mustn't",
        "mustve": "must've",
        "neednt": "needn't",
        "notve": "not've",
        "oclock": "o'clock",
        "oughtnt": "oughtn't",
        "ow's'at": "'ow's'at",
        "'ows'at": "'ow's'at",
        "'ow'sat": "'ow's'at",
        "shant": "shan't",
        "shed've": "she'd've",
        "she'dve": "she'd've",
        "she's": "she's",
        "shouldve": "should've",
        "shouldnt": "shouldn't",
        "shouldnt've": "shouldn't've",
        "shouldn'tve": "shouldn't've",
        "somebody'd": "somebodyd",
        "somebodyd've": "somebody'd've",
        "somebody'dve": "somebody'd've",
        "somebodyll": "somebody'll",
        "somebodys": "somebody's",
        "someoned": "someone'd",
        "someoned've": "someone'd've",
        "someone'dve": "someone'd've",
        "someonell": "someone'll",
        "someones": "someone's",
        "somethingd": "something'd",
        "somethingd've": "something'd've",
        "something'dve": "something'd've",
        "somethingll": "something'll",
        "thats": "that's",
        "thered": "there'd",
        "thered've": "there'd've",
        "there'dve": "there'd've",
        "therere": "there're",
        "theres": "there's",
        "theyd": "they'd",
        "theyd've": "they'd've",
        "they'dve": "they'd've",
        "theyll": "they'll",
        "theyre": "they're",
        "theyve": "they've",
        "twas": "'twas",
        "wasnt": "wasn't",
        "wed've": "we'd've",
        "we'dve": "we'd've",
        "weve": "we've",
        "werent": "weren't",
        "whatll": "what'll",
        "whatre": "what're",
        "whats": "what's",
        "whatve": "what've",
        "whens": "when's",
        "whered": "where'd",
        "wheres": "where's",
        "whereve": "where've",
        "whod": "who'd",
        "whod've": "who'd've",
        "who'dve": "who'd've",
        "wholl": "who'll",
        "whos": "who's",
        "whove": "who've",
        "whyll": "why'll",
        "whyre": "why're",
        "whys": "why's",
        "wont": "won't",
        "wouldve": "would've",
        "wouldnt": "wouldn't",
        "wouldnt've": "wouldn't've",
        "wouldn'tve": "wouldn't've",
        "yall": "y'all",
        "yall'll": "y'all'll",
        "y'allll": "y'all'll",
        "yall'd've": "y'all'd've",
        "y'alld've": "y'all'd've",
        "y'all'dve": "y'all'd've",
        "youd": "you'd",
        "youd've": "you'd've",
        "you'dve": "you'd've",
        "youll": "you'll",
        "youre": "you're",
        "youve": "you've",
    }

    # Number words replaced by their digit form.
    NUMBER_MAP = {
        "none": "0",
        "zero": "0",
        "one": "1",
        "two": "2",
        "three": "3",
        "four": "4",
        "five": "5",
        "six": "6",
        "seven": "7",
        "eight": "8",
        "nine": "9",
        "ten": "10",
    }
    # Words dropped entirely from the answer.
    ARTICLES = ["a", "an", "the"]
    # NOTE(review): `(?!<=\d)` is a negative look-AHEAD for the literal text
    # "<=" plus a digit, so it never blocks a match on "."; the pattern
    # therefore matches every period that is not followed by a digit.  It
    # looks like `(?<!\d)` was intended, but the pattern is kept as-is so the
    # normalized output stays identical to the implementation this was copied
    # from — confirm before changing.
    PERIOD_STRIP = re.compile(r"(?!<=\d)(\.)(?!\d)")
    # One or more commas sitting between two digits (e.g. "1,000").
    COMMA_STRIP = re.compile(r"(?<=\d)(\,)+(?=\d)")
    PUNCTUATIONS = [
        ";",
        r"/",
        "[",
        "]",
        '"',
        "{",
        "}",
        "(",
        ")",
        "=",
        "+",
        "\\",
        "_",
        "-",
        ">",
        "<",
        "@",
        "`",
        ",",
        "?",
        "!",
    ]

    def __init__(self, *args, **kwargs):
        # Stateless: arguments are accepted and ignored.
        pass

    def word_tokenize(self, word: str) -> str:
        """Lower-case *word*, drop ``,`` and ``?``, and detach ``'s``.

        Every ``'s`` is replaced by `` 's`` (a space is inserted before it);
        the result is stripped of surrounding whitespace.
        """
        word = word.lower()
        word = word.replace(",", "").replace("?", "").replace("'s", " 's")
        return word.strip()

    def process_punctuation(self, in_text: str) -> str:
        """Remove or space out each character listed in ``PUNCTUATIONS``.

        A punctuation mark is deleted outright when it appears next to a
        space in the *original* text, or when the original text contains a
        comma between digits; otherwise it is replaced by a single space.
        Finally every match of ``PERIOD_STRIP`` is removed.
        """
        # The digit-comma test depends only on the input, so evaluate it once
        # instead of once per punctuation mark.
        has_digit_comma = self.COMMA_STRIP.search(in_text) is not None

        out_text = in_text
        for p in self.PUNCTUATIONS:
            # The adjacency test is deliberately made against the unmodified
            # input, not against the partially processed `out_text`.
            if has_digit_comma or p + " " in in_text or " " + p in in_text:
                out_text = out_text.replace(p, "")
            else:
                out_text = out_text.replace(p, " ")

        # The original passed `re.UNICODE` as the third positional argument,
        # which is `count`, silently capping the substitution at 32 matches.
        # No flag is needed here: str patterns are Unicode-aware by default.
        out_text = self.PERIOD_STRIP.sub("", out_text)
        return out_text

    def process_digit_article(self, in_text: str) -> str:
        """Map number words to digits, drop articles, fix contractions.

        The text is lower-cased and split on whitespace; the surviving words
        are joined back with single spaces.
        """
        out_text = []
        for word in in_text.lower().split():
            # `.get` rather than `.setdefault`: the latter inserted every
            # unseen word into the shared class-level NUMBER_MAP, growing it
            # on each call.
            word = self.NUMBER_MAP.get(word, word)
            if word not in self.ARTICLES:
                out_text.append(word)

        for word_id, word in enumerate(out_text):
            if word in self.CONTRACTIONS:
                out_text[word_id] = self.CONTRACTIONS[word]

        return " ".join(out_text)

    def __call__(self, item: str) -> str:
        """Return the normalized form of the answer string *item*."""
        item = self.word_tokenize(item)
        item = item.replace("\n", " ").replace("\t", " ").strip()
        item = self.process_punctuation(item)
        item = self.process_digit_article(item)
        return item
|
||||
p = EvalAIAnswerProcessor()
|
||||
for line in tqdm(hfa):
|
||||
line['answer'] = p(line['answer'])
|
||||
|
|
@ -1,2 +1,2 @@
|
|||
from .modeling_mplug_owl2 import MPLUGOwl2LlamaForCausalLM
|
||||
from .configuration_mplug_owl2 import MPLUGOwl2Config
|
||||
from .modeling_mplug_owl2 import MPLUGOwl2LlamaForCausalLM, MPLUGOwl2QWenForCausalLM
|
||||
from .configuration_mplug_owl2 import MPLUGOwl2Config,MPLUGOwl2QwenConfig
|
||||
|
|
|
|||
|
|
@ -20,10 +20,16 @@ import shutil
|
|||
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig, BitsAndBytesConfig
|
||||
from transformers.models.clip.image_processing_clip import CLIPImageProcessor
|
||||
import torch
|
||||
from mplug_owl2.model import *
|
||||
from extensions.sd_smartprocess.mplug_owl2.model import *
|
||||
from icecream import ic
|
||||
def load_pretrained_model(model_path, model_base, model_name, load_8bit=False, load_4bit=False, device_map="auto", device="cuda"):
|
||||
kwargs = {"device_map": device_map}
|
||||
|
||||
from extensions.sd_smartprocess.mplug_owl2 import MPLUGOwl2LlamaForCausalLM
|
||||
from extensions.sd_smartprocess.mplug_owl2.model import MPLUGOwl2QWenForCausalLM
|
||||
|
||||
|
||||
def load_pretrained_model(model_path, model_base, model_name, load_8bit=False, load_4bit=False, device_map="auto",
|
||||
device="cuda", **kwargs):
|
||||
kwargs = {"device_map": device_map, "ignore_mismatched_sizes": False, **kwargs}
|
||||
|
||||
if device != "cuda":
|
||||
kwargs['device_map'] = {"": device}
|
||||
|
|
@ -43,20 +49,29 @@ def load_pretrained_model(model_path, model_base, model_name, load_8bit=False, l
|
|||
if 'mplug_owl2' in model_name.lower():
|
||||
# Load LLaVA model
|
||||
if 'lora' in model_name.lower() and model_base is None:
|
||||
warnings.warn('There is `lora` in model name but no `model_base` is provided. If you are loading a LoRA model, please provide the `model_base` argument. Detailed instruction: https://github.com/haotian-liu/LLaVA#launch-a-model-worker-lora-weights-unmerged.')
|
||||
warnings.warn(
|
||||
'There is `lora` in model name but no `model_base` is provided. If you are loading a LoRA model, please provide the `model_base` argument. Detailed instruction: https://github.com/haotian-liu/LLaVA#launch-a-model-worker-lora-weights-unmerged.')
|
||||
if 'lora' in model_name.lower() and model_base is not None:
|
||||
lora_cfg_pretrained = AutoConfig.from_pretrained(model_path)
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
|
||||
print('Loading mPLUG-Owl2 from base model...')
|
||||
model = MPLUGOwl2LlamaForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=lora_cfg_pretrained, **kwargs)
|
||||
if 'mplug_owl2_1' in model_name.lower():
|
||||
model = MPLUGOwl2QWenForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True,
|
||||
config=lora_cfg_pretrained, **kwargs)
|
||||
else:
|
||||
model = MPLUGOwl2LlamaForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True,
|
||||
config=lora_cfg_pretrained, **kwargs)
|
||||
token_num, tokem_dim = model.lm_head.out_features, model.lm_head.in_features
|
||||
if model.lm_head.weight.shape[0] != token_num:
|
||||
model.lm_head.weight = torch.nn.Parameter(torch.empty(token_num, tokem_dim, device=model.device, dtype=model.dtype))
|
||||
model.model.embed_tokens.weight = torch.nn.Parameter(torch.empty(token_num, tokem_dim, device=model.device, dtype=model.dtype))
|
||||
model.lm_head.weight = torch.nn.Parameter(
|
||||
torch.empty(token_num, tokem_dim, device=model.device, dtype=model.dtype))
|
||||
model.model.embed_tokens.weight = torch.nn.Parameter(
|
||||
torch.empty(token_num, tokem_dim, device=model.device, dtype=model.dtype))
|
||||
|
||||
print('Loading additional mPLUG-Owl2 weights...')
|
||||
if os.path.exists(os.path.join(model_path, 'non_lora_trainables.bin')):
|
||||
non_lora_trainables = torch.load(os.path.join(model_path, 'non_lora_trainables.bin'), map_location='cpu')
|
||||
non_lora_trainables = torch.load(os.path.join(model_path, 'non_lora_trainables.bin'),
|
||||
map_location='cpu')
|
||||
else:
|
||||
# this is probably from HF Hub
|
||||
from huggingface_hub import hf_hub_download
|
||||
|
|
@ -66,10 +81,13 @@ def load_pretrained_model(model_path, model_base, model_name, load_8bit=False, l
|
|||
filename=filename,
|
||||
subfolder=subfolder)
|
||||
return torch.load(cache_file, map_location='cpu')
|
||||
|
||||
non_lora_trainables = load_from_hf(model_path, 'non_lora_trainables.bin')
|
||||
non_lora_trainables = {(k[11:] if k.startswith('base_model.') else k): v for k, v in non_lora_trainables.items()}
|
||||
non_lora_trainables = {(k[11:] if k.startswith('base_model.') else k): v for k, v in
|
||||
non_lora_trainables.items()}
|
||||
if any(k.startswith('model.model.') for k in non_lora_trainables):
|
||||
non_lora_trainables = {(k[6:] if k.startswith('model.') else k): v for k, v in non_lora_trainables.items()}
|
||||
non_lora_trainables = {(k[6:] if k.startswith('model.') else k): v for k, v in
|
||||
non_lora_trainables.items()}
|
||||
model.load_state_dict(non_lora_trainables, strict=False)
|
||||
|
||||
from peft import PeftModel
|
||||
|
|
@ -83,16 +101,24 @@ def load_pretrained_model(model_path, model_base, model_name, load_8bit=False, l
|
|||
print('Loading mPLUG-Owl2 from base model...')
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
|
||||
cfg_pretrained = AutoConfig.from_pretrained(model_path)
|
||||
model = MPLUGOwl2LlamaForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=cfg_pretrained, **kwargs)
|
||||
if 'mplug_owl2_1' in model_name.lower():
|
||||
model = MPLUGOwl2QWenForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True,
|
||||
config=cfg_pretrained, **kwargs)
|
||||
else:
|
||||
model = MPLUGOwl2LlamaForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True,
|
||||
config=cfg_pretrained, **kwargs)
|
||||
else:
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
|
||||
model = MPLUGOwl2LlamaForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, **kwargs)
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False, trust_remote_code=True)
|
||||
if 'mplug_owl2_1' in model_name.lower():
|
||||
model = MPLUGOwl2QWenForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, **kwargs)
|
||||
else:
|
||||
model = MPLUGOwl2LlamaForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, **kwargs)
|
||||
else:
|
||||
# Load language model
|
||||
if model_base is not None:
|
||||
# PEFT model
|
||||
from peft import PeftModel
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False, trust_remote_code=True)
|
||||
model = AutoModelForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, **kwargs)
|
||||
print(f"Loading LoRA weights from {model_path}")
|
||||
model = PeftModel.from_pretrained(model, model_path)
|
||||
|
|
@ -102,12 +128,11 @@ def load_pretrained_model(model_path, model_base, model_name, load_8bit=False, l
|
|||
model.to(torch.float16)
|
||||
else:
|
||||
use_fast = False
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False, trust_remote_code=True)
|
||||
model = AutoModelForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, **kwargs)
|
||||
|
||||
|
||||
vision_tower = model.get_model().vision_model
|
||||
vision_tower.to(device=device, dtype=torch.float16)
|
||||
# vision_tower.to(device=device, dtype=torch.float16)
|
||||
image_processor = CLIPImageProcessor.from_pretrained(model_path)
|
||||
|
||||
if hasattr(model.config, "max_sequence_length"):
|
||||
|
|
@ -115,4 +140,4 @@ def load_pretrained_model(model_path, model_base, model_name, load_8bit=False, l
|
|||
else:
|
||||
context_len = 2048
|
||||
|
||||
return tokenizer, model, image_processor, context_len
|
||||
return tokenizer, model, image_processor, context_len
|
||||
|
|
|
|||
|
|
@ -10,7 +10,7 @@ from transformers.configuration_utils import PretrainedConfig
|
|||
from transformers.models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
|
||||
from transformers.utils import logging
|
||||
from transformers.models.auto import CONFIG_MAPPING
|
||||
|
||||
from .configuration_qwen import QWenConfig
|
||||
|
||||
class LlamaConfig(PretrainedConfig):
|
||||
r"""
|
||||
|
|
@ -229,9 +229,13 @@ class MplugOwlVisionConfig(PretrainedConfig):
|
|||
initializer_range=0.02,
|
||||
initializer_factor=1.0,
|
||||
use_flash_attn=False,
|
||||
use_post_layernorm=True,
|
||||
use_cls_token=True,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(**kwargs)
|
||||
self.use_cls_token=use_cls_token
|
||||
self.use_post_layernorm=use_post_layernorm
|
||||
self.hidden_size = hidden_size
|
||||
self.intermediate_size = intermediate_size
|
||||
self.projection_dim = projection_dim
|
||||
|
|
@ -269,6 +273,8 @@ class MplugOwlVisualAbstractorConfig(PretrainedConfig):
|
|||
|
||||
def __init__(
|
||||
self,
|
||||
add_v2t_pos_emb=False,
|
||||
use_cls_token=True,
|
||||
num_learnable_queries=64,
|
||||
hidden_size=1024,
|
||||
num_hidden_layers=6,
|
||||
|
|
@ -282,6 +288,8 @@ class MplugOwlVisualAbstractorConfig(PretrainedConfig):
|
|||
**kwargs,
|
||||
):
|
||||
super().__init__(**kwargs)
|
||||
self.use_cls_token=use_cls_token
|
||||
self.add_v2t_pos_emb=add_v2t_pos_emb
|
||||
self.hidden_size = hidden_size
|
||||
self.num_learnable_queries = num_learnable_queries
|
||||
self.num_hidden_layers = num_hidden_layers
|
||||
|
|
@ -327,6 +335,17 @@ class MPLUGOwl2Config(LlamaConfig):
|
|||
super().__init__(
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
class MPLUGOwl2QwenConfig(QWenConfig):
    """Qwen-based mPLUG-Owl2 configuration, registered as ``"mplug_owl2_1"``.

    Adds a ``visual_config`` attribute on top of ``QWenConfig``. When the
    caller supplies none, the module-level ``DEFAULT_VISUAL_CONFIG`` is used
    (the same object, not a copy).
    """

    model_type = "mplug_owl2_1"

    def __init__(self, visual_config=None, **kwargs):
        # Assigned before delegating to the parent, as in the original ordering.
        self.visual_config = DEFAULT_VISUAL_CONFIG if visual_config is None else visual_config
        super().__init__(**kwargs)
|
||||
if __name__ == "__main__":
|
||||
print(MplugOwlVisionConfig().to_dict())
|
||||
|
|
@ -0,0 +1,73 @@
|
|||
# Copyright (c) Alibaba Cloud.
|
||||
#
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
from transformers import PretrainedConfig
|
||||
|
||||
|
||||
class QWenConfig(PretrainedConfig):
    """Hyper-parameter container for the QWen language model.

    Every constructor argument is stored verbatim as an attribute of the same
    name, except ``tie_word_embeddings`` and any extra ``kwargs``, which are
    forwarded to ``PretrainedConfig.__init__``.
    """

    model_type = "qwen"
    keys_to_ignore_at_inference = ["past_key_values"]

    def __init__(
        self,
        multiway=False,
        vocab_size=151936,
        hidden_size=4096,
        num_hidden_layers=32,
        num_attention_heads=32,
        emb_dropout_prob=0.0,
        attn_dropout_prob=0.0,
        layer_norm_epsilon=1e-6,
        initializer_range=0.02,
        max_position_embeddings=8192,
        scale_attn_weights=True,
        use_cache=True,
        bf16=False,
        fp16=False,
        fp32=False,
        kv_channels=128,
        rotary_pct=1.0,
        rotary_emb_base=10000,
        use_dynamic_ntk=True,
        use_logn_attn=True,
        use_flash_attn="auto",
        intermediate_size=22016,
        no_bias=True,
        tie_word_embeddings=False,
        use_cache_quantization=False,
        use_cache_kernel=False,
        softmax_in_fp32=False,
        **kwargs,
    ):
        # --- model shape ---
        self.multiway = multiway
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads

        # --- regularisation / initialisation ---
        self.emb_dropout_prob = emb_dropout_prob
        self.attn_dropout_prob = attn_dropout_prob
        self.layer_norm_epsilon = layer_norm_epsilon
        self.initializer_range = initializer_range

        # --- attention and caching behaviour ---
        self.scale_attn_weights = scale_attn_weights
        self.use_cache = use_cache
        self.max_position_embeddings = max_position_embeddings

        # --- precision flags ---
        self.bf16 = bf16
        self.fp16 = fp16
        self.fp32 = fp32

        # --- rotary embedding / long-context options ---
        self.kv_channels = kv_channels
        self.rotary_pct = rotary_pct
        self.rotary_emb_base = rotary_emb_base
        self.use_dynamic_ntk = use_dynamic_ntk
        self.use_logn_attn = use_logn_attn

        # --- kernel / implementation switches ---
        self.use_flash_attn = use_flash_attn
        self.no_bias = no_bias
        self.use_cache_quantization = use_cache_quantization
        self.use_cache_kernel = use_cache_kernel
        self.softmax_in_fp32 = softmax_in_fp32

        super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)
|
||||
|
|
@ -16,6 +16,7 @@ from transformers.utils import logging
|
|||
|
||||
from .modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
|
||||
from .configuration_mplug_owl2 import LlamaConfig
|
||||
from .multiway import MultiwayNetwork
|
||||
|
||||
|
||||
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
|
||||
|
|
@ -30,33 +31,6 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
|
|||
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
|
||||
|
||||
|
||||
class MultiwayNetwork(nn.Module):
    """Dispatch each position of a tensor to one of several parallel sub-modules.

    ``module_provider`` is called ``num_multiway`` times to build the branches.
    In ``forward``, ``multiway_indices`` selects, per position, which branch
    processes that position.
    """

    def __init__(self, module_provider, num_multiway=2):
        super().__init__()
        branches = [module_provider() for _ in range(num_multiway)]
        self.multiway = torch.nn.ModuleList(branches)

    def forward(self, hidden_states, multiway_indices):
        # With a single branch there is nothing to route.
        if len(self.multiway) == 1:
            return self.multiway[0](hidden_states)

        # Positions whose index matches no branch are left uninitialised.
        routed = torch.empty_like(hidden_states)

        for branch_id, branch in enumerate(self.multiway):
            positions = multiway_indices.eq(branch_id).nonzero(as_tuple=True)
            selected = hidden_states[positions].unsqueeze(1).contiguous()
            if not selected.numel():
                continue
            result = branch(selected)
            if isinstance(result, tuple):
                # Keep only the first element of tuple-returning branches.
                result = result[0]
            routed[positions] = result.squeeze(1)

        return routed.contiguous()
|
||||
|
||||
|
||||
class LlamaAttention(nn.Module):
|
||||
"""Multi-headed attention from 'Attention Is All You Need' paper"""
|
||||
|
||||
|
|
@ -142,7 +116,7 @@ class LlamaAttention(nn.Module):
|
|||
kv_seq_len = key_states.shape[-2]
|
||||
if past_key_value is not None:
|
||||
kv_seq_len += past_key_value[0].shape[-2]
|
||||
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
|
||||
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len, position_ids=position_ids)
|
||||
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
|
||||
|
||||
if past_key_value is not None:
|
||||
|
|
@ -193,19 +167,15 @@ class LlamaAttention(nn.Module):
|
|||
|
||||
|
||||
class LlamaDecoderLayer(nn.Module):
|
||||
def __init__(self, config: LlamaConfig):
|
||||
def __init__(self, config: LlamaConfig, annoying_param):
|
||||
super().__init__()
|
||||
self.hidden_size = config.hidden_size
|
||||
self.self_attn = LlamaAttention(config=config)
|
||||
# Check if LlamaMLP takes one or three args
|
||||
num_llama_args = len(inspect.signature(LlamaMLP.__init__).parameters)
|
||||
if num_llama_args == 1:
|
||||
self.mlp = LlamaMLP(config)
|
||||
elif num_llama_args == 3:
|
||||
self.mlp = LlamaMLP(config.hidden_size, config.intermediate_size, config.hidden_act)
|
||||
else:
|
||||
raise ValueError(f"Invalid number of arguments for LlamaMLP: {num_llama_args}")
|
||||
self.mlp = LlamaMLP(config)
|
||||
mlp_kwargs = {'config': config, "hidden_size": config.hidden_size,
|
||||
"intermediate_size": config.intermediate_size, "hidden_act": config.hidden_act}
|
||||
valid_params = set(inspect.signature(LlamaMLP.__init__).parameters.keys()) - {'self'}
|
||||
mlp_kwargs = {k: v for k, v in mlp_kwargs.items() if k in valid_params}
|
||||
self.mlp = LlamaMLP(**mlp_kwargs)
|
||||
self.input_layernorm = MultiwayNetwork(module_provider=partial(
|
||||
LlamaRMSNorm, hidden_size=config.hidden_size, eps=config.rms_norm_eps
|
||||
))
|
||||
|
|
|
|||
|
|
@ -18,15 +18,15 @@ from typing import List, Optional, Tuple, Union
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.nn import CrossEntropyLoss
|
||||
|
||||
from transformers import AutoConfig, AutoModelForCausalLM, LlamaConfig, LlamaModel, LlamaForCausalLM
|
||||
from transformers import AutoConfig, AutoModelForCausalLM, LlamaModel, LlamaForCausalLM
|
||||
from transformers.modeling_outputs import CausalLMOutputWithPast
|
||||
|
||||
from .configuration_mplug_owl2 import MPLUGOwl2Config, MplugOwlVisionConfig, MplugOwlVisualAbstractorConfig
|
||||
from .visual_encoder import MplugOwlVisionModel, MplugOwlVisualAbstractorModel
|
||||
from extensions.sd_smartprocess.mplug_owl2.constants import IMAGE_TOKEN_INDEX, IGNORE_INDEX
|
||||
from .configuration_mplug_owl2 import MPLUGOwl2Config, MplugOwlVisionConfig, MplugOwlVisualAbstractorConfig, \
|
||||
MPLUGOwl2QwenConfig
|
||||
from .modeling_llama2 import replace_llama_modality_adaptive
|
||||
from mplug_owl2.constants import IMAGE_TOKEN_INDEX, IGNORE_INDEX
|
||||
from icecream import ic
|
||||
from .modeling_qwen import QWenLMHeadModel, QWenModel
|
||||
from .visual_encoder import MplugOwlVisionModel, MplugOwlVisualAbstractorModel
|
||||
|
||||
|
||||
class MPLUGOwl2MetaModel:
|
||||
|
|
@ -67,8 +67,10 @@ class MPLUGOwl2MetaForCausalLM(ABC):
|
|||
):
|
||||
if images is None or input_ids.shape[1] == 1:
|
||||
if past_key_values is not None and images is not None and input_ids.shape[1] == 1:
|
||||
attention_mask = torch.ones((attention_mask.shape[0], past_key_values[-1][-1].shape[-2] + 1),
|
||||
dtype=attention_mask.dtype, device=attention_mask.device)
|
||||
# print(attention_mask)
|
||||
if attention_mask is not None:
|
||||
attention_mask = torch.ones((attention_mask.shape[0], past_key_values[-1][-1].shape[-2] + 1),
|
||||
dtype=attention_mask.dtype, device=attention_mask.device)
|
||||
multiway_indices = torch.zeros_like(input_ids).long().to(self.device)
|
||||
return input_ids, multiway_indices, attention_mask, past_key_values, None, labels
|
||||
|
||||
|
|
@ -210,12 +212,82 @@ class MPLUGOwl2MetaForCausalLM(ABC):
|
|||
return None, new_modality_indicators, attention_mask, past_key_values, new_input_embeds, new_labels
|
||||
|
||||
|
||||
def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
|
||||
"""
|
||||
Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
|
||||
"""
|
||||
bsz, src_len = mask.size()
|
||||
tgt_len = tgt_len if tgt_len is not None else src_len
|
||||
|
||||
expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
|
||||
|
||||
inverted_mask = 1.0 - expanded_mask
|
||||
|
||||
return inverted_mask.masked_fill(
|
||||
inverted_mask.to(torch.bool), torch.finfo(dtype).min
|
||||
)
|
||||
|
||||
|
||||
def _make_causal_mask(
|
||||
input_ids_shape: torch.Size, dtype: torch.dtype, past_key_values_length: int = 0
|
||||
):
|
||||
"""
|
||||
Make causal mask used for bi-directional self-attention.
|
||||
"""
|
||||
bsz, tgt_len = input_ids_shape
|
||||
mask = torch.full((tgt_len, tgt_len), torch.tensor(torch.finfo(dtype).min))
|
||||
mask_cond = torch.arange(mask.size(-1))
|
||||
mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
|
||||
mask = mask.to(dtype)
|
||||
|
||||
if past_key_values_length > 0:
|
||||
mask = torch.cat(
|
||||
[torch.zeros(tgt_len, past_key_values_length, dtype=dtype), mask], dim=-1
|
||||
)
|
||||
return mask[None, None, :, :].expand(
|
||||
bsz, 1, tgt_len, tgt_len + past_key_values_length
|
||||
)
|
||||
|
||||
|
||||
class MPLUGOwl2LlamaModel(MPLUGOwl2MetaModel, LlamaModel):
    """Llama backbone combined with the mPLUG-Owl2 meta-model mixin."""

    config_class = MPLUGOwl2Config

    def __init__(self, config: MPLUGOwl2Config):
        super().__init__(config)

    def _prepare_decoder_attention_mask(
        self, attention_mask, input_shape, inputs_embeds, past_key_values_length
    ):
        """Combine the causal mask and the padding mask into one additive mask.

        Returns ``None`` when there is a single target position and no
        ``attention_mask``. Otherwise returns a
        ``[bsz, 1, tgt_seq_len, src_seq_len]`` tensor on the device of
        ``inputs_embeds``.
        """
        mask_dtype = inputs_embeds.dtype
        target_device = inputs_embeds.device

        # A causal mask is only needed when more than one position is decoded.
        causal = None
        if input_shape[-1] > 1:
            causal = _make_causal_mask(
                input_shape,
                mask_dtype,
                past_key_values_length=past_key_values_length,
            ).to(target_device)

        if attention_mask is None:
            return causal

        # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
        padding = _expand_mask(
            attention_mask, mask_dtype, tgt_len=input_shape[-1]
        ).to(target_device)

        return padding if causal is None else padding + causal
|
||||
|
||||
|
||||
class MPLUGOwl2QWenModel(MPLUGOwl2MetaModel, QWenModel):
    """QWen backbone combined with the mPLUG-Owl2 meta-model mixin."""

    config_class = MPLUGOwl2QwenConfig

    def __init__(self, config: MPLUGOwl2QwenConfig):
        super().__init__(config)
|
||||
|
||||
|
||||
class MPLUGOwl2LlamaForCausalLM(LlamaForCausalLM, MPLUGOwl2MetaForCausalLM):
|
||||
config_class = MPLUGOwl2Config
|
||||
|
|
@ -229,13 +301,17 @@ class MPLUGOwl2LlamaForCausalLM(LlamaForCausalLM, MPLUGOwl2MetaForCausalLM):
|
|||
# Initialize weights and apply final processing
|
||||
self.post_init()
|
||||
|
||||
def encode_images(self, images):
|
||||
image_features = self.get_model().vision_model(images).last_hidden_state
|
||||
image_features = self.get_model().visual_abstractor(encoder_hidden_states=image_features).last_hidden_state
|
||||
return image_features
|
||||
|
||||
def get_model(self):
|
||||
return self.model
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.LongTensor = None,
|
||||
# modality_indicators: torch.LongTensor = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
|
|
@ -318,8 +394,145 @@ class MPLUGOwl2LlamaForCausalLM(LlamaForCausalLM, MPLUGOwl2MetaForCausalLM):
|
|||
return model_inputs
|
||||
|
||||
|
||||
class MPLUGOwl2QWenForCausalLM(QWenLMHeadModel, MPLUGOwl2MetaForCausalLM):
    """QWen causal-LM head on top of the multimodal ``MPLUGOwl2QWenModel``."""

    config_class = MPLUGOwl2QwenConfig

    def __init__(self, config):
        """Resolve precision and flash-attention settings, then build the model.

        ``config.bf16`` / ``fp16`` / ``fp32`` and ``config.use_flash_attn`` may
        be modified in place when they are left at their automatic values.
        """
        # Starts the MRO lookup *after* QWenLMHeadModel, so that class's own
        # __init__ is not run; the setup below builds `transformer` and
        # `lm_head` here instead (presumably mirroring QWenLMHeadModel — confirm).
        super(QWenLMHeadModel, self).__init__(config)
        from .modeling_qwen import SUPPORT_BF16, logger, SUPPORT_FP16, SUPPORT_CUDA, _import_flash_attn

        # No precision flag set by the caller: pick the best supported one.
        autoset_precision = config.bf16 + config.fp16 + config.fp32 == 0

        if autoset_precision:
            if SUPPORT_BF16:
                logger.warning(
                    "The model is automatically converting to bf16 for faster inference. "
                    "If you want to disable the automatic precision, please manually add bf16/fp16/fp32=True to \"AutoModelForCausalLM.from_pretrained\"."
                )
                config.bf16 = True
            elif SUPPORT_FP16:
                logger.warning(
                    "The model is automatically converting to fp16 for faster inference. "
                    "If you want to disable the automatic precision, please manually add bf16/fp16/fp32=True to \"AutoModelForCausalLM.from_pretrained\"."
                )
                config.fp16 = True
            else:
                config.fp32 = True

        if config.bf16 and SUPPORT_CUDA and not SUPPORT_BF16:
            logger.warning(
                "Your device does NOT seem to support bf16, you can switch to fp16 or fp32 by passing fp16/fp32=True in \"AutoModelForCausalLM.from_pretrained\".")
        if config.fp16 and SUPPORT_CUDA and not SUPPORT_FP16:
            logger.warning(
                "Your device does NOT support faster inference with fp16, please switch to fp32 which is likely to be faster")
        if config.fp32:
            if SUPPORT_BF16:
                logger.warning(
                    "Your device support faster inference by passing bf16=True in \"AutoModelForCausalLM.from_pretrained\".")
            elif SUPPORT_FP16:
                logger.warning(
                    "Your device support faster inference by passing fp16=True in \"AutoModelForCausalLM.from_pretrained\".")

        if config.use_flash_attn == "auto":
            if config.bf16 or config.fp16:
                logger.warning("Try importing flash-attention for faster inference...")
                config.use_flash_attn = True
            else:
                config.use_flash_attn = False
        if config.use_flash_attn and config.fp32:
            logger.warning("Flash attention will be disabled because it does NOT support fp32.")

        if config.use_flash_attn:
            _import_flash_attn()

        self.transformer = MPLUGOwl2QWenModel(config)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Cast the freshly built modules to the selected half-precision format.
        if config.bf16:
            self.transformer.bfloat16()
            self.lm_head.bfloat16()
        if config.fp16:
            self.transformer.half()
            self.lm_head.half()
        self.post_init()

    def get_model(self):
        """Return the underlying ``MPLUGOwl2QWenModel`` backbone."""
        return self.transformer

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        images=None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        """Run the multimodal forward pass and optionally compute the LM loss.

        ``input_ids``, ``attention_mask``, ``past_key_values``, ``labels`` and
        ``images`` first go through ``prepare_inputs_labels_for_multimodal``.
        The values it returns (including ``inputs_embeds`` and the modality
        indicators) are what reach the transformer; the ``inputs_embeds``
        argument passed by the caller is overwritten.

        Returns a ``CausalLMOutputWithPast`` when ``return_dict`` resolves to
        true, otherwise a tuple ``(loss?, logits, *rest)``.
        """
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )
        input_ids, modality_indicators, attention_mask, past_key_values, inputs_embeds, labels = \
            self.prepare_inputs_labels_for_multimodal(input_ids, attention_mask, past_key_values, labels, images)
        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
        outputs = self.transformer(
            input_ids,
            modality_indicators=modality_indicators,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        hidden_states = outputs[0]
        logits = self.lm_head(hidden_states)

        loss = None
        if labels is not None:
            # Shift so that tokens < n predict n
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            # Flatten the tokens
            loss_fct = CrossEntropyLoss()
            shift_logits = shift_logits.view(-1, self.config.vocab_size)
            shift_labels = shift_labels.view(-1)
            # Enable model/pipeline parallelism
            shift_labels = shift_labels.to(shift_logits.device)
            loss = loss_fct(shift_logits, shift_labels)

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
|
||||
|
||||
|
||||
AutoConfig.register("mplug_owl2", MPLUGOwl2Config)
|
||||
AutoModelForCausalLM.register(MPLUGOwl2Config, MPLUGOwl2LlamaForCausalLM)
|
||||
AutoConfig.register("mplug_owl2_1", MPLUGOwl2QwenConfig)
|
||||
AutoModelForCausalLM.register(MPLUGOwl2QwenConfig, MPLUGOwl2QWenForCausalLM)
|
||||
|
||||
replace_llama_modality_adaptive()
|
||||
|
||||
|
|
|
|||
File diff suppressed because it is too large
Load Diff
|
|
@ -0,0 +1,35 @@
|
|||
|
||||
import torch
|
||||
import torch.utils.checkpoint
|
||||
from torch import nn
|
||||
|
||||
|
||||
class MultiwayNetwork(nn.Module):
    """Dispatch each position of a tensor to one of several parallel sub-modules.

    Args:
        module_provider: zero-argument callable returning a fresh sub-module;
            called ``num_multiway`` times.
        num_multiway: number of parallel branches.
        out_features: size of the last output dimension when the branches
            change it. When falsy, the output has the same shape as the input.
    """

    def __init__(self, module_provider, num_multiway=2, out_features=None):
        super(MultiwayNetwork, self).__init__()

        self.multiway = torch.nn.ModuleList([module_provider() for _ in range(num_multiway)])
        self.out_features = out_features

    def forward(self, hidden_states, multiway_indices):
        """Apply branch ``i`` to every position where ``multiway_indices == i``.

        ``hidden_states`` is indexed by its first two dimensions.
        Positions whose index matches no branch are left uninitialised in the
        returned tensor.
        """
        # With a single branch there is nothing to route.
        if len(self.multiway) == 1:
            return self.multiway[0](hidden_states)

        if self.out_features:
            # new_empty allocates directly with the input's dtype and device,
            # avoiding a CPU allocation followed by a device copy on every call.
            output_hidden_states = hidden_states.new_empty(
                (hidden_states.size(0), hidden_states.size(1), self.out_features)
            )
        else:
            output_hidden_states = torch.empty_like(hidden_states)

        for idx, subway in enumerate(self.multiway):
            local_indices = multiway_indices.eq(idx).nonzero(as_tuple=True)
            hidden = hidden_states[local_indices].unsqueeze(1).contiguous()
            if hidden.numel():
                output = subway(hidden)
                if isinstance(output, tuple):
                    # Keep only the first element of tuple-returning branches.
                    output = output[0]
                output = output.squeeze(1)
                output_hidden_states[local_indices] = output

        return output_hidden_states.contiguous()
|
||||
|
|
@ -10,7 +10,7 @@ import torch
|
|||
import torch.nn as nn
|
||||
import torch.utils.checkpoint
|
||||
from icecream import ic
|
||||
|
||||
import torch.nn.functional as F
|
||||
def get_abs_pos(abs_pos, tgt_size):
|
||||
# abs_pos: L, C
|
||||
# tgt_size: M
|
||||
|
|
@ -29,6 +29,7 @@ def get_abs_pos(abs_pos, tgt_size):
|
|||
else:
|
||||
return abs_pos
|
||||
|
||||
|
||||
# https://github.com/facebookresearch/mae/blob/efb2a8062c206524e35e47d04501ed4f544c0ae8/util/pos_embed.py#L20
|
||||
def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False):
|
||||
"""
|
||||
|
|
@ -88,8 +89,10 @@ class MplugOwlVisionEmbeddings(nn.Module):
|
|||
self.hidden_size = config.hidden_size
|
||||
self.image_size = config.image_size
|
||||
self.patch_size = config.patch_size
|
||||
|
||||
self.cls_token = nn.Parameter(torch.randn(1, 1, self.hidden_size))
|
||||
if config.use_cls_token:
|
||||
self.cls_token = nn.Parameter(torch.randn(1, 1, self.hidden_size))
|
||||
else:
|
||||
self.cls_token = None
|
||||
|
||||
self.patch_embed = nn.Conv2d(
|
||||
in_channels=3,
|
||||
|
|
@ -99,20 +102,25 @@ class MplugOwlVisionEmbeddings(nn.Module):
|
|||
bias=False,
|
||||
)
|
||||
|
||||
self.num_patches = (self.image_size // self.patch_size) ** 2
|
||||
|
||||
self.position_embedding = nn.Parameter(torch.randn(1, self.num_patches + 1, self.hidden_size))
|
||||
|
||||
if self.cls_token is not None:
|
||||
self.num_patches = (self.image_size // self.patch_size) ** 2
|
||||
self.position_embedding = nn.Parameter(torch.randn(1, self.num_patches + 1, self.hidden_size))
|
||||
else:
|
||||
self.num_patches = 256
|
||||
self.position_embedding = nn.Parameter(torch.randn(256, self.hidden_size))
|
||||
self.pre_layernorm = nn.LayerNorm(self.hidden_size, eps=config.layer_norm_eps)
|
||||
|
||||
def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
|
||||
batch_size = pixel_values.size(0)
|
||||
image_embeds = self.patch_embed(pixel_values)
|
||||
image_embeds = image_embeds.flatten(2).transpose(1, 2)
|
||||
|
||||
class_embeds = self.cls_token.expand(batch_size, 1, -1).to(image_embeds.dtype)
|
||||
embeddings = torch.cat([class_embeds, image_embeds], dim=1)
|
||||
embeddings = embeddings + self.position_embedding[:, : embeddings.size(1)].to(image_embeds.dtype)
|
||||
if self.cls_token is not None:
|
||||
class_embeds = self.cls_token.expand(batch_size, 1, -1).to(image_embeds.dtype)
|
||||
embeddings = torch.cat([class_embeds, image_embeds], dim=1)
|
||||
embeddings = embeddings + self.position_embedding[:, : embeddings.size(1)].to(image_embeds.dtype)
|
||||
else:
|
||||
embeddings = image_embeds
|
||||
embeddings = embeddings + get_abs_pos(self.position_embedding,embeddings.size(1))
|
||||
embeddings = self.pre_layernorm(embeddings)
|
||||
return embeddings
|
||||
|
||||
|
|
@ -221,16 +229,17 @@ class MplugOwlVisionAttention(nn.Module):
|
|||
return outputs
|
||||
|
||||
|
||||
class QuickGELU(nn.Module):
    """Sigmoid-based activation: ``x * sigmoid(1.702 * x)``."""

    def forward(self, x: torch.Tensor):
        gate = torch.sigmoid(1.702 * x)
        return x * gate
|
||||
# class QuickGELU(nn.Module):
|
||||
# def forward(self, x: torch.Tensor):
|
||||
# return x * torch.sigmoid(1.702 * x)
|
||||
|
||||
|
||||
class MplugOwlMLP(nn.Module):
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.activation_fn = QuickGELU()
|
||||
from transformers.activations import ACT2FN
|
||||
self.activation_fn = ACT2FN[config.hidden_act]
|
||||
self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
|
||||
self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
|
||||
|
||||
|
|
@ -391,10 +400,20 @@ class MplugOwlVisionModel(PreTrainedModel):
|
|||
|
||||
self.embeddings = MplugOwlVisionEmbeddings(config)
|
||||
self.encoder = MplugOwlVisionEncoder(config)
|
||||
self.post_layernorm = nn.LayerNorm(self.hidden_size, eps=config.layer_norm_eps)
|
||||
|
||||
if config.use_post_layernorm:
|
||||
self.post_layernorm = nn.LayerNorm(self.hidden_size, eps=config.layer_norm_eps)
|
||||
else:
|
||||
self.post_layernorm = None
|
||||
self._no_split_modules = self._get_no_split_modules("")
|
||||
self.post_init()
|
||||
|
||||
def _get_no_split_modules(self, device_map: str):
|
||||
if self._no_split_modules is None:
|
||||
self._no_split_modules = {
|
||||
"embeddings": self.embeddings,
|
||||
"encoder": self.encoder,
|
||||
}
|
||||
return self._no_split_modules
|
||||
|
||||
def forward(
|
||||
self,
|
||||
|
|
@ -426,10 +445,12 @@ class MplugOwlVisionModel(PreTrainedModel):
|
|||
)
|
||||
|
||||
last_hidden_state = encoder_outputs[0]
|
||||
last_hidden_state = self.post_layernorm(last_hidden_state)
|
||||
if self.post_layernorm:
|
||||
last_hidden_state = self.post_layernorm(last_hidden_state)
|
||||
|
||||
pooled_output = last_hidden_state[:, 0, :]
|
||||
pooled_output = self.post_layernorm(pooled_output)
|
||||
if self.post_layernorm:
|
||||
pooled_output = self.post_layernorm(pooled_output)
|
||||
|
||||
if not return_dict:
|
||||
return (last_hidden_state, pooled_output) + encoder_outputs[1:]
|
||||
|
|
@ -499,7 +520,7 @@ class MplugOwlVisualAbstractorMultiHeadAttention(nn.Module):
|
|||
)
|
||||
self.register_buffer(
|
||||
'k_pos_embed',
|
||||
torch.from_numpy(get_2d_sincos_pos_embed(config.hidden_size, grids, cls_token=True)).float()
|
||||
torch.from_numpy(get_2d_sincos_pos_embed(config.hidden_size, grids, cls_token=config.use_cls_token)).float()
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -585,7 +606,7 @@ class MplugOwlVisualAbstractorMultiHeadAttention(nn.Module):
|
|||
class MplugOwlVisualAbstractorCrossOutput(nn.Module):
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
dim = config.hidden_size
|
||||
dim = config.encoder_hidden_size
|
||||
self.out_proj = nn.Linear(dim, dim, bias=True)
|
||||
self.norm2 = nn.LayerNorm(dim)
|
||||
self.mlp = MplugOwlVisualAbstractorMLP(config)
|
||||
|
|
@ -602,9 +623,19 @@ class MplugOwlVisualAbstractorAttention(nn.Module):
|
|||
self.attention = MplugOwlVisualAbstractorMultiHeadAttention(config)
|
||||
self.output = MplugOwlVisualAbstractorCrossOutput(config)
|
||||
self.pruned_heads = set()
|
||||
self.norm1 = nn.LayerNorm(config.hidden_size)
|
||||
self.normk = nn.LayerNorm(config.hidden_size)
|
||||
self.norm1 = nn.LayerNorm(config.encoder_hidden_size)
|
||||
self.normk = nn.LayerNorm(config.encoder_hidden_size)
|
||||
|
||||
self.add_pos_embed = config.add_v2t_pos_emb
|
||||
if self.add_pos_embed:
|
||||
self.q_pos_embed = nn.Parameter(
|
||||
torch.from_numpy(get_1d_sincos_pos_embed_from_grid(config.encoder_hidden_size, np.arange(config.num_learnable_queries, dtype=np.float32))).float()
|
||||
).requires_grad_(False)
|
||||
|
||||
self.k_pos_embed = nn.Parameter(
|
||||
torch.from_numpy(get_2d_sincos_pos_embed(config.encoder_hidden_size, config.grid_size, cls_token=config.cls_token)).float()
|
||||
).requires_grad_(False)
|
||||
|
||||
def prune_heads(self, heads):
|
||||
if len(heads) == 0:
|
||||
return
|
||||
|
|
@ -757,14 +788,23 @@ class MplugOwlVisualAbstractorModel(PreTrainedModel):
|
|||
def __init__(self, config, language_hidden_size):
|
||||
super().__init__(config)
|
||||
self.config = config
|
||||
|
||||
self.encoder = MplugOwlVisualAbstractorEncoder(config)
|
||||
self.visual_fc = torch.nn.Linear(config.hidden_size, language_hidden_size)
|
||||
self._no_split_modules = self._get_no_split_modules("")
|
||||
|
||||
self.query_embeds = torch.nn.Parameter(torch.randn(1, config.num_learnable_queries, config.hidden_size))
|
||||
self.vit_eos = torch.nn.Parameter(torch.randn(1, 1, language_hidden_size))
|
||||
|
||||
self.post_init()
|
||||
|
||||
def _get_no_split_modules(self, device_map: str):
|
||||
if self._no_split_modules is None:
|
||||
self._no_split_modules = {
|
||||
"encoder": self.encoder,
|
||||
"visual_fc": self.visual_fc,
|
||||
}
|
||||
return self._no_split_modules
|
||||
|
||||
def _prune_heads(self, heads_to_prune):
|
||||
"""
|
||||
Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
|
||||
|
|
@ -849,7 +889,6 @@ class MplugOwlVisualAbstractorModel(PreTrainedModel):
|
|||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||
)
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
query_embeds = self.query_embeds.repeat(encoder_hidden_states.shape[0], 1, 1)
|
||||
embedding_output = query_embeds
|
||||
input_shape = embedding_output.size()[:-1]
|
||||
|
|
|
|||
|
|
@ -12,4 +12,4 @@ ruamel.yaml
|
|||
markdown2
|
||||
sconf
|
||||
tensorboardX
|
||||
transformers>=4.35.0
|
||||
transformers==4.38.1
|
||||
|
|
@ -1,3 +1,4 @@
|
|||
|
||||
import os
|
||||
import shutil
|
||||
import sys
|
||||
|
|
@ -37,7 +38,7 @@ registry = InterrogatorRegistry()
|
|||
|
||||
int_dict = registry.list_interrogators()
|
||||
print(f"Found {len(int_dict.keys())} interrogators: {int_dict}")
|
||||
natural_captioner_names = ["CLIP", "BLIP", "LLAVA"]
|
||||
natural_captioner_names = ["CLIP", "BLIP", "MPLUG2"]
|
||||
default_captioners = ["Swin"]
|
||||
|
||||
natural_captioners = {}
|
||||
|
|
|
|||
|
|
@ -27,7 +27,7 @@ clip_interrogator = None
|
|||
crop_clip = None
|
||||
image_interrogators = {}
|
||||
global_unpickler = None
|
||||
|
||||
image_features = None
|
||||
|
||||
def printi(message):
|
||||
shared.state.textinfo = message
|
||||
|
|
@ -77,18 +77,23 @@ def save_img_caption(image_path: str, img_caption: str, params: ProcessParams):
|
|||
|
||||
|
||||
def list_features():
    """Return the file extensions listed by ``features.pilinfo()``.

    The result is parsed once and cached in the module-level
    ``image_features``; later calls return the cached list without calling
    ``pilinfo()`` again.

    Returns:
        A list of extension strings, de-duplicated and in first-seen order.
    """
    global image_features
    if image_features is None:
        # Create buffer for pilinfo() to write into rather than stdout
        buffer = StringIO()
        features.pilinfo(out=buffer)
        pil_features = []
        # Only the "Extensions: a, b, c" lines carry the data we want.
        for line in buffer.getvalue().splitlines():
            if "Extensions:" not in line:
                continue
            ext_list = line.split(": ")[1]
            for extension in ext_list.split(", "):
                if extension not in pil_features:
                    pil_features.append(extension)
        image_features = pil_features
    else:
        pil_features = image_features
    return pil_features
|
||||
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue