v0.6.0 (add token merging)
parent
2b55811f35
commit
6ce7f2a102
|
|
@ -10,6 +10,7 @@ Alpha version of a plugin for [automatic1111/stable-diffusion-webui](https://git
|
|||
|
||||
## Releases
|
||||
|
||||
- [29.08.2023] 🏎️ v0.6.0: Up to 2x faster and 4x less memory usage thanks to [Token Merging](https://github.com/dbolya/tomesd/tree/main) (tested with 16 feedback images and a batch size of 4), moderate gains for fewer feedback images (10% speedup for 2 images, 30% for 8 images). Make sure to enable the Token Merging option to take advantage of this.
|
||||
- [22.08.2023] 🗃️ v0.5.0: Adds support for presets; makes generated images using FABRIC more reproducible by loading the correct (previously used) feedback images when using "send to text2img/img2img".
|
||||
|
||||
## Installation
|
||||
|
|
|
|||
|
|
@ -13,7 +13,7 @@ from PIL import Image
|
|||
import modules.scripts
|
||||
from modules import script_callbacks
|
||||
from modules.ui_components import FormGroup, FormRow
|
||||
from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img
|
||||
from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, get_fixed_seed
|
||||
|
||||
from scripts.helpers import WebUiComponents, image_hash
|
||||
from scripts.patching import patch_unet_forward_pass, unpatch_unet_forward_pass
|
||||
|
|
@ -27,7 +27,7 @@ except ImportError:
|
|||
from modules.ui import create_refresh_button
|
||||
|
||||
|
||||
__version__ = "0.5.3"
|
||||
__version__ = "0.6.0"
|
||||
|
||||
DEBUG = os.getenv("DEBUG", "false").lower() in ("true", "1")
|
||||
|
||||
|
|
@ -135,6 +135,10 @@ class FabricParams:
|
|||
neg_latent_cache: Optional[dict] = None
|
||||
|
||||
feedback_during_high_res_fix: bool = False
|
||||
tome_enabled: bool = False
|
||||
tome_ratio: float = 0.5
|
||||
tome_max_tokens: int = 4*4096
|
||||
tome_seed: int = -1
|
||||
|
||||
|
||||
# TODO: replace global state with Gradio state
|
||||
|
|
@ -208,17 +212,24 @@ class FabricScript(modules.scripts.Script):
|
|||
gr.HTML("<hr style='border-color: var(--block-border-color)'>")
|
||||
|
||||
with gr.Column():
|
||||
with FormRow():
|
||||
feedback_max_images = gr.Slider(minimum=0, maximum=10, step=1, value=4, label="Max. feedback images")
|
||||
|
||||
with FormRow():
|
||||
feedback_start = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, value=0.0, label="Feedback start")
|
||||
feedback_end = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, value=0.8, label="Feedback end")
|
||||
|
||||
with FormRow():
|
||||
feedback_min_weight = gr.Slider(minimum=0, maximum=1, step=0.05, value=0.0, label="Min. weight")
|
||||
feedback_max_weight = gr.Slider(minimum=0.0, maximum=1.0, step=0.5, value=0.8, label="Max. weight")
|
||||
feedback_neg_scale = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, value=0.5, label="Neg. scale")
|
||||
feedback_max_weight = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, value=0.8, label="Feedback Strength", elem_id="fabric_max_weight")
|
||||
with FormRow():
|
||||
tome_enabled = gr.Checkbox(label="Enable Token Merging (faster, less VRAM, less accurate)", value=False)
|
||||
|
||||
with gr.Accordion("Advanced options", open=DEBUG):
|
||||
with FormGroup():
|
||||
feedback_min_weight = gr.Slider(minimum=0, maximum=1, step=0.05, value=0.0, label="Min. strength", info="Minimum feedback strength at every diffusion step.", elem_id="fabric_min_weight")
|
||||
feedback_neg_scale = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, value=0.5, label="Negative weight", info="Strength of negative feedback relative to positive feedback.", elem_id="fabric_neg_scale")
|
||||
|
||||
with FormGroup():
|
||||
tome_ratio = gr.Slider(minimum=0.0, maximum=0.75, step=0.125, value=0.5, label="ToMe merge ratio", info="Percentage of tokens to be merged (higher improves speed)", elem_id="fabric_tome_ratio")
|
||||
tome_max_tokens = gr.Slider(minimum=4096, maximum=16*4096, step=4096, value=2*4096, label="ToMe max. tokens", info="Maximum number of tokens after merging (lower improves VRAM usage)", elem_id="fabric_tome_max_tokens")
|
||||
tome_seed = gr.Number(label="ToMe seed", value=-1, info="Random seed for ToMe partition", elem_id="fabric_tome_seed")
|
||||
|
||||
|
||||
|
||||
WebUiComponents.on_txt2img_gallery(self.register_txt2img_gallery_select)
|
||||
|
|
@ -289,6 +300,10 @@ class FabricScript(modules.scripts.Script):
|
|||
(feedback_min_weight, "fabric_min_weight"),
|
||||
(feedback_max_weight, "fabric_max_weight"),
|
||||
(feedback_neg_scale, "fabric_neg_scale"),
|
||||
(tome_enabled, "fabric_tome_enabled"),
|
||||
(tome_ratio, "fabric_tome_ratio"),
|
||||
(tome_max_tokens, "fabric_tome_max_tokens"),
|
||||
(tome_seed, "fabric_tome_seed"),
|
||||
(feedback_during_high_res_fix, "fabric_feedback_during_high_res_fix"),
|
||||
(liked_paths, lambda d: gr.update(value=_load_feedback_paths(d, "fabric_pos_images")) if "fabric_pos_images" in d else None),
|
||||
(disliked_paths, lambda d: gr.update(value=_load_feedback_paths(d, "fabric_neg_images")) if "fabric_neg_images" in d else None),
|
||||
|
|
@ -300,13 +315,16 @@ class FabricScript(modules.scripts.Script):
|
|||
liked_paths,
|
||||
disliked_paths,
|
||||
feedback_enabled,
|
||||
feedback_max_images,
|
||||
feedback_start,
|
||||
feedback_end,
|
||||
feedback_min_weight,
|
||||
feedback_max_weight,
|
||||
feedback_neg_scale,
|
||||
feedback_during_high_res_fix,
|
||||
tome_enabled,
|
||||
tome_ratio,
|
||||
tome_max_tokens,
|
||||
tome_seed,
|
||||
]
|
||||
|
||||
|
||||
|
|
@ -403,13 +421,16 @@ class FabricScript(modules.scripts.Script):
|
|||
liked_paths,
|
||||
disliked_paths,
|
||||
feedback_enabled,
|
||||
feedback_max_images,
|
||||
feedback_start,
|
||||
feedback_end,
|
||||
feedback_min_weight,
|
||||
feedback_max_weight,
|
||||
feedback_neg_scale,
|
||||
feedback_during_high_res_fix,
|
||||
tome_enabled,
|
||||
tome_ratio,
|
||||
tome_max_tokens,
|
||||
tome_seed,
|
||||
) = args
|
||||
|
||||
# restore original U-Net forward pass in case previous batch errored out
|
||||
|
|
@ -418,9 +439,6 @@ class FabricScript(modules.scripts.Script):
|
|||
if not feedback_enabled:
|
||||
return
|
||||
|
||||
liked_paths = liked_paths[-int(feedback_max_images):]
|
||||
disliked_paths = disliked_paths[-int(feedback_max_images):]
|
||||
|
||||
likes = [load_feedback_image(path) for path in liked_paths]
|
||||
dislikes = [load_feedback_image(path) for path in disliked_paths]
|
||||
|
||||
|
|
@ -434,6 +452,10 @@ class FabricScript(modules.scripts.Script):
|
|||
pos_images=likes,
|
||||
neg_images=dislikes,
|
||||
feedback_during_high_res_fix=feedback_during_high_res_fix,
|
||||
tome_enabled=tome_enabled,
|
||||
tome_ratio=(round(tome_ratio * 16) / 16),
|
||||
tome_max_tokens=tome_max_tokens,
|
||||
tome_seed=get_fixed_seed(tome_seed),
|
||||
)
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -0,0 +1,191 @@
|
|||
import torch
|
||||
import math
|
||||
from typing import Dict, Any, Tuple, Callable
|
||||
|
||||
|
||||
"""
|
||||
Copied and adapted from https://github.com/dbolya/tomesd/tree/main
|
||||
Relevant files:
|
||||
- https://github.com/dbolya/tomesd/blob/main/tomesd/merge.py
|
||||
- https://github.com/dbolya/tomesd/blob/main/tomesd/patching.py
|
||||
"""
|
||||
|
||||
def init_generator(device: torch.device, fallback: torch.Generator = None, seed: int = 42):
    """Create a seeded ``torch.Generator`` appropriate for *device*.

    Args:
        device: Target device for the generator.
        fallback: Generator to reuse when *device* has no native generator
            support (e.g. MPS). If ``None``, a CPU generator is created.
        seed: Seed applied to any newly created generator.

    Returns:
        A seeded ``torch.Generator`` on *device* for cpu/cuda, otherwise
        *fallback* if given, otherwise a seeded CPU generator.
    """
    # cpu and cuda both support torch.Generator directly.
    if device.type in ("cpu", "cuda"):
        return torch.Generator(device=device).manual_seed(seed)
    if fallback is not None:
        return fallback
    # Bug fix: the recursive fallback previously dropped `seed`, so devices
    # like MPS silently fell back to the default seed of 42.
    return init_generator(torch.device("cpu"), seed=seed)
|
||||
|
||||
def do_nothing(x: torch.Tensor, mode: str = None):
    """No-op stand-in for a merge/unmerge function when merging is disabled.

    The ``mode`` argument is accepted (and ignored) so the signature matches
    a real merge callable.
    """
    return x
|
||||
|
||||
|
||||
def mps_gather_workaround(input, dim, index):
    """``torch.gather`` that works around an MPS issue when the last
    dimension of ``input`` has size 1.

    In that case the gather is performed on tensors expanded by one trailing
    axis and the extra axis is squeezed off the result; otherwise this is a
    plain ``torch.gather``.
    """
    if input.shape[-1] != 1:
        return torch.gather(input, dim, index)
    # Negative dims must be shifted left to account for the added axis.
    lifted_dim = dim if dim >= 0 else dim - 1
    gathered = torch.gather(input.unsqueeze(-1), lifted_dim, index.unsqueeze(-1))
    return gathered.squeeze(-1)
|
||||
|
||||
|
||||
def bipartite_soft_matching_random2d(metric: torch.Tensor,
                                     w: int, h: int, sx: int, sy: int, r: int,
                                     no_rand: bool = False,
                                     generator: torch.Generator = None) -> Tuple[Callable, Callable]:
    """
    Partitions the tokens into src and dst and merges r tokens from src to dst.
    Dst tokens are partitioned by choosing one randomly in each (sx, sy) region.

    Args:
        - metric [B, N, C]: metric to use for similarity (N is assumed to
          equal w * h — TODO confirm at call sites)
        - w: image width in tokens
        - h: image height in tokens
        - sx: stride in the x dimension for dst, must divide w
        - sy: stride in the y dimension for dst, must divide h
        - r: number of tokens to remove (by merging); r <= 0 returns no-ops
        - no_rand: if true, disable randomness (use top left corner only)
        - generator: RNG used to pick the dst token in each region.
          NOTE(review): required whenever no_rand is False — its .device is
          read unconditionally in that branch.

    Returns:
        A (merge, unmerge) pair of callables. merge(x, mode) reduces the r
        most similar src tokens into their dst matches; unmerge(x) scatters
        a merged tensor back to the original [B, N, C] layout.
    """
    B, N, _ = metric.shape

    # Nothing to merge: hand back identity functions.
    if r <= 0:
        return do_nothing, do_nothing

    gather = mps_gather_workaround if metric.device.type == "mps" else torch.gather

    with torch.no_grad():
        hsy, wsx = h // sy, w // sx

        # For each sy-by-sx kernel, assign one token to be dst and the rest src:
        # either the top-left token (deterministic) or a random one per region.
        if no_rand:
            rand_idx = torch.zeros(hsy, wsx, 1, device=metric.device, dtype=torch.int64)
        else:
            rand_idx = torch.randint(sy*sx, size=(hsy, wsx, 1), device=generator.device, generator=generator).to(metric.device)

        # The image might not divide sx and sy, so we need to work on a view of the top left of the idx buffer.
        # dst positions are marked with -1, src positions stay 0.
        idx_buffer_view = torch.zeros(hsy, wsx, sy*sx, device=metric.device, dtype=torch.int64)
        idx_buffer_view.scatter_(dim=2, index=rand_idx, src=-torch.ones_like(rand_idx, dtype=rand_idx.dtype))
        idx_buffer_view = idx_buffer_view.view(hsy, wsx, sy, sx).transpose(1, 2).reshape(hsy * sy, wsx * sx)

        # Image is not divisible by sx or sy, so we need to move it into a new buffer
        # (the leftover border tokens are all src).
        if (hsy * sy) < h or (wsx * sx) < w:
            idx_buffer = torch.zeros(h, w, device=metric.device, dtype=torch.int64)
            idx_buffer[:(hsy * sy), :(wsx * sx)] = idx_buffer_view
        else:
            idx_buffer = idx_buffer_view

        # We set dst tokens to be -1 and src to be 0, so an argsort gives us dst|src indices
        rand_idx = idx_buffer.reshape(1, -1, 1).argsort(dim=1)

        # We're finished with these
        del idx_buffer, idx_buffer_view

        # rand_idx is currently dst|src, so split them
        num_dst = hsy * wsx
        a_idx = rand_idx[:, num_dst:, :]  # src
        b_idx = rand_idx[:, :num_dst, :]  # dst

        def split(x):
            # Partition an [B, N, C] tensor into its (src, dst) token sets.
            C = x.shape[-1]
            src = gather(x, dim=1, index=a_idx.expand(B, N - num_dst, C))
            dst = gather(x, dim=1, index=b_idx.expand(B, num_dst, C))
            return src, dst

        # Cosine similarity between A and B (normalize, then dot product)
        metric = metric / metric.norm(dim=-1, keepdim=True)
        a, b = split(metric)
        scores = a @ b.transpose(-1, -2)

        # Can't reduce more than the # tokens in src
        r = min(a.shape[1], r)

        # Find the most similar greedily: each src token's best dst match,
        # then keep the r highest-scoring src tokens for merging.
        node_max, node_idx = scores.max(dim=-1)
        edge_idx = node_max.argsort(dim=-1, descending=True)[..., None]

        unm_idx = edge_idx[..., r:, :]  # Unmerged Tokens
        src_idx = edge_idx[..., :r, :]  # Merged Tokens
        dst_idx = gather(node_idx[..., None], dim=-2, index=src_idx)

    def merge(x: torch.Tensor, mode="mean") -> torch.Tensor:
        # Reduce the r chosen src tokens into their dst targets (via `mode`)
        # and concatenate the unmerged src tokens in front of the dst tokens.
        src, dst = split(x)
        n, t1, c = src.shape

        unm = gather(src, dim=-2, index=unm_idx.expand(n, t1 - r, c))
        src = gather(src, dim=-2, index=src_idx.expand(n, r, c))
        dst = dst.scatter_reduce(-2, dst_idx.expand(n, r, c), src, reduce=mode)

        return torch.cat([unm, dst], dim=1)

    def unmerge(x: torch.Tensor) -> torch.Tensor:
        # Inverse of merge: scatter unmerged, dst, and (duplicated) merged
        # tokens back to their original positions in an [B, N, c] tensor.
        unm_len = unm_idx.shape[1]
        unm, dst = x[..., :unm_len, :], x[..., unm_len:, :]
        _, _, c = unm.shape

        # Each merged src token takes the value of the dst token it merged into.
        src = gather(dst, dim=-2, index=dst_idx.expand(B, r, c))

        # Combine back to the original shape
        out = torch.zeros(B, N, c, device=x.device, dtype=x.dtype)
        out.scatter_(dim=-2, index=b_idx.expand(B, num_dst, c), src=dst)
        out.scatter_(dim=-2, index=gather(a_idx.expand(B, a_idx.shape[1], 1), dim=1, index=unm_idx).expand(B, unm_len, c), src=unm)
        out.scatter_(dim=-2, index=gather(a_idx.expand(B, a_idx.shape[1], 1), dim=1, index=src_idx).expand(B, r, c), src=src)

        return out

    return merge, unmerge
|
||||
|
||||
|
||||
def compute_merge(
    x: torch.Tensor,
    args: Dict[str, Any],
    size: Tuple[int, int],
    max_tokens: int = None,
    ratio: float = None,
) -> Tuple[Callable, ...]:
    """Build (merge, unmerge) callables for token merging on ``x``.

    Args:
        x: Token tensor of shape [B, N, C].
        args: ToMe state dict with keys "enabled", "sx", "sy", "use_rand",
            "generator" (mutated in place: lazily created / migrated), "seed".
        size: (height, width) of the original latent in tokens.
        max_tokens: If > 0, upper bound on the number of tokens kept.
        ratio: Fraction of tokens to merge away (at least one of
            ``max_tokens`` / ``ratio`` must be given).

    Returns:
        A (merge, unmerge) pair, or (do_nothing, do_nothing) when merging
        is disabled or no tokens need to be removed.

    Raises:
        ValueError: if both ``max_tokens`` and ``ratio`` are ``None``.
    """
    if not args["enabled"]:
        return do_nothing, do_nothing
    if max_tokens is None and ratio is None:
        raise ValueError("Must specify either max_tokens or ratio")

    original_h, original_w = size
    num_tokens = x.shape[1]
    # How far this layer's tokens are downsampled relative to the full latent.
    downsample = int(math.ceil(math.sqrt((original_h * original_w) // num_tokens)))

    # Desired token count after merging: ratio-based target, optionally
    # clamped by the absolute max_tokens cap.
    keep = int(num_tokens * (1 - ratio)) if ratio is not None else num_tokens
    if max_tokens is not None and max_tokens > 0:
        keep = min(keep, max_tokens)  # remove all but max_tokens tokens
    num_to_merge = num_tokens - keep
    if num_to_merge <= 0:
        return do_nothing, do_nothing

    grid_w = int(math.ceil(original_w / downsample))
    grid_h = int(math.ceil(original_h / downsample))

    # Re-init the generator if it hasn't been initialized yet or the device changed.
    gen = args["generator"]
    if gen is None:
        args["generator"] = init_generator(x.device, seed=args["seed"])
    elif gen.device != x.device:
        args["generator"] = init_generator(x.device, fallback=gen, seed=args["seed"])

    # If the batch size is odd, prompted and unprompted images cannot share a
    # batch, which causes artifacts with the random partition — disable it.
    use_rand = args["use_rand"] and x.shape[0] % 2 == 0
    return bipartite_soft_matching_random2d(x, grid_w, grid_h, args["sx"], args["sy"],
                                            num_to_merge, no_rand=not use_rand,
                                            generator=args["generator"])
|
||||
|
|
@ -11,6 +11,7 @@ from ldm.modules.attention import BasicTransformerBlock
|
|||
from scripts.marking import apply_marking_patch, unmark_prompt_context
|
||||
from scripts.helpers import image_hash
|
||||
from scripts.weighted_attention import weighted_attention
|
||||
from scripts.merging import compute_merge
|
||||
|
||||
|
||||
def encode_to_latent(p, image, w, h):
|
||||
|
|
@ -89,6 +90,14 @@ def patch_unet_forward_pass(p, unet, params):
|
|||
hr_w = width
|
||||
hr_h = height
|
||||
|
||||
tome_args = {
|
||||
"enabled": params.tome_enabled,
|
||||
"sx": 2, "sy": 2,
|
||||
"use_rand": True,
|
||||
"generator": None,
|
||||
"seed": params.tome_seed,
|
||||
}
|
||||
|
||||
def new_forward(self, x, timesteps=None, context=None, **kwargs):
|
||||
_, uncond_ids, context = unmark_prompt_context(context)
|
||||
cond_ids = [i for i in range(context.size(0)) if i not in uncond_ids]
|
||||
|
|
@ -133,81 +142,101 @@ def patch_unet_forward_pass(p, unet, params):
|
|||
for module in self.modules():
|
||||
if isinstance(module, BasicTransformerBlock) and not hasattr(module.attn1, "_fabric_old_forward"):
|
||||
module.attn1._fabric_old_forward = module.attn1.forward
|
||||
module.attn2._fabric_old_forward = module.attn2.forward
|
||||
|
||||
## cache hidden states
|
||||
try:
|
||||
## cache hidden states
|
||||
cached_hiddens = {}
|
||||
def patched_attn1_forward(attn1, layer_idx, x, **kwargs):
|
||||
merge, unmerge = compute_merge(x, args=tome_args, size=(h_latent, w_latent), ratio=params.tome_ratio)
|
||||
x = merge(x)
|
||||
if layer_idx not in cached_hiddens:
|
||||
cached_hiddens[layer_idx] = x.detach().clone().cpu()
|
||||
else:
|
||||
cached_hiddens[layer_idx] = torch.cat([cached_hiddens[layer_idx], x.detach().clone().cpu()], dim=0)
|
||||
out = attn1._fabric_old_forward(x, **kwargs)
|
||||
out = unmerge(out)
|
||||
return out
|
||||
|
||||
def patched_attn2_forward(attn2, x, **kwargs):
|
||||
merge, unmerge = compute_merge(x, args=tome_args, size=(h_latent, w_latent), ratio=params.tome_ratio)
|
||||
x = merge(x)
|
||||
out = attn2._fabric_old_forward(x, **kwargs)
|
||||
out = unmerge(out)
|
||||
return out
|
||||
|
||||
cached_hiddens = {}
|
||||
def patched_attn1_forward(attn1, idx, x, **kwargs):
|
||||
if idx not in cached_hiddens:
|
||||
cached_hiddens[idx] = x.detach().clone().cpu()
|
||||
else:
|
||||
cached_hiddens[idx] = torch.cat([cached_hiddens[idx], x.detach().clone().cpu()], dim=0)
|
||||
out = attn1._fabric_old_forward(x, **kwargs)
|
||||
return out
|
||||
# patch forward pass to cache hidden states
|
||||
layer_idx = 0
|
||||
for module in self.modules():
|
||||
if isinstance(module, BasicTransformerBlock):
|
||||
module.attn1.forward = functools.partial(patched_attn1_forward, module.attn1, layer_idx)
|
||||
module.attn2.forward = functools.partial(patched_attn2_forward, module.attn2)
|
||||
layer_idx += 1
|
||||
|
||||
# patch forward pass to cache hidden states
|
||||
layer_idx = 0
|
||||
for module in self.modules():
|
||||
if isinstance(module, BasicTransformerBlock):
|
||||
module.attn1.forward = functools.partial(patched_attn1_forward, module.attn1, layer_idx)
|
||||
layer_idx += 1
|
||||
# run forward pass just to cache hidden states, output is discarded
|
||||
for i in range(0, len(all_zs), batch_size):
|
||||
zs = all_zs[i : i + batch_size].to(x.device, dtype=self.dtype)
|
||||
ts = timesteps[:1].expand(zs.size(0)) # (bs,)
|
||||
# use the null prompt for pre-computing hidden states on feedback images
|
||||
ctx = null_ctx.expand(zs.size(0), -1, -1) # (bs, p_seq, p_dim)
|
||||
_ = self._fabric_old_forward(zs, ts, ctx)
|
||||
|
||||
# run forward pass just to cache hidden states, output is discarded
|
||||
for i in range(0, len(all_zs), batch_size):
|
||||
zs = all_zs[i : i + batch_size].to(x.device, dtype=self.dtype)
|
||||
ts = timesteps[:1].expand(zs.size(0)) # (bs,)
|
||||
# use the null prompt for pre-computing hidden states on feedback images
|
||||
ctx = null_ctx.expand(zs.size(0), -1, -1) # (bs, p_seq, p_dim)
|
||||
_ = self._fabric_old_forward(zs, ts, ctx)
|
||||
num_pos = len(pos_latents)
|
||||
num_neg = len(neg_latents)
|
||||
num_cond = len(cond_ids)
|
||||
num_uncond = len(uncond_ids)
|
||||
tome_h_latent = h_latent * (1 - params.tome_ratio)
|
||||
|
||||
num_pos = len(pos_latents)
|
||||
num_neg = len(neg_latents)
|
||||
num_cond = len(cond_ids)
|
||||
num_uncond = len(uncond_ids)
|
||||
def patched_attn1_forward(attn1, idx, x, context=None, **kwargs):
|
||||
if context is None:
|
||||
context = x
|
||||
|
||||
def patched_attn1_forward(attn1, idx, x, context=None, **kwargs):
|
||||
if context is None:
|
||||
context = x
|
||||
cached_hs = cached_hiddens[idx].to(x.device)
|
||||
|
||||
cached_hs = cached_hiddens[idx].to(x.device)
|
||||
d_model = x.shape[-1]
|
||||
|
||||
seq_len, d_model = x.shape[1:]
|
||||
def attention_with_feedback(_x, context, feedback_hs, w):
|
||||
num_xs, num_fb = _x.shape[0], feedback_hs.shape[0]
|
||||
if num_fb > 0:
|
||||
feedback_ctx = feedback_hs.view(1, -1, d_model).expand(num_xs, -1, -1) # (n_cond, seq * n_pos, dim)
|
||||
merge, _ = compute_merge(feedback_ctx, args=tome_args, size=(tome_h_latent * num_fb, w_latent), max_tokens=params.tome_max_tokens)
|
||||
feedback_ctx = merge(feedback_ctx)
|
||||
ctx = torch.cat([context, feedback_ctx], dim=1) # (n_cond, seq + seq*n_pos, dim)
|
||||
else:
|
||||
ctx = context
|
||||
weights = torch.ones_like(ctx[0, :, 0]) # (seq + seq*n_pos,)
|
||||
weights[_x.shape[1]:] = w
|
||||
return weighted_attention(attn1, attn1._fabric_old_forward, _x, ctx, weights, **kwargs) # (n_cond, seq, dim)
|
||||
|
||||
outs = []
|
||||
if num_cond > 0:
|
||||
pos_hs = cached_hs[:num_pos].view(1, num_pos * seq_len, d_model).expand(num_cond, -1, -1) # (n_cond, seq * n_pos, dim)
|
||||
x_cond = x[cond_ids] # (n_cond, seq, dim)
|
||||
ctx_cond = torch.cat([context[cond_ids], pos_hs], dim=1) # (n_cond, seq * (1 + n_pos), dim)
|
||||
ws = torch.ones_like(ctx_cond[0, :, 0]) # (seq * (1 + n_pos),)
|
||||
ws[x_cond.size(1):] = pos_weight
|
||||
out_cond = weighted_attention(attn1, attn1._fabric_old_forward, x_cond, ctx_cond, ws, **kwargs) # (n_cond, seq, dim)
|
||||
outs.append(out_cond)
|
||||
if num_uncond > 0:
|
||||
neg_hs = cached_hs[num_pos:].view(1, num_neg * seq_len, d_model).expand(num_uncond, -1, -1) # (n_uncond, seq * n_neg, dim)
|
||||
x_uncond = x[uncond_ids] # (n_uncond, seq, dim)
|
||||
ctx_uncond = torch.cat([context[uncond_ids], neg_hs], dim=1) # (n_uncond, seq * (1 + n_neg), dim)
|
||||
ws = torch.ones_like(ctx_uncond[0, :, 0]) # (seq * (1 + n_neg),)
|
||||
ws[x_uncond.size(1):] = neg_weight
|
||||
out_uncond = weighted_attention(attn1, attn1._fabric_old_forward, x_uncond, ctx_uncond, ws, **kwargs) # (n_uncond, seq, dim)
|
||||
outs.append(out_uncond)
|
||||
out = torch.cat(outs, dim=0)
|
||||
return out
|
||||
outs = []
|
||||
if num_cond > 0:
|
||||
out_cond = attention_with_feedback(x[cond_ids], context[cond_ids], cached_hs[:num_pos], pos_weight) # (n_cond, seq, dim)
|
||||
outs.append(out_cond)
|
||||
if num_uncond > 0:
|
||||
out_uncond = attention_with_feedback(x[uncond_ids], context[uncond_ids], cached_hs[num_pos:], neg_weight) # (n_cond, seq, dim)
|
||||
outs.append(out_uncond)
|
||||
out = torch.cat(outs, dim=0)
|
||||
return out
|
||||
|
||||
# patch forward pass to inject cached hidden states
|
||||
layer_idx = 0
|
||||
for module in self.modules():
|
||||
if isinstance(module, BasicTransformerBlock):
|
||||
module.attn1.forward = functools.partial(patched_attn1_forward, module.attn1, layer_idx)
|
||||
layer_idx += 1
|
||||
# patch forward pass to inject cached hidden states
|
||||
layer_idx = 0
|
||||
for module in self.modules():
|
||||
if isinstance(module, BasicTransformerBlock):
|
||||
module.attn1.forward = functools.partial(patched_attn1_forward, module.attn1, layer_idx)
|
||||
layer_idx += 1
|
||||
|
||||
# run forward pass with cached hidden states
|
||||
out = self._fabric_old_forward(x, timesteps, context, **kwargs)
|
||||
# run forward pass with cached hidden states
|
||||
out = self._fabric_old_forward(x, timesteps, context, **kwargs)
|
||||
|
||||
# restore original pass
|
||||
for module in self.modules():
|
||||
if isinstance(module, BasicTransformerBlock) and hasattr(module.attn1, "_fabric_old_forward"):
|
||||
module.attn1.forward = module.attn1._fabric_old_forward
|
||||
del module.attn1._fabric_old_forward
|
||||
finally:
|
||||
# restore original pass
|
||||
for module in self.modules():
|
||||
if isinstance(module, BasicTransformerBlock) and hasattr(module.attn1, "_fabric_old_forward"):
|
||||
module.attn1.forward = module.attn1._fabric_old_forward
|
||||
del module.attn1._fabric_old_forward
|
||||
if isinstance(module, BasicTransformerBlock) and hasattr(module.attn2, "_fabric_old_forward"):
|
||||
module.attn2.forward = module.attn2._fabric_old_forward
|
||||
del module.attn2._fabric_old_forward
|
||||
|
||||
return out
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue