diff --git a/scripts/fabric.py b/scripts/fabric.py
index f1a833d..069b69e 100644
--- a/scripts/fabric.py
+++ b/scripts/fabric.py
@@ -27,7 +27,7 @@ except ImportError:
     from modules.ui import create_refresh_button
 
-__version__ = "0.5.1"
+__version__ = "0.5.2"
 
 DEBUG = os.getenv("DEBUG", "false").lower() in ("true", "1")
@@ -421,6 +421,9 @@ class FabricScript(modules.scripts.Script):
         # restore original U-Net forward pass in case previous batch errored out
         unpatch_unet_forward_pass(p.sd_model.model.diffusion_model)
 
+        if not feedback_enabled:
+            return
+
         liked_paths = liked_paths[-int(feedback_max_images):]
         disliked_paths = disliked_paths[-int(feedback_max_images):]
diff --git a/scripts/weighted_attention.py b/scripts/weighted_attention.py
index 098a2b3..5e74209 100644
--- a/scripts/weighted_attention.py
+++ b/scripts/weighted_attention.py
@@ -1,5 +1,5 @@
 import math
-import functools
+import psutil
 import torch
 import torch.nn.functional
 
@@ -8,65 +8,257 @@ from einops import rearrange
 
 from ldm.util import default
-import modules.sd_hijack_optimizations
-from modules import shared, devices
+from modules import shared, devices, sd_hijack
 from modules.hypernetworks import hypernetwork
 from modules.sd_hijack_optimizations import (
-    split_cross_attention_forward_invokeAI,
-    xformers_attention_forward,
-    scaled_dot_product_no_mem_attention_forward,
-    scaled_dot_product_attention_forward,
-    split_cross_attention_forward,
+    get_xformers_flash_attention_op,
     get_available_vram,
 )
-
 try:
+    import xformers
     import xformers.ops
-    _xformers_attn = xformers.ops.memory_efficient_attention
 except ImportError:
     pass
 
-_einsum_op_compvis = modules.sd_hijack_optimizations.einsum_op_compvis
-_sdp_attention = torch.nn.functional.scaled_dot_product_attention
+
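+# Dispatch to the weighted attention variant matching the webui's active
+# cross-attention optimization, instead of monkey-patching the optimized kernels.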
+def get_weighted_attn_fn():
+    method = sd_hijack.model_hijack.optimization_method
+    if method is None:
+        return weighted_split_cross_attention_forward
+    method = method.lower()
+
+    if method not in ['none', 'sdp-no-mem', 'sdp', 'xformers', 'sub-quadratic', 'v1', 'invokeai', 'doggettx']:
+        print(f"[FABRIC] Warning: Unknown attention optimization method {method}.")
+        return weighted_split_cross_attention_forward
+
+    if method == 'none':
+        return weighted_split_cross_attention_forward
+    elif method == 'xformers':
+        return weighted_xformers_attention_forward
+    elif method == 'sdp-no-mem':
+        return weighted_scaled_dot_product_no_mem_attention_forward
+    elif method == 'sdp':
+        return weighted_scaled_dot_product_attention_forward
+    elif method == 'doggettx':
+        return weighted_split_cross_attention_forward
+    elif method == 'invokeai':
+        return weighted_split_cross_attention_forward_invokeAI
+    elif method == 'sub-quadratic':
+        print("[FABRIC] Warning: Sub-quadratic attention is not supported yet. Please open an issue if you need this for your workflow. Falling back to split attention.")
+        return weighted_split_cross_attention_forward
+    elif method == 'v1':
+        print("[FABRIC] Warning: V1 attention is not supported yet. Please open an issue if you need this for your workflow. Falling back to split attention.")
+        return weighted_split_cross_attention_forward
+    else:
+        return weighted_split_cross_attention_forward
 
 
-def patched_einsum_op_compvis(q, k, v, weights=None):
+def weighted_attention(self, attn_fn, x, context=None, weights=None, **kwargs):
+    if weights is None:
+        return attn_fn(x, context=context, **kwargs)
+
+    weighted_attn_fn = get_weighted_attn_fn()
+    return weighted_attn_fn(self, x, context=context, weights=weights, **kwargs)
+
+
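+# Turning per-token weights w into an additive pre-softmax bias log(w) rescales
+# each key's unnormalized attention by w: w = 1 is neutral, w < 1 attenuates a
+# token, w > 1 amplifies it. log(0) = -inf is clamped to the dtype's minimum.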
+def _get_attn_bias(weights, shape=None, dtype=torch.float32):
+    min_val = torch.finfo(dtype).min
+    w_bias = weights.log().clamp(min=min_val)
+    if shape is not None:
+        assert shape[-1] == w_bias.shape[-1], "Last dimension of shape must match last dimension of weights (number of keys)"
+        w_bias = w_bias.view([1] * (len(shape) - 1) + [-1]).expand(shape)
+    w_bias = w_bias.to(dtype=dtype)
+    return w_bias
+
+
+### The following attn functions are copied and adapted from modules.sd_hijack_optimizations
+
+# --- InvokeAI ---
+mem_total_gb = psutil.virtual_memory().total // (1 << 30)
+
+def einsum_op_compvis(q, k, v, weights=None):
     s = einsum('b i d, b j d -> b i j', q, k)
-    s = s.softmax(dim=-1, dtype=s.dtype)
     if weights is not None:
-        s = s * weights[None, None, :]
+        s += _get_attn_bias(weights, s.shape, s.dtype)
+    s = s.softmax(dim=-1, dtype=s.dtype)
     return einsum('b i j, b j d -> b i d', s, v)
 
+def einsum_op_slice_0(q, k, v, slice_size, weights=None):
+    r = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
+    for i in range(0, q.shape[0], slice_size):
+        end = i + slice_size
+        r[i:end] = einsum_op_compvis(q[i:end], k[i:end], v[i:end], weights)
+    return r
 
-def patched_xformers_attn(q, k, v, attn_bias=None, op=None, weights=None, orig_attn=None):
-    bs, nq, nh, dh = q.shape  # batch_size, num_queries, num_heads, dim_per_head
+def einsum_op_slice_1(q, k, v, slice_size, weights=None):
+    r = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
+    for i in range(0, q.shape[1], slice_size):
+        end = i + slice_size
+        r[:, i:end] = einsum_op_compvis(q[:, i:end], k, v, weights)
+    return r
+
+def einsum_op_mps_v1(q, k, v, weights=None):
+    if q.shape[0] * q.shape[1] <= 2**16:  # (512x512) max q.shape[1]: 4096
+        return einsum_op_compvis(q, k, v, weights)
+    else:
+        slice_size = math.floor(2**30 / (q.shape[0] * q.shape[1]))
+        if slice_size % 4096 == 0:
+            slice_size -= 1
+        return einsum_op_slice_1(q, k, v, slice_size, weights)
+
+def einsum_op_mps_v2(q, k, v, weights=None):
+    if mem_total_gb > 8 and q.shape[0] * q.shape[1] <= 2**16:
+        return einsum_op_compvis(q, k, v, weights)
+    else:
+        return einsum_op_slice_0(q, k, v, 1, weights)
+
+def einsum_op_tensor_mem(q, k, v, max_tensor_mb, weights=None):
+    size_mb = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size() // (1 << 20)
+    if size_mb <= max_tensor_mb:
+        return einsum_op_compvis(q, k, v, weights)
+    div = 1 << int((size_mb - 1) / max_tensor_mb).bit_length()
+    if div <= q.shape[0]:
+        return einsum_op_slice_0(q, k, v, q.shape[0] // div, weights)
+    return einsum_op_slice_1(q, k, v, max(q.shape[1] // div, 1), weights)
+
+def einsum_op_cuda(q, k, v, weights=None):
+    stats = torch.cuda.memory_stats(q.device)
+    mem_active = stats['active_bytes.all.current']
+    mem_reserved = stats['reserved_bytes.all.current']
+    mem_free_cuda, _ = torch.cuda.mem_get_info(q.device)
+    mem_free_torch = mem_reserved - mem_active
+    mem_free_total = mem_free_cuda + mem_free_torch
+    # Divide factor of safety as there's copying and fragmentation
+    return einsum_op_tensor_mem(q, k, v, mem_free_total / 3.3 / (1 << 20), weights)
+
+def einsum_op(q, k, v, weights=None):
+    if q.device.type == 'cuda':
+        return einsum_op_cuda(q, k, v, weights)
+
+    if q.device.type == 'mps':
+        if mem_total_gb >= 32 and q.shape[0] % 32 != 0 and q.shape[0] * q.shape[1] < 2**18:
+            return einsum_op_mps_v1(q, k, v, weights)
+        return einsum_op_mps_v2(q, k, v, weights)
+
+    # Smaller slices are faster due to L2/L3/SLC caches.
+    # Tested on i7 with 8MB L3 cache.
+    return einsum_op_tensor_mem(q, k, v, 32, weights)
+
+def weighted_split_cross_attention_forward_invokeAI(self, x, context=None, mask=None, weights=None):
+    h = self.heads
+
+    q = self.to_q(x)
+    context = default(context, x)
+
+    context_k, context_v = hypernetwork.apply_hypernetworks(shared.loaded_hypernetworks, context)
+    k = self.to_k(context_k)
+    v = self.to_v(context_v)
+    del context, context_k, context_v, x
+
+    dtype = q.dtype
+    if shared.opts.upcast_attn:
+        q, k, v = q.float(), k.float(), v if v.device.type == 'mps' else v.float()
+
+    with devices.without_autocast(disable=not shared.opts.upcast_attn):
+        k = k * self.scale
+
+        q, k, v = (rearrange(t, 'b n (h d) -> (b h) n d', h=h) for t in (q, k, v))
+        r = einsum_op(q, k, v, weights)
+    r = r.to(dtype)
+    return self.to_out(rearrange(r, '(b h) n d -> b n (h d)', h=h))
+# --- end InvokeAI ---
+
+
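+# xformers path: the weights enter as an attn_bias tensor that
+# memory_efficient_attention adds to the attention scores before softmax.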
+def weighted_xformers_attention_forward(self, x, context=None, mask=None, weights=None):
+    h = self.heads
+    q_in = self.to_q(x)
+    context = default(context, x)
+
+    context_k, context_v = hypernetwork.apply_hypernetworks(shared.loaded_hypernetworks, context)
+    k_in = self.to_k(context_k)
+    v_in = self.to_v(context_v)
+
+    q, k, v = (rearrange(t, 'b n (h d) -> b n h d', h=h) for t in (q_in, k_in, v_in))
+    del q_in, k_in, v_in
+
+    dtype = q.dtype
+    if shared.opts.upcast_attn:
+        q, k, v = q.float(), k.float(), v.float()
+
+    ### FABRIC ###
+    bias_shape = (q.size(0), q.size(2), q.size(1), k.size(1))  # (bs, h, nq, nk)
     if weights is not None:
-        min_val = torch.finfo(q.dtype).min
-        w_bias = weights.log().clamp(min=min_val)[None, None, None, :].expand(bs, nh, nq, -1).contiguous()
-        w_bias = w_bias.to(dtype=q.dtype)
-        if attn_bias is None:
-            attn_bias = w_bias
-        else:
-            attn_bias += w_bias
-    return orig_attn(q, k, v, attn_bias=attn_bias, op=op)
+        attn_bias = _get_attn_bias(weights, bias_shape, dtype=q.dtype)
+    else:
+        attn_bias = None
+    ### END FABRIC ###
+
+    out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=attn_bias, op=get_xformers_flash_attention_op(q, k, v))
+
+    out = out.to(dtype)
+
+    out = rearrange(out, 'b n h d -> b n (h d)', h=h)
+    return self.to_out(out)
 
-def patched_sdp_attn(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False, weights=None, orig_attn=None):
-    if attn_mask is not None:
-        attn_mask.masked_fill(not attn_mask, -float('inf')) if attn_mask.dtype==torch.bool else attn_mask
-        attn_mask = attn_mask.to(dtype=q.dtype)
+def weighted_scaled_dot_product_attention_forward(self, x, context=None, mask=None, weights=None):
+    batch_size, sequence_length, inner_dim = x.shape
+
+    if mask is not None:
+        mask = self.prepare_attention_mask(mask, sequence_length, batch_size)
+        mask = mask.view(batch_size, self.heads, -1, mask.shape[-1])
+
+    h = self.heads
+    q_in = self.to_q(x)
+    context = default(context, x)
+
+    context_k, context_v = hypernetwork.apply_hypernetworks(shared.loaded_hypernetworks, context)
+    k_in = self.to_k(context_k)
+    v_in = self.to_v(context_v)
+
+    head_dim = inner_dim // h
+    q = q_in.view(batch_size, -1, h, head_dim).transpose(1, 2)
+    k = k_in.view(batch_size, -1, h, head_dim).transpose(1, 2)
+    v = v_in.view(batch_size, -1, h, head_dim).transpose(1, 2)
+
+    del q_in, k_in, v_in
+
+    dtype = q.dtype
+    if shared.opts.upcast_attn:
+        q, k, v = q.float(), k.float(), v.float()
+
+    ### FABRIC ###
+    mask_shape = q.shape[:3] + (k.shape[2],)  # (bs, h, nq, nk)
+    if mask is not None:
+        if mask.dtype == torch.bool:
+            # convert a boolean mask (True = attend) into an additive float bias
+            mask = torch.zeros_like(mask, dtype=q.dtype).masked_fill_(~mask, float('-inf'))
+        else:
+            mask = mask.to(dtype=q.dtype)
     if weights is not None:
-        min_val = torch.finfo(q.dtype).min
-        w_bias = weights.log().clamp(min=min_val)[None, None, None, :].expand(*q.shape[:3], -1)
-        if attn_mask is None:
-            attn_mask = w_bias
-        else:
-            attn_mask += w_bias
-    return orig_attn(q, k, v, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal)
+        w_bias = _get_attn_bias(weights, mask_shape, dtype=q.dtype)
+        mask = w_bias if mask is None else mask + w_bias
+    ### END FABRIC ###
+
+    # the output of sdp = (batch, num_heads, seq_len, head_dim)
+    hidden_states = torch.nn.functional.scaled_dot_product_attention(
+        q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False
+    )
+
+    hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, h * head_dim)
+    hidden_states = hidden_states.to(dtype)
+
+    # linear proj
+    hidden_states = self.to_out[0](hidden_states)
+    # dropout
+    hidden_states = self.to_out[1](hidden_states)
+    return hidden_states
+
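+# "sdp-no-mem": restrict SDPA to the flash/math kernels, mirroring the webui's
+# deterministic sdp-no-mem option (the mem-efficient kernel is non-deterministic).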
+def weighted_scaled_dot_product_no_mem_attention_forward(self, x, context=None, mask=None, weights=None):
+    with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=True, enable_mem_efficient=False):
+        return weighted_scaled_dot_product_attention_forward(self, x, context, mask, weights)
 
-# copied and adapted from modules.sd_hijack_optimizations.split_cross_attention_forward
 def weighted_split_cross_attention_forward(self, x, context=None, mask=None, weights=None):
     h = self.heads
@@ -98,7 +290,7 @@ def weighted_split_cross_attention_forward(self, x, context=None, mask=None, weights=None):
         modifier = 3 if q.element_size() == 2 else 2.5
         mem_required = tensor_size * modifier
 
-        # FABRIC some batch-size dependend overhead. Found empirically on RTX 3090.
+        # FABRIC incurs some batch-size-dependent overhead. Found empirically on RTX 3090.
         bs = q.shape[0] / 8  # batch size
         mem_required *= 1/(bs + 1) + 1.25
         mem_required *= 1.05  # safety margin
@@ -139,45 +331,3 @@ def weighted_split_cross_attention_forward(self, x, context=None, mask=None, weights=None):
         del r1
 
     return self.to_out(r2)
-
-
-def is_the_same(fn1, fn2):
-    if isinstance(fn2, (list, tuple)):
-        return any(is_the_same(fn1, f) for f in fn2)
-    return fn1.__name__ == fn2.__name__ and fn1.__module__ == fn2.__module__
-
-
-def weighted_attention(self, attn_fn, x, context=None, weights=None, **kwargs):
-    if weights is None:
-        return attn_fn(x, context=context, **kwargs)
-
-    if is_the_same(attn_fn, split_cross_attention_forward_invokeAI):
-        modules.sd_hijack_optimizations.einsum_op_compvis = functools.partial(patched_einsum_op_compvis, weights=weights)
-        try:
-            out = attn_fn(x, context=context, **kwargs)
-        finally:
-            modules.sd_hijack_optimizations.einsum_op_compvis = _einsum_op_compvis
-        return out
-
-    elif is_the_same(attn_fn, xformers_attention_forward):
-        assert _xformers_attn in locals() or _xformers_attn in globals(), "xformers attention function not found"
-        xformers.ops.memory_efficient_attention = functools.partial(patched_xformers_attn, weights=weights, orig_attn=_xformers_attn)
-        try:
-            out = attn_fn(x, context=context, **kwargs)
-        finally:
-            xformers.ops.memory_efficient_attention = _xformers_attn
-        return out
-
-    elif is_the_same(attn_fn, [scaled_dot_product_no_mem_attention_forward, scaled_dot_product_attention_forward]):
-        torch.nn.functional.scaled_dot_product_attention = functools.partial(patched_sdp_attn, weights=weights, orig_attn=_sdp_attention)
-        try:
-            out = attn_fn(x, context=context, **kwargs)
-        finally:
-            torch.nn.functional.scaled_dot_product_attention = _sdp_attention
-        return out
-
-    elif is_the_same(attn_fn, split_cross_attention_forward):
-        return weighted_split_cross_attention_forward(self, x, context=context, weights=weights, **kwargs)
-
-    else:
-        raise NotImplementedError(f"FABRIC does not support `{attn_fn.__module__}.{attn_fn.__name__}` yet. Please choose a supported attention function.")