refactor prompt set

Signed-off-by: vladmandic <mandic00@live.com>
pull/4545/head
vladmandic 2026-01-19 10:51:28 +01:00
parent 204bee6d2b
commit e8a158f4f5
6 changed files with 202 additions and 132 deletions

View File

@ -1,8 +1,8 @@
# Change Log for SD.Next
## Update for 2025-01-18
## Update for 2025-01-19
### Highlights for 2025-01-18
### Highlights for 2025-01-19
First release of 2026 brings quite a few new models: **Flux.2-Klein, Qwen-Image-2512, LTX-2-Dev, GLM-Image**
There are also improvements to SDNQ quantization engine, updated Prompt Enhance and many others.
@ -11,7 +11,7 @@ For full list of changes, see full changelog.
[ReadMe](https://github.com/vladmandic/automatic/blob/master/README.md) | [ChangeLog](https://github.com/vladmandic/automatic/blob/master/CHANGELOG.md) | [Docs](https://vladmandic.github.io/sdnext-docs/) | [WiKi](https://github.com/vladmandic/automatic/wiki) | [Discord](https://discord.com/invite/sd-next-federal-batch-inspectors-1101998836328697867) | [Sponsor](https://github.com/sponsors/vladmandic)
### Details for 2025-01-18
### Details for 2025-01-19
- **Models**
- [Flux.2 Klein](https://bfl.ai/blog/flux2-klein-towards-interactive-visual-intelligence)
@ -76,6 +76,7 @@ For full list of changes, see full changelog.
- split `reference.json`
- print system env on startup
- disable fallback on models with custom loaders
- refactor triggering of prompt parser and set secondary prompts when needed
- **Fixes**
- extension tab: update checker, date handling, formatting etc., thanks @awsr
- controlnet with non-english ui locales
@ -101,6 +102,7 @@ For full list of changes, see full changelog.
- networks icon/list view type switch, thanks @awsr
- lora skip with strength zero
- lora force unapply on change
- lora handle null description, thanks @CalamitousFelicitousness
## Update for 2025-12-26

@ -1 +1 @@
Subproject commit d83de1303abc930d6a2e8570e6c929a0139510f1
Subproject commit 6be3b994ad4002a4b6e280d03174286abd0c4218

View File

@ -7,9 +7,10 @@ import inspect
import torch
import numpy as np
from PIL import Image
from modules import shared, errors, sd_models, processing, processing_vae, processing_helpers, sd_hijack_hypertile, prompt_parser_diffusers, timer, extra_networks, sd_vae
from modules import shared, sd_models, processing, processing_vae, processing_helpers, sd_hijack_hypertile, extra_networks, sd_vae
from modules.processing_callbacks import diffusers_callback_legacy, diffusers_callback, set_callbacks_p
from modules.processing_helpers import resize_hires, fix_prompts, calculate_base_steps, calculate_hires_steps, calculate_refiner_steps, get_generator, set_latents, apply_circular # pylint: disable=unused-import
from modules.processing_helpers import resize_hires, calculate_base_steps, calculate_hires_steps, calculate_refiner_steps, get_generator, set_latents, apply_circular # pylint: disable=unused-import
from modules.processing_prompt import set_prompt
from modules.api import helpers
@ -191,6 +192,7 @@ def set_pipeline_args(p, model, prompts:list, negative_prompts:list, prompts_2:t
apply_circular(p.tiling, model)
args = {}
has_vae = hasattr(model, 'vae') or (hasattr(model, 'pipe') and hasattr(model.pipe, 'vae'))
cls = model.__class__.__name__
if hasattr(model, 'pipe') and not hasattr(model, 'no_recurse'): # recurse
model = model.pipe
has_vae = has_vae or hasattr(model, 'vae')
@ -204,90 +206,20 @@ def set_pipeline_args(p, model, prompts:list, negative_prompts:list, prompts_2:t
if debug_enabled:
debug_log(f'Process pipeline possible: {possible}')
prompts, negative_prompts, prompts_2, negative_prompts_2 = fix_prompts(p, prompts, negative_prompts, prompts_2, negative_prompts_2)
steps = kwargs.get("num_inference_steps", None) or len(getattr(p, 'timesteps', ['1']))
clip_skip = kwargs.pop("clip_skip", 1)
prompt_attention, args = set_prompt(p, args, possible, cls, prompt_attention, steps, clip_skip, prompts, negative_prompts, prompts_2, negative_prompts_2)
if 'clip_skip' in possible:
if clip_skip == 1:
pass # clip_skip = None
else:
args['clip_skip'] = clip_skip - 1
if shared.opts.lora_apply_te:
extra_networks.activate(p, include=['text_encoder', 'text_encoder_2', 'text_encoder_3'])
parser = 'fixed'
prompt_attention = prompt_attention or shared.opts.prompt_attention
if (prompt_attention != 'fixed') and ('Onnx' not in model.__class__.__name__) and ('prompt' not in p.task_args) and (
'StableDiffusion' in model.__class__.__name__ or
'StableCascade' in model.__class__.__name__ or
('Flux' in model.__class__.__name__ and 'Flux2' not in model.__class__.__name__) or
'Chroma' in model.__class__.__name__ or
'HiDreamImagePipeline' in model.__class__.__name__
):
jobid = shared.state.begin('TE Encode')
try:
prompt_parser_diffusers.embedder = prompt_parser_diffusers.PromptEmbedder(prompts, negative_prompts, steps, clip_skip, p)
parser = shared.opts.prompt_attention
except Exception as e:
shared.log.error(f'Prompt parser encode: {e}')
if os.environ.get('SD_PROMPT_DEBUG', None) is not None:
errors.display(e, 'Prompt parser encode')
timer.process.record('prompt', reset=False)
shared.state.end(jobid)
else:
prompt_parser_diffusers.embedder = None
if 'prompt' in possible:
if 'OmniGen' in model.__class__.__name__:
prompts = [p.replace('|image|', '<img><|image_1|></img>') for p in prompts]
if ('HiDreamImage' in model.__class__.__name__) and (prompt_parser_diffusers.embedder is not None):
args['pooled_prompt_embeds'] = prompt_parser_diffusers.embedder('positive_pooleds')
prompt_embeds = prompt_parser_diffusers.embedder('prompt_embeds')
args['prompt_embeds_t5'] = prompt_embeds[0]
args['prompt_embeds_llama3'] = prompt_embeds[1]
elif hasattr(model, 'text_encoder') and hasattr(model, 'tokenizer') and ('prompt_embeds' in possible) and (prompt_parser_diffusers.embedder is not None):
embeds = prompt_parser_diffusers.embedder('prompt_embeds')
if embeds is None:
shared.log.warning('Prompt parser encode: empty prompt embeds')
prompt_parser_diffusers.embedder = None
args['prompt'] = prompts
elif embeds.device == torch.device('meta'):
shared.log.warning('Prompt parser encode: embeds on meta device')
prompt_parser_diffusers.embedder = None
args['prompt'] = prompts
else:
args['prompt_embeds'] = embeds
if 'StableCascade' in model.__class__.__name__:
args['prompt_embeds_pooled'] = prompt_parser_diffusers.embedder('positive_pooleds').unsqueeze(0)
elif 'XL' in model.__class__.__name__:
args['pooled_prompt_embeds'] = prompt_parser_diffusers.embedder('positive_pooleds')
elif 'StableDiffusion3' in model.__class__.__name__:
args['pooled_prompt_embeds'] = prompt_parser_diffusers.embedder('positive_pooleds')
elif 'Flux' in model.__class__.__name__:
args['pooled_prompt_embeds'] = prompt_parser_diffusers.embedder('positive_pooleds')
elif 'Chroma' in model.__class__.__name__:
args['pooled_prompt_embeds'] = prompt_parser_diffusers.embedder('positive_pooleds')
args['prompt_attention_mask'] = prompt_parser_diffusers.embedder('prompt_attention_masks')
else:
args['prompt'] = prompts
if 'negative_prompt' in possible:
if 'HiDreamImage' in model.__class__.__name__ and prompt_parser_diffusers.embedder is not None:
args['negative_pooled_prompt_embeds'] = prompt_parser_diffusers.embedder('negative_pooleds')
negative_prompt_embeds = prompt_parser_diffusers.embedder('negative_prompt_embeds')
args['negative_prompt_embeds_t5'] = negative_prompt_embeds[0]
args['negative_prompt_embeds_llama3'] = negative_prompt_embeds[1]
elif hasattr(model, 'text_encoder') and hasattr(model, 'tokenizer') and 'negative_prompt_embeds' in possible and prompt_parser_diffusers.embedder is not None:
args['negative_prompt_embeds'] = prompt_parser_diffusers.embedder('negative_prompt_embeds')
if 'StableCascade' in model.__class__.__name__:
args['negative_prompt_embeds_pooled'] = prompt_parser_diffusers.embedder('negative_pooleds').unsqueeze(0)
elif 'XL' in model.__class__.__name__:
args['negative_pooled_prompt_embeds'] = prompt_parser_diffusers.embedder('negative_pooleds')
elif 'StableDiffusion3' in model.__class__.__name__:
args['negative_pooled_prompt_embeds'] = prompt_parser_diffusers.embedder('negative_pooleds')
elif 'Chroma' in model.__class__.__name__:
args['negative_prompt_attention_mask'] = prompt_parser_diffusers.embedder('negative_prompt_attention_masks')
else:
if 'PixArtSigmaPipeline' in model.__class__.__name__: # pixart-sigma pipeline throws list-of-list for negative prompt
args['negative_prompt'] = negative_prompts[0]
else:
args['negative_prompt'] = negative_prompts
if 'complex_human_instruction' in possible:
chi = shared.opts.te_complex_human_instruction
p.extra_generation_params["CHI"] = chi
@ -297,14 +229,6 @@ def set_pipeline_args(p, model, prompts:list, negative_prompts:list, prompts_2:t
args['use_resolution_binning'] = False
if 'use_mask_in_transformer' in possible:
args['use_mask_in_transformer'] = shared.opts.te_use_mask
if prompt_parser_diffusers.embedder is not None and not prompt_parser_diffusers.embedder.scheduled_prompt: # not scheduled so we dont need it anymore
prompt_parser_diffusers.embedder = None
if 'clip_skip' in possible and parser == 'fixed':
if clip_skip == 1:
pass # clip_skip = None
else:
args['clip_skip'] = clip_skip - 1
timesteps = re.split(',| ', shared.opts.schedulers_timesteps)
if len(timesteps) > 2:
@ -491,7 +415,7 @@ def set_pipeline_args(p, model, prompts:list, negative_prompts:list, prompts_2:t
clean['negative_prompt'] = len(clean['negative_prompt'])
if generator is not None:
clean['generator'] = f'{generator[0].device}:{[g.initial_seed() for g in generator]}'
clean['parser'] = parser
clean['parser'] = prompt_attention
for k, v in clean.copy().items():
if v is None:
clean[k] = None

View File

@ -372,44 +372,6 @@ def resize_hires(p, latents): # input=latents output=pil if not latent_upscaler
return resized
def fix_prompts(p, prompts, negative_prompts, prompts_2, negative_prompts_2):
    """Normalize primary and secondary prompt lists to consistent batch lengths.

    Converts bare strings to single-element lists and pads each list by repeating
    its last entry until positive/negative counts match each other, the number of
    init images, and `p.batch_size`. Secondary prompt lists, when present, are
    padded to match the primary lists. Returns the four (possibly padded) prompt
    collections unchanged in type otherwise.
    """
    if hasattr(p, 'keep_prompts'):  # caller opted out of any prompt rewriting
        return prompts, negative_prompts, prompts_2, negative_prompts_2
    if isinstance(prompts, str):
        prompts = [prompts]
    if isinstance(negative_prompts, str):
        negative_prompts = [negative_prompts]
    # bugfix: original tested hasattr(p, '[init_images]') with a literal bracketed
    # attribute name, so the init-images padding branch could never execute
    init_images = getattr(p, 'init_images', None)
    if init_images is not None and len(init_images) > 1:
        while len(prompts) < len(init_images):
            prompts.append(prompts[-1])
        while len(negative_prompts) < len(init_images):
            negative_prompts.append(negative_prompts[-1])
    while len(prompts) < p.batch_size:
        prompts.append(prompts[-1])
    while len(negative_prompts) < p.batch_size:
        negative_prompts.append(negative_prompts[-1])
    # equalize positive vs negative counts by padding the shorter list
    while len(negative_prompts) < len(prompts):
        negative_prompts.append(negative_prompts[-1])
    while len(prompts) < len(negative_prompts):
        prompts.append(prompts[-1])
    if isinstance(prompts_2, str):
        prompts_2 = [prompts_2]
    if isinstance(prompts_2, list):
        while len(prompts_2) < len(prompts):
            prompts_2.append(prompts_2[-1])
    if isinstance(negative_prompts_2, str):
        negative_prompts_2 = [negative_prompts_2]
    if isinstance(negative_prompts_2, list):
        while len(negative_prompts_2) < len(prompts_2):
            negative_prompts_2.append(negative_prompts_2[-1])
    return prompts, negative_prompts, prompts_2, negative_prompts_2
def calculate_base_steps(p, use_denoise_start, use_refiner_start):
if len(getattr(p, 'timesteps', [])) > 0:
return None

View File

@ -0,0 +1,176 @@
import os
import torch
from modules import shared, errors, timer, prompt_parser_diffusers
# verbose prompt debugging is enabled by setting the SD_PROMPT_DEBUG environment variable
debug_enabled = os.environ.get('SD_PROMPT_DEBUG', None) is not None
# trace-level logger when debugging, a no-op callable otherwise
debug_log = shared.log.trace if debug_enabled else lambda *args, **kwargs: None
def fix_prompt_batch(p, prompts, negative_prompts, prompts_2, negative_prompts_2):
    """Normalize primary and secondary prompt lists to consistent batch lengths.

    Converts bare strings to single-element lists and pads each list by repeating
    its last entry until positive/negative counts match each other, the number of
    init images, and `p.batch_size`. Secondary prompt lists, when present, are
    padded to match the primary lists. Returns the four (possibly padded) prompt
    collections unchanged in type otherwise.
    """
    if hasattr(p, 'keep_prompts'):  # caller opted out of any prompt rewriting
        return prompts, negative_prompts, prompts_2, negative_prompts_2
    if isinstance(prompts, str):
        prompts = [prompts]
    if isinstance(negative_prompts, str):
        negative_prompts = [negative_prompts]
    # bugfix: original tested hasattr(p, '[init_images]') with a literal bracketed
    # attribute name, so the init-images padding branch could never execute
    init_images = getattr(p, 'init_images', None)
    if init_images is not None and len(init_images) > 1:
        while len(prompts) < len(init_images):
            prompts.append(prompts[-1])
        while len(negative_prompts) < len(init_images):
            negative_prompts.append(negative_prompts[-1])
    while len(prompts) < p.batch_size:
        prompts.append(prompts[-1])
    while len(negative_prompts) < p.batch_size:
        negative_prompts.append(negative_prompts[-1])
    # equalize positive vs negative counts by padding the shorter list
    while len(negative_prompts) < len(prompts):
        negative_prompts.append(negative_prompts[-1])
    while len(prompts) < len(negative_prompts):
        prompts.append(prompts[-1])
    if isinstance(prompts_2, str):
        prompts_2 = [prompts_2]
    if isinstance(prompts_2, list):
        while len(prompts_2) < len(prompts):
            prompts_2.append(prompts_2[-1])
    if isinstance(negative_prompts_2, str):
        negative_prompts_2 = [negative_prompts_2]
    if isinstance(negative_prompts_2, list):
        while len(negative_prompts_2) < len(prompts_2):
            negative_prompts_2.append(negative_prompts_2[-1])
    return prompts, negative_prompts, prompts_2, negative_prompts_2
def fix_prompt_model(cls, prompts, negative_prompts, prompts_2, negative_prompts_2):
    """Apply per-pipeline prompt text adjustments keyed on the pipeline class name."""
    is_omnigen = 'OmniGen' in cls
    is_pixart_sigma = 'PixArtSigmaPipeline' in cls
    if is_omnigen:
        # OmniGen expects its image placeholder in tag form
        rewritten = []
        for text in prompts:
            rewritten.append(text.replace('|image|', '<img><|image_1|></img>'))
        prompts = rewritten
    if is_pixart_sigma:
        # pixart-sigma pipeline throws list-of-list for negative prompt
        negative_prompts = negative_prompts[0]
    return prompts, negative_prompts, prompts_2, negative_prompts_2
def set_fallback_prompt(args: dict, possible: list[str], prompts, negative_prompts, prompts_2, negative_prompts_2) -> dict:
    """Fill plain-text prompt arguments into `args` for any accepted key not already set.

    A key is filled only when the pipeline accepts it (`possible`), it is not
    already present in `args`, and the corresponding prompt list is non-empty.
    Returns the updated `args` dict.
    """
    candidates = {
        'prompt': prompts,
        'negative_prompt': negative_prompts,
        'prompt_2': prompts_2,
        'negative_prompt_2': negative_prompts_2,
    }
    for key, values in candidates.items():
        if (key in possible) and (key not in args) and (values is not None) and (len(values) > 0):
            debug_log(f'Prompt fallback: {key}={values}')
            args[key] = values
    return args
def set_prompt(p,
               args: dict,
               possible: list[str],
               cls: str,
               prompt_attention: str,
               steps: int,
               clip_skip: int,
               prompts: list[str],
               negative_prompts: list[str],
               prompts_2: list[str],
               negative_prompts_2: list[str],
               ) -> tuple[str, dict]:
    """Populate pipeline call `args` with parsed prompt embeds or raw prompt strings.

    For pipeline classes that support attention-weighted parsing, encodes prompts
    via `prompt_parser_diffusers.PromptEmbedder` and fills the embed/pooled/mask
    entries of `args`; otherwise, or on any encode failure, falls back to passing
    the plain prompt lists.

    Returns the effective attention parser name ('fixed' when raw prompts are
    used) and the updated `args` dict.
    """
    prompt_attention = prompt_attention or shared.opts.prompt_attention
    # parsing is attempted only for known-supported pipeline families, and never
    # when the caller already supplied an explicit 'prompt' through task_args
    if (prompt_attention != 'fixed') and ('Onnx' not in cls) and ('prompt' not in p.task_args) and (
        ('StableDiffusion' in cls) or
        ('StableCascade' in cls) or
        ('Flux' in cls and 'Flux2' not in cls) or
        ('Chroma' in cls) or
        ('HiDreamImagePipeline' in cls)
    ):
        jobid = shared.state.begin('TE Encode')
        try:
            prompt_parser_diffusers.embedder = prompt_parser_diffusers.PromptEmbedder(prompts, negative_prompts, steps, clip_skip, p)
        except Exception as e:
            shared.log.error(f'Prompt parser encode: {e}')
            if debug_enabled:
                errors.display(e, 'Prompt parser encode')
            prompt_parser_diffusers.embedder = None  # encode failed: raw prompts are used below
        timer.process.record('prompt', reset=False)
        shared.state.end(jobid)
    else:
        prompt_parser_diffusers.embedder = None
        prompt_attention = 'fixed'
    prompts, negative_prompts, prompts_2, negative_prompts_2 = fix_prompt_batch(p, prompts, negative_prompts, prompts_2, negative_prompts_2)
    prompts, negative_prompts, prompts_2, negative_prompts_2 = fix_prompt_model(cls, prompts, negative_prompts, prompts_2, negative_prompts_2)
    args = set_fallback_prompt(args, possible, prompts=None, negative_prompts=None, prompts_2=prompts_2, negative_prompts_2=negative_prompts_2)  # secondary prompts are never parsed, always passed as-is
    if prompt_parser_diffusers.embedder is not None:
        if 'prompt' in possible:
            debug_log(f'Prompt set embeds: positive={prompts}')
            prompt_embeds = prompt_parser_diffusers.embedder('prompt_embeds')
            prompt_pooled_embeds = prompt_parser_diffusers.embedder('positive_pooleds')
            prompt_attention_masks = prompt_parser_diffusers.embedder('prompt_attention_masks')
            if prompt_embeds is None:
                shared.log.warning('Prompt parser encode: empty prompt embeds')
                prompt_parser_diffusers.embedder = None
                # bugfix: include the negative fallback here as well — the embedder is
                # cleared, so the guarded negative branch below will not run and the
                # original code would have called None as the embedder
                args = set_fallback_prompt(args, possible, prompts=prompts, negative_prompts=negative_prompts, prompts_2=None, negative_prompts_2=None)
                prompt_attention = 'fixed'
            elif prompt_embeds.device == torch.device('meta'):
                shared.log.warning('Prompt parser encode: embeds on meta device')
                prompt_parser_diffusers.embedder = None
                args = set_fallback_prompt(args, possible, prompts=prompts, negative_prompts=negative_prompts, prompts_2=None, negative_prompts_2=None)
                prompt_attention = 'fixed'
            else:
                if 'prompt_embeds' in possible:
                    args['prompt_embeds'] = prompt_embeds
                if 'pooled_prompt_embeds' in possible:
                    args['pooled_prompt_embeds'] = prompt_pooled_embeds
                if 'StableCascade' in cls:
                    args['prompt_embeds_pooled'] = prompt_pooled_embeds.unsqueeze(0)
                if 'HiDreamImage' in cls:
                    # HiDream carries two embed tensors: t5 and llama3
                    args['prompt_embeds_t5'] = prompt_embeds[0]
                    args['prompt_embeds_llama3'] = prompt_embeds[1]
                if 'prompt_attention_mask' in possible:
                    args['prompt_attention_mask'] = prompt_attention_masks
        # bugfix: guard on the embedder still being set — the positive branch above may
        # have cleared it, in which case the raw negative prompt was already set there
        if ('negative_prompt' in possible) and (prompt_parser_diffusers.embedder is not None):
            debug_log(f'Prompt set embeds: negative={negative_prompts}')
            negative_embeds = prompt_parser_diffusers.embedder('negative_prompt_embeds')
            negative_pooled_embeds = prompt_parser_diffusers.embedder('negative_pooleds')
            negative_attention_masks = prompt_parser_diffusers.embedder('negative_prompt_attention_masks')
            if negative_embeds is None:
                shared.log.warning('Prompt parser encode: empty negative prompt embeds')
                prompt_parser_diffusers.embedder = None
                args = set_fallback_prompt(args, possible, prompts=None, negative_prompts=negative_prompts, prompts_2=None, negative_prompts_2=None)
                prompt_attention = 'fixed'
            elif negative_embeds.device == torch.device('meta'):
                shared.log.warning('Prompt parser encode: negative embeds on meta device')
                prompt_parser_diffusers.embedder = None
                args = set_fallback_prompt(args, possible, prompts=None, negative_prompts=negative_prompts, prompts_2=None, negative_prompts_2=None)
                prompt_attention = 'fixed'
            else:
                if 'negative_prompt_embeds' in possible:
                    args['negative_prompt_embeds'] = negative_embeds
                if 'negative_pooled_prompt_embeds' in possible:
                    args['negative_pooled_prompt_embeds'] = negative_pooled_embeds
                if 'StableCascade' in cls:
                    args['negative_prompt_embeds_pooled'] = negative_pooled_embeds.unsqueeze(0)
                if 'HiDreamImage' in cls:
                    args['negative_prompt_embeds_t5'] = negative_embeds[0]
                    args['negative_prompt_embeds_llama3'] = negative_embeds[1]
                if 'negative_prompt_attention_mask' in possible:
                    args['negative_prompt_attention_mask'] = negative_attention_masks
    else:
        debug_log('Prompt fallback: no embedder')
        args = set_fallback_prompt(args, possible, prompts=prompts, negative_prompts=negative_prompts, prompts_2=None, negative_prompts_2=None)
        prompt_attention = 'fixed'
    if (prompt_parser_diffusers.embedder is not None) and (not prompt_parser_diffusers.embedder.scheduled_prompt):
        prompt_parser_diffusers.embedder = None  # not scheduled so we dont need it anymore
    return prompt_attention, args

View File

@ -48,7 +48,13 @@ def prepare_model(pipe = None):
class PromptEmbedder:
def __init__(self, prompts, negative_prompts, steps, clip_skip, p):
def __init__(self,
prompts,
negative_prompts,
steps,
clip_skip,
p,
):
t0 = time.time()
self.prompts = prompts
self.negative_prompts = negative_prompts