test and fix attention processors

pull/23/head
dvruette 2023-07-25 00:53:53 +02:00
parent 0290e51ab0
commit 5de45e9c9d
4 changed files with 54 additions and 33 deletions

View File

@ -17,8 +17,10 @@ Alpha version of a plugin for [automatic1111/stable-diffusion-webui](https://git
### Compatibility Notes
- SDXL is currently not supported (PRs welcome!)
- Compatibility to other plugins is largely untested. If you experience errors with other plugins enabled, please disable all other plugins for the best chance for FABRIC to work. If you can figure out which plugin is incompatible, please open an issue.
- Compatibility with other plugins is largely untested. If you experience errors with other plugins enabled, please disable all other plugins for the best chance for FABRIC to work. If you can figure out which plugin is incompatible, please open an issue.
- The plugin is INCOMPATIBLE with `reference` mode in the ControlNet plugin. Instead of using a reference image, simply add it as a liked image. If you accidentally enable FABRIC and `reference` mode at the same time, you will have to restart the WebUI to fix it.
- Some attention processors are not supported. In particular, `--opt-sub-quad-attention` and `--opt-split-attention-v1` are not supported at the moment.
## How-to and Examples

View File

@ -2,7 +2,7 @@ import os
import dataclasses
import functools
from pathlib import Path
from dataclasses import dataclass
from dataclasses import dataclass, asdict
import gradio as gr
from PIL import Image
@ -17,7 +17,7 @@ from scripts.helpers import WebUiComponents
__version__ = "0.3.5"
DEBUG = os.getenv("DEBUG", False)
DEBUG = os.getenv("DEBUG", "false").lower() in ("true", "1")
if DEBUG:
print(f"WARNING: Loading FABRIC v{__version__} in DEBUG mode")
@ -46,6 +46,13 @@ def use_feedback(params):
return False
return True
def pil_to_str(img):
    """Return a short, loggable string identifier for an image.

    Images opened from disk expose a ``filename`` attribute, which is the
    most useful identifier; purely in-memory images do not, so we fall
    back to a compact summary of the object instead.
    """
    # Guard clause: no source path available -> describe the object.
    if not hasattr(img, "filename"):
        return f"{img.__class__.__name__}(size={img.size}, format={img.format})"
    return img.filename
@dataclass
class FabricParams:
enabled: bool = True
@ -121,7 +128,6 @@ class FabricScript(modules.scripts.Script):
with FormGroup():
with gr.Row():
# TODO: figure out how to make the step size do what it's supposed to
feedback_max_images = gr.Slider(minimum=0, maximum=10, step=1, value=4, label="Max. feedback images")
with gr.Row():
@ -252,8 +258,8 @@ class FabricScript(modules.scripts.Script):
feedback_during_high_res_fix,
) = args
likes = liked_images[:int(feedback_max_images)]
dislikes = disliked_images[:int(feedback_max_images)]
likes = liked_images[-int(feedback_max_images):]
dislikes = disliked_images[-int(feedback_max_images):]
params = FabricParams(
enabled=(not feedback_disabled),
@ -267,8 +273,18 @@ class FabricScript(modules.scripts.Script):
feedback_during_high_res_fix=feedback_during_high_res_fix,
)
if use_feedback(params) or (DEBUG and not feedback_disabled):
print("[FABRIC] Patching U-Net forward pass...")
# log the generation params to be displayed/stored as metadata
log_params = asdict(params)
del log_params["enabled"]
log_params["pos_images"] = [pil_to_str(img) for img in log_params["pos_images"]]
log_params["neg_images"] = [pil_to_str(img) for img in log_params["neg_images"]]
log_params = {f"fabric/{k}": v for k, v in log_params.items()}
p.extra_generation_params.update(log_params)
unet = p.sd_model.model.diffusion_model
patch_unet_forward_pass(p, unet, params)
else:

View File

@ -46,14 +46,18 @@ def get_latents_from_params(p, params, width, height):
return params.pos_latents, params.neg_latents
def get_curr_feedback_weight(p, params):
w = params.max_weight
return w, w * params.neg_scale
def get_curr_feedback_weight(p, params, timestep):
    """Return the (positive, negative) feedback weights for the current step.

    The diffusion timestep counts down, so sampling progress is mapped to
    [0, 1] as ``1 - timestep / (num_timesteps - 1)``. The maximum weight is
    applied only while progress lies inside the [params.start, params.end]
    window; outside it the minimum weight is used. Both returned weights
    are clamped to be non-negative.
    """
    progress = 1 - (timestep / (p.sd_model.num_timesteps - 1))
    in_window = params.start <= progress <= params.end
    w = params.max_weight if in_window else params.min_weight
    # neg_scale rescales the weight applied to disliked-image feedback.
    return max(0, w), max(0, w * params.neg_scale)
def patch_unet_forward_pass(p, unet, params):
if not params.pos_images and not params.neg_images:
print("[FABRIC] No images to found, aborting patching")
print("[FABRIC] No feedback images found, aborting patching")
return
if not hasattr(unet, "_fabric_old_forward"):
@ -74,8 +78,8 @@ def patch_unet_forward_pass(p, unet, params):
hr_w = (hr_w // 8) * 8
hr_h = (hr_h // 8) * 8
else:
hr_h = width
hr_w = height
hr_w = width
hr_h = height
def new_forward(self, x, timesteps=None, context=None, **kwargs):
_, uncond_ids, context = unmark_prompt_context(context)
@ -83,11 +87,16 @@ def patch_unet_forward_pass(p, unet, params):
has_cond = len(cond_ids) > 0
has_uncond = len(uncond_ids) > 0
w_latent, h_latent = x.shape[-2:]
h_latent, w_latent = x.shape[-2:]
w, h = 8 * w_latent, 8 * h_latent
if has_hires_fix and w == hr_w and h == hr_h:
if not params.feedback_during_high_res_fix:
print("[FABRIC] Skipping feedback during high-res fix")
return self._fabric_old_forward(x, timesteps, context, **kwargs)
pos_weight, neg_weight = get_curr_feedback_weight(p, params, timesteps[0].item())
if pos_weight <= 0 and neg_weight <= 0:
return self._fabric_old_forward(x, timesteps, context, **kwargs)
pos_latents, neg_latents = get_latents_from_params(p, params, w, h)
pos_latents = pos_latents if has_cond else []
@ -148,7 +157,6 @@ def patch_unet_forward_pass(p, unet, params):
num_cond = len(cond_ids)
num_uncond = len(uncond_ids)
pos_weight, neg_weight = get_curr_feedback_weight(p, params)
outs = []
if num_cond > 0:
@ -157,15 +165,15 @@ def patch_unet_forward_pass(p, unet, params):
ctx_cond = torch.cat([context[cond_ids], pos_hs], dim=1) # (n_cond, seq * (1 + n_pos), dim)
ws = torch.ones_like(ctx_cond[0, :, 0]) # (seq * (1 + n_pos),)
ws[x_cond.size(1):] = pos_weight
out_cond = weighted_attention(attn1._fabric_old_forward, x_cond, ctx_cond, ws, **kwargs) # (n_cond, seq, dim)
out_cond = weighted_attention(attn1, attn1._fabric_old_forward, x_cond, ctx_cond, ws, **kwargs) # (n_cond, seq, dim)
outs.append(out_cond)
if num_uncond > 0:
neg_hs = cached_hs[num_pos:].view(1, num_neg * seq_len, d_model).expand(num_uncond, -1, -1) # (n_uncond, seq * n_neg, dim)
x_uncond = x[uncond_ids] # (n_uncond, seq, dim)
ctx_uncond = torch.cat([context[uncond_ids], neg_hs], dim=1) # (n_uncond, seq * (1 + n_neg), dim)
ws = torch.ones_like(ctx_cond[0, :, 0]) # (seq * (1 + n_neg),)
ws[x_cond.size(1):] = neg_weight
out_uncond = weighted_attention(attn1._fabric_old_forward, x_uncond, ctx_uncond, **kwargs) # (n_uncond, seq, dim)
ws = torch.ones_like(ctx_uncond[0, :, 0]) # (seq * (1 + n_neg),)
ws[x_uncond.size(1):] = neg_weight
out_uncond = weighted_attention(attn1, attn1._fabric_old_forward, x_uncond, ctx_uncond, **kwargs) # (n_uncond, seq, dim)
outs.append(out_uncond)
out = torch.cat(outs, dim=0)
return out

View File

@ -35,10 +35,11 @@ def patched_einsum_op_compvis(q, k, v, weights=None):
def patched_xformers_attn(q, k, v, attn_bias=None, op=None, weights=None, orig_attn=None):
print(q.shape, v.shape, weights.shape)
bs, nq, nh, dh = q.shape # batch_size, num_queries, num_heads, dim_per_head
if weights is not None:
min_val = torch.finfo(q.dtype).min
w_bias = weights.log().clamp(min=min_val)[None, None, None, :].expand(*q.shape[:3], -1).transpose(-2, -1)
w_bias = weights.log().clamp(min=min_val)[None, None, None, :].expand(bs, nh, nq, -1).contiguous()
w_bias = w_bias.to(dtype=q.dtype)
if attn_bias is None:
attn_bias = w_bias
else:
@ -64,10 +65,6 @@ def patched_sdp_attn(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False, we
def weighted_split_cross_attention_forward(self, x, context=None, mask=None, weights=None):
h = self.heads
# OURS: normalize weights to preserve attention magnitude
if weights is not None:
weights = weights[None, None, :] / weights.sum(dim=-1, keepdim=True)
q_in = self.to_q(x)
context = default(context, x)
@ -112,12 +109,13 @@ def weighted_split_cross_attention_forward(self, x, context=None, mask=None, wei
end = i + slice_size
s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k)
s2 = s1.softmax(dim=-1, dtype=q.dtype)
del s1
# OURS: apply weights to attention
if weights is not None:
s2 = s2 * weights
bias = weights.to(s1.dtype).log().clamp(min=torch.finfo(s1.dtype).min)
s1 = s1 + bias
s2 = s1.softmax(dim=-1, dtype=q.dtype)
del s1
r1[:, i:end] = einsum('b i j, b j d -> b i d', s2, v)
del s2
@ -138,12 +136,9 @@ def is_the_same(fn1, fn2):
return fn1.__name__ == fn2.__name__ and fn1.__module__ == fn2.__module__
def weighted_attention(attn_fn, x, context=None, weights=None, **kwargs):
def weighted_attention(self, attn_fn, x, context=None, weights=None, **kwargs):
if weights is None:
return attn_fn(x, context=context, **kwargs)
print(attn_fn.__module__, attn_fn.__name__, type(attn_fn))
print(split_cross_attention_forward_invokeAI.__module__, split_cross_attention_forward_invokeAI.__name__, type(split_cross_attention_forward_invokeAI))
if is_the_same(attn_fn, split_cross_attention_forward_invokeAI):
modules.sd_hijack_optimizations.einsum_op_compvis = functools.partial(patched_einsum_op_compvis, weights=weights)
@ -167,7 +162,7 @@ def weighted_attention(attn_fn, x, context=None, weights=None, **kwargs):
return out
elif is_the_same(attn_fn, split_cross_attention_forward):
return weighted_split_cross_attention_forward(x, context=context, weights=weights, **kwargs)
return weighted_split_cross_attention_forward(self, x, context=context, weights=weights, **kwargs)
else:
raise NotImplementedError(f"FABRIC does not support `{attn_fn.__module__}.{attn_fn.__name__}` yet. Please choose a supported attention function.")