diff --git a/.gitignore b/.gitignore index 68bc17f..98fd9f0 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,5 @@ +.DS_Store + # Byte-compiled / optimized / DLL files __pycache__/ *.py[cod] diff --git a/scripts/fabric.py b/scripts/fabric.py index 26e68a9..2761f6d 100644 --- a/scripts/fabric.py +++ b/scripts/fabric.py @@ -17,7 +17,7 @@ from scripts.helpers import WebUiComponents __version__ = "0.3.5" -DEBUG = False +DEBUG = os.getenv("DEBUG", "").lower() in ("1", "true", "yes") if DEBUG: print(f"WARNING: Loading FABRIC v{__version__} in DEBUG mode") diff --git a/scripts/patching.py b/scripts/patching.py index e7e468f..3472991 100644 --- a/scripts/patching.py +++ b/scripts/patching.py @@ -7,6 +7,7 @@ from modules.processing import StableDiffusionProcessingTxt2Img from ldm.modules.attention import BasicTransformerBlock from scripts.marking import patch_process_sample, unmark_prompt_context +from scripts.weighted_attention import weighted_attention def encode_to_latent(p, image, w, h): @@ -45,6 +46,11 @@ def get_latents_from_params(p, params, width, height): return params.pos_latents, params.neg_latents +def get_curr_feedback_weight(p, params): + w = params.max_weight + return w, w * params.neg_scale + + def patch_unet_forward_pass(p, unet, params): if not params.pos_images and not params.neg_images: print("[FABRIC] No images to found, aborting patching") @@ -53,7 +59,7 @@ def patch_unet_forward_pass(p, unet, params): if not hasattr(unet, "_fabric_old_forward"): unet._fabric_old_forward = unet.forward - null_ctx = p.sd_model.get_learned_conditioning([""]) + null_ctx = p.sd_model.get_learned_conditioning([""]).to(devices.device, dtype=devices.dtype_unet) width = (p.width // 8) * 8 height = (p.height // 8) * 8 @@ -142,18 +148,24 @@ def patch_unet_forward_pass(p, unet, params): num_cond = len(cond_ids) num_uncond = len(uncond_ids) + pos_weight, neg_weight = get_curr_feedback_weight(p, params) + outs = [] if num_cond > 0: pos_hs = cached_hs[:num_pos].view(1, num_pos * seq_len, 
d_model).expand(num_cond, -1, -1) # (n_cond, seq * n_pos, dim) x_cond = x[cond_ids] # (n_cond, seq, dim) ctx_cond = torch.cat([context[cond_ids], pos_hs], dim=1) # (n_cond, seq * (1 + n_pos), dim) - out_cond = attn1._fabric_old_forward(x_cond, ctx_cond, **kwargs) # (n_cond, seq, dim) + ws = torch.ones_like(ctx_cond[0, :, 0]) # (seq * (1 + n_pos),) + ws[x_cond.size(1):] = pos_weight + out_cond = weighted_attention(attn1._fabric_old_forward, x_cond, ctx_cond, ws, **kwargs) # (n_cond, seq, dim) outs.append(out_cond) if num_uncond > 0: neg_hs = cached_hs[num_pos:].view(1, num_neg * seq_len, d_model).expand(num_uncond, -1, -1) # (n_uncond, seq * n_neg, dim) x_uncond = x[uncond_ids] # (n_uncond, seq, dim) ctx_uncond = torch.cat([context[uncond_ids], neg_hs], dim=1) # (n_uncond, seq * (1 + n_neg), dim) - out_uncond = attn1._fabric_old_forward(x_uncond, ctx_uncond, **kwargs) # (n_uncond, seq, dim) + ws = torch.ones_like(ctx_uncond[0, :, 0]) # (seq * (1 + n_neg),) + ws[x_uncond.size(1):] = neg_weight + out_uncond = weighted_attention(attn1._fabric_old_forward, x_uncond, ctx_uncond, ws, **kwargs) # (n_uncond, seq, dim) outs.append(out_uncond) out = torch.cat(outs, dim=0) return out diff --git a/scripts/weighted_attention.py b/scripts/weighted_attention.py new file mode 100644 index 0000000..5643d27 --- /dev/null +++ b/scripts/weighted_attention.py @@ -0,0 +1,173 @@ +import math +import functools + +import torch +import torch.nn.functional +from torch import einsum +from einops import rearrange + +from ldm.util import default + +import modules.sd_hijack_optimizations +from modules import shared, devices +from modules.hypernetworks import hypernetwork +from modules.sd_hijack_optimizations import ( + split_cross_attention_forward_invokeAI, + xformers_attention_forward, + scaled_dot_product_no_mem_attention_forward, + scaled_dot_product_attention_forward, + split_cross_attention_forward, + get_available_vram, +) + + +_einsum_op_compvis = modules.sd_hijack_optimizations.einsum_op_compvis 
+ + +def patched_einsum_op_compvis(q, k, v, weights=None): + # print("Calling patched einsum_op_compvis") + s = einsum('b i d, b j d -> b i j', q, k) + s = s.softmax(dim=-1, dtype=s.dtype) + if weights is not None: + # print(s.shape, weights.shape) + s = s * weights[None, None, :] + return einsum('b i j, b j d -> b i d', s, v) + + +def patched_xformers_attn(q, k, v, attn_bias=None, op=None, weights=None, orig_attn=None): + # print(q.shape, v.shape, weights.shape) + if weights is not None: + min_val = torch.finfo(q.dtype).min + w_bias = weights.log().clamp(min=min_val)[None, None, None, :].expand(*q.shape[:3], -1).transpose(-2, -1) + if attn_bias is None: + attn_bias = w_bias + else: + attn_bias += w_bias + return orig_attn(q, k, v, attn_bias=attn_bias, op=op) + + +def patched_sdp_attn(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False, weights=None, orig_attn=None): + if attn_mask is not None: + attn_mask = attn_mask.masked_fill(attn_mask.logical_not(), -float('inf')) if attn_mask.dtype==torch.bool else attn_mask + attn_mask = attn_mask.to(dtype=q.dtype) + if weights is not None: + min_val = torch.finfo(q.dtype).min + w_bias = weights.log().clamp(min=min_val)[None, None, None, :].expand(*q.shape[:3], -1) + if attn_mask is None: + attn_mask = w_bias + else: + attn_mask += w_bias + return orig_attn(q, k, v, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal) + + +# copied and adapted from modules.sd_hijack_optimizations.split_cross_attention_forward +def weighted_split_cross_attention_forward(self, x, context=None, mask=None, weights=None): + h = self.heads + + # OURS: normalize weights to preserve attention magnitude + if weights is not None: + weights = weights[None, None, :] / weights.sum(dim=-1, keepdim=True) + + q_in = self.to_q(x) + context = default(context, x) + + context_k, context_v = hypernetwork.apply_hypernetworks(shared.loaded_hypernetworks, context) + k_in = self.to_k(context_k) + v_in = self.to_v(context_v) + + dtype = q_in.dtype + if shared.opts.upcast_attn: + 
q_in, k_in, v_in = q_in.float(), k_in.float(), v_in if v_in.device.type == 'mps' else v_in.float() + + with devices.without_autocast(disable=not shared.opts.upcast_attn): + k_in = k_in * self.scale + + del context, x + + q, k, v = (rearrange(t, 'b n (h d) -> (b h) n d', h=h) for t in (q_in, k_in, v_in)) + del q_in, k_in, v_in + + r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype) + + mem_free_total = get_available_vram() + + gb = 1024 ** 3 + tensor_size = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size() + modifier = 3 if q.element_size() == 2 else 2.5 + mem_required = tensor_size * modifier + steps = 1 + + if mem_required > mem_free_total: + steps = 2 ** (math.ceil(math.log(mem_required / mem_free_total, 2))) + # print(f"Expected tensor size:{tensor_size/gb:0.1f}GB, cuda free:{mem_free_cuda/gb:0.1f}GB " + # f"torch free:{mem_free_torch/gb:0.1f} total:{mem_free_total/gb:0.1f} steps:{steps}") + + if steps > 64: + max_res = math.floor(math.sqrt(math.sqrt(mem_free_total / 2.5)) / 8) * 64 + raise RuntimeError(f'Not enough memory, use lower resolution (max approx. {max_res}x{max_res}). 
' + f'Need: {mem_required / 64 / gb:0.1f}GB free, Have:{mem_free_total / gb:0.1f}GB free') + + slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1] + for i in range(0, q.shape[1], slice_size): + end = i + slice_size + s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k) + + s2 = s1.softmax(dim=-1, dtype=q.dtype) + del s1 + + # OURS: apply weights to attention + if weights is not None: + s2 = s2 * weights + + r1[:, i:end] = einsum('b i j, b j d -> b i d', s2, v) + del s2 + + del q, k, v + + r1 = r1.to(dtype) + + r2 = rearrange(r1, '(b h) n d -> b n (h d)', h=h) + del r1 + + return self.to_out(r2) + + +def is_the_same(fn1, fn2): + if isinstance(fn2, (list, tuple)): + return any(is_the_same(fn1, f) for f in fn2) + return fn1.__name__ == fn2.__name__ and fn1.__module__ == fn2.__module__ + + +def weighted_attention(attn_fn, x, context=None, weights=None, **kwargs): + if weights is None: + return attn_fn(x, context=context, **kwargs) + + print(attn_fn.__module__, attn_fn.__name__, type(attn_fn)) + print(split_cross_attention_forward_invokeAI.__module__, split_cross_attention_forward_invokeAI.__name__, type(split_cross_attention_forward_invokeAI)) + + if is_the_same(attn_fn, split_cross_attention_forward_invokeAI): + modules.sd_hijack_optimizations.einsum_op_compvis = functools.partial(patched_einsum_op_compvis, weights=weights) + out = attn_fn(x, context=context, **kwargs) + modules.sd_hijack_optimizations.einsum_op_compvis = _einsum_op_compvis + return out + + elif is_the_same(attn_fn, xformers_attention_forward): + import xformers.ops # xformers dependency is optional + _memory_efficient_attention = xformers.ops.memory_efficient_attention + xformers.ops.memory_efficient_attention = functools.partial(patched_xformers_attn, weights=weights, orig_attn=_memory_efficient_attention) + out = attn_fn(x, context=context, **kwargs) + xformers.ops.memory_efficient_attention = _memory_efficient_attention + return out + + elif is_the_same(attn_fn, 
[scaled_dot_product_no_mem_attention_forward, scaled_dot_product_attention_forward]): + _sdp_attention = torch.nn.functional.scaled_dot_product_attention + torch.nn.functional.scaled_dot_product_attention = functools.partial(patched_sdp_attn, weights=weights, orig_attn=_sdp_attention) + out = attn_fn(x, context=context, **kwargs) + torch.nn.functional.scaled_dot_product_attention = _sdp_attention + return out + + elif is_the_same(attn_fn, split_cross_attention_forward): + return weighted_split_cross_attention_forward(attn_fn.__self__, x, context=context, weights=weights, **kwargs) + + else: + raise NotImplementedError(f"FABRIC does not support `{attn_fn.__module__}.{attn_fn.__name__}` yet. Please choose a supported attention function.")