affine
ljleb 2024-02-01 07:20:33 -05:00
parent fe5d76e6b9
commit 82ab6f3b56
2 changed files with 10 additions and 24 deletions

View File

@ -28,7 +28,7 @@ def combine_denoised_hijack(
aux_cond_delta = prompt.accept(AuxCondDeltaVisitor(), args, cond_delta, 0)
if prompt.local_transform is not None:
cond_delta, weight = apply_masked_transform(cond_delta, prompt.local_transform)
cond_delta = apply_affine_transform(cond_delta, prompt.local_transform)
aux_cond_delta = apply_affine_transform(aux_cond_delta, prompt.local_transform)
cfg_cond = denoised[batch_i] + aux_cond_delta * cond_scale
@ -184,14 +184,14 @@ def get_cond_delta(prompt: neutral_prompt_parser.PromptExpr, args: CombineDenois
weight = prompt.weight
if prompt.local_transform is not None:
cond_delta, weight = apply_masked_transform(cond_delta, prompt.local_transform)
weight = weight
transformed_cond, weight = apply_masked_transform(cond_delta + args.uncond, prompt.local_transform, prompt.weight)
cond_delta = transformed_cond - args.uncond
return cond_delta, weight
def apply_masked_transform(tensor: torch.Tensor, affine: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
tensor_with_mask = torch.cat([tensor, create_cosine_feathered_mask(tensor.shape[-2:]).unsqueeze(0).to(tensor.device)])
def apply_masked_transform(tensor: torch.Tensor, affine: torch.Tensor, weight: float) -> Tuple[torch.Tensor, torch.Tensor]:
tensor_with_mask = torch.cat([tensor, create_cosine_feathered_mask(tensor.shape[-2:], weight).unsqueeze(0).to(tensor.device)])
transformed = apply_affine_transform(tensor_with_mask, affine)
return transformed[:-1], transformed[-1]
@ -203,11 +203,11 @@ def apply_affine_transform(tensor, affine):
affine[1, 0] /= aspect_ratio
grid = F.affine_grid(affine.unsqueeze(0), tensor.unsqueeze(0).size(), align_corners=False)
transformed_tensors = F.grid_sample(tensor.unsqueeze(0), grid, align_corners=False)
transformed_tensors = F.grid_sample(tensor.unsqueeze(0), grid, 'nearest', align_corners=False)
return transformed_tensors.squeeze(0)
def create_cosine_feathered_mask(size):
def create_cosine_feathered_mask(size, weight: float):
"""
Create a cosine-based feathered mask.
"""
@ -215,7 +215,7 @@ def create_cosine_feathered_mask(size):
dist = torch.sqrt(x**2 + y**2)
mask = 0.5 * (1 + torch.cos(torch.pi * dist))
mask[dist > 1] = 0
return mask.float()
return mask.float() * weight
def get_perpendicular_component(normal: torch.Tensor, vector: torch.Tensor) -> torch.Tensor:

View File

@ -7,9 +7,6 @@ import torch
import functools
sampling_step = 0
class NeutralPromptScript(scripts.Script):
def __init__(self):
self.accordion_interface = None
@ -109,25 +106,14 @@ class CombinePreNoiseArgs:
cond_indices: List[Tuple[int, float]]
noises = []
def on_cfg_denoiser(params: script_callbacks.CFGDenoiserParams):
if not global_state.is_enabled:
return
global noises, sampling_step
sampling_step += 1
if sampling_step == 1:
noises = params.x.clone()
return
for batch_i, (prompt, cond_indices) in enumerate(zip(global_state.prompt_exprs, global_state.batch_cond_indices)):
for prompt, cond_indices in zip(global_state.prompt_exprs, global_state.batch_cond_indices):
args = CombinePreNoiseArgs(params.x, cond_indices)
inv_transforms = prompt.accept(GlobalToLocalAffineVisitor(), args, 0)
for cond_index, weight in cond_indices:
# noisy_component = noises[cond_index] * torch.sum(noises[cond_index] * params.x[cond_index]) / torch.norm(noises[cond_index]) ** 2
# params.x[cond_index] = apply_affine_transform(params.x[cond_index] - noisy_component, inv_transforms[cond_index]) + noisy_component
for cond_index, _ in cond_indices:
params.x[cond_index] = apply_affine_transform(params.x[cond_index], inv_transforms[cond_index])