302 lines
11 KiB
Python
302 lines
11 KiB
Python
from lib_neutral_prompt import hijacker, global_state, neutral_prompt_parser
|
|
from modules import script_callbacks, sd_samplers, shared
|
|
from typing import Tuple, List, Union
|
|
import dataclasses
|
|
import functools
|
|
import torch
|
|
import torch.nn.functional as F
|
|
import sys
|
|
import textwrap
|
|
|
|
|
|
def combine_denoised_hijack(
    x_out: torch.Tensor,
    batch_cond_indices: List[List[Tuple[int, float]]],
    text_uncond: torch.Tensor,
    cond_scale: float,
    original_function,
) -> torch.Tensor:
    """Replacement for the webui `combine_denoised`.

    Delegates the conds the webui can handle natively to `original_function`
    (via get_webui_denoised), then layers this extension's conciliated cond
    deltas on top of that result, per batch item.

    Args:
        x_out: model outputs — conds for all prompts followed by the unconds.
        batch_cond_indices: per batch item, (index into x_out, weight) pairs.
        text_uncond: unconditional conditioning; only its batch size is used
            here, to locate the unconds at the tail of x_out.
        cond_scale: CFG guidance scale.
        original_function: the unpatched `combine_denoised`.

    Returns:
        The combined denoised batch tensor.
    """
    if not global_state.is_enabled:
        return original_function(x_out, batch_cond_indices, text_uncond, cond_scale)

    # Baseline result from the original webui combination over the
    # unconciliated conds only.
    denoised = get_webui_denoised(x_out, batch_cond_indices, text_uncond, cond_scale, original_function)
    # Unconds are stored at the tail of x_out, one per batch item.
    uncond = x_out[-text_uncond.shape[0]:]

    for batch_i, (prompt, cond_indices) in enumerate(zip(global_state.prompt_exprs, batch_cond_indices)):
        args = CombineDenoiseArgs(x_out, uncond[batch_i], cond_indices)
        # Weighted delta from children without a conciliation strategy.
        cond_delta = prompt.accept(CondDeltaVisitor(), args, 0)
        # Extra delta from conciliated children (perpendicular / salience /
        # semantic guidance), computed relative to cond_delta.
        aux_cond_delta = prompt.accept(AuxCondDeltaVisitor(), args, cond_delta, 0)

        if prompt.local_transform is not None:
            cond_delta = apply_affine_transform(cond_delta, prompt.local_transform)
            aux_cond_delta = apply_affine_transform(aux_cond_delta, prompt.local_transform)

        # Only the auxiliary delta still needs CFG scaling here; the baseline
        # part was already scaled inside the original function.
        cfg_cond = denoised[batch_i] + aux_cond_delta * cond_scale
        denoised[batch_i] = cfg_rescale(cfg_cond, uncond[batch_i] + cond_delta + aux_cond_delta)

    return denoised
|
|
|
|
|
|
def get_webui_denoised(
    x_out: torch.Tensor,
    batch_cond_indices: List[List[Tuple[int, float]]],
    text_uncond: torch.Tensor,
    cond_scale: float,
    original_function,
):
    """Run the original webui `combine_denoised` on the conds it can handle.

    Builds a reduced x_out containing only the prompt children that need no
    conciliation (gathered by gather_webui_conds), appends the unconds, and
    delegates to `original_function` with matching rebased cond indices.
    """
    # Unconds live at the tail of x_out, one per batch item.
    uncond = x_out[-text_uncond.shape[0]:]
    sliced_batch_x_out = []
    sliced_batch_cond_indices = []

    for batch_i, (prompt, cond_indices) in enumerate(zip(global_state.prompt_exprs, batch_cond_indices)):
        args = CombineDenoiseArgs(x_out, uncond[batch_i], cond_indices)
        # `len(sliced_batch_x_out)` rebases this item's indices onto the
        # sliced tensor being built across batch items.
        sliced_x_out, sliced_cond_indices = gather_webui_conds(prompt, args, 0, len(sliced_batch_x_out))
        if sliced_cond_indices:
            sliced_batch_cond_indices.append(sliced_cond_indices)
        sliced_batch_x_out.extend(sliced_x_out)

    # The original function expects unconds appended after all conds.
    sliced_batch_x_out += list(uncond)
    sliced_batch_x_out = torch.stack(sliced_batch_x_out, dim=0)
    return original_function(sliced_batch_x_out, sliced_batch_cond_indices, text_uncond, cond_scale)
|
|
|
|
|
|
def cfg_rescale(cfg_cond, cond):
    """Rescale the CFG result toward the statistics (mean/std) of `cond`.

    A cfg_rescale of 0 disables the adjustment entirely; 1 matches the mean
    and standard deviation of `cond` exactly.
    """
    if global_state.cfg_rescale == 0:
        return cfg_cond

    # Consume any pending one-shot override before reading the final value.
    global_state.apply_and_clear_cfg_rescale_override()
    scale = global_state.cfg_rescale

    base_mean = cfg_cond.mean()
    # Interpolate mean and std between cfg_cond's and cond's statistics.
    target_mean = (1 - scale) * base_mean + scale * cond.mean()
    std_factor = scale * (cond.std() / cfg_cond.std() - 1) + 1
    return target_mean + (cfg_cond - base_mean) * std_factor
|
|
|
|
|
|
@dataclasses.dataclass
class CombineDenoiseArgs:
    """Shared inputs passed to the prompt-tree visitors for one batch item."""
    # Full model output batch: conds for all prompts followed by the unconds.
    x_out: torch.Tensor
    # Unconditional denoised sample for the current batch item.
    uncond: torch.Tensor
    # (index into x_out, weight) pairs for the current batch item's conds.
    cond_indices: List[Tuple[int, float]]
|
|
|
|
|
|
def gather_webui_conds(
    prompt: neutral_prompt_parser.CompositePrompt,
    args: CombineDenoiseArgs,
    index_in: int,
    index_out: int,
) -> Tuple[List[torch.Tensor], List[Tuple[int, float]]]:
    """Collect the conds that the stock webui `combine_denoised` can process.

    Scans `prompt`'s direct children and, for each child without a
    conciliation strategy, emits a cond tensor plus an (index, weight) entry
    rebased onto the sliced tensor the caller is assembling.

    Args:
        prompt: composite prompt whose direct children are scanned.
        args: shared tensors/indices for the current batch item.
        index_in: flat leaf index of the first child within args.cond_indices.
        index_out: number of conds already emitted for previous batch items.

    Returns:
        (cond tensors, (output index, weight) pairs) for the webui to combine.
    """
    sliced_x_out = []
    sliced_cond_indices = []

    for child in prompt.children:
        if child.conciliation is None:
            if isinstance(child, neutral_prompt_parser.LeafPrompt) and child.local_transform is None:
                # Untransformed leaf: reuse the raw model output directly.
                child_x_out = args.x_out[index_in]
                child_weight = child.weight
            else:
                # Composite or locally-transformed child: rebuild its cond
                # from its delta relative to uncond.
                child_x_out, child_weight = get_cond_delta(child, args, index_in)
                child_x_out += args.uncond
            index_offset = index_out + len(sliced_x_out)
            sliced_x_out.append(child_x_out)
            sliced_cond_indices.append((index_offset, child_weight))

        # Advance past this child's flattened leaf count whether or not it
        # was emitted, to keep index_in aligned with args.cond_indices.
        index_in += child.accept(neutral_prompt_parser.FlatSizeVisitor())

    return sliced_x_out, sliced_cond_indices
|
|
|
|
|
|
class CondDeltaVisitor:
    """Prompt-tree visitor computing the weighted cond delta (cond - uncond)
    contributed by children that use no conciliation strategy."""

    def visit_leaf_prompt(
        self,
        that: neutral_prompt_parser.LeafPrompt,
        args: CombineDenoiseArgs,
        index: int,
    ) -> torch.Tensor:
        """Return this leaf's raw cond delta, warning on a weight mismatch.

        A mismatch between the parsed leaf's weight and the weight recorded
        in args.cond_indices suggests another extension has also patched
        `combine_denoised`.
        """
        cond_info = args.cond_indices[index]
        if that.weight != cond_info[1]:
            console_warn(f'''
                An unexpected noise weight was encountered at prompt #{index}
                Expected :{that.weight}, but got :{cond_info[1]}
                This is likely due to another extension also monkey patching the webui `combine_denoised` function
                Please open a bug report here so that the conflict can be resolved:
                https://github.com/ljleb/sd-webui-neutral-prompt/issues
            ''')

        return args.x_out[cond_info[0]] - args.uncond

    def visit_composite_prompt(
        self,
        that: neutral_prompt_parser.CompositePrompt,
        args: CombineDenoiseArgs,
        index: int,
    ) -> torch.Tensor:
        """Sum the weighted cond deltas of all unconciliated children."""
        cond_delta = torch.zeros_like(args.x_out[0])

        for child in that.children:
            if child.conciliation is None:
                child_cond_delta, child_weight = get_cond_delta(child, args, index)
                cond_delta += child_weight * child_cond_delta

            # `index` tracks this child's position in the flattened leaf list,
            # advanced for every child regardless of conciliation.
            index += child.accept(neutral_prompt_parser.FlatSizeVisitor())

        return cond_delta
|
|
|
|
|
|
class AuxCondDeltaVisitor:
    """Prompt-tree visitor computing the auxiliary cond delta contributed by
    children that DO carry a conciliation strategy."""

    def visit_leaf_prompt(
        self,
        that: neutral_prompt_parser.LeafPrompt,
        args: CombineDenoiseArgs,
        cond_delta: torch.Tensor,
        index: int,
    ) -> torch.Tensor:
        """Leaves have no children, hence no auxiliary contribution."""
        return torch.zeros_like(args.x_out[0])

    def visit_composite_prompt(
        self,
        that: neutral_prompt_parser.CompositePrompt,
        args: CombineDenoiseArgs,
        cond_delta: torch.Tensor,
        index: int,
    ) -> torch.Tensor:
        """Accumulate each conciliated child's delta per its strategy.

        - PERPENDICULAR: keep only the component orthogonal to `cond_delta`.
        - SALIENCE_MASK: collected, then blended in where the child's
          salience exceeds `cond_delta`'s (see salient_blend).
        - SEMANTIC_GUIDANCE: keep only the top 5% of values by magnitude.
        """
        aux_cond_delta = torch.zeros_like(args.x_out[0])
        salient_cond_deltas = []

        for child in that.children:
            if child.conciliation is not None:
                child_cond_delta, child_weight = get_cond_delta(child, args, index)

                if child.conciliation == neutral_prompt_parser.ConciliationStrategy.PERPENDICULAR:
                    aux_cond_delta += child_weight * get_perpendicular_component(cond_delta, child_cond_delta)
                elif child.conciliation == neutral_prompt_parser.ConciliationStrategy.SALIENCE_MASK:
                    salient_cond_deltas.append((child_cond_delta, child_weight))
                elif child.conciliation == neutral_prompt_parser.ConciliationStrategy.SEMANTIC_GUIDANCE:
                    aux_cond_delta += child_weight * filter_abs_top_k(child_cond_delta, 0.05)

            # Advance past this child's flattened leaf count either way.
            index += child.accept(neutral_prompt_parser.FlatSizeVisitor())

        # Salience-masked children are blended together in a single pass.
        aux_cond_delta += salient_blend(cond_delta, salient_cond_deltas)
        return aux_cond_delta
|
|
|
|
|
|
def get_cond_delta(prompt: neutral_prompt_parser.PromptExpr, args: CombineDenoiseArgs, index: int) -> Tuple[torch.Tensor, Union[float, torch.Tensor]]:
    """Compute a prompt's full cond delta (base + auxiliary) and its weight.

    When the prompt carries a local affine transform, the transform is applied
    in cond space (delta + uncond) and the returned weight becomes a spatial
    mask tensor produced by apply_masked_transform instead of a scalar.
    """
    cond_delta = prompt.accept(CondDeltaVisitor(), args, index)
    cond_delta += prompt.accept(AuxCondDeltaVisitor(), args, cond_delta, index)
    weight = prompt.weight

    if prompt.local_transform is not None:
        # Transform the reconstructed cond (not the bare delta) so the warp
        # happens in the same space as the model output.
        transformed_cond, weight = apply_masked_transform(cond_delta + args.uncond, prompt.local_transform, prompt.weight)
        cond_delta = transformed_cond - args.uncond

    return cond_delta, weight
|
|
|
|
|
|
def apply_masked_transform(tensor: torch.Tensor, affine: torch.Tensor, weight: float) -> Tuple[torch.Tensor, torch.Tensor]:
    """Affine-warp `tensor` together with a feathered weight mask.

    The mask is appended as an extra channel so it is warped by exactly the
    same transform, then split back off.

    Returns:
        (warped tensor, warped mask) — the mask serves as a spatial weight.
    """
    feather = create_cosine_feathered_mask(tensor.shape[-2:], weight)
    stacked = torch.cat([tensor, feather.unsqueeze(0).to(tensor.device)])
    warped = apply_affine_transform(stacked, affine)
    return warped[:-1], warped[-1]
|
|
|
|
|
|
def apply_affine_transform(tensor, affine):
|
|
affine = affine.clone().to(tensor.device)
|
|
aspect_ratio = tensor.shape[-2] / tensor.shape[-1]
|
|
affine[0, 1] *= aspect_ratio
|
|
affine[1, 0] /= aspect_ratio
|
|
|
|
grid = F.affine_grid(affine.unsqueeze(0), tensor.unsqueeze(0).size(), align_corners=False)
|
|
transformed_tensors = F.grid_sample(tensor.unsqueeze(0), grid, 'nearest', align_corners=False)
|
|
return transformed_tensors.squeeze(0)
|
|
|
|
|
|
def create_cosine_feathered_mask(size, weight: float):
    """
    Create a cosine-based feathered mask scaled by `weight`.

    The mask equals `weight` at the center, falls off as a raised cosine of
    radial distance, and is exactly 0 outside the unit circle.

    Args:
        size: (height, width) of the mask.
        weight: scalar multiplier applied to the whole mask.

    Returns:
        A float tensor of shape `size`.
    """
    # indexing='ij' matches the historical torch.meshgrid default; passing it
    # explicitly avoids the deprecation warning (and a future behavior change)
    # when the argument is omitted.
    y, x = torch.meshgrid(torch.linspace(-1, 1, size[0]), torch.linspace(-1, 1, size[1]), indexing='ij')
    dist = torch.sqrt(x**2 + y**2)
    mask = 0.5 * (1 + torch.cos(torch.pi * dist))
    mask[dist > 1] = 0
    return mask.float() * weight
|
|
|
|
|
|
def get_perpendicular_component(normal: torch.Tensor, vector: torch.Tensor) -> torch.Tensor:
    """Remove from `vector` its component along `normal`.

    A zero `normal` has no defined direction; in that case `vector` is
    returned unchanged, with a warning on the first sampling step.
    """
    if (normal == 0).all():
        if shared.state.sampling_step <= 0:
            warn_projection_not_found()
        return vector

    projection = normal * torch.sum(normal * vector) / torch.norm(normal) ** 2
    return vector - projection
|
|
|
|
|
|
def salient_blend(normal: torch.Tensor, vectors: List[Tuple[torch.Tensor, float]]) -> torch.Tensor:
    """Blend each vector into `normal` only where that vector is most salient.

    A per-element salience map (softmax over absolute values) is computed for
    `normal` and every vector; each element is then attributed to whichever
    tensor is most salient there, and the winning vectors' weighted deltas
    relative to `normal` are summed into the result.
    """
    def salience(t: torch.Tensor) -> torch.Tensor:
        # Softmax over magnitudes highlights the most active regions.
        return torch.softmax(t.abs().flatten(), dim=0).reshape_as(t)

    maps = torch.stack([salience(normal)] + [salience(v) for v, _ in vectors], dim=0)
    # winner[e] == 0 means `normal` dominates element e; i > 0 means vectors[i-1].
    winner = torch.argmax(maps, dim=0)

    blended = torch.zeros_like(normal)
    for slot, (vector, weight) in enumerate(vectors, start=1):
        region = (winner == slot).float()
        blended += weight * region * (vector - normal)

    return blended
|
|
|
|
|
|
def get_salience(vector: torch.Tensor) -> torch.Tensor:
    """Return a salience map: softmax over absolute values, reshaped back."""
    flat_magnitudes = vector.abs().flatten()
    return torch.softmax(flat_magnitudes, dim=0).reshape_as(vector)
|
|
|
|
|
|
def filter_abs_top_k(vector: torch.Tensor, k_ratio: float) -> torch.Tensor:
    """
    Zero out all but roughly the top `k_ratio` fraction of entries by magnitude.

    Args:
        vector: tensor to filter.
        k_ratio: fraction of entries to keep (by absolute value).

    Returns:
        A tensor like `vector` with sub-threshold entries replaced by 0.
    """
    numel = torch.numel(vector)
    if numel == 0:
        # Nothing to threshold; kthvalue would raise on an empty input.
        return vector

    # kthvalue is 1-indexed; clamp so tiny tensors never request k == 0.
    k = max(1, int(numel * (1 - k_ratio)))
    top_k, _ = torch.kthvalue(torch.abs(torch.flatten(vector)), k)
    return vector * (torch.abs(vector) >= top_k).to(vector.dtype)
|
|
|
|
|
|
# Install (or fetch the already-installed) hijacker on the sd_samplers module
# so `create_sampler` can be patched below; it uninstalls itself when the
# script unloads.
sd_samplers_hijacker = hijacker.ModuleHijacker.install_or_get(
    module=sd_samplers,
    hijacker_attribute='__neutral_prompt_hijacker',
    on_uninstall=script_callbacks.on_script_unloaded,
)
|
|
|
|
|
|
@sd_samplers_hijacker.hijack('create_sampler')
def create_sampler_hijack(name: str, model, original_function):
    """Create a sampler, then patch its `combine_denoised` with this extension's.

    Samplers lacking `model_wrap_cfg.combine_denoised` cannot be patched; a
    warning is emitted when the extension is enabled and the sampler is
    returned unmodified.
    """
    sampler = original_function(name, model)
    if not hasattr(sampler, 'model_wrap_cfg') or not hasattr(sampler.model_wrap_cfg, 'combine_denoised'):
        if global_state.is_enabled:
            warn_unsupported_sampler()

        return sampler

    # Bind the sampler's original implementation as the fallback so
    # combine_denoised_hijack can delegate to it.
    sampler.model_wrap_cfg.combine_denoised = functools.partial(
        combine_denoised_hijack,
        original_function=sampler.model_wrap_cfg.combine_denoised
    )
    return sampler
|
|
|
|
|
|
def warn_unsupported_sampler():
    """Warn that the active sampler does not support AND composition."""
    console_warn('''
        Neutral prompt relies on composition via AND, which the webui does not support when using any of the DDIM, PLMS and UniPC samplers
        The sampler will NOT be patched
        Falling back on original sampler implementation...
    ''')
|
|
|
|
|
|
def warn_projection_not_found():
    """Warn that an AND_PERP prompt had no projection (zero normal)."""
    console_warn('''
        Could not find a projection for one or more AND_PERP prompts
        These prompts will NOT be made perpendicular
    ''')
|
|
|
|
|
|
def console_warn(message):
    """Print a dedented, prefixed warning to stderr when verbose mode is on."""
    if global_state.verbose:
        body = textwrap.dedent(message)
        print(f'\n[sd-webui-neutral-prompt extension]{body}', file=sys.stderr)
|