sd-webui-neutral-prompt/lib_neutral_prompt/cfg_denoiser.py

105 lines
4.1 KiB
Python

from lib_neutral_prompt import hijacker, global_state, prompt_parser
from modules import script_callbacks, sd_samplers
import functools
import torch
import sys
import textwrap
def combine_denoised_hijack(x_out, batch_cond_indices, noisy_uncond, cond_scale, original_function):
if not global_state.is_enabled:
return original_function(x_out, batch_cond_indices, noisy_uncond, cond_scale)
denoised = get_original_denoised(original_function, x_out, batch_cond_indices, noisy_uncond, cond_scale)
uncond = x_out[-noisy_uncond.shape[0]:]
for batch_i, (combination_keywords, cond_indices) in enumerate(zip(global_state.perp_profile, batch_cond_indices)):
def get_cond_indices(filter_k: prompt_parser.PromptKeywords):
return [cond_indices[i] for i, k in enumerate(combination_keywords) if k == filter_k]
cond_delta = combine_cond_deltas(x_out, uncond[batch_i], get_cond_indices(prompt_parser.PromptKeywords.AND))
perp_cond_delta = combine_perp_cond_deltas(x_out, cond_delta, uncond[batch_i], get_cond_indices(prompt_parser.PromptKeywords.AND_PERP))
cfg_cond = denoised[batch_i] - perp_cond_delta * cond_scale
denoised[batch_i] = cfg_cond * get_cfg_rescale_factor(cfg_cond, uncond[batch_i] + cond_delta - perp_cond_delta)
return denoised
def get_original_denoised(original_function, x_out, batch_cond_indices, noisy_uncond, cond_scale):
sliced_batch_cond_indices = list([list(cond_indices) for cond_indices in batch_cond_indices])
seen_indices = set()
for batch_i, (combination_keywords, cond_indices) in reversed(list(enumerate(zip(global_state.perp_profile, batch_cond_indices)))):
for keyword, (cond_index, weight) in reversed(list(zip(combination_keywords, cond_indices))):
seen_indices.add(cond_index)
if keyword == prompt_parser.PromptKeywords.AND or keyword in seen_indices:
continue
x_out = torch.cat([x_out[:cond_index], x_out[cond_index + 1:]], dim=0)
del sliced_batch_cond_indices[batch_i][cond_index]
return original_function(x_out, sliced_batch_cond_indices, noisy_uncond, cond_scale)
def combine_cond_deltas(x_out, uncond, cond_indices):
cond_delta = torch.zeros_like(x_out[0])
for cond_index, weight in cond_indices:
cond_delta += weight * (x_out[cond_index] - uncond)
return cond_delta
def combine_perp_cond_deltas(x_out, cond_delta, uncond, cond_indices):
perp_cond_delta = torch.zeros_like(x_out[0])
for cond_index, weight in cond_indices:
perp_cond = x_out[cond_index]
perp_cond_delta -= weight * get_perpendicular_component(cond_delta, perp_cond - uncond)
return perp_cond_delta
def get_perpendicular_component(normal, vector):
assert vector.shape == normal.shape
return vector - normal * torch.sum(normal * vector) / torch.norm(normal) ** 2
def get_cfg_rescale_factor(denoised, positive_epsilon):
x_pos_std = torch.std(positive_epsilon)
x_cfg_std = torch.std(denoised)
return global_state.cfg_rescale * (x_pos_std / x_cfg_std - 1) + 1
sd_samplers_hijacker = hijacker.ModuleHijacker.install_or_get(
module=sd_samplers,
hijacker_attribute='__neutral_prompt_hijacker',
on_uninstall=script_callbacks.on_script_unloaded,
)
@sd_samplers_hijacker.hijack('create_sampler')
def create_sampler_hijack(name, model, original_function):
sampler = original_function(name, model)
if name in ('DDIM', 'PLMS', 'UniPC'):
if global_state.is_enabled:
warn_unsupported_sampler()
return sampler
sampler.model_wrap_cfg.combine_denoised = functools.partial(
combine_denoised_hijack,
original_function=sampler.model_wrap_cfg.combine_denoised
)
return sampler
def warn_unsupported_sampler():
if not global_state.verbose:
return
print(textwrap.dedent('''
[sd-webui-neutral-prompt extension]
Neutral prompt relies on composition via AND, which the webui does not support when using any of the DDIM, PLMS and UniPC samplers
The sampler will NOT be patched
Falling back on original sampler implementation...
'''), file=sys.stderr)