v0.6.5 (fix compatibility with WebUI Forge)

pull/38/head
dvruette 2024-03-07 16:50:07 +01:00
parent 5d4335d59a
commit 19a3c72499
4 changed files with 51 additions and 26 deletions

View File

@@ -10,7 +10,8 @@ ComfyUI node (by [@ssitu](https://github.com/ssitu)): https://github.com/ssitu/C
![demo](static/fabric_demo.gif)
## Releases
## Releases and Changelog
- [07.03.2024] 🔨 v0.6.5: Fixes compatibility with [WebUI Forge](https://github.com/lllyasviel/stable-diffusion-webui-forge).
- [07.03.2024] ✨ v0.6.4: SDXL support has been added. For optimal results, lowering the feedback strength is recommended (0.5 seems to be a good starting point).
- [29.08.2023] 🏎️ v0.6.0: Up to 2x faster and 4x less VRAM usage thanks to [Token Merging](https://github.com/dbolya/tomesd/tree/main) (tested with 16 feedback images and a batch size of 4), moderate gains for fewer feedback images (10% speedup for 2 images, 30% for 8 images). Enable the Token Merging option to take advantage of this.
- [22.08.2023] 🗃️ v0.5.0: Adds support for presets. Makes generated images using FABRIC more reproducible by loading the correct (previously used) feedback images when using "send to text2img/img2img".

View File

@@ -27,7 +27,7 @@ except ImportError:
from modules.ui import create_refresh_button
__version__ = "0.6.4"
__version__ = "0.6.5"
DEBUG = os.getenv("DEBUG", "false").lower() in ("true", "1")

View File

@@ -17,6 +17,13 @@ from scripts.fabric_utils import image_hash
from scripts.weighted_attention import weighted_attention
from scripts.merging import compute_merge
try:
import ldm_patched
has_webui_forge = True
print("[FABRIC] Detected WebUI Forge, running in compatibility mode.")
except ImportError:
has_webui_forge = False
SD15 = "sd15"
SDXL = "sdxl"
@@ -43,10 +50,12 @@ def encode_to_latent(p, image, w, h):
return z.squeeze(0)
def forward_noise(p, x_0, t, noise=None):
    """Diffuse clean latents ``x_0`` forward to timestep ``t``.

    Implements the closed-form forward diffusion process:
    ``x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise``.

    Args:
        p: processing object exposing ``p.sd_model.alphas_cumprod``
           (1-D cumulative-alpha schedule tensor); assumed, per the call
           sites in this file — confirm against the webui API.
        x_0: clean latents, batched along dim 0.
        t: integer timesteps used to index the schedule, one per batch item.
        noise: optional Gaussian noise with the same shape as ``x_0``;
            sampled with ``torch.randn_like`` when ``None``.

    Returns:
        Noised latents ``x_t`` with the same shape and device as ``x_0``.
    """
    device = x_0.device
    if noise is None:
        noise = torch.randn_like(x_0)
    # Move the schedule to the latents' device before indexing. The stale
    # pre-image lines (diff residue) indexed the un-moved schedule first,
    # which both duplicated the computation and risked a device-mismatch
    # error when the schedule lives on CPU and the latents on GPU.
    alpha_bar = p.sd_model.alphas_cumprod.to(device)
    sqrt_alpha_bar_t = extract_into_tensor(alpha_bar.sqrt(), t, x_0.shape)
    sqrt_one_minus_alpha_bar_t = extract_into_tensor((1.0 - alpha_bar).sqrt(), t, x_0.shape)
    x_t = sqrt_alpha_bar_t * x_0 + sqrt_one_minus_alpha_bar_t * noise
    return x_t
@@ -96,13 +105,19 @@ def patch_unet_forward_pass(p, unet, params):
if isinstance(p.sd_model, LatentDiffusion):
sd_version = SD15
num_timesteps = p.sd_model.num_timesteps
BasicTransformerBlock = ldm.modules.attention.BasicTransformerBlock
elif isinstance(p.sd_model, DiffusionEngine):
sd_version = SDXL
num_timesteps = len(p.sd_model.alphas_cumprod)
BasicTransformerBlock = sgm.modules.attention.BasicTransformerBlock
else:
raise ValueError(f"[FABRIC] Unsupported SD model: {type(p.sd_model)}")
transformer_block_type = tuple(
[
ldm.modules.attention.BasicTransformerBlock, # SD 1.5
sgm.modules.attention.BasicTransformerBlock, # SDXL
]
+ ([ldm_patched.ldm.modules.attention.BasicTransformerBlock] if has_webui_forge else [])
)
batch_size = p.batch_size
@@ -173,20 +188,13 @@ def patch_unet_forward_pass(p, unet, params):
return self._fabric_old_forward(x, timesteps, context, **kwargs)
# add noise to reference latents
if sd_version == SD15:
all_zs = []
for latent in all_latents:
z = p.sd_model.q_sample(latent.unsqueeze(0), timesteps[0].unsqueeze(0))[0]
all_zs.append(z)
all_zs = torch.stack(all_zs, dim=0)
else: # SDXL
xs_0 = torch.stack(all_latents, dim=0)
ts = timesteps[0, None].expand(xs_0.size(0)) # (bs,)
all_zs = forward_noise(p, xs_0, torch.round(ts.float()).long())
xs_0 = torch.stack(all_latents, dim=0)
ts = timesteps[0, None].expand(xs_0.size(0)) # (bs,)
all_zs = forward_noise(p, xs_0, torch.round(ts.float()).long())
# save original forward pass
for module in self.modules():
if isinstance(module, BasicTransformerBlock) and not hasattr(module.attn1, "_fabric_old_forward"):
if isinstance(module, transformer_block_type) and not hasattr(module.attn1, "_fabric_old_forward"):
module.attn1._fabric_old_forward = module.attn1.forward
module.attn2._fabric_old_forward = module.attn2.forward
@@ -214,7 +222,7 @@ def patch_unet_forward_pass(p, unet, params):
# patch forward pass to cache hidden states
layer_idx = 0
for module in self.modules():
if isinstance(module, BasicTransformerBlock):
if isinstance(module, transformer_block_type):
module.attn1.forward = functools.partial(patched_attn1_forward, module.attn1, layer_idx)
module.attn2.forward = functools.partial(patched_attn2_forward, module.attn2)
layer_idx += 1
@@ -260,20 +268,19 @@ def patch_unet_forward_pass(p, unet, params):
weights = None
return weighted_attention(attn1, attn1._fabric_old_forward, _x, ctx, weights, **kwargs) # (n_cond, seq, dim)
outs = []
out = torch.zeros_like(x)
if num_cond > 0:
out_cond = attention_with_feedback(x[cond_ids], context[cond_ids], cached_hs[:num_pos], pos_weight) # (n_cond, seq, dim)
outs.append(out_cond)
out[cond_ids] = out_cond
if num_uncond > 0:
out_uncond = attention_with_feedback(x[uncond_ids], context[uncond_ids], cached_hs[num_pos:], neg_weight) # (n_cond, seq, dim)
outs.append(out_uncond)
out = torch.cat(outs, dim=0)
out[uncond_ids] = out_uncond
return out
# patch forward pass to inject cached hidden states
layer_idx = 0
for module in self.modules():
if isinstance(module, BasicTransformerBlock):
if isinstance(module, transformer_block_type):
module.attn1.forward = functools.partial(patched_attn1_forward, module.attn1, layer_idx)
layer_idx += 1
@@ -283,10 +290,10 @@ def patch_unet_forward_pass(p, unet, params):
finally:
# restore original pass
for module in self.modules():
if isinstance(module, BasicTransformerBlock) and hasattr(module.attn1, "_fabric_old_forward"):
if isinstance(module, transformer_block_type) and hasattr(module.attn1, "_fabric_old_forward"):
module.attn1.forward = module.attn1._fabric_old_forward
del module.attn1._fabric_old_forward
if isinstance(module, BasicTransformerBlock) and hasattr(module.attn2, "_fabric_old_forward"):
if isinstance(module, transformer_block_type) and hasattr(module.attn2, "_fabric_old_forward"):
module.attn2.forward = module.attn2._fabric_old_forward
del module.attn2._fabric_old_forward

View File

@@ -21,8 +21,24 @@ try:
except ImportError:
pass
try:
from ldm_patched.modules import model_management
has_webui_forge = True
print("[FABRIC] Detected WebUI Forge, running in compatibility mode.")
except ImportError:
has_webui_forge = False
def get_weighted_attn_fn():
if has_webui_forge:
if model_management.xformers_enabled():
return weighted_xformers_attention_forward
elif model_management.pytorch_attention_enabled():
return weighted_scaled_dot_product_attention_forward
else:
print(f"[FABRIC] Warning: No attention method enabled. Falling back to split attention.")
return weighted_split_cross_attention_forward
method = sd_hijack.model_hijack.optimization_method
if method is None:
return weighted_split_cross_attention_forward
@@ -283,7 +299,8 @@ def weighted_split_cross_attention_forward(self, x, context=None, mask=None, wei
q_in, k_in, v_in = q_in.float(), k_in.float(), v_in if v_in.device.type == 'mps' else v_in.float()
with devices.without_autocast(disable=not shared.opts.upcast_attn):
k_in = k_in * self.scale
default_scale = (q_in.shape[-1] / h) ** -0.5
k_in = k_in * getattr(self, "scale", default_scale)
del context, x