implement weighted attention for different processors

pull/23/head
Dimitri 2023-07-24 22:29:24 +02:00
parent 10e76b321e
commit 0290e51ab0
4 changed files with 191 additions and 4 deletions

2
.gitignore vendored
View File

@ -1,3 +1,5 @@
.DS_Store
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]

View File

@ -17,7 +17,7 @@ from scripts.helpers import WebUiComponents
__version__ = "0.3.5"
DEBUG = False
DEBUG = os.getenv("DEBUG", False)
if DEBUG:
print(f"WARNING: Loading FABRIC v{__version__} in DEBUG mode")

View File

@ -7,6 +7,7 @@ from modules.processing import StableDiffusionProcessingTxt2Img
from ldm.modules.attention import BasicTransformerBlock
from scripts.marking import patch_process_sample, unmark_prompt_context
from scripts.weighted_attention import weighted_attention
def encode_to_latent(p, image, w, h):
@ -45,6 +46,11 @@ def get_latents_from_params(p, params, width, height):
return params.pos_latents, params.neg_latents
def get_curr_feedback_weight(p, params):
    """Return the (positive, negative) feedback weights for the current step.

    The positive weight is ``params.max_weight``; the negative weight is the
    positive weight scaled by ``params.neg_scale``. ``p`` is accepted for
    interface symmetry with the other helpers but is not read here.
    """
    positive = params.max_weight
    negative = positive * params.neg_scale
    return positive, negative
def patch_unet_forward_pass(p, unet, params):
if not params.pos_images and not params.neg_images:
print("[FABRIC] No images to found, aborting patching")
@ -53,7 +59,7 @@ def patch_unet_forward_pass(p, unet, params):
if not hasattr(unet, "_fabric_old_forward"):
unet._fabric_old_forward = unet.forward
null_ctx = p.sd_model.get_learned_conditioning([""])
null_ctx = p.sd_model.get_learned_conditioning([""]).to(devices.device, dtype=devices.dtype_unet)
width = (p.width // 8) * 8
height = (p.height // 8) * 8
@ -142,18 +148,24 @@ def patch_unet_forward_pass(p, unet, params):
num_cond = len(cond_ids)
num_uncond = len(uncond_ids)
pos_weight, neg_weight = get_curr_feedback_weight(p, params)
outs = []
if num_cond > 0:
pos_hs = cached_hs[:num_pos].view(1, num_pos * seq_len, d_model).expand(num_cond, -1, -1) # (n_cond, seq * n_pos, dim)
x_cond = x[cond_ids] # (n_cond, seq, dim)
ctx_cond = torch.cat([context[cond_ids], pos_hs], dim=1) # (n_cond, seq * (1 + n_pos), dim)
out_cond = attn1._fabric_old_forward(x_cond, ctx_cond, **kwargs) # (n_cond, seq, dim)
ws = torch.ones_like(ctx_cond[0, :, 0]) # (seq * (1 + n_pos),)
ws[x_cond.size(1):] = pos_weight
out_cond = weighted_attention(attn1._fabric_old_forward, x_cond, ctx_cond, ws, **kwargs) # (n_cond, seq, dim)
outs.append(out_cond)
if num_uncond > 0:
neg_hs = cached_hs[num_pos:].view(1, num_neg * seq_len, d_model).expand(num_uncond, -1, -1) # (n_uncond, seq * n_neg, dim)
x_uncond = x[uncond_ids] # (n_uncond, seq, dim)
ctx_uncond = torch.cat([context[uncond_ids], neg_hs], dim=1) # (n_uncond, seq * (1 + n_neg), dim)
out_uncond = attn1._fabric_old_forward(x_uncond, ctx_uncond, **kwargs) # (n_uncond, seq, dim)
ws = torch.ones_like(ctx_cond[0, :, 0]) # (seq * (1 + n_neg),)
ws[x_cond.size(1):] = neg_weight
out_uncond = weighted_attention(attn1._fabric_old_forward, x_uncond, ctx_uncond, **kwargs) # (n_uncond, seq, dim)
outs.append(out_uncond)
out = torch.cat(outs, dim=0)
return out

View File

@ -0,0 +1,173 @@
import math
import functools
import torch
import torch.nn.functional
from torch import einsum
from einops import rearrange
from ldm.util import default
import modules.sd_hijack_optimizations
from modules import shared, devices
from modules.hypernetworks import hypernetwork
from modules.sd_hijack_optimizations import (
split_cross_attention_forward_invokeAI,
xformers_attention_forward,
scaled_dot_product_no_mem_attention_forward,
scaled_dot_product_attention_forward,
split_cross_attention_forward,
get_available_vram,
)
_einsum_op_compvis = modules.sd_hijack_optimizations.einsum_op_compvis
def patched_einsum_op_compvis(q, k, v, weights=None):
    """Weight-aware replacement for `sd_hijack_optimizations.einsum_op_compvis`.

    Computes standard attention ``softmax(q @ k^T) @ v``. When ``weights``
    (a 1-D tensor indexed along the key dimension) is given, each attention
    column j is scaled by ``weights[j]`` after the softmax, so feedback
    tokens can be emphasized or suppressed.

    Args:
        q: queries, shape (batch, seq_q, dim).
        k: keys, shape (batch, seq_k, dim).
        v: values, shape (batch, seq_k, dim_v).
        weights: optional per-key weights, shape (seq_k,).

    Returns:
        Attention output of shape (batch, seq_q, dim_v).
    """
    # Fix: removed leftover debug prints — this runs on every attention call.
    s = einsum('b i d, b j d -> b i j', q, k)
    s = s.softmax(dim=-1, dtype=s.dtype)
    if weights is not None:
        # Broadcast over batch and query dims; weights index the key dim.
        s = s * weights[None, None, :]
    return einsum('b i j, b j d -> b i d', s, v)
def patched_xformers_attn(q, k, v, attn_bias=None, op=None, weights=None, orig_attn=None):
    """Wrap xformers' `memory_efficient_attention`, folding key `weights` into the bias.

    Weights enter the softmax as a log-space additive bias, clamped to the
    dtype's minimum so that weight 0 becomes a large finite negative bias
    instead of -inf.

    Args:
        q, k, v: attention inputs (q at least 3-D; bias is expanded over
            ``q.shape[:3]`` — presumably (batch, heads, seq); TODO confirm
            against the xformers layout used by the webui).
        attn_bias: optional pre-existing additive bias.
        op: xformers operator selection, forwarded unchanged.
        weights: optional 1-D per-key weights.
        orig_attn: the unpatched `memory_efficient_attention` to delegate to.
    """
    # Fix: the old debug print dereferenced `weights.shape` before the None
    # check, crashing whenever weights was None. Prints removed entirely.
    if weights is not None:
        min_val = torch.finfo(q.dtype).min
        w_bias = weights.log().clamp(min=min_val)[None, None, None, :].expand(*q.shape[:3], -1).transpose(-2, -1)
        # Out-of-place add: don't mutate the caller's attn_bias tensor.
        attn_bias = w_bias if attn_bias is None else attn_bias + w_bias
    return orig_attn(q, k, v, attn_bias=attn_bias, op=op)
def patched_sdp_attn(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False, weights=None, orig_attn=None):
    """Wrap `torch.nn.functional.scaled_dot_product_attention`, folding key
    `weights` into the additive attention mask.

    A boolean mask is converted to an additive float mask (False -> -inf,
    True -> 0) so it can be summed with the log-space weight bias; weights
    are clamped at the dtype minimum so weight 0 stays finite.

    Args:
        q, k, v: attention inputs; the weight bias is expanded over
            ``q.shape[:3]`` (presumably batch, heads, seq — TODO confirm).
        attn_mask: optional boolean or additive mask.
        dropout_p, is_causal: forwarded to the original SDPA unchanged.
        weights: optional 1-D per-key weights.
        orig_attn: the unpatched SDPA function to delegate to.
    """
    if attn_mask is not None:
        if attn_mask.dtype == torch.bool:
            # Fix: the old code discarded the masked_fill() result (and used
            # `not attn_mask`, which raises on multi-element tensors), then
            # cast the bool mask to float — turning "keep" into bias 1.0 and
            # "drop" into bias 0.0. Build the additive mask explicitly.
            attn_mask = torch.zeros_like(attn_mask, dtype=q.dtype).masked_fill(~attn_mask, float("-inf"))
        else:
            attn_mask = attn_mask.to(dtype=q.dtype)
    if weights is not None:
        min_val = torch.finfo(q.dtype).min
        w_bias = weights.log().clamp(min=min_val)[None, None, None, :].expand(*q.shape[:3], -1)
        # Out-of-place add: don't mutate the caller's attn_mask tensor.
        attn_mask = w_bias if attn_mask is None else attn_mask + w_bias
    return orig_attn(q, k, v, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal)
# copied and adapted from modules.sd_hijack_optimizations.split_cross_attention_forward
def weighted_split_cross_attention_forward(self, x, context=None, mask=None, weights=None):
    """Split (memory-chunked) cross-attention forward with per-key weighting.

    Identical to webui's split_cross_attention_forward except for the two
    sections marked "OURS": the incoming 1-D ``weights`` are normalized to
    sum to 1 over the key dimension, then multiplied into the post-softmax
    attention scores of every slice.

    Args:
        self: the attention module (provides heads, to_q/to_k/to_v/to_out, scale).
        x: query input, shape (batch, seq, inner_dim).
        context: key/value input; defaults to ``x`` (self-attention).
        mask: accepted for signature compatibility; not used in this body.
        weights: optional per-key weights, shape (seq_k,) — inferred from the
            ``weights[None, None, :]`` indexing below.

    Raises:
        RuntimeError: when free VRAM is too low even at the maximum number
            of slicing steps (64).
    """
    h = self.heads
    # OURS: normalize weights to preserve attention magnitude
    # (divides by the sum over the key dim, so weights sum to 1)
    if weights is not None:
        weights = weights[None, None, :] / weights.sum(dim=-1, keepdim=True)
    q_in = self.to_q(x)
    context = default(context, x)
    # Hypernetworks may transform keys/values independently of the queries.
    context_k, context_v = hypernetwork.apply_hypernetworks(shared.loaded_hypernetworks, context)
    k_in = self.to_k(context_k)
    v_in = self.to_v(context_v)
    dtype = q_in.dtype
    if shared.opts.upcast_attn:
        # On MPS, v is kept in its original dtype (float32 v is not needed there).
        q_in, k_in, v_in = q_in.float(), k_in.float(), v_in if v_in.device.type == 'mps' else v_in.float()
    with devices.without_autocast(disable=not shared.opts.upcast_attn):
        # Pre-scale k once instead of scaling the q@k product per slice.
        k_in = k_in * self.scale
        del context, x
        # Fold heads into the batch dim: (b, n, h*d) -> (b*h, n, d).
        q, k, v = (rearrange(t, 'b n (h d) -> (b h) n d', h=h) for t in (q_in, k_in, v_in))
        del q_in, k_in, v_in
        r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
        # Estimate the size of the full attention matrix and pick a number of
        # query-dim slices (power of two) that fits in free VRAM.
        mem_free_total = get_available_vram()
        gb = 1024 ** 3
        tensor_size = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size()
        modifier = 3 if q.element_size() == 2 else 2.5
        mem_required = tensor_size * modifier
        steps = 1
        if mem_required > mem_free_total:
            steps = 2 ** (math.ceil(math.log(mem_required / mem_free_total, 2)))
            # print(f"Expected tensor size:{tensor_size/gb:0.1f}GB, cuda free:{mem_free_cuda/gb:0.1f}GB "
            #       f"torch free:{mem_free_torch/gb:0.1f} total:{mem_free_total/gb:0.1f} steps:{steps}")
        if steps > 64:
            max_res = math.floor(math.sqrt(math.sqrt(mem_free_total / 2.5)) / 8) * 64
            raise RuntimeError(f'Not enough memory, use lower resolution (max approx. {max_res}x{max_res}). '
                               f'Need: {mem_required / 64 / gb:0.1f}GB free, Have:{mem_free_total / gb:0.1f}GB free')
        # If seq doesn't divide evenly, fall back to a single full-size slice.
        slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1]
        for i in range(0, q.shape[1], slice_size):
            end = i + slice_size
            s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k)
            s2 = s1.softmax(dim=-1, dtype=q.dtype)
            del s1
            # OURS: apply weights to attention
            if weights is not None:
                s2 = s2 * weights
            r1[:, i:end] = einsum('b i j, b j d -> b i d', s2, v)
            del s2
        del q, k, v
    # Cast back to the pre-upcast dtype before projecting out.
    r1 = r1.to(dtype)
    r2 = rearrange(r1, '(b h) n d -> b n (h d)', h=h)
    del r1
    return self.to_out(r2)
def is_the_same(fn1, fn2):
    """Return True when `fn1` matches `fn2` by module and name.

    `fn2` may also be a list/tuple of candidates, in which case a match
    against any element counts. Comparing by ``__module__``/``__name__``
    (rather than identity) survives bound methods and re-imports.
    """
    if not isinstance(fn2, (list, tuple)):
        return fn1.__name__ == fn2.__name__ and fn1.__module__ == fn2.__module__
    return any(is_the_same(fn1, candidate) for candidate in fn2)
def weighted_attention(attn_fn, x, context=None, weights=None, **kwargs):
    """Run `attn_fn` while applying per-key `weights` to its attention scores.

    Detects which optimized attention implementation `attn_fn` is (by
    module/name) and temporarily monkey-patches its inner primitive with a
    weight-aware version for the duration of the call. With ``weights=None``
    this is a plain pass-through.

    Args:
        attn_fn: the attention module's original (bound) forward.
        x: query input, forwarded to `attn_fn`.
        context: key/value input, forwarded to `attn_fn`.
        weights: optional 1-D per-key weights.
        **kwargs: forwarded to `attn_fn` unchanged.

    Raises:
        NotImplementedError: if `attn_fn` is not a supported implementation.
    """
    if weights is None:
        return attn_fn(x, context=context, **kwargs)
    # Fix: removed leftover debug prints that ran on every weighted call.
    if is_the_same(attn_fn, split_cross_attention_forward_invokeAI):
        modules.sd_hijack_optimizations.einsum_op_compvis = functools.partial(patched_einsum_op_compvis, weights=weights)
        try:
            return attn_fn(x, context=context, **kwargs)
        finally:
            # Fix: restore in a finally so an exception in the forward pass
            # can't leave the global primitive permanently patched.
            modules.sd_hijack_optimizations.einsum_op_compvis = _einsum_op_compvis
    elif is_the_same(attn_fn, xformers_attention_forward):
        import xformers.ops  # xformers dependency is optional
        _memory_efficient_attention = xformers.ops.memory_efficient_attention
        xformers.ops.memory_efficient_attention = functools.partial(patched_xformers_attn, weights=weights, orig_attn=_memory_efficient_attention)
        try:
            return attn_fn(x, context=context, **kwargs)
        finally:
            xformers.ops.memory_efficient_attention = _memory_efficient_attention
    elif is_the_same(attn_fn, [scaled_dot_product_no_mem_attention_forward, scaled_dot_product_attention_forward]):
        _sdp_attention = torch.nn.functional.scaled_dot_product_attention
        torch.nn.functional.scaled_dot_product_attention = functools.partial(patched_sdp_attn, weights=weights, orig_attn=_sdp_attention)
        try:
            return attn_fn(x, context=context, **kwargs)
        finally:
            torch.nn.functional.scaled_dot_product_attention = _sdp_attention
    elif is_the_same(attn_fn, split_cross_attention_forward):
        # Fix: the original call omitted `self`, so `x` was bound to the
        # `self` parameter (crashing on `x.heads`). `attn_fn` is the bound
        # original forward, so its __self__ is the attention module.
        return weighted_split_cross_attention_forward(attn_fn.__self__, x, context=context, weights=weights, **kwargs)
    else:
        raise NotImplementedError(f"FABRIC does not support `{attn_fn.__module__}.{attn_fn.__name__}` yet. Please choose a supported attention function.")