diff --git a/CHANGELOG.md b/CHANGELOG.md
index 0b2124b1e..8ad0b182a 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -91,6 +91,7 @@
 - use shared T5 text encoder for video models when possible
 - unified video save code across all video models
   also avoids creation of temporary files for each frame unless user wants to save them
+- unified prompt enhance code across all video models
 - add job state tracking for video generation
 - improve offloading for **ltx** and **wan**
 - fix model selection in ltx tab
diff --git a/modules/interrogate/vqa.py b/modules/interrogate/vqa.py
index baebb62c2..0821d5c4f 100644
--- a/modules/interrogate/vqa.py
+++ b/modules/interrogate/vqa.py
@@ -636,12 +636,11 @@ def interrogate(question:str='', system_prompt:str=None, prompt:str=None, image:
         image = image[0] if len(image) > 0 else None
     if isinstance(image, dict) and 'name' in image:
         image = Image.open(image['name'])
-    if image is None:
-        return ''
-    if image.width > 768 or image.height > 768:
-        image.thumbnail((768, 768), Image.Resampling.LANCZOS)
-    if image.mode != 'RGB':
-        image = image.convert('RGB')
+    if isinstance(image, Image.Image):
+        if image.width > 768 or image.height > 768:
+            image.thumbnail((768, 768), Image.Resampling.LANCZOS)
+        if image.mode != 'RGB':
+            image = image.convert('RGB')
     if prompt is not None and len(prompt) > 0:
         question = prompt
     if len(question) < 2:
@@ -664,9 +663,9 @@ def interrogate(question:str='', system_prompt:str=None, prompt:str=None, image:
     if vqa_model is None:
         shared.log.error(f'Interrogate: type=vlm model="{model_name}" unknown')
         return ''
-    if image is None:
-        shared.log.error(f'Interrogate: type=vlm model="{model_name}" no input image')
-        return ''
+    # if image is None:
+    #     shared.log.error(f'Interrogate: type=vlm model="{model_name}" no input image')
+    #     return ''

     if 'git' in vqa_model.lower():
         answer = git(question, image, vqa_model)
diff --git a/modules/ui_video_vlm.py b/modules/ui_video_vlm.py
index 685c6d8d8..c7ce3a229 100644
--- a/modules/ui_video_vlm.py
+++ b/modules/ui_video_vlm.py
@@ -46,7 +46,6 @@ def enhance_prompt(enable:bool, model:str=None, image=None, prompt:str='', syste
         system_prompt += system_prompts['nsfw_ok'] if nsfw else system_prompts['nsfw_no']
         system_prompt += f" {system_prompts['suffix']} {system_prompts['example']}"
     shared.log.debug(f'Video prompt enhance: model="{model}" image={image} nsfw={nsfw} prompt="{prompt}"')
-    # shared.log.trace(f'Video prompt enhance: system="{system_prompt}"')
     answer = vqa.interrogate(question='', prompt=prompt, system_prompt=system_prompt, image=image, model_name=model, quiet=False)
     shared.log.debug(f'Video prompt enhance: answer="{answer}"')
     return answer
diff --git a/modules/video_models/video_prompt.py b/modules/video_models/video_prompt.py
new file mode 100644
index 000000000..bbb29c409
--- /dev/null
+++ b/modules/video_models/video_prompt.py
@@ -0,0 +1,21 @@
+from modules import shared, extra_networks, ui_video_vlm
+
+
+def prepare_prompt(p, init_image, prompt:str, vlm_enhance:bool, vlm_model:str, vlm_system_prompt:str):
+    p.prompt = shared.prompt_styles.apply_styles_to_prompt(p.prompt, p.styles)
+    p.negative_prompt = shared.prompt_styles.apply_negative_styles_to_prompt(p.negative_prompt, p.styles)
+    shared.prompt_styles.apply_styles_to_extra(p)
+    p.prompts, p.network_data = extra_networks.parse_prompts([p.prompt])
+    extra_networks.activate(p)
+    prompt = p.prompts[0]
+
+    new_prompt = ui_video_vlm.enhance_prompt(
+        enable=vlm_enhance,
+        model=vlm_model,
+        image=init_image,
+        prompt=prompt,
+        system_prompt=vlm_system_prompt,
+    )
+    if new_prompt is not None and len(new_prompt) > 0:
+        prompt = new_prompt
+    return prompt
diff --git a/modules/video_models/video_run.py b/modules/video_models/video_run.py
index 1dccc3611..40f7e154f 100644
--- a/modules/video_models/video_run.py
+++ b/modules/video_models/video_run.py
@@ -1,14 +1,14 @@
 import os
 import time
 from modules import shared, errors, sd_models, processing, devices, images, ui_common
-from modules.video_models import models_def, video_utils, video_load, video_vae, video_overrides, video_save
+from modules.video_models import models_def, video_utils, video_load, video_vae, video_overrides, video_save, video_prompt


 debug = shared.log.trace if os.environ.get('SD_VIDEO_DEBUG', None) is not None else lambda *args, **kwargs: None


 def generate(*args, **kwargs):
-    task_id, ui_state, engine, model, prompt, negative, styles, width, height, frames, steps, sampler_index, sampler_shift, dynamic_shift, seed, guidance_scale, guidance_true, init_image, init_strength, last_image, vae_type, vae_tile_frames, mp4_fps, mp4_interpolate, mp4_codec, mp4_ext, mp4_opt, mp4_video, mp4_frames, mp4_sf, override_settings = args
+    task_id, ui_state, engine, model, prompt, negative, styles, width, height, frames, steps, sampler_index, sampler_shift, dynamic_shift, seed, guidance_scale, guidance_true, init_image, init_strength, last_image, vae_type, vae_tile_frames, mp4_fps, mp4_interpolate, mp4_codec, mp4_ext, mp4_opt, mp4_video, mp4_frames, mp4_sf, vlm_enhance, vlm_model, vlm_system_prompt, override_settings = args
     if engine is None or model is None or engine == 'None' or model == 'None':
         return video_utils.queue_err('model not selected')
     # videojob = shared.state.begin('Video')
@@ -81,6 +81,8 @@ def generate(*args, **kwargs):
         shared.sd_model = sd_models.apply_balanced_offload(shared.sd_model)
     devices.torch_gc(force=True, reason='video')

+    prompt = video_prompt.prepare_prompt(p, init_image, prompt, vlm_enhance, vlm_model, vlm_system_prompt)
+
     # set args
     processing.fix_seed(p)
     video_vae.set_vae_params(p)
diff --git a/modules/video_models/video_ui.py b/modules/video_models/video_ui.py
index faea8b6ef..cd78f45a2 100644
--- a/modules/video_models/video_ui.py
+++ b/modules/video_models/video_ui.py
@@ -163,6 +163,7 @@ def create_ui(prompt, negative, styles, overrides):
         init_image, init_strength, last_image,
         vae_type, vae_tile_frames,
         mp4_fps, mp4_interpolate, mp4_codec, mp4_ext, mp4_opt, mp4_video, mp4_frames, mp4_sf,
+        vlm_enhance, vlm_model, vlm_system_prompt,
         overrides,
     ]
     video_outputs = [
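
Note for reviewers: a minimal usage sketch of the new video_prompt.prepare_prompt helper. The call below mirrors the invocation added to modules/video_models/video_run.py; the surrounding variables (`p`, `init_image`, `prompt`, `vlm_*`) are assumed to already exist as they do inside video_run.generate() and are not defined here.

    from modules.video_models import video_prompt

    # sketch only: `p` is the processing object built earlier in generate().
    # prepare_prompt first applies styles and extra networks to p.prompt, then
    # re-derives the working prompt from p.prompts[0], so the `prompt` argument
    # passed in is effectively superseded by the styled prompt.
    prompt = video_prompt.prepare_prompt(
        p,                   # processing object carrying prompt/negative_prompt/styles
        init_image,          # optional init image, forwarded to the VLM
        prompt,              # raw prompt string from the UI
        vlm_enhance,         # enable flag forwarded to ui_video_vlm.enhance_prompt
        vlm_model,           # VLM model name from the new UI inputs
        vlm_system_prompt,   # system prompt forwarded to vqa.interrogate
    )
    # per prepare_prompt above, an empty or None answer from the VLM is
    # ignored and the styled prompt is kept unchanged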