diff --git a/README.md b/README.md
index 9e3de1c..b4c9ebd 100644
--- a/README.md
+++ b/README.md
@@ -10,7 +10,8 @@ ComfyUI node (by [@ssitu](https://github.com/ssitu)): https://github.com/ssitu/C
 
 ![demo](static/fabric_demo.gif)
 
-## Releases
+## Releases and Changelog
+- [07.03.2024] 🔨 v0.6.5: Fixes compatibility with [WebUI Forge](https://github.com/lllyasviel/stable-diffusion-webui-forge).
 - [07.03.2024] ✨ v0.6.4: SDXL support has been added. For optimal results, lowering the feedback strength is recommended (0.5 seems to be a good starting point).
 - [29.08.2023] 🏎️ v0.6.0: Up to 2x faster and 4x less VRAM usage thanks to [Token Merging](https://github.com/dbolya/tomesd/tree/main) (tested with 16 feedback images and a batch size of 4), moderate gains for fewer feedback images (10% speedup for 2 images, 30% for 8 images). Enable the Token Merging option to take advantage of this.
 - [22.08.2023] 🗃️ v0.5.0: Adds support for presets. Makes generated images using FABRIC more reproducible by loading the correct (previously used) feedback images when using "send to text2img/img2img".
diff --git a/scripts/fabric.py b/scripts/fabric.py
index 3a9565f..1cd1117 100644
--- a/scripts/fabric.py
+++ b/scripts/fabric.py
@@ -27,7 +27,7 @@ except ImportError:
     from modules.ui import create_refresh_button
 
 
-__version__ = "0.6.4"
+__version__ = "0.6.5"
 
 DEBUG = os.getenv("DEBUG", "false").lower() in ("true", "1")
 
diff --git a/scripts/patching.py b/scripts/patching.py
index 58073ac..1832cd4 100644
--- a/scripts/patching.py
+++ b/scripts/patching.py
@@ -17,6 +17,13 @@ from scripts.fabric_utils import image_hash
 from scripts.weighted_attention import weighted_attention
 from scripts.merging import compute_merge
 
+try:
+    import ldm_patched
+    has_webui_forge = True
+    print("[FABRIC] Detected WebUI Forge, running in compatibility mode.")
+except ImportError:
+    has_webui_forge = False
+
 SD15 = "sd15"
 SDXL = "sdxl"
 
@@ -43,10 +50,12 @@ def encode_to_latent(p, image, w, h):
     return z.squeeze(0)
 
 def forward_noise(p, x_0, t, noise=None):
+    device = x_0.device
     if noise is None:
         noise = torch.randn_like(x_0)
-    sqrt_alpha_bar_t = extract_into_tensor(p.sd_model.alphas_cumprod.sqrt(), t, x_0.shape)
-    sqrt_one_minus_alpha_bar_t = extract_into_tensor((1.0 - p.sd_model.alphas_cumprod).sqrt(), t, x_0.shape)
+    alpha_bar = p.sd_model.alphas_cumprod.to(device)
+    sqrt_alpha_bar_t = extract_into_tensor(alpha_bar.sqrt(), t, x_0.shape)
+    sqrt_one_minus_alpha_bar_t = extract_into_tensor((1.0 - alpha_bar).sqrt(), t, x_0.shape)
     x_t = sqrt_alpha_bar_t * x_0 + sqrt_one_minus_alpha_bar_t * noise
     return x_t
 
@@ -96,13 +105,19 @@ def patch_unet_forward_pass(p, unet, params):
     if isinstance(p.sd_model, LatentDiffusion):
         sd_version = SD15
         num_timesteps = p.sd_model.num_timesteps
-        BasicTransformerBlock = ldm.modules.attention.BasicTransformerBlock
     elif isinstance(p.sd_model, DiffusionEngine):
         sd_version = SDXL
         num_timesteps = len(p.sd_model.alphas_cumprod)
-        BasicTransformerBlock = sgm.modules.attention.BasicTransformerBlock
     else:
         raise ValueError(f"[FABRIC] Unsupported SD model: {type(p.sd_model)}")
+
+    transformer_block_type = tuple(
+        [
+            ldm.modules.attention.BasicTransformerBlock,  # SD 1.5
+            sgm.modules.attention.BasicTransformerBlock,  # SDXL
+        ]
+        + ([ldm_patched.ldm.modules.attention.BasicTransformerBlock] if has_webui_forge else [])
+    )
 
     batch_size = p.batch_size
 
@@ -173,20 +188,13 @@ def patch_unet_forward_pass(p, unet, params):
             return self._fabric_old_forward(x, timesteps, context, **kwargs)
 
         # add noise to reference latents
-        if sd_version == SD15:
-            all_zs = []
-            for latent in all_latents:
-                z = p.sd_model.q_sample(latent.unsqueeze(0), timesteps[0].unsqueeze(0))[0]
-                all_zs.append(z)
-            all_zs = torch.stack(all_zs, dim=0)
-        else:  # SDXL
-            xs_0 = torch.stack(all_latents, dim=0)
-            ts = timesteps[0, None].expand(xs_0.size(0))  # (bs,)
-            all_zs = forward_noise(p, xs_0, torch.round(ts.float()).long())
+        xs_0 = torch.stack(all_latents, dim=0)
+        ts = timesteps[0, None].expand(xs_0.size(0))  # (bs,)
+        all_zs = forward_noise(p, xs_0, torch.round(ts.float()).long())
 
         # save original forward pass
         for module in self.modules():
-            if isinstance(module, BasicTransformerBlock) and not hasattr(module.attn1, "_fabric_old_forward"):
+            if isinstance(module, transformer_block_type) and not hasattr(module.attn1, "_fabric_old_forward"):
                 module.attn1._fabric_old_forward = module.attn1.forward
                 module.attn2._fabric_old_forward = module.attn2.forward
 
@@ -214,7 +222,7 @@ def patch_unet_forward_pass(p, unet, params):
         # patch forward pass to cache hidden states
         layer_idx = 0
         for module in self.modules():
-            if isinstance(module, BasicTransformerBlock):
+            if isinstance(module, transformer_block_type):
                 module.attn1.forward = functools.partial(patched_attn1_forward, module.attn1, layer_idx)
                 module.attn2.forward = functools.partial(patched_attn2_forward, module.attn2)
                 layer_idx += 1
@@ -260,20 +268,19 @@ def patch_unet_forward_pass(p, unet, params):
                 weights = None
             return weighted_attention(attn1, attn1._fabric_old_forward, _x, ctx, weights, **kwargs)  # (n_cond, seq, dim)
 
-        outs = []
+        out = torch.zeros_like(x)
         if num_cond > 0:
             out_cond = attention_with_feedback(x[cond_ids], context[cond_ids], cached_hs[:num_pos], pos_weight)  # (n_cond, seq, dim)
-            outs.append(out_cond)
+            out[cond_ids] = out_cond
         if num_uncond > 0:
             out_uncond = attention_with_feedback(x[uncond_ids], context[uncond_ids], cached_hs[num_pos:], neg_weight)  # (n_cond, seq, dim)
-            outs.append(out_uncond)
-        out = torch.cat(outs, dim=0)
+            out[uncond_ids] = out_uncond
         return out
 
     # patch forward pass to inject cached hidden states
     layer_idx = 0
     for module in self.modules():
-        if isinstance(module, BasicTransformerBlock):
+        if isinstance(module, transformer_block_type):
             module.attn1.forward = functools.partial(patched_attn1_forward, module.attn1, layer_idx)
             layer_idx += 1
 
@@ -283,10 +290,10 @@ def patch_unet_forward_pass(p, unet, params):
     finally:
         # restore original pass
         for module in self.modules():
-            if isinstance(module, BasicTransformerBlock) and hasattr(module.attn1, "_fabric_old_forward"):
+            if isinstance(module, transformer_block_type) and hasattr(module.attn1, "_fabric_old_forward"):
                 module.attn1.forward = module.attn1._fabric_old_forward
                 del module.attn1._fabric_old_forward
-            if isinstance(module, BasicTransformerBlock) and hasattr(module.attn2, "_fabric_old_forward"):
+            if isinstance(module, transformer_block_type) and hasattr(module.attn2, "_fabric_old_forward"):
                 module.attn2.forward = module.attn2._fabric_old_forward
                 del module.attn2._fabric_old_forward
 
diff --git a/scripts/weighted_attention.py b/scripts/weighted_attention.py
index 3ca4770..08fd215 100644
--- a/scripts/weighted_attention.py
+++ b/scripts/weighted_attention.py
@@ -21,8 +21,24 @@ try:
 except ImportError:
     pass
 
+try:
+    from ldm_patched.modules import model_management
+    has_webui_forge = True
+    print("[FABRIC] Detected WebUI Forge, running in compatibility mode.")
+except ImportError:
+    has_webui_forge = False
+
 
 def get_weighted_attn_fn():
+    if has_webui_forge:
+        if model_management.xformers_enabled():
+            return weighted_xformers_attention_forward
+        elif model_management.pytorch_attention_enabled():
+            return weighted_scaled_dot_product_attention_forward
+        else:
+            print(f"[FABRIC] Warning: No attention method enabled. Falling back to split attention.")
+            return weighted_split_cross_attention_forward
+
     method = sd_hijack.model_hijack.optimization_method
     if method is None:
         return weighted_split_cross_attention_forward
@@ -283,7 +299,8 @@ def weighted_split_cross_attention_forward(self, x, context=None, mask=None, wei
         q_in, k_in, v_in = q_in.float(), k_in.float(), v_in if v_in.device.type == 'mps' else v_in.float()
 
     with devices.without_autocast(disable=not shared.opts.upcast_attn):
-        k_in = k_in * self.scale
+        default_scale = (q_in.shape[-1] / h) ** -0.5
+        k_in = k_in * getattr(self, "scale", default_scale)
 
         del context, x
 
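
For context on the `forward_noise` change above: SD 1.5 and SDXL now share the closed-form forward-diffusion step x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise, with the cumulative-alpha schedule moved onto the latent's device before use. The snippet below is a minimal standalone sketch of that computation, not code from the extension; the function name, the toy linear-beta schedule, and the tensor shapes are illustrative stand-ins for `p.sd_model.alphas_cumprod` and the stacked reference latents.

# Illustrative sketch only: closed-form forward diffusion q(x_t | x_0), with the
# schedule moved to the latent's device as in the forward_noise fix above.
import torch

def forward_noise_sketch(x_0, t, alphas_cumprod, noise=None):
    device = x_0.device
    if noise is None:
        noise = torch.randn_like(x_0)
    # gather alpha_bar_t per sample and reshape so it broadcasts over the latent dims
    alpha_bar_t = alphas_cumprod.to(device)[t].view(-1, *([1] * (x_0.dim() - 1)))
    return alpha_bar_t.sqrt() * x_0 + (1.0 - alpha_bar_t).sqrt() * noise

# toy linear-beta schedule standing in for p.sd_model.alphas_cumprod
betas = torch.linspace(1e-4, 2e-2, 1000)
alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)

x_0 = torch.randn(4, 4, 64, 64)              # stack of reference latents (toy shape)
t = torch.full((4,), 500, dtype=torch.long)  # one rounded timestep per latent
x_t = forward_noise_sketch(x_0, t, alphas_cumprod)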
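
The remaining patching.py hunks all follow the same save/patch/restore pattern: stash each attention module's original `forward` as `_fabric_old_forward`, swap in a `functools.partial` wrapper, and put the restore in a `finally` block so the model is left untouched even if generation fails. The sketch below illustrates that pattern on a toy `nn.Linear` stack; it is not FABRIC code, and the module type, wrapper body, and attribute names are placeholders.

# Illustrative sketch only: save/patch/restore of module.forward on a toy model.
import functools
import torch
import torch.nn as nn

def patched_forward(module, layer_idx, x):
    # a real patch would cache or inject hidden states here
    # before deferring to the saved original forward
    return module._old_forward(x)

model = nn.Sequential(nn.Linear(8, 8), nn.Linear(8, 8))
try:
    # save original forward pass and swap in the wrapper
    layer_idx = 0
    for module in model.modules():
        if isinstance(module, nn.Linear) and not hasattr(module, "_old_forward"):
            module._old_forward = module.forward
            module.forward = functools.partial(patched_forward, module, layer_idx)
            layer_idx += 1
    out = model(torch.randn(2, 8))  # runs through the patched forwards
finally:
    # restore original pass even if the forward call raised
    for module in model.modules():
        if isinstance(module, nn.Linear) and hasattr(module, "_old_forward"):
            module.forward = module._old_forward
            del module._old_forward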