fix
parent
fe5d76e6b9
commit
82ab6f3b56
|
|
@ -28,7 +28,7 @@ def combine_denoised_hijack(
|
|||
aux_cond_delta = prompt.accept(AuxCondDeltaVisitor(), args, cond_delta, 0)
|
||||
|
||||
if prompt.local_transform is not None:
|
||||
cond_delta, weight = apply_masked_transform(cond_delta, prompt.local_transform)
|
||||
cond_delta = apply_affine_transform(cond_delta, prompt.local_transform)
|
||||
aux_cond_delta = apply_affine_transform(aux_cond_delta, prompt.local_transform)
|
||||
|
||||
cfg_cond = denoised[batch_i] + aux_cond_delta * cond_scale
|
||||
|
|
@ -184,14 +184,14 @@ def get_cond_delta(prompt: neutral_prompt_parser.PromptExpr, args: CombineDenois
|
|||
weight = prompt.weight
|
||||
|
||||
if prompt.local_transform is not None:
|
||||
cond_delta, weight = apply_masked_transform(cond_delta, prompt.local_transform)
|
||||
weight = weight
|
||||
transformed_cond, weight = apply_masked_transform(cond_delta + args.uncond, prompt.local_transform, prompt.weight)
|
||||
cond_delta = transformed_cond - args.uncond
|
||||
|
||||
return cond_delta, weight
|
||||
|
||||
|
||||
def apply_masked_transform(tensor: torch.Tensor, affine: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
tensor_with_mask = torch.cat([tensor, create_cosine_feathered_mask(tensor.shape[-2:]).unsqueeze(0).to(tensor.device)])
|
||||
def apply_masked_transform(tensor: torch.Tensor, affine: torch.Tensor, weight: float) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
tensor_with_mask = torch.cat([tensor, create_cosine_feathered_mask(tensor.shape[-2:], weight).unsqueeze(0).to(tensor.device)])
|
||||
transformed = apply_affine_transform(tensor_with_mask, affine)
|
||||
return transformed[:-1], transformed[-1]
|
||||
|
||||
|
|
@ -203,11 +203,11 @@ def apply_affine_transform(tensor, affine):
|
|||
affine[1, 0] /= aspect_ratio
|
||||
|
||||
grid = F.affine_grid(affine.unsqueeze(0), tensor.unsqueeze(0).size(), align_corners=False)
|
||||
transformed_tensors = F.grid_sample(tensor.unsqueeze(0), grid, align_corners=False)
|
||||
transformed_tensors = F.grid_sample(tensor.unsqueeze(0), grid, 'nearest', align_corners=False)
|
||||
return transformed_tensors.squeeze(0)
|
||||
|
||||
|
||||
def create_cosine_feathered_mask(size):
|
||||
def create_cosine_feathered_mask(size, weight: float):
|
||||
"""
|
||||
Create a cosine-based feathered mask.
|
||||
"""
|
||||
|
|
@ -215,7 +215,7 @@ def create_cosine_feathered_mask(size):
|
|||
dist = torch.sqrt(x**2 + y**2)
|
||||
mask = 0.5 * (1 + torch.cos(torch.pi * dist))
|
||||
mask[dist > 1] = 0
|
||||
return mask.float()
|
||||
return mask.float() * weight
|
||||
|
||||
|
||||
def get_perpendicular_component(normal: torch.Tensor, vector: torch.Tensor) -> torch.Tensor:
|
||||
|
|
|
|||
|
|
@ -7,9 +7,6 @@ import torch
|
|||
import functools
|
||||
|
||||
|
||||
sampling_step = 0
|
||||
|
||||
|
||||
class NeutralPromptScript(scripts.Script):
|
||||
def __init__(self):
|
||||
self.accordion_interface = None
|
||||
|
|
@ -109,25 +106,14 @@ class CombinePreNoiseArgs:
|
|||
cond_indices: List[Tuple[int, float]]
|
||||
|
||||
|
||||
noises = []
|
||||
|
||||
|
||||
def on_cfg_denoiser(params: script_callbacks.CFGDenoiserParams):
|
||||
if not global_state.is_enabled:
|
||||
return
|
||||
|
||||
global noises, sampling_step
|
||||
sampling_step += 1
|
||||
if sampling_step == 1:
|
||||
noises = params.x.clone()
|
||||
return
|
||||
|
||||
for batch_i, (prompt, cond_indices) in enumerate(zip(global_state.prompt_exprs, global_state.batch_cond_indices)):
|
||||
for prompt, cond_indices in zip(global_state.prompt_exprs, global_state.batch_cond_indices):
|
||||
args = CombinePreNoiseArgs(params.x, cond_indices)
|
||||
inv_transforms = prompt.accept(GlobalToLocalAffineVisitor(), args, 0)
|
||||
for cond_index, weight in cond_indices:
|
||||
# noisy_component = noises[cond_index] * torch.sum(noises[cond_index] * params.x[cond_index]) / torch.norm(noises[cond_index]) ** 2
|
||||
# params.x[cond_index] = apply_affine_transform(params.x[cond_index] - noisy_component, inv_transforms[cond_index]) + noisy_component
|
||||
for cond_index, _ in cond_indices:
|
||||
params.x[cond_index] = apply_affine_transform(params.x[cond_index], inv_transforms[cond_index])
|
||||
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue