v0.6.0 (add token merging)

pull/38/head
dvruette 2023-08-29 00:43:14 +02:00
parent 2b55811f35
commit 6ce7f2a102
4 changed files with 320 additions and 77 deletions

View File

@ -10,6 +10,7 @@ Alpha version of a plugin for [automatic1111/stable-diffusion-webui](https://git
## Releases
- [29.08.2023] 🏎️ v0.6.0: Up to 2x faster and 4x less memory usage thanks to [Token Merging](https://github.com/dbolya/tomesd/tree/main) (tested with 16 feedback images and a batch size of 4), moderate gains for fewer feedback images (10% speedup for 2 images, 30% for 8 images). Make sure to enable the Token Merging option to take advantage of this.
- [22.08.2023] 🗃️ v0.5.0: Adds support for presets; makes generated images using FABRIC more reproducible by loading the correct (previously used) feedback images when using "send to text2img/img2img".
## Installation

View File

@ -13,7 +13,7 @@ from PIL import Image
import modules.scripts
from modules import script_callbacks
from modules.ui_components import FormGroup, FormRow
from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img
from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, get_fixed_seed
from scripts.helpers import WebUiComponents, image_hash
from scripts.patching import patch_unet_forward_pass, unpatch_unet_forward_pass
@ -27,7 +27,7 @@ except ImportError:
from modules.ui import create_refresh_button
__version__ = "0.5.3"
__version__ = "0.6.0"
DEBUG = os.getenv("DEBUG", "false").lower() in ("true", "1")
@ -135,6 +135,10 @@ class FabricParams:
neg_latent_cache: Optional[dict] = None
feedback_during_high_res_fix: bool = False
tome_enabled: bool = False
tome_ratio: float = 0.5
tome_max_tokens: int = 4*4096
tome_seed: int = -1
# TODO: replace global state with Gradio state
@ -208,17 +212,24 @@ class FabricScript(modules.scripts.Script):
gr.HTML("<hr style='border-color: var(--block-border-color)'>")
with gr.Column():
with FormRow():
feedback_max_images = gr.Slider(minimum=0, maximum=10, step=1, value=4, label="Max. feedback images")
with FormRow():
feedback_start = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, value=0.0, label="Feedback start")
feedback_end = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, value=0.8, label="Feedback end")
with FormRow():
feedback_min_weight = gr.Slider(minimum=0, maximum=1, step=0.05, value=0.0, label="Min. weight")
feedback_max_weight = gr.Slider(minimum=0.0, maximum=1.0, step=0.5, value=0.8, label="Max. weight")
feedback_neg_scale = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, value=0.5, label="Neg. scale")
feedback_max_weight = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, value=0.8, label="Feedback Strength", elem_id="fabric_max_weight")
with FormRow():
tome_enabled = gr.Checkbox(label="Enable Token Merging (faster, less VRAM, less accurate)", value=False)
with gr.Accordion("Advanced options", open=DEBUG):
with FormGroup():
feedback_min_weight = gr.Slider(minimum=0, maximum=1, step=0.05, value=0.0, label="Min. strength", info="Minimum feedback strength at every diffusion step.", elem_id="fabric_min_weight")
feedback_neg_scale = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, value=0.5, label="Negative weight", info="Strength of negative feedback relative to positive feedback.", elem_id="fabric_neg_scale")
with FormGroup():
tome_ratio = gr.Slider(minimum=0.0, maximum=0.75, step=0.125, value=0.5, label="ToMe merge ratio", info="Percentage of tokens to be merged (higher improves speed)", elem_id="fabric_tome_ratio")
tome_max_tokens = gr.Slider(minimum=4096, maximum=16*4096, step=4096, value=2*4096, label="ToMe max. tokens", info="Maximum number of tokens after merging (lower improves VRAM usage)", elem_id="fabric_tome_max_tokens")
tome_seed = gr.Number(label="ToMe seed", value=-1, info="Random seed for ToMe partition", elem_id="fabric_tome_seed")
WebUiComponents.on_txt2img_gallery(self.register_txt2img_gallery_select)
@ -289,6 +300,10 @@ class FabricScript(modules.scripts.Script):
(feedback_min_weight, "fabric_min_weight"),
(feedback_max_weight, "fabric_max_weight"),
(feedback_neg_scale, "fabric_neg_scale"),
(tome_enabled, "fabric_tome_enabled"),
(tome_ratio, "fabric_tome_ratio"),
(tome_max_tokens, "fabric_tome_max_tokens"),
(tome_seed, "fabric_tome_seed"),
(feedback_during_high_res_fix, "fabric_feedback_during_high_res_fix"),
(liked_paths, lambda d: gr.update(value=_load_feedback_paths(d, "fabric_pos_images")) if "fabric_pos_images" in d else None),
(disliked_paths, lambda d: gr.update(value=_load_feedback_paths(d, "fabric_neg_images")) if "fabric_neg_images" in d else None),
@ -300,13 +315,16 @@ class FabricScript(modules.scripts.Script):
liked_paths,
disliked_paths,
feedback_enabled,
feedback_max_images,
feedback_start,
feedback_end,
feedback_min_weight,
feedback_max_weight,
feedback_neg_scale,
feedback_during_high_res_fix,
tome_enabled,
tome_ratio,
tome_max_tokens,
tome_seed,
]
@ -403,13 +421,16 @@ class FabricScript(modules.scripts.Script):
liked_paths,
disliked_paths,
feedback_enabled,
feedback_max_images,
feedback_start,
feedback_end,
feedback_min_weight,
feedback_max_weight,
feedback_neg_scale,
feedback_during_high_res_fix,
tome_enabled,
tome_ratio,
tome_max_tokens,
tome_seed,
) = args
# restore original U-Net forward pass in case previous batch errored out
@ -418,9 +439,6 @@ class FabricScript(modules.scripts.Script):
if not feedback_enabled:
return
liked_paths = liked_paths[-int(feedback_max_images):]
disliked_paths = disliked_paths[-int(feedback_max_images):]
likes = [load_feedback_image(path) for path in liked_paths]
dislikes = [load_feedback_image(path) for path in disliked_paths]
@ -434,6 +452,10 @@ class FabricScript(modules.scripts.Script):
pos_images=likes,
neg_images=dislikes,
feedback_during_high_res_fix=feedback_during_high_res_fix,
tome_enabled=tome_enabled,
tome_ratio=(round(tome_ratio * 16) / 16),
tome_max_tokens=tome_max_tokens,
tome_seed=get_fixed_seed(tome_seed),
)

191
scripts/merging.py Normal file
View File

@ -0,0 +1,191 @@
import torch
import math
from typing import Dict, Any, Tuple, Callable
"""
Copied and adapted from https://github.com/dbolya/tomesd/tree/main
Relevant files:
- https://github.com/dbolya/tomesd/blob/main/tomesd/merge.py
- https://github.com/dbolya/tomesd/blob/main/tomesd/patching.py
"""
def init_generator(device: torch.device, fallback: torch.Generator = None, seed: int = 42):
    """
    Create a seeded ``torch.Generator`` for the given device.

    Args:
        device: Target device; ``cpu`` and ``cuda`` are supported directly.
        fallback: Generator to reuse when ``device`` is neither cpu nor cuda.
        seed: Manual seed applied to any newly created generator.

    Returns:
        A generator seeded with ``seed`` on ``device``; for unsupported device
        types, ``fallback`` if provided, otherwise a seeded cpu generator.
    """
    if device.type == "cpu":
        return torch.Generator(device="cpu").manual_seed(seed)
    elif device.type == "cuda":
        return torch.Generator(device=device).manual_seed(seed)
    else:
        if fallback is None:
            # Fix: previously the seed was dropped on this cpu-fallback path
            # (always reseeding with the default 42), making runs on devices
            # like mps non-reproducible w.r.t. the requested seed.
            return init_generator(torch.device("cpu"), seed=seed)
        else:
            return fallback
def do_nothing(x: torch.Tensor, mode: str = None):
    """Identity stand-in for merge/unmerge when token merging is disabled.

    The unused ``mode`` parameter mirrors the signature of ``merge`` so the
    two are interchangeable at call sites.
    """
    return x
def mps_gather_workaround(input, dim, index):
    """Drop-in replacement for ``torch.gather`` on the mps backend.

    When the last dimension has size 1, gathers on an unsqueezed view and
    squeezes the result back, which sidesteps an mps gather issue; otherwise
    defers to ``torch.gather`` unchanged.
    """
    if input.shape[-1] != 1:
        return torch.gather(input, dim, index)
    # Unsqueezing appends a trailing dim, so negative dims shift left by one.
    adjusted_dim = dim - 1 if dim < 0 else dim
    gathered = torch.gather(input.unsqueeze(-1), adjusted_dim, index.unsqueeze(-1))
    return gathered.squeeze(-1)
def bipartite_soft_matching_random2d(metric: torch.Tensor,
                                     w: int, h: int, sx: int, sy: int, r: int,
                                     no_rand: bool = False,
                                     generator: torch.Generator = None) -> Tuple[Callable, Callable]:
    """
    Partitions the tokens into src and dst and merges r tokens from src to dst.
    Dst tokens are partitioned by choosing one randomly in each (sx, sy) region.

    Args:
     - metric [B, N, C]: metric to use for similarity
     - w: image width in tokens
     - h: image height in tokens
     - sx: stride in the x dimension for dst, must divide w
     - sy: stride in the y dimension for dst, must divide h
     - r: number of tokens to remove (by merging)
     - no_rand: if true, disable randomness (use top left corner only)
     - generator: random generator for the dst partition; used only when
       no_rand is false (must be non-None in that case)

    Returns:
        (merge, unmerge) closures: merge reduces [B, N, C] -> [B, N - r, C];
        unmerge scatters a merged tensor back to [B, N, C].
    """
    B, N, _ = metric.shape

    # Nothing to merge: hand back identity passthroughs.
    if r <= 0:
        return do_nothing, do_nothing

    # mps needs a gather workaround (see mps_gather_workaround above).
    gather = mps_gather_workaround if metric.device.type == "mps" else torch.gather

    with torch.no_grad():
        hsy, wsx = h // sy, w // sx

        # For each sy by sx kernel, randomly assign one token to be dst and the rest src
        if no_rand:
            rand_idx = torch.zeros(hsy, wsx, 1, device=metric.device, dtype=torch.int64)
        else:
            rand_idx = torch.randint(sy*sx, size=(hsy, wsx, 1), device=generator.device, generator=generator).to(metric.device)

        # The image might not divide sx and sy, so we need to work on a view of the top left of the idx buffer instead
        idx_buffer_view = torch.zeros(hsy, wsx, sy*sx, device=metric.device, dtype=torch.int64)
        idx_buffer_view.scatter_(dim=2, index=rand_idx, src=-torch.ones_like(rand_idx, dtype=rand_idx.dtype))
        idx_buffer_view = idx_buffer_view.view(hsy, wsx, sy, sx).transpose(1, 2).reshape(hsy * sy, wsx * sx)

        # Image is not divisible by sx or sy so we need to move it into a new buffer
        if (hsy * sy) < h or (wsx * sx) < w:
            idx_buffer = torch.zeros(h, w, device=metric.device, dtype=torch.int64)
            idx_buffer[:(hsy * sy), :(wsx * sx)] = idx_buffer_view
        else:
            idx_buffer = idx_buffer_view

        # We set dst tokens to be -1 and src to be 0, so an argsort gives us dst|src indices
        rand_idx = idx_buffer.reshape(1, -1, 1).argsort(dim=1)

        # We're finished with these
        del idx_buffer, idx_buffer_view

        # rand_idx is currently dst|src, so split them
        num_dst = hsy * wsx
        a_idx = rand_idx[:, num_dst:, :]  # src
        b_idx = rand_idx[:, :num_dst, :]  # dst

        def split(x):
            # Gather the src/dst subsets of x according to the partition above.
            C = x.shape[-1]
            src = gather(x, dim=1, index=a_idx.expand(B, N - num_dst, C))
            dst = gather(x, dim=1, index=b_idx.expand(B, num_dst, C))
            return src, dst

        # Cosine similarity between A and B
        metric = metric / metric.norm(dim=-1, keepdim=True)
        a, b = split(metric)
        scores = a @ b.transpose(-1, -2)

        # Can't reduce more than the # tokens in src
        r = min(a.shape[1], r)

        # Find the most similar greedily
        node_max, node_idx = scores.max(dim=-1)
        edge_idx = node_max.argsort(dim=-1, descending=True)[..., None]

        unm_idx = edge_idx[..., r:, :]  # Unmerged Tokens
        src_idx = edge_idx[..., :r, :]  # Merged Tokens
        dst_idx = gather(node_idx[..., None], dim=-2, index=src_idx)

    def merge(x: torch.Tensor, mode="mean") -> torch.Tensor:
        # Fold the r most-similar src tokens into their dst counterparts
        # (scatter_reduce with `mode`), shrinking [B, N, C] to [B, N - r, C].
        src, dst = split(x)
        n, t1, c = src.shape
        unm = gather(src, dim=-2, index=unm_idx.expand(n, t1 - r, c))
        src = gather(src, dim=-2, index=src_idx.expand(n, r, c))
        dst = dst.scatter_reduce(-2, dst_idx.expand(n, r, c), src, reduce=mode)
        return torch.cat([unm, dst], dim=1)

    def unmerge(x: torch.Tensor) -> torch.Tensor:
        # Invert merge: copy each merged dst token back to every src position
        # it absorbed, restoring the original [B, N, C] layout.
        unm_len = unm_idx.shape[1]
        unm, dst = x[..., :unm_len, :], x[..., unm_len:, :]
        _, _, c = unm.shape
        src = gather(dst, dim=-2, index=dst_idx.expand(B, r, c))

        # Combine back to the original shape
        out = torch.zeros(B, N, c, device=x.device, dtype=x.dtype)
        out.scatter_(dim=-2, index=b_idx.expand(B, num_dst, c), src=dst)
        out.scatter_(dim=-2, index=gather(a_idx.expand(B, a_idx.shape[1], 1), dim=1, index=unm_idx).expand(B, unm_len, c), src=unm)
        out.scatter_(dim=-2, index=gather(a_idx.expand(B, a_idx.shape[1], 1), dim=1, index=src_idx).expand(B, r, c), src=src)

        return out

    return merge, unmerge
def compute_merge(
    x: torch.Tensor,
    args: Dict[str, Any],
    size: Tuple[int, int],
    max_tokens: int = None,
    ratio: float = None,
) -> Tuple[Callable, ...]:
    """Build (merge, unmerge) callables for the token tensor ``x``.

    Returns identity passthroughs when merging is disabled or nothing would
    be removed. At least one of ``max_tokens`` / ``ratio`` must be given:
    ``ratio`` caps the fraction of tokens removed, ``max_tokens`` caps the
    absolute number kept. May lazily (re)initialize ``args["generator"]``
    in place when it is missing or on the wrong device.
    """
    if not args["enabled"]:
        return do_nothing, do_nothing
    if max_tokens is None and ratio is None:
        raise ValueError("Must specify either max_tokens or ratio")

    full_h, full_w = size
    n_tokens = x.shape[1]
    # Downsampling factor of this attention layer relative to the full latent.
    downsample = int(math.ceil(math.sqrt((full_h * full_w) // n_tokens)))

    keep = n_tokens if ratio is None else int(n_tokens * (1 - ratio))
    if max_tokens is not None and max_tokens > 0:
        keep = min(keep, max_tokens)  # remove all but max_tokens tokens

    num_removed = n_tokens - keep
    if num_removed <= 0:
        return do_nothing, do_nothing

    w = int(math.ceil(full_w / downsample))
    h = int(math.ceil(full_h / downsample))

    # Re-init the generator if it hasn't already been initialized or device has changed.
    if args["generator"] is None:
        args["generator"] = init_generator(x.device, seed=args["seed"])
    elif args["generator"].device != x.device:
        args["generator"] = init_generator(x.device, fallback=args["generator"], seed=args["seed"])

    # If the batch size is odd, then it's not possible for prompted and unprompted
    # images to be in the same batch, which causes artifacts with use_rand,
    # so force it to be off.
    use_rand = False if x.shape[0] % 2 == 1 else args["use_rand"]

    return bipartite_soft_matching_random2d(x, w, h, args["sx"], args["sy"], num_removed,
                                            no_rand=not use_rand, generator=args["generator"])

View File

@ -11,6 +11,7 @@ from ldm.modules.attention import BasicTransformerBlock
from scripts.marking import apply_marking_patch, unmark_prompt_context
from scripts.helpers import image_hash
from scripts.weighted_attention import weighted_attention
from scripts.merging import compute_merge
def encode_to_latent(p, image, w, h):
@ -89,6 +90,14 @@ def patch_unet_forward_pass(p, unet, params):
hr_w = width
hr_h = height
tome_args = {
"enabled": params.tome_enabled,
"sx": 2, "sy": 2,
"use_rand": True,
"generator": None,
"seed": params.tome_seed,
}
def new_forward(self, x, timesteps=None, context=None, **kwargs):
_, uncond_ids, context = unmark_prompt_context(context)
cond_ids = [i for i in range(context.size(0)) if i not in uncond_ids]
@ -133,81 +142,101 @@ def patch_unet_forward_pass(p, unet, params):
for module in self.modules():
if isinstance(module, BasicTransformerBlock) and not hasattr(module.attn1, "_fabric_old_forward"):
module.attn1._fabric_old_forward = module.attn1.forward
module.attn2._fabric_old_forward = module.attn2.forward
## cache hidden states
try:
## cache hidden states
cached_hiddens = {}
def patched_attn1_forward(attn1, layer_idx, x, **kwargs):
merge, unmerge = compute_merge(x, args=tome_args, size=(h_latent, w_latent), ratio=params.tome_ratio)
x = merge(x)
if layer_idx not in cached_hiddens:
cached_hiddens[layer_idx] = x.detach().clone().cpu()
else:
cached_hiddens[layer_idx] = torch.cat([cached_hiddens[layer_idx], x.detach().clone().cpu()], dim=0)
out = attn1._fabric_old_forward(x, **kwargs)
out = unmerge(out)
return out
def patched_attn2_forward(attn2, x, **kwargs):
merge, unmerge = compute_merge(x, args=tome_args, size=(h_latent, w_latent), ratio=params.tome_ratio)
x = merge(x)
out = attn2._fabric_old_forward(x, **kwargs)
out = unmerge(out)
return out
cached_hiddens = {}
def patched_attn1_forward(attn1, idx, x, **kwargs):
if idx not in cached_hiddens:
cached_hiddens[idx] = x.detach().clone().cpu()
else:
cached_hiddens[idx] = torch.cat([cached_hiddens[idx], x.detach().clone().cpu()], dim=0)
out = attn1._fabric_old_forward(x, **kwargs)
return out
# patch forward pass to cache hidden states
layer_idx = 0
for module in self.modules():
if isinstance(module, BasicTransformerBlock):
module.attn1.forward = functools.partial(patched_attn1_forward, module.attn1, layer_idx)
module.attn2.forward = functools.partial(patched_attn2_forward, module.attn2)
layer_idx += 1
# patch forward pass to cache hidden states
layer_idx = 0
for module in self.modules():
if isinstance(module, BasicTransformerBlock):
module.attn1.forward = functools.partial(patched_attn1_forward, module.attn1, layer_idx)
layer_idx += 1
# run forward pass just to cache hidden states, output is discarded
for i in range(0, len(all_zs), batch_size):
zs = all_zs[i : i + batch_size].to(x.device, dtype=self.dtype)
ts = timesteps[:1].expand(zs.size(0)) # (bs,)
# use the null prompt for pre-computing hidden states on feedback images
ctx = null_ctx.expand(zs.size(0), -1, -1) # (bs, p_seq, p_dim)
_ = self._fabric_old_forward(zs, ts, ctx)
# run forward pass just to cache hidden states, output is discarded
for i in range(0, len(all_zs), batch_size):
zs = all_zs[i : i + batch_size].to(x.device, dtype=self.dtype)
ts = timesteps[:1].expand(zs.size(0)) # (bs,)
# use the null prompt for pre-computing hidden states on feedback images
ctx = null_ctx.expand(zs.size(0), -1, -1) # (bs, p_seq, p_dim)
_ = self._fabric_old_forward(zs, ts, ctx)
num_pos = len(pos_latents)
num_neg = len(neg_latents)
num_cond = len(cond_ids)
num_uncond = len(uncond_ids)
tome_h_latent = h_latent * (1 - params.tome_ratio)
num_pos = len(pos_latents)
num_neg = len(neg_latents)
num_cond = len(cond_ids)
num_uncond = len(uncond_ids)
def patched_attn1_forward(attn1, idx, x, context=None, **kwargs):
if context is None:
context = x
def patched_attn1_forward(attn1, idx, x, context=None, **kwargs):
if context is None:
context = x
cached_hs = cached_hiddens[idx].to(x.device)
cached_hs = cached_hiddens[idx].to(x.device)
d_model = x.shape[-1]
seq_len, d_model = x.shape[1:]
def attention_with_feedback(_x, context, feedback_hs, w):
num_xs, num_fb = _x.shape[0], feedback_hs.shape[0]
if num_fb > 0:
feedback_ctx = feedback_hs.view(1, -1, d_model).expand(num_xs, -1, -1) # (n_cond, seq * n_pos, dim)
merge, _ = compute_merge(feedback_ctx, args=tome_args, size=(tome_h_latent * num_fb, w_latent), max_tokens=params.tome_max_tokens)
feedback_ctx = merge(feedback_ctx)
ctx = torch.cat([context, feedback_ctx], dim=1) # (n_cond, seq + seq*n_pos, dim)
else:
ctx = context
weights = torch.ones_like(ctx[0, :, 0]) # (seq + seq*n_pos,)
weights[_x.shape[1]:] = w
return weighted_attention(attn1, attn1._fabric_old_forward, _x, ctx, weights, **kwargs) # (n_cond, seq, dim)
outs = []
if num_cond > 0:
pos_hs = cached_hs[:num_pos].view(1, num_pos * seq_len, d_model).expand(num_cond, -1, -1) # (n_cond, seq * n_pos, dim)
x_cond = x[cond_ids] # (n_cond, seq, dim)
ctx_cond = torch.cat([context[cond_ids], pos_hs], dim=1) # (n_cond, seq * (1 + n_pos), dim)
ws = torch.ones_like(ctx_cond[0, :, 0]) # (seq * (1 + n_pos),)
ws[x_cond.size(1):] = pos_weight
out_cond = weighted_attention(attn1, attn1._fabric_old_forward, x_cond, ctx_cond, ws, **kwargs) # (n_cond, seq, dim)
outs.append(out_cond)
if num_uncond > 0:
neg_hs = cached_hs[num_pos:].view(1, num_neg * seq_len, d_model).expand(num_uncond, -1, -1) # (n_uncond, seq * n_neg, dim)
x_uncond = x[uncond_ids] # (n_uncond, seq, dim)
ctx_uncond = torch.cat([context[uncond_ids], neg_hs], dim=1) # (n_uncond, seq * (1 + n_neg), dim)
ws = torch.ones_like(ctx_uncond[0, :, 0]) # (seq * (1 + n_neg),)
ws[x_uncond.size(1):] = neg_weight
out_uncond = weighted_attention(attn1, attn1._fabric_old_forward, x_uncond, ctx_uncond, ws, **kwargs) # (n_uncond, seq, dim)
outs.append(out_uncond)
out = torch.cat(outs, dim=0)
return out
outs = []
if num_cond > 0:
out_cond = attention_with_feedback(x[cond_ids], context[cond_ids], cached_hs[:num_pos], pos_weight) # (n_cond, seq, dim)
outs.append(out_cond)
if num_uncond > 0:
out_uncond = attention_with_feedback(x[uncond_ids], context[uncond_ids], cached_hs[num_pos:], neg_weight) # (n_cond, seq, dim)
outs.append(out_uncond)
out = torch.cat(outs, dim=0)
return out
# patch forward pass to inject cached hidden states
layer_idx = 0
for module in self.modules():
if isinstance(module, BasicTransformerBlock):
module.attn1.forward = functools.partial(patched_attn1_forward, module.attn1, layer_idx)
layer_idx += 1
# patch forward pass to inject cached hidden states
layer_idx = 0
for module in self.modules():
if isinstance(module, BasicTransformerBlock):
module.attn1.forward = functools.partial(patched_attn1_forward, module.attn1, layer_idx)
layer_idx += 1
# run forward pass with cached hidden states
out = self._fabric_old_forward(x, timesteps, context, **kwargs)
# run forward pass with cached hidden states
out = self._fabric_old_forward(x, timesteps, context, **kwargs)
# restore original pass
for module in self.modules():
if isinstance(module, BasicTransformerBlock) and hasattr(module.attn1, "_fabric_old_forward"):
module.attn1.forward = module.attn1._fabric_old_forward
del module.attn1._fabric_old_forward
finally:
# restore original pass
for module in self.modules():
if isinstance(module, BasicTransformerBlock) and hasattr(module.attn1, "_fabric_old_forward"):
module.attn1.forward = module.attn1._fabric_old_forward
del module.attn1._fabric_old_forward
if isinstance(module, BasicTransformerBlock) and hasattr(module.attn2, "_fabric_old_forward"):
module.attn2.forward = module.attn2._fabric_old_forward
del module.attn2._fabric_old_forward
return out