From 82ab6f3b56613eef48de84a3e4e8ed4cad20b3ee Mon Sep 17 00:00:00 2001 From: ljleb Date: Thu, 1 Feb 2024 07:20:33 -0500 Subject: [PATCH] fix --- lib_neutral_prompt/cfg_denoiser_hijack.py | 16 ++++++++-------- scripts/neutral_prompt.py | 18 ++---------------- 2 files changed, 10 insertions(+), 24 deletions(-) diff --git a/lib_neutral_prompt/cfg_denoiser_hijack.py b/lib_neutral_prompt/cfg_denoiser_hijack.py index f541c96..b7702d2 100644 --- a/lib_neutral_prompt/cfg_denoiser_hijack.py +++ b/lib_neutral_prompt/cfg_denoiser_hijack.py @@ -28,7 +28,7 @@ def combine_denoised_hijack( aux_cond_delta = prompt.accept(AuxCondDeltaVisitor(), args, cond_delta, 0) if prompt.local_transform is not None: - cond_delta, weight = apply_masked_transform(cond_delta, prompt.local_transform) + cond_delta = apply_affine_transform(cond_delta, prompt.local_transform) aux_cond_delta = apply_affine_transform(aux_cond_delta, prompt.local_transform) cfg_cond = denoised[batch_i] + aux_cond_delta * cond_scale @@ -184,14 +184,14 @@ def get_cond_delta(prompt: neutral_prompt_parser.PromptExpr, args: CombineDenois weight = prompt.weight if prompt.local_transform is not None: - cond_delta, weight = apply_masked_transform(cond_delta, prompt.local_transform) - weight = weight + transformed_cond, weight = apply_masked_transform(cond_delta + args.uncond, prompt.local_transform, prompt.weight) + cond_delta = transformed_cond - args.uncond return cond_delta, weight -def apply_masked_transform(tensor: torch.Tensor, affine: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: - tensor_with_mask = torch.cat([tensor, create_cosine_feathered_mask(tensor.shape[-2:]).unsqueeze(0).to(tensor.device)]) +def apply_masked_transform(tensor: torch.Tensor, affine: torch.Tensor, weight: float) -> Tuple[torch.Tensor, torch.Tensor]: + tensor_with_mask = torch.cat([tensor, create_cosine_feathered_mask(tensor.shape[-2:], weight).unsqueeze(0).to(tensor.device)]) transformed = apply_affine_transform(tensor_with_mask, affine) return transformed[:-1], transformed[-1] @@ -203,11 +203,11 @@ def apply_affine_transform(tensor, affine): affine[1, 0] /= aspect_ratio grid = F.affine_grid(affine.unsqueeze(0), tensor.unsqueeze(0).size(), align_corners=False) - transformed_tensors = F.grid_sample(tensor.unsqueeze(0), grid, align_corners=False) + transformed_tensors = F.grid_sample(tensor.unsqueeze(0), grid, 'nearest', align_corners=False) return transformed_tensors.squeeze(0) -def create_cosine_feathered_mask(size): +def create_cosine_feathered_mask(size, weight: float): """ Create a cosine-based feathered mask. """ @@ -215,7 +215,7 @@ def create_cosine_feathered_mask(size): dist = torch.sqrt(x**2 + y**2) mask = 0.5 * (1 + torch.cos(torch.pi * dist)) mask[dist > 1] = 0 - return mask.float() + return mask.float() * weight def get_perpendicular_component(normal: torch.Tensor, vector: torch.Tensor) -> torch.Tensor: diff --git a/scripts/neutral_prompt.py b/scripts/neutral_prompt.py index 515407d..6fcdc87 100644 --- a/scripts/neutral_prompt.py +++ b/scripts/neutral_prompt.py @@ -7,9 +7,6 @@ import torch import functools -sampling_step = 0 - - class NeutralPromptScript(scripts.Script): def __init__(self): self.accordion_interface = None @@ -109,25 +106,14 @@ class CombinePreNoiseArgs: cond_indices: List[Tuple[int, float]] -noises = [] - - def on_cfg_denoiser(params: script_callbacks.CFGDenoiserParams): if not global_state.is_enabled: return - global noises, sampling_step - sampling_step += 1 - if sampling_step == 1: - noises = params.x.clone() - return - - for batch_i, (prompt, cond_indices) in enumerate(zip(global_state.prompt_exprs, global_state.batch_cond_indices)): + for prompt, cond_indices in zip(global_state.prompt_exprs, global_state.batch_cond_indices): args = CombinePreNoiseArgs(params.x, cond_indices) inv_transforms = prompt.accept(GlobalToLocalAffineVisitor(), args, 0) - for cond_index, weight in cond_indices: - # noisy_component = noises[cond_index] * torch.sum(noises[cond_index] * params.x[cond_index]) / torch.norm(noises[cond_index]) ** 2 - # params.x[cond_index] = apply_affine_transform(params.x[cond_index] - noisy_component, inv_transforms[cond_index]) + noisy_component + for cond_index, _ in cond_indices: params.x[cond_index] = apply_affine_transform(params.x[cond_index], inv_transforms[cond_index])