mirror of https://github.com/vladmandic/automatic
167 lines · 6.3 KiB · Python · Executable File
#!/bin/env python
|
|
|
|
import os
|
|
import gc
|
|
import json
|
|
import time
|
|
import argparse
|
|
import torch
|
|
import filetype
|
|
from PIL import Image
|
|
import transformers
|
|
from transformers import AutoProcessor, AutoModelForCausalLM
|
|
from transformers import BlipProcessor, BlipForConditionalGeneration
|
|
from transformers import VisionEncoderDecoderModel, ViTFeatureExtractor, AutoTokenizer
|
|
from util import log, Map
|
|
|
|
|
|
# Loaded lazily: load_model() fills these in, unload_model() resets them to None.
model = None
processor = None
extractor = None

# Weights are always loaded in full precision on the best available device.
dtype = torch.float32
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Default run options; interrogate_files() overlays its per-call params on top of these.
# NOTE(review): 'git', 'blip' and 'precision' are not read anywhere in this file
# (dtype above stays float32 regardless of 'precision') — confirm whether callers use them.
options = Map(dict(
    input='',
    min=8,
    max=256,
    beams=1,
    json='',
    txt=False,
    tag='',
    git=True,
    blip=True,
    precision='fp16',
    model='git',
))
|
|
|
|
|
|
def cleanup(s: str) -> str:
    """Trim a raw model caption down to its leading descriptive phrase.

    Keeps only the text before the first double quote, period, or " that";
    cuts trailing " with a letter/number/word ..." clauses; drops the
    "arafed image of " artifact; removes the standalone article "a"; and
    strips surrounding whitespace.
    """
    import re
    s = s.split('"')[0].split('.')[0].split(' that')[0]
    s = s.split(' with a letter')[0].split(' with the number')[0].split(' with the word')[0]
    s = s.replace('arafed image of ', '')
    # Remove only the whole word "a". A plain str.replace('a ', '') also ate
    # word endings, e.g. "banana bread" -> "bananbread".
    s = re.sub(r'\ba ', '', s)
    # Strip so the first-word tag built by the caller is never empty.
    return s.strip()
|
|
|
|
|
|
def load_model(args):
    """Load the captioning model selected by ``args.model`` into the module globals.

    Supported values are 'git', 'blip' and 'vit'. Any other value is logged as
    an error and nothing is loaded. Sets ``model`` and ``processor`` (a tokenizer
    for 'vit'), and additionally ``extractor`` for 'vit'.

    A model already in memory is reused only when it is the same checkpoint.
    Otherwise the previous model is released and the requested one is loaded,
    so the model and its processor always belong together.

    Args:
        args: options object exposing a ``model`` attribute.
    """
    global model
    global processor
    global extractor
    transformers.logging.set_verbosity_error()
    checkpoints = {
        'git': "microsoft/git-large-textcaps",
        'blip': "Salesforce/blip-image-captioning-large",
        'vit': "nlpconnect/vit-gpt2-image-captioning",
    }
    model_name = checkpoints.get(args.model)
    if model_name is None:
        log.error({ 'interrogate unknown model': args.model })
        return
    # Previously any loaded model was kept while the processor was swapped,
    # which could pair e.g. a GIT model with a BLIP processor.
    already_loaded = (
        model is not None
        and processor is not None
        and (args.model != 'vit' or extractor is not None)
        and getattr(model, 'name_or_path', None) == model_name
    )
    if already_loaded:
        return
    if model is not None or processor is not None or extractor is not None:
        unload_model()  # release the previous weights before loading new ones
    if args.model == 'git':
        model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=dtype)
        processor = AutoProcessor.from_pretrained(model_name)
    elif args.model == 'blip':
        model = BlipForConditionalGeneration.from_pretrained(model_name, torch_dtype=dtype)
        processor = BlipProcessor.from_pretrained(model_name)
    else:
        # 'vit' uses a separate image feature extractor and a text tokenizer
        model = VisionEncoderDecoderModel.from_pretrained(model_name, torch_dtype=dtype)
        extractor = ViTFeatureExtractor.from_pretrained(model_name)
        processor = AutoTokenizer.from_pretrained(model_name)
    model.to(device)
    log.info( { 'interrogate loaded model': model_name })
|
|
|
|
|
|
def _caption_image(args, image):
    """Return the raw caption for one RGB PIL image using the currently loaded model.

    Dispatches on ``args.model``. Relies on the module-level ``model``,
    ``processor`` and ``extractor`` populated by ``load_model``. Logs an error
    and returns '' for an unsupported model name.
    """
    if args.model == 'git':
        inputs = processor(images=[image], return_tensors="pt").to(device)
        ids = model.generate(pixel_values=inputs.pixel_values, num_beams=args.beams, min_length=args.min, max_length=args.max)
        return processor.batch_decode(ids, skip_special_tokens=True)[0]
    if args.model == 'blip':
        inputs = processor(image, return_tensors="pt").to(device, dtype)
        ids = model.generate(**inputs, num_beams=args.beams, min_length=args.min, max_length=args.max)
        return processor.decode(ids[0], skip_special_tokens=True)
    if args.model == 'vit':
        pixel_values = extractor(images=[image], return_tensors="pt").pixel_values.to(device)
        ids = model.generate(pixel_values, num_beams=args.beams, min_length=args.min, max_length=args.max)
        return processor.batch_decode(ids, skip_special_tokens=True)[0]
    log.error({ 'interrogate unknown model': args.model })
    return ''


def interrogate_files(params, files):
    """Caption every image in ``files`` with the model selected by ``params``.

    Args:
        params: dict of overrides merged over the module ``options`` defaults
            (e.g. 'model', 'min', 'max', 'beams', 'tag', 'txt', 'json').
        files: list of file paths. Entries that ``filetype.is_image`` rejects are skipped.

    Returns:
        dict mapping image path -> {'caption': str, 'tags': str}. Images that
        cannot be opened or decoded are logged and left out.

    Side effects: when ``txt`` is set, writes ``<image>.txt`` next to each image.
    When ``json`` is non-empty, writes the whole result dict to that path.
    """
    args = Map({**options, **params})
    data = [f for f in files if filetype.is_image(f)]
    log.info({ 'interrogate files': len(files), 'images': len(data), 'args': args })
    load_model(args)
    metadata = {}
    for image_path in data:
        try:
            # The context manager closes the file handle that Image.open keeps
            # open; the converted copy is fully loaded in memory.
            with Image.open(image_path) as img:
                image = img.convert('RGB')
        except OSError as e:
            # filetype only checks magic bytes, so truncated or corrupt files still
            # reach here; skip them instead of aborting the whole batch.
            log.error({ 'interrogate image': image_path, 'error': str(e) })
            continue
        caption = cleanup(_caption_image(args, image))
        # Tags: the optional user tag, then the caption's first word.
        prefix = args.tag + ',' if args.tag != '' else ''
        tags = prefix + caption.split(' ')[0]
        if args.txt:
            with open(os.path.splitext(image_path)[0] + '.txt', "wt", encoding='utf-8') as f:
                f.write(caption + "\n")
        metadata[image_path] = { 'caption': caption, 'tags': tags }
        log.info({ 'interrogate image': image_path, 'model': args.model, 'caption': caption, 'tags': tags })
    if args.json != '':
        with open(args.json, "wt", encoding='utf-8') as f:
            f.write(json.dumps(metadata, indent=2) + "\n")
    return metadata
|
|
|
|
|
|
def unload_model():
    """Drop the loaded model, processor and extractor and reclaim memory.

    Resets all three module globals to None and runs the garbage collector.
    When CUDA is available, it also empties the CUDA allocator cache and
    calls ``torch.cuda.ipc_collect()``.
    """
    global processor
    global model
    global extractor
    # Chained assignment rebinds left to right, so references are released
    # in the order model, processor, extractor.
    model = processor = extractor = None
    gc.collect()
    if not torch.cuda.is_available():
        return
    with torch.no_grad():
        torch.cuda.empty_cache()
    with torch.cuda.device('cuda'):
        torch.cuda.empty_cache()
        torch.cuda.ipc_collect()
|
|
|
|
|
|
def _collect_files(locations):
    """Expand CLI input locations into a flat, de-duplicated list of file paths.

    Regular files are kept as given. Directories are walked recursively and
    every file found is added. Paths that are neither are logged and skipped.
    Order of first appearance is preserved.
    """
    files = []
    for loc in locations:
        if os.path.isfile(loc):
            files.append(loc)
        elif os.path.isdir(loc):
            for root, _sub_dirs, names in os.walk(loc):
                # extend, not assign: the old code kept only the last walked directory
                files.extend(os.path.join(root, name) for name in names)
        else:
            log.error({ 'interrogate input not found': loc })
    # A file passed both directly and via its directory is captioned only once.
    return list(dict.fromkeys(files))


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description = 'image interrogate')
    parser.add_argument('input', type=str, nargs='*', help='input file or directory')
    parser.add_argument('--model', default = 'git', choices = ['git', 'blip', 'vit'], help = "which model to use")
    parser.add_argument("--min", type=int, default=8, help="min length of caption")
    parser.add_argument("--max", type=int, default=256, help="max length of caption")
    parser.add_argument("--beams", type=int, default=1, help="number of beams to use")
    parser.add_argument("--json", type=str, default='', help="output json file")
    parser.add_argument("--tag", type=str, default='', help="append tag")
    parser.add_argument('--txt', default = False, action='store_true', help = "write captions to text files")
    params = parser.parse_args()
    log.info({ 'interrogate args': vars(params) })
    if len(params.input) == 0:
        parser.print_help()
        raise SystemExit(1)
    files = _collect_files(params.input)
    t0 = time.perf_counter()
    metadata = interrogate_files(vars(params), files)
    t1 = time.perf_counter()
    log.info({ 'interrogate files': len(files), 'time': round(t1 - t0, 2) })
    unload_model()
|