v0.5.2 (more robust patching, fix negative weights)
parent
015d810c23
commit
0018b6be2d
|
|
@ -27,7 +27,7 @@ except ImportError:
|
|||
from modules.ui import create_refresh_button
|
||||
|
||||
|
||||
__version__ = "0.5.1"
|
||||
__version__ = "0.5.2"
|
||||
|
||||
DEBUG = os.getenv("DEBUG", "false").lower() in ("true", "1")
|
||||
|
||||
|
|
@ -421,6 +421,9 @@ class FabricScript(modules.scripts.Script):
|
|||
# restore original U-Net forward pass in case previous batch errored out
|
||||
unpatch_unet_forward_pass(p.sd_model.model.diffusion_model)
|
||||
|
||||
if not feedback_enabled:
|
||||
return
|
||||
|
||||
liked_paths = liked_paths[-int(feedback_max_images):]
|
||||
disliked_paths = disliked_paths[-int(feedback_max_images):]
|
||||
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
import math
|
||||
import functools
|
||||
import psutil
|
||||
|
||||
import torch
|
||||
import torch.nn.functional
|
||||
|
|
@ -8,65 +8,257 @@ from einops import rearrange
|
|||
|
||||
from ldm.util import default
|
||||
|
||||
import modules.sd_hijack_optimizations
|
||||
from modules import shared, devices
|
||||
from modules import shared, devices, sd_hijack
|
||||
from modules.hypernetworks import hypernetwork
|
||||
from modules.sd_hijack_optimizations import (
|
||||
split_cross_attention_forward_invokeAI,
|
||||
xformers_attention_forward,
|
||||
scaled_dot_product_no_mem_attention_forward,
|
||||
scaled_dot_product_attention_forward,
|
||||
split_cross_attention_forward,
|
||||
get_xformers_flash_attention_op,
|
||||
get_available_vram,
|
||||
)
|
||||
|
||||
|
||||
try:
|
||||
import xformers
|
||||
import xformers.ops
|
||||
_xformers_attn = xformers.ops.memory_efficient_attention
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
_einsum_op_compvis = modules.sd_hijack_optimizations.einsum_op_compvis
|
||||
_sdp_attention = torch.nn.functional.scaled_dot_product_attention
|
||||
|
||||
def get_weighted_attn_fn():
|
||||
method = sd_hijack.model_hijack.optimization_method
|
||||
if method is None:
|
||||
return weighted_split_cross_attention_forward
|
||||
method = method.lower()
|
||||
|
||||
if method not in ['none', 'sdp-no-mem', 'sdp', 'xformers', 'sub-quadratic', 'v1', 'invokeai', 'doggettx']:
|
||||
print(f"[FABRIC] Warning: Unknown attention optimization method {method}.")
|
||||
return weighted_split_cross_attention_forward
|
||||
|
||||
if method == 'none':
|
||||
return weighted_split_cross_attention_forward
|
||||
elif method == 'xformers':
|
||||
return weighted_xformers_attention_forward
|
||||
elif method == 'sdp-no-mem':
|
||||
return weighted_scaled_dot_product_no_mem_attention_forward
|
||||
elif method == 'sdp':
|
||||
return weighted_scaled_dot_product_attention_forward
|
||||
elif method == 'doggettx':
|
||||
return weighted_split_cross_attention_forward
|
||||
elif method == 'invokeai':
|
||||
return weighted_split_cross_attention_forward_invokeAI
|
||||
elif method == 'sub-quadratic':
|
||||
print(f"[FABRIC] Warning: Sub-quadratic attention is not supported yet. Please open an issue if you need this for your workflow. Falling back to split attention.")
|
||||
return weighted_split_cross_attention_forward
|
||||
elif method == 'v1':
|
||||
print(f"[FABRIC] Warning: V1 attention is not supported yet. Please open an issue if you need this for your workflow. Falling back to split attention.")
|
||||
return weighted_split_cross_attention_forward
|
||||
else:
|
||||
return weighted_split_cross_attention_forward
|
||||
|
||||
|
||||
def patched_einsum_op_compvis(q, k, v, weights=None):
|
||||
def weighted_attention(self, attn_fn, x, context=None, weights=None, **kwargs):
|
||||
if weights is None:
|
||||
return attn_fn(x, context=context, **kwargs)
|
||||
|
||||
weighted_attn_fn = get_weighted_attn_fn()
|
||||
return weighted_attn_fn(self, x, context=context, weights=weights, **kwargs)
|
||||
|
||||
|
||||
def _get_attn_bias(weights, shape=None, dtype=torch.float32):
|
||||
min_val = torch.finfo(dtype).min
|
||||
w_bias = weights.log().clamp(min=min_val)
|
||||
if shape is not None:
|
||||
assert shape[-1] == w_bias.shape[-1], "Last dimension of shape must match last dimension of weights (number of keys)"
|
||||
w_bias = w_bias.view([1] * (len(shape) - 1) + [-1]).expand(shape)
|
||||
w_bias = w_bias.to(dtype=dtype)
|
||||
return w_bias
|
||||
|
||||
### The following attn functions are copied and adapted from modules.sd_hijack_optimizations
|
||||
|
||||
# --- InvokeAI ---
|
||||
mem_total_gb = psutil.virtual_memory().total // (1 << 30)
|
||||
|
||||
def einsum_op_compvis(q, k, v, weights=None):
|
||||
s = einsum('b i d, b j d -> b i j', q, k)
|
||||
s = s.softmax(dim=-1, dtype=s.dtype)
|
||||
if weights is not None:
|
||||
s = s * weights[None, None, :]
|
||||
s += _get_attn_bias(weights, s.shape, s.dtype)
|
||||
s = s.softmax(dim=-1, dtype=s.dtype)
|
||||
return einsum('b i j, b j d -> b i d', s, v)
|
||||
|
||||
def einsum_op_slice_0(q, k, v, slice_size, weights=None):
|
||||
r = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
|
||||
for i in range(0, q.shape[0], slice_size):
|
||||
end = i + slice_size
|
||||
r[i:end] = einsum_op_compvis(q[i:end], k[i:end], v[i:end], weights)
|
||||
return r
|
||||
|
||||
def patched_xformers_attn(q, k, v, attn_bias=None, op=None, weights=None, orig_attn=None):
|
||||
bs, nq, nh, dh = q.shape # batch_size, num_queries, num_heads, dim_per_head
|
||||
def einsum_op_slice_1(q, k, v, slice_size, weights=None):
|
||||
r = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
|
||||
for i in range(0, q.shape[1], slice_size):
|
||||
end = i + slice_size
|
||||
r[:, i:end] = einsum_op_compvis(q[:, i:end], k, v, weights)
|
||||
return r
|
||||
|
||||
def einsum_op_mps_v1(q, k, v, weights=None):
|
||||
if q.shape[0] * q.shape[1] <= 2**16: # (512x512) max q.shape[1]: 4096
|
||||
return einsum_op_compvis(q, k, v, weights)
|
||||
else:
|
||||
slice_size = math.floor(2**30 / (q.shape[0] * q.shape[1]))
|
||||
if slice_size % 4096 == 0:
|
||||
slice_size -= 1
|
||||
return einsum_op_slice_1(q, k, v, slice_size, weights)
|
||||
|
||||
def einsum_op_mps_v2(q, k, v, weights=None):
|
||||
if mem_total_gb > 8 and q.shape[0] * q.shape[1] <= 2**16:
|
||||
return einsum_op_compvis(q, k, v, weights)
|
||||
else:
|
||||
return einsum_op_slice_0(q, k, v, 1, weights)
|
||||
|
||||
def einsum_op_tensor_mem(q, k, v, max_tensor_mb, weights=None):
|
||||
size_mb = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size() // (1 << 20)
|
||||
if size_mb <= max_tensor_mb:
|
||||
return einsum_op_compvis(q, k, v, weights)
|
||||
div = 1 << int((size_mb - 1) / max_tensor_mb).bit_length()
|
||||
if div <= q.shape[0]:
|
||||
return einsum_op_slice_0(q, k, v, q.shape[0] // div, weights)
|
||||
return einsum_op_slice_1(q, k, v, max(q.shape[1] // div, 1), weights)
|
||||
|
||||
def einsum_op_cuda(q, k, v, weights=None):
|
||||
stats = torch.cuda.memory_stats(q.device)
|
||||
mem_active = stats['active_bytes.all.current']
|
||||
mem_reserved = stats['reserved_bytes.all.current']
|
||||
mem_free_cuda, _ = torch.cuda.mem_get_info(q.device)
|
||||
mem_free_torch = mem_reserved - mem_active
|
||||
mem_free_total = mem_free_cuda + mem_free_torch
|
||||
# Divide factor of safety as there's copying and fragmentation
|
||||
return einsum_op_tensor_mem(q, k, v, mem_free_total / 3.3 / (1 << 20), weights)
|
||||
|
||||
def einsum_op(q, k, v, weights=None):
|
||||
if q.device.type == 'cuda':
|
||||
return einsum_op_cuda(q, k, v, weights)
|
||||
|
||||
if q.device.type == 'mps':
|
||||
if mem_total_gb >= 32 and q.shape[0] % 32 != 0 and q.shape[0] * q.shape[1] < 2**18:
|
||||
return einsum_op_mps_v1(q, k, v, weights)
|
||||
return einsum_op_mps_v2(q, k, v, weights)
|
||||
|
||||
# Smaller slices are faster due to L2/L3/SLC caches.
|
||||
# Tested on i7 with 8MB L3 cache.
|
||||
return einsum_op_tensor_mem(q, k, v, 32, weights)
|
||||
|
||||
def weighted_split_cross_attention_forward_invokeAI(self, x, context=None, mask=None, weights=None):
|
||||
h = self.heads
|
||||
|
||||
q = self.to_q(x)
|
||||
context = default(context, x)
|
||||
|
||||
context_k, context_v = hypernetwork.apply_hypernetworks(shared.loaded_hypernetworks, context)
|
||||
k = self.to_k(context_k)
|
||||
v = self.to_v(context_v)
|
||||
del context, context_k, context_v, x
|
||||
|
||||
dtype = q.dtype
|
||||
if shared.opts.upcast_attn:
|
||||
q, k, v = q.float(), k.float(), v if v.device.type == 'mps' else v.float()
|
||||
|
||||
with devices.without_autocast(disable=not shared.opts.upcast_attn):
|
||||
k = k * self.scale
|
||||
|
||||
q, k, v = (rearrange(t, 'b n (h d) -> (b h) n d', h=h) for t in (q, k, v))
|
||||
r = einsum_op(q, k, v, weights)
|
||||
r = r.to(dtype)
|
||||
return self.to_out(rearrange(r, '(b h) n d -> b n (h d)', h=h))
|
||||
# --- end InvokeAI ---
|
||||
|
||||
|
||||
def weighted_xformers_attention_forward(self, x, context=None, mask=None, weights=None):
|
||||
|
||||
h = self.heads
|
||||
q_in = self.to_q(x)
|
||||
context = default(context, x)
|
||||
|
||||
context_k, context_v = hypernetwork.apply_hypernetworks(shared.loaded_hypernetworks, context)
|
||||
k_in = self.to_k(context_k)
|
||||
v_in = self.to_v(context_v)
|
||||
|
||||
q, k, v = (rearrange(t, 'b n (h d) -> b n h d', h=h) for t in (q_in, k_in, v_in))
|
||||
del q_in, k_in, v_in
|
||||
|
||||
dtype = q.dtype
|
||||
if shared.opts.upcast_attn:
|
||||
q, k, v = q.float(), k.float(), v.float()
|
||||
|
||||
### FABRIC ###
|
||||
bias_shape = (q.size(0), q.size(2), q.size(1), k.size(1)) # (bs, h, nq, nk)
|
||||
if weights is not None:
|
||||
min_val = torch.finfo(q.dtype).min
|
||||
w_bias = weights.log().clamp(min=min_val)[None, None, None, :].expand(bs, nh, nq, -1).contiguous()
|
||||
w_bias = w_bias.to(dtype=q.dtype)
|
||||
if attn_bias is None:
|
||||
attn_bias = w_bias
|
||||
else:
|
||||
attn_bias += w_bias
|
||||
return orig_attn(q, k, v, attn_bias=attn_bias, op=op)
|
||||
attn_bias = _get_attn_bias(weights, bias_shape, dtype=q.dtype)
|
||||
else:
|
||||
attn_bias = None
|
||||
### END FABRIC ###
|
||||
|
||||
out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=attn_bias, op=get_xformers_flash_attention_op(q, k, v))
|
||||
|
||||
out = out.to(dtype)
|
||||
|
||||
out = rearrange(out, 'b n h d -> b n (h d)', h=h)
|
||||
return self.to_out(out)
|
||||
|
||||
|
||||
def patched_sdp_attn(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False, weights=None, orig_attn=None):
|
||||
if attn_mask is not None:
|
||||
attn_mask.masked_fill(not attn_mask, -float('inf')) if attn_mask.dtype==torch.bool else attn_mask
|
||||
attn_mask = attn_mask.to(dtype=q.dtype)
|
||||
def weighted_scaled_dot_product_attention_forward(self, x, context=None, mask=None, weights=None):
|
||||
batch_size, sequence_length, inner_dim = x.shape
|
||||
|
||||
if mask is not None:
|
||||
mask = self.prepare_attention_mask(mask, sequence_length, batch_size)
|
||||
mask = mask.view(batch_size, self.heads, -1, mask.shape[-1])
|
||||
|
||||
h = self.heads
|
||||
q_in = self.to_q(x)
|
||||
context = default(context, x)
|
||||
|
||||
context_k, context_v = hypernetwork.apply_hypernetworks(shared.loaded_hypernetworks, context)
|
||||
k_in = self.to_k(context_k)
|
||||
v_in = self.to_v(context_v)
|
||||
|
||||
head_dim = inner_dim // h
|
||||
q = q_in.view(batch_size, -1, h, head_dim).transpose(1, 2)
|
||||
k = k_in.view(batch_size, -1, h, head_dim).transpose(1, 2)
|
||||
v = v_in.view(batch_size, -1, h, head_dim).transpose(1, 2)
|
||||
|
||||
del q_in, k_in, v_in
|
||||
|
||||
dtype = q.dtype
|
||||
if shared.opts.upcast_attn:
|
||||
q, k, v = q.float(), k.float(), v.float()
|
||||
|
||||
### FABRIC ###
|
||||
mask_shape = q.shape[:3] + (k.shape[2],) # (bs, h, nq, nk)
|
||||
if mask is None:
|
||||
mask = 0
|
||||
else:
|
||||
mask.masked_fill(not mask, -float('inf')) if mask.dtype==torch.bool else mask
|
||||
mask = mask.to(dtype=q.dtype)
|
||||
if weights is not None:
|
||||
min_val = torch.finfo(q.dtype).min
|
||||
w_bias = weights.log().clamp(min=min_val)[None, None, None, :].expand(*q.shape[:3], -1)
|
||||
if attn_mask is None:
|
||||
attn_mask = w_bias
|
||||
else:
|
||||
attn_mask += w_bias
|
||||
return orig_attn(q, k, v, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal)
|
||||
w_bias = _get_attn_bias(weights, mask_shape, dtype=q.dtype)
|
||||
mask += w_bias
|
||||
### END FABRIC ###
|
||||
|
||||
# the output of sdp = (batch, num_heads, seq_len, head_dim)
|
||||
hidden_states = torch.nn.functional.scaled_dot_product_attention(
|
||||
q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False
|
||||
)
|
||||
|
||||
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, h * head_dim)
|
||||
hidden_states = hidden_states.to(dtype)
|
||||
|
||||
# linear proj
|
||||
hidden_states = self.to_out[0](hidden_states)
|
||||
# dropout
|
||||
hidden_states = self.to_out[1](hidden_states)
|
||||
return hidden_states
|
||||
|
||||
def weighted_scaled_dot_product_no_mem_attention_forward(self, x, context=None, mask=None, weights=None):
|
||||
with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=True, enable_mem_efficient=False):
|
||||
return weighted_scaled_dot_product_attention_forward(self, x, context, mask, weights)
|
||||
|
||||
|
||||
# copied and adapted from modules.sd_hijack_optimizations.split_cross_attention_forward
|
||||
def weighted_split_cross_attention_forward(self, x, context=None, mask=None, weights=None):
|
||||
h = self.heads
|
||||
|
||||
|
|
@ -98,7 +290,7 @@ def weighted_split_cross_attention_forward(self, x, context=None, mask=None, wei
|
|||
modifier = 3 if q.element_size() == 2 else 2.5
|
||||
mem_required = tensor_size * modifier
|
||||
|
||||
# FABRIC some batch-size dependend overhead. Found empirically on RTX 3090.
|
||||
# FABRIC incurs some batch-size-dependend overhead. Found empirically on RTX 3090.
|
||||
bs = q.shape[0] / 8 # batch size
|
||||
mem_required *= 1/(bs + 1) + 1.25
|
||||
mem_required *= 1.05 # safety margin
|
||||
|
|
@ -139,45 +331,3 @@ def weighted_split_cross_attention_forward(self, x, context=None, mask=None, wei
|
|||
del r1
|
||||
|
||||
return self.to_out(r2)
|
||||
|
||||
|
||||
def is_the_same(fn1, fn2):
|
||||
if isinstance(fn2, (list, tuple)):
|
||||
return any(is_the_same(fn1, f) for f in fn2)
|
||||
return fn1.__name__ == fn2.__name__ and fn1.__module__ == fn2.__module__
|
||||
|
||||
|
||||
def weighted_attention(self, attn_fn, x, context=None, weights=None, **kwargs):
|
||||
if weights is None:
|
||||
return attn_fn(x, context=context, **kwargs)
|
||||
|
||||
if is_the_same(attn_fn, split_cross_attention_forward_invokeAI):
|
||||
modules.sd_hijack_optimizations.einsum_op_compvis = functools.partial(patched_einsum_op_compvis, weights=weights)
|
||||
try:
|
||||
out = attn_fn(x, context=context, **kwargs)
|
||||
finally:
|
||||
modules.sd_hijack_optimizations.einsum_op_compvis = _einsum_op_compvis
|
||||
return out
|
||||
|
||||
elif is_the_same(attn_fn, xformers_attention_forward):
|
||||
assert _xformers_attn in locals() or _xformers_attn in globals(), "xformers attention function not found"
|
||||
xformers.ops.memory_efficient_attention = functools.partial(patched_xformers_attn, weights=weights, orig_attn=_xformers_attn)
|
||||
try:
|
||||
out = attn_fn(x, context=context, **kwargs)
|
||||
finally:
|
||||
xformers.ops.memory_efficient_attention = _xformers_attn
|
||||
return out
|
||||
|
||||
elif is_the_same(attn_fn, [scaled_dot_product_no_mem_attention_forward, scaled_dot_product_attention_forward]):
|
||||
torch.nn.functional.scaled_dot_product_attention = functools.partial(patched_sdp_attn, weights=weights, orig_attn=_sdp_attention)
|
||||
try:
|
||||
out = attn_fn(x, context=context, **kwargs)
|
||||
finally:
|
||||
torch.nn.functional.scaled_dot_product_attention = _sdp_attention
|
||||
return out
|
||||
|
||||
elif is_the_same(attn_fn, split_cross_attention_forward):
|
||||
return weighted_split_cross_attention_forward(self, x, context=context, weights=weights, **kwargs)
|
||||
|
||||
else:
|
||||
raise NotImplementedError(f"FABRIC does not support `{attn_fn.__module__}.{attn_fn.__name__}` yet. Please choose a supported attention function.")
|
||||
|
|
|
|||
Loading…
Reference in New Issue