test and fix attention processors
parent
0290e51ab0
commit
5de45e9c9d
|
|
@ -17,8 +17,10 @@ Alpha version of a plugin for [automatic1111/stable-diffusion-webui](https://git
|
|||
|
||||
### Compatibility Notes
|
||||
- SDXL is currently not supported (PRs welcome!)
|
||||
- Compatibility to other plugins is largely untested. If you experience errors with other plugins enabled, please disable all other plugins for the best chance for FABRIC to work. If you can figure out which plugin is incompatible, please open an issue.
|
||||
- Compatibility with other plugins is largely untested. If you experience errors with other plugins enabled, please disable all other plugins for the best chance for FABRIC to work. If you can figure out which plugin is incompatible, please open an issue.
|
||||
- The plugin is INCOMPATIBLE with `reference` mode in the ControlNet plugin. Instead of using a reference image, simply add it as a liked image. If you accidentally enable FABRIC and `reference` mode at the same time, you will have to restart the WebUI to fix it.
|
||||
- Some attention processors are not supported. In particular, `--opt-sub-quad-attention` and `--opt-split-attention-v1` are not supported at the moment.
|
||||
|
||||
|
||||
|
||||
## How-to and Examples
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@ import os
|
|||
import dataclasses
|
||||
import functools
|
||||
from pathlib import Path
|
||||
from dataclasses import dataclass
|
||||
from dataclasses import dataclass, asdict
|
||||
|
||||
import gradio as gr
|
||||
from PIL import Image
|
||||
|
|
@ -17,7 +17,7 @@ from scripts.helpers import WebUiComponents
|
|||
|
||||
__version__ = "0.3.5"
|
||||
|
||||
DEBUG = os.getenv("DEBUG", False)
|
||||
DEBUG = os.getenv("DEBUG", "false").lower() in ("true", "1")
|
||||
|
||||
if DEBUG:
|
||||
print(f"WARNING: Loading FABRIC v{__version__} in DEBUG mode")
|
||||
|
|
@ -46,6 +46,13 @@ def use_feedback(params):
|
|||
return False
|
||||
return True
|
||||
|
||||
|
||||
def pil_to_str(img):
    """Return a short, human-readable identifier for an image.

    The image's ``filename`` attribute is used when it is present and
    non-empty. Pillow sets ``filename`` to an empty string for images that
    were opened from a file-like object, so an empty value is treated the
    same as a missing one and falls through to a generic description.

    Args:
        img: An image-like object exposing ``size`` and ``format`` and,
            optionally, ``filename``.

    Returns:
        The image's ``filename`` if it is set and non-empty, otherwise a
        string of the form ``ClassName(size=..., format=...)``.
    """
    filename = getattr(img, "filename", None)
    if filename:
        return filename
    # No usable filename: describe the image by its type and basic properties.
    return f"{img.__class__.__name__}(size={img.size}, format={img.format})"
|
||||
|
||||
@dataclass
|
||||
class FabricParams:
|
||||
enabled: bool = True
|
||||
|
|
@ -121,7 +128,6 @@ class FabricScript(modules.scripts.Script):
|
|||
|
||||
with FormGroup():
|
||||
with gr.Row():
|
||||
# TODO: figure out how to make the step size do what it's supposed to
|
||||
feedback_max_images = gr.Slider(minimum=0, maximum=10, step=1, value=4, label="Max. feedback images")
|
||||
|
||||
with gr.Row():
|
||||
|
|
@ -252,8 +258,8 @@ class FabricScript(modules.scripts.Script):
|
|||
feedback_during_high_res_fix,
|
||||
) = args
|
||||
|
||||
likes = liked_images[:int(feedback_max_images)]
|
||||
dislikes = disliked_images[:int(feedback_max_images)]
|
||||
likes = liked_images[-int(feedback_max_images):]
|
||||
dislikes = disliked_images[-int(feedback_max_images):]
|
||||
|
||||
params = FabricParams(
|
||||
enabled=(not feedback_disabled),
|
||||
|
|
@ -267,8 +273,18 @@ class FabricScript(modules.scripts.Script):
|
|||
feedback_during_high_res_fix=feedback_during_high_res_fix,
|
||||
)
|
||||
|
||||
|
||||
if use_feedback(params) or (DEBUG and not feedback_disabled):
|
||||
print("[FABRIC] Patching U-Net forward pass...")
|
||||
|
||||
# log the generation params to be displayed/stored as metadata
|
||||
log_params = asdict(params)
|
||||
del log_params["enabled"]
|
||||
log_params["pos_images"] = [pil_to_str(img) for img in log_params["pos_images"]]
|
||||
log_params["neg_images"] = [pil_to_str(img) for img in log_params["neg_images"]]
|
||||
log_params = {f"fabric/{k}": v for k, v in log_params.items()}
|
||||
p.extra_generation_params.update(log_params)
|
||||
|
||||
unet = p.sd_model.model.diffusion_model
|
||||
patch_unet_forward_pass(p, unet, params)
|
||||
else:
|
||||
|
|
|
|||
|
|
@ -46,14 +46,18 @@ def get_latents_from_params(p, params, width, height):
|
|||
return params.pos_latents, params.neg_latents
|
||||
|
||||
|
||||
def get_curr_feedback_weight(p, params):
    """Return the (positive, negative) feedback weights.

    The positive weight is ``params.max_weight``; the negative weight is the
    positive weight scaled by ``params.neg_scale``. ``p`` is accepted but not
    used.
    """
    pos_weight = params.max_weight
    neg_weight = pos_weight * params.neg_scale
    return pos_weight, neg_weight
|
||||
def get_curr_feedback_weight(p, params, timestep):
    """Return the (positive, negative) feedback weights for a timestep.

    Progress is computed as ``1 - timestep / (num_timesteps - 1)`` using
    ``p.sd_model.num_timesteps``. Inside the inclusive window
    ``[params.start, params.end]`` the base weight is ``params.max_weight``;
    outside it is ``params.min_weight``. The negative weight is the base
    weight scaled by ``params.neg_scale``. Both results are clamped to be
    non-negative.
    """
    last_step = p.sd_model.num_timesteps - 1
    progress = 1 - (timestep / last_step)

    in_window = params.start <= progress <= params.end
    base_weight = params.max_weight if in_window else params.min_weight

    pos_weight = max(0, base_weight)
    neg_weight = max(0, base_weight * params.neg_scale)
    return pos_weight, neg_weight
|
||||
|
||||
|
||||
def patch_unet_forward_pass(p, unet, params):
|
||||
if not params.pos_images and not params.neg_images:
|
||||
print("[FABRIC] No images to found, aborting patching")
|
||||
print("[FABRIC] No feedback images found, aborting patching")
|
||||
return
|
||||
|
||||
if not hasattr(unet, "_fabric_old_forward"):
|
||||
|
|
@ -74,8 +78,8 @@ def patch_unet_forward_pass(p, unet, params):
|
|||
hr_w = (hr_w // 8) * 8
|
||||
hr_h = (hr_h // 8) * 8
|
||||
else:
|
||||
hr_h = width
|
||||
hr_w = height
|
||||
hr_w = width
|
||||
hr_h = height
|
||||
|
||||
def new_forward(self, x, timesteps=None, context=None, **kwargs):
|
||||
_, uncond_ids, context = unmark_prompt_context(context)
|
||||
|
|
@ -83,11 +87,16 @@ def patch_unet_forward_pass(p, unet, params):
|
|||
has_cond = len(cond_ids) > 0
|
||||
has_uncond = len(uncond_ids) > 0
|
||||
|
||||
w_latent, h_latent = x.shape[-2:]
|
||||
h_latent, w_latent = x.shape[-2:]
|
||||
w, h = 8 * w_latent, 8 * h_latent
|
||||
if has_hires_fix and w == hr_w and h == hr_h:
|
||||
if not params.feedback_during_high_res_fix:
|
||||
print("[FABRIC] Skipping feedback during high-res fix")
|
||||
return self._fabric_old_forward(x, timesteps, context, **kwargs)
|
||||
|
||||
pos_weight, neg_weight = get_curr_feedback_weight(p, params, timesteps[0].item())
|
||||
if pos_weight <= 0 and neg_weight <= 0:
|
||||
return self._fabric_old_forward(x, timesteps, context, **kwargs)
|
||||
|
||||
pos_latents, neg_latents = get_latents_from_params(p, params, w, h)
|
||||
pos_latents = pos_latents if has_cond else []
|
||||
|
|
@ -148,7 +157,6 @@ def patch_unet_forward_pass(p, unet, params):
|
|||
num_cond = len(cond_ids)
|
||||
num_uncond = len(uncond_ids)
|
||||
|
||||
pos_weight, neg_weight = get_curr_feedback_weight(p, params)
|
||||
|
||||
outs = []
|
||||
if num_cond > 0:
|
||||
|
|
@ -157,15 +165,15 @@ def patch_unet_forward_pass(p, unet, params):
|
|||
ctx_cond = torch.cat([context[cond_ids], pos_hs], dim=1) # (n_cond, seq * (1 + n_pos), dim)
|
||||
ws = torch.ones_like(ctx_cond[0, :, 0]) # (seq * (1 + n_pos),)
|
||||
ws[x_cond.size(1):] = pos_weight
|
||||
out_cond = weighted_attention(attn1._fabric_old_forward, x_cond, ctx_cond, ws, **kwargs) # (n_cond, seq, dim)
|
||||
out_cond = weighted_attention(attn1, attn1._fabric_old_forward, x_cond, ctx_cond, ws, **kwargs) # (n_cond, seq, dim)
|
||||
outs.append(out_cond)
|
||||
if num_uncond > 0:
|
||||
neg_hs = cached_hs[num_pos:].view(1, num_neg * seq_len, d_model).expand(num_uncond, -1, -1) # (n_uncond, seq * n_neg, dim)
|
||||
x_uncond = x[uncond_ids] # (n_uncond, seq, dim)
|
||||
ctx_uncond = torch.cat([context[uncond_ids], neg_hs], dim=1) # (n_uncond, seq * (1 + n_neg), dim)
|
||||
ws = torch.ones_like(ctx_cond[0, :, 0]) # (seq * (1 + n_neg),)
|
||||
ws[x_cond.size(1):] = neg_weight
|
||||
out_uncond = weighted_attention(attn1._fabric_old_forward, x_uncond, ctx_uncond, **kwargs) # (n_uncond, seq, dim)
|
||||
ws = torch.ones_like(ctx_uncond[0, :, 0]) # (seq * (1 + n_neg),)
|
||||
ws[x_uncond.size(1):] = neg_weight
|
||||
out_uncond = weighted_attention(attn1, attn1._fabric_old_forward, x_uncond, ctx_uncond, **kwargs) # (n_uncond, seq, dim)
|
||||
outs.append(out_uncond)
|
||||
out = torch.cat(outs, dim=0)
|
||||
return out
|
||||
|
|
|
|||
|
|
@ -35,10 +35,11 @@ def patched_einsum_op_compvis(q, k, v, weights=None):
|
|||
|
||||
|
||||
def patched_xformers_attn(q, k, v, attn_bias=None, op=None, weights=None, orig_attn=None):
|
||||
print(q.shape, v.shape, weights.shape)
|
||||
bs, nq, nh, dh = q.shape # batch_size, num_queries, num_heads, dim_per_head
|
||||
if weights is not None:
|
||||
min_val = torch.finfo(q.dtype).min
|
||||
w_bias = weights.log().clamp(min=min_val)[None, None, None, :].expand(*q.shape[:3], -1).transpose(-2, -1)
|
||||
w_bias = weights.log().clamp(min=min_val)[None, None, None, :].expand(bs, nh, nq, -1).contiguous()
|
||||
w_bias = w_bias.to(dtype=q.dtype)
|
||||
if attn_bias is None:
|
||||
attn_bias = w_bias
|
||||
else:
|
||||
|
|
@ -64,10 +65,6 @@ def patched_sdp_attn(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False, we
|
|||
def weighted_split_cross_attention_forward(self, x, context=None, mask=None, weights=None):
|
||||
h = self.heads
|
||||
|
||||
# OURS: normalize weights to preserve attention magnitude
|
||||
if weights is not None:
|
||||
weights = weights[None, None, :] / weights.sum(dim=-1, keepdim=True)
|
||||
|
||||
q_in = self.to_q(x)
|
||||
context = default(context, x)
|
||||
|
||||
|
|
@ -112,12 +109,13 @@ def weighted_split_cross_attention_forward(self, x, context=None, mask=None, wei
|
|||
end = i + slice_size
|
||||
s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k)
|
||||
|
||||
s2 = s1.softmax(dim=-1, dtype=q.dtype)
|
||||
del s1
|
||||
|
||||
# OURS: apply weights to attention
|
||||
if weights is not None:
|
||||
s2 = s2 * weights
|
||||
bias = weights.to(s1.dtype).log().clamp(min=torch.finfo(s1.dtype).min)
|
||||
s1 = s1 + bias
|
||||
|
||||
s2 = s1.softmax(dim=-1, dtype=q.dtype)
|
||||
del s1
|
||||
|
||||
r1[:, i:end] = einsum('b i j, b j d -> b i d', s2, v)
|
||||
del s2
|
||||
|
|
@ -138,12 +136,9 @@ def is_the_same(fn1, fn2):
|
|||
return fn1.__name__ == fn2.__name__ and fn1.__module__ == fn2.__module__
|
||||
|
||||
|
||||
def weighted_attention(attn_fn, x, context=None, weights=None, **kwargs):
|
||||
def weighted_attention(self, attn_fn, x, context=None, weights=None, **kwargs):
|
||||
if weights is None:
|
||||
return attn_fn(x, context=context, **kwargs)
|
||||
|
||||
print(attn_fn.__module__, attn_fn.__name__, type(attn_fn))
|
||||
print(split_cross_attention_forward_invokeAI.__module__, split_cross_attention_forward_invokeAI.__name__, type(split_cross_attention_forward_invokeAI))
|
||||
|
||||
if is_the_same(attn_fn, split_cross_attention_forward_invokeAI):
|
||||
modules.sd_hijack_optimizations.einsum_op_compvis = functools.partial(patched_einsum_op_compvis, weights=weights)
|
||||
|
|
@ -167,7 +162,7 @@ def weighted_attention(attn_fn, x, context=None, weights=None, **kwargs):
|
|||
return out
|
||||
|
||||
elif is_the_same(attn_fn, split_cross_attention_forward):
|
||||
return weighted_split_cross_attention_forward(x, context=context, weights=weights, **kwargs)
|
||||
return weighted_split_cross_attention_forward(self, x, context=context, weights=weights, **kwargs)
|
||||
|
||||
else:
|
||||
raise NotImplementedError(f"FABRIC does not support `{attn_fn.__module__}.{attn_fn.__name__}` yet. Please choose a supported attention function.")
|
||||
|
|
|
|||
Loading…
Reference in New Issue