implement weighted attention for different processors
parent
10e76b321e
commit
0290e51ab0
|
|
@ -1,3 +1,5 @@
|
|||
.DS_Store
|
||||
|
||||
# Byte-compiled / optimized / DLL files
|
||||
__pycache__/
|
||||
*.py[cod]
|
||||
|
|
|
|||
|
|
@ -17,7 +17,7 @@ from scripts.helpers import WebUiComponents
|
|||
|
||||
__version__ = "0.3.5"

# Enable debug behavior via the DEBUG environment variable.
# NOTE: os.getenv returns a *string*, so the previous
# `os.getenv("DEBUG", False)` made any non-empty value truthy — including
# "0" and "false". Parse the common truthy spellings explicitly instead.
DEBUG = os.getenv("DEBUG", "").strip().lower() in ("1", "true", "yes", "on")

if DEBUG:
    print(f"WARNING: Loading FABRIC v{__version__} in DEBUG mode")
|
||||
|
|
|
|||
|
|
@ -7,6 +7,7 @@ from modules.processing import StableDiffusionProcessingTxt2Img
|
|||
from ldm.modules.attention import BasicTransformerBlock
|
||||
|
||||
from scripts.marking import patch_process_sample, unmark_prompt_context
|
||||
from scripts.weighted_attention import weighted_attention
|
||||
|
||||
|
||||
def encode_to_latent(p, image, w, h):
|
||||
|
|
@ -45,6 +46,11 @@ def get_latents_from_params(p, params, width, height):
|
|||
return params.pos_latents, params.neg_latents
|
||||
|
||||
|
||||
def get_curr_feedback_weight(p, params):
    """Return the (positive, negative) feedback weights for the current step.

    The positive weight is ``params.max_weight`` as-is; the negative weight
    is the same value scaled by ``params.neg_scale``. ``p`` is accepted for
    signature compatibility with the other helpers but is not used here.
    """
    positive = params.max_weight
    negative = positive * params.neg_scale
    return positive, negative
|
||||
|
||||
|
||||
def patch_unet_forward_pass(p, unet, params):
|
||||
if not params.pos_images and not params.neg_images:
|
||||
print("[FABRIC] No images to found, aborting patching")
|
||||
|
|
@ -53,7 +59,7 @@ def patch_unet_forward_pass(p, unet, params):
|
|||
if not hasattr(unet, "_fabric_old_forward"):
|
||||
unet._fabric_old_forward = unet.forward
|
||||
|
||||
null_ctx = p.sd_model.get_learned_conditioning([""])
|
||||
null_ctx = p.sd_model.get_learned_conditioning([""]).to(devices.device, dtype=devices.dtype_unet)
|
||||
|
||||
width = (p.width // 8) * 8
|
||||
height = (p.height // 8) * 8
|
||||
|
|
@ -142,18 +148,24 @@ def patch_unet_forward_pass(p, unet, params):
|
|||
num_cond = len(cond_ids)
|
||||
num_uncond = len(uncond_ids)
|
||||
|
||||
pos_weight, neg_weight = get_curr_feedback_weight(p, params)
|
||||
|
||||
outs = []
|
||||
if num_cond > 0:
|
||||
pos_hs = cached_hs[:num_pos].view(1, num_pos * seq_len, d_model).expand(num_cond, -1, -1) # (n_cond, seq * n_pos, dim)
|
||||
x_cond = x[cond_ids] # (n_cond, seq, dim)
|
||||
ctx_cond = torch.cat([context[cond_ids], pos_hs], dim=1) # (n_cond, seq * (1 + n_pos), dim)
|
||||
out_cond = attn1._fabric_old_forward(x_cond, ctx_cond, **kwargs) # (n_cond, seq, dim)
|
||||
ws = torch.ones_like(ctx_cond[0, :, 0]) # (seq * (1 + n_pos),)
|
||||
ws[x_cond.size(1):] = pos_weight
|
||||
out_cond = weighted_attention(attn1._fabric_old_forward, x_cond, ctx_cond, ws, **kwargs) # (n_cond, seq, dim)
|
||||
outs.append(out_cond)
|
||||
if num_uncond > 0:
|
||||
neg_hs = cached_hs[num_pos:].view(1, num_neg * seq_len, d_model).expand(num_uncond, -1, -1) # (n_uncond, seq * n_neg, dim)
|
||||
x_uncond = x[uncond_ids] # (n_uncond, seq, dim)
|
||||
ctx_uncond = torch.cat([context[uncond_ids], neg_hs], dim=1) # (n_uncond, seq * (1 + n_neg), dim)
|
||||
out_uncond = attn1._fabric_old_forward(x_uncond, ctx_uncond, **kwargs) # (n_uncond, seq, dim)
|
||||
ws = torch.ones_like(ctx_cond[0, :, 0]) # (seq * (1 + n_neg),)
|
||||
ws[x_cond.size(1):] = neg_weight
|
||||
out_uncond = weighted_attention(attn1._fabric_old_forward, x_uncond, ctx_uncond, **kwargs) # (n_uncond, seq, dim)
|
||||
outs.append(out_uncond)
|
||||
out = torch.cat(outs, dim=0)
|
||||
return out
|
||||
|
|
|
|||
|
|
@ -0,0 +1,173 @@
|
|||
import math
|
||||
import functools
|
||||
|
||||
import torch
|
||||
import torch.nn.functional
|
||||
from torch import einsum
|
||||
from einops import rearrange
|
||||
|
||||
from ldm.util import default
|
||||
|
||||
import modules.sd_hijack_optimizations
|
||||
from modules import shared, devices
|
||||
from modules.hypernetworks import hypernetwork
|
||||
from modules.sd_hijack_optimizations import (
|
||||
split_cross_attention_forward_invokeAI,
|
||||
xformers_attention_forward,
|
||||
scaled_dot_product_no_mem_attention_forward,
|
||||
scaled_dot_product_attention_forward,
|
||||
split_cross_attention_forward,
|
||||
get_available_vram,
|
||||
)
|
||||
|
||||
|
||||
_einsum_op_compvis = modules.sd_hijack_optimizations.einsum_op_compvis
|
||||
|
||||
|
||||
def patched_einsum_op_compvis(q, k, v, weights=None):
    """einsum_op_compvis with optional per-key attention weights.

    Computes standard compvis attention (softmax(q @ k^T) @ v). When
    `weights` is given (1D, one entry per key), each key's post-softmax
    attention column is scaled by its weight before aggregating `v`.
    Debug prints from the original were removed — this runs once per
    attention layer per step, so stdout spam is both noisy and slow.
    """
    s = einsum('b i d, b j d -> b i j', q, k)
    s = s.softmax(dim=-1, dtype=s.dtype)
    if weights is not None:
        # Broadcast (seq_k,) weights across the batch and query dims.
        s = s * weights[None, None, :]
    return einsum('b i j, b j d -> b i d', s, v)
|
||||
|
||||
|
||||
def patched_xformers_attn(q, k, v, attn_bias=None, op=None, weights=None, orig_attn=None):
    """xformers memory_efficient_attention with per-key weights as a bias.

    When `weights` is given, its log is added to the attention bias, which
    is equivalent to scaling the post-softmax attention by `weights`.
    The original crashed when `weights is None` because it unconditionally
    printed `weights.shape`; that debug print is removed.
    """
    if weights is not None:
        # Clamp log-weights at the dtype minimum so zero weights yield a
        # large negative bias instead of -inf (avoids NaNs in softmax).
        min_val = torch.finfo(q.dtype).min
        # assumes q is laid out (batch, seq_q, heads, ...) as the webui
        # xformers path produces — TODO confirm bias shape against xformers.
        w_bias = weights.log().clamp(min=min_val)[None, None, None, :].expand(*q.shape[:3], -1).transpose(-2, -1)
        # Use out-of-place addition: `attn_bias += w_bias` would mutate the
        # caller's tensor.
        attn_bias = w_bias if attn_bias is None else attn_bias + w_bias
    return orig_attn(q, k, v, attn_bias=attn_bias, op=op)
|
||||
|
||||
|
||||
def patched_sdp_attn(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False, weights=None, orig_attn=None):
    """scaled_dot_product_attention with per-key weights folded into the mask.

    A boolean `attn_mask` is first converted to an additive float mask
    (False -> -inf, True -> 0). When `weights` is given, its log is added
    as a logit bias, equivalent to scaling post-softmax attention.

    The original had two bugs fixed here: `attn_mask.masked_fill(not
    attn_mask, ...)` raised on any multi-element tensor (`not` on a tensor)
    and its result was discarded; and `attn_mask += w_bias` mutated the
    caller's mask in place.
    """
    if attn_mask is not None:
        if attn_mask.dtype == torch.bool:
            # Disallowed (False) positions become -inf, allowed ones 0.
            attn_mask = torch.zeros_like(attn_mask, dtype=q.dtype).masked_fill_(~attn_mask, float("-inf"))
        else:
            attn_mask = attn_mask.to(dtype=q.dtype)
    if weights is not None:
        # Clamp log-weights so zero weights don't introduce NaNs.
        min_val = torch.finfo(q.dtype).min
        w_bias = weights.log().clamp(min=min_val)[None, None, None, :].expand(*q.shape[:3], -1)
        attn_mask = w_bias if attn_mask is None else attn_mask + w_bias
    return orig_attn(q, k, v, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal)
|
||||
|
||||
|
||||
# copied and adapted from modules.sd_hijack_optimizations.split_cross_attention_forward
def weighted_split_cross_attention_forward(self, x, context=None, mask=None, weights=None):
    """Memory-sliced cross-attention with optional per-key weights.

    Mirrors webui's split_cross_attention_forward: projects q/k/v, splits
    the query dimension into VRAM-sized slices, and runs einsum attention
    per slice. FABRIC additions are marked "OURS": `weights` is a 1D tensor
    with one entry per context token, normalized and multiplied into the
    post-softmax scores. `mask` is accepted for signature compatibility but
    is not used in this implementation.
    """
    h = self.heads

    # OURS: normalize weights to preserve attention magnitude
    if weights is not None:
        weights = weights[None, None, :] / weights.sum(dim=-1, keepdim=True)

    q_in = self.to_q(x)
    context = default(context, x)  # self-attention when no context given

    # Hypernetworks may transform the context for k and v independently.
    context_k, context_v = hypernetwork.apply_hypernetworks(shared.loaded_hypernetworks, context)
    k_in = self.to_k(context_k)
    v_in = self.to_v(context_v)

    dtype = q_in.dtype
    if shared.opts.upcast_attn:
        # Upcast to fp32 for stability; v stays as-is on MPS (fp32 v is
        # slower/buggy there in upstream webui).
        q_in, k_in, v_in = q_in.float(), k_in.float(), v_in if v_in.device.type == 'mps' else v_in.float()

    with devices.without_autocast(disable=not shared.opts.upcast_attn):
        k_in = k_in * self.scale

        del context, x

        # Fold heads into the batch dimension: (b, n, h*d) -> (b*h, n, d).
        q, k, v = (rearrange(t, 'b n (h d) -> (b h) n d', h=h) for t in (q_in, k_in, v_in))
        del q_in, k_in, v_in

        r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)

        mem_free_total = get_available_vram()

        # Estimate attention-matrix memory and pick a slice count that fits.
        gb = 1024 ** 3
        tensor_size = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size()
        modifier = 3 if q.element_size() == 2 else 2.5
        mem_required = tensor_size * modifier
        steps = 1

        if mem_required > mem_free_total:
            steps = 2 ** (math.ceil(math.log(mem_required / mem_free_total, 2)))
            # print(f"Expected tensor size:{tensor_size/gb:0.1f}GB, cuda free:{mem_free_cuda/gb:0.1f}GB "
            #       f"torch free:{mem_free_torch/gb:0.1f} total:{mem_free_total/gb:0.1f} steps:{steps}")

        if steps > 64:
            max_res = math.floor(math.sqrt(math.sqrt(mem_free_total / 2.5)) / 8) * 64
            raise RuntimeError(f'Not enough memory, use lower resolution (max approx. {max_res}x{max_res}). '
                               f'Need: {mem_required / 64 / gb:0.1f}GB free, Have:{mem_free_total / gb:0.1f}GB free')

        # Slice the query axis; fall back to a single full pass when the
        # length is not evenly divisible.
        slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1]
        for i in range(0, q.shape[1], slice_size):
            end = i + slice_size
            s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k)

            s2 = s1.softmax(dim=-1, dtype=q.dtype)
            del s1

            # OURS: apply weights to attention
            if weights is not None:
                s2 = s2 * weights

            r1[:, i:end] = einsum('b i j, b j d -> b i d', s2, v)
            del s2

        del q, k, v

    r1 = r1.to(dtype)

    # Unfold heads: (b*h, n, d) -> (b, n, h*d).
    r2 = rearrange(r1, '(b h) n d -> b n (h d)', h=h)
    del r1

    return self.to_out(r2)
|
||||
|
||||
|
||||
def is_the_same(fn1, fn2):
    """Check whether `fn1` is the same function as `fn2`.

    Functions are compared by qualified identity (`__module__` plus
    `__name__`) rather than object identity, so rebound references and
    `functools.partial`-free wrappers still match. `fn2` may also be a
    list or tuple of candidates, in which case any match counts.
    """
    if isinstance(fn2, (list, tuple)):
        return any(is_the_same(fn1, candidate) for candidate in fn2)
    same_name = fn1.__name__ == fn2.__name__
    return same_name and fn1.__module__ == fn2.__module__
|
||||
|
||||
|
||||
def weighted_attention(attn_fn, x, context=None, weights=None, **kwargs):
    """Run an attention forward with per-context-token weights applied.

    Detects which webui attention optimization `attn_fn` is, temporarily
    monkey-patches its underlying primitive so post-softmax scores are
    reweighted, calls `attn_fn`, and restores the original primitive.

    attn_fn -- the module's saved `_fabric_old_forward` (a bound forward)
    x       -- query hidden states
    context -- key/value hidden states (attn_fn defaults them to `x`)
    weights -- 1D tensor with one weight per context token, or None for a
               plain pass-through call
    Raises NotImplementedError for unsupported attention backends.

    Changes vs. the original: removed three debug prints, restored the
    patched primitive via try/finally so an exception inside `attn_fn`
    cannot leave the global patch in place, and fixed the
    split_cross_attention branch (see NOTE below).
    """
    if weights is None:
        # Nothing to reweight — call the unmodified forward.
        return attn_fn(x, context=context, **kwargs)

    if is_the_same(attn_fn, split_cross_attention_forward_invokeAI):
        # The InvokeAI path funnels through einsum_op_compvis; patch it for
        # the duration of this call only.
        modules.sd_hijack_optimizations.einsum_op_compvis = functools.partial(patched_einsum_op_compvis, weights=weights)
        try:
            return attn_fn(x, context=context, **kwargs)
        finally:
            modules.sd_hijack_optimizations.einsum_op_compvis = _einsum_op_compvis

    elif is_the_same(attn_fn, xformers_attention_forward):
        import xformers.ops  # xformers dependency is optional
        _memory_efficient_attention = xformers.ops.memory_efficient_attention
        xformers.ops.memory_efficient_attention = functools.partial(patched_xformers_attn, weights=weights, orig_attn=_memory_efficient_attention)
        try:
            return attn_fn(x, context=context, **kwargs)
        finally:
            xformers.ops.memory_efficient_attention = _memory_efficient_attention

    elif is_the_same(attn_fn, [scaled_dot_product_no_mem_attention_forward, scaled_dot_product_attention_forward]):
        _sdp_attention = torch.nn.functional.scaled_dot_product_attention
        torch.nn.functional.scaled_dot_product_attention = functools.partial(patched_sdp_attn, weights=weights, orig_attn=_sdp_attention)
        try:
            return attn_fn(x, context=context, **kwargs)
        finally:
            torch.nn.functional.scaled_dot_product_attention = _sdp_attention

    elif is_the_same(attn_fn, split_cross_attention_forward):
        # NOTE(review): the original called the reimplementation without a
        # module instance, which raises TypeError (its first parameter is
        # `self`). `attn_fn` is a bound method here, so recover the
        # CrossAttention module via `__self__` — confirm against caller.
        return weighted_split_cross_attention_forward(attn_fn.__self__, x, context=context, weights=weights, **kwargs)

    else:
        raise NotImplementedError(f"FABRIC does not support `{attn_fn.__module__}.{attn_fn.__name__}` yet. Please choose a supported attention function.")
|
||||
Loading…
Reference in New Issue