automatic/modules/caption/caption.py

52 lines
2.6 KiB
Python

import time
from PIL import Image
from modules import shared
from modules.logger import log
def caption(image):
    """Caption an image with the backend selected by ``shared.opts.caption_default_type``.

    Args:
        image: The image to caption. A list is reduced to its first element
            (an empty list counts as "no image"), and a dict that carries a
            ``'name'`` key is opened from that path with ``Image.open``.
            Any other value is handed to the selected backend unchanged.

    Returns:
        Whatever the selected backend returns (``openclip.caption``,
        ``tagger.tag`` or ``vqa.caption``), or ``''`` when no image was
        provided or the configured caption type is not recognized.
    """
    if isinstance(image, list):
        image = image[0] if len(image) > 0 else None
    if isinstance(image, dict) and 'name' in image:
        image = Image.open(image['name'])
    if image is None:
        log.error('Caption: no image provided')
        return ''
    t0 = time.time()
    # Read the selected type once so the log lines and the dispatch below
    # always refer to the same value.
    caption_type = shared.opts.caption_default_type
    # Each backend module is imported inside its own branch, so only the
    # selected backend is imported by a given call.
    if caption_type == 'OpenCLiP':
        log.info(f'Caption: type={caption_type} clip="{shared.opts.caption_openclip_model}" blip="{shared.opts.caption_openclip_blip_model}" mode="{shared.opts.caption_openclip_mode}"')
        from modules.caption import openclip
        openclip.load_captioner(clip_model=shared.opts.caption_openclip_model, blip_model=shared.opts.caption_openclip_blip_model)
        try:
            openclip.update_caption_params()
            prompt = openclip.caption(image, mode=shared.opts.caption_openclip_mode)
        finally:
            # Honor the offload setting even when captioning raises, so a
            # failed request does not leave the loaded model resident.
            if shared.opts.caption_offload:
                openclip.unload_clip_model()
    elif caption_type == 'Tagger':
        log.info(f'Caption: type={caption_type} model="{shared.opts.waifudiffusion_model}"')
        from modules.caption import tagger
        prompt = tagger.tag(
            image=image,
            model_name=shared.opts.waifudiffusion_model,
            general_threshold=shared.opts.tagger_threshold,
            character_threshold=shared.opts.waifudiffusion_character_threshold,
            include_rating=shared.opts.tagger_include_rating,
            exclude_tags=shared.opts.tagger_exclude_tags,
            max_tags=shared.opts.tagger_max_tags,
            sort_alpha=shared.opts.tagger_sort_alpha,
            use_spaces=shared.opts.tagger_use_spaces,
            escape_brackets=shared.opts.tagger_escape_brackets,
        )
    elif caption_type == 'VLM':
        log.info(f'Caption: type={caption_type} vlm="{shared.opts.caption_vlm_model}" prompt="{shared.opts.caption_vlm_prompt}"')
        from modules.caption import vqa
        prompt = vqa.caption(image=image, model_name=shared.opts.caption_vlm_model, question=shared.opts.caption_vlm_prompt, prompt=None, system_prompt=shared.opts.caption_vlm_system)
    else:
        log.error(f'Caption: type="{caption_type}" unknown')
        return ''
    # Shared tail for every recognized backend: report elapsed time and result.
    log.debug(f'Caption: time={time.time()-t0:.2f} answer="{prompt}"')
    return prompt