v0.6.5 (fix compatibility with WebUI Forge)
parent
5d4335d59a
commit
19a3c72499
|
|
@ -10,7 +10,8 @@ ComfyUI node (by [@ssitu](https://github.com/ssitu)): https://github.com/ssitu/C
|
|||
|
||||

|
||||
|
||||
## Releases
|
||||
## Releases and Changelog
|
||||
- [07.03.2024] 🔨 v0.6.5: Fixes compatibility with [WebUI Forge](https://github.com/lllyasviel/stable-diffusion-webui-forge).
|
||||
- [07.03.2024] ✨ v0.6.4: SDXL support has been added. For optimal results, lowering the feedback strength is recommended (0.5 seems to be a good starting point).
|
||||
- [29.08.2023] 🏎️ v0.6.0: Up to 2x faster with up to 4x lower VRAM usage thanks to [Token Merging](https://github.com/dbolya/tomesd/tree/main) (tested with 16 feedback images and a batch size of 4), with moderate gains for fewer feedback images (10% speedup for 2 images, 30% for 8 images). Enable the Token Merging option to take advantage of this.
|
||||
- [22.08.2023] 🗃️ v0.5.0: Adds support for presets. Makes images generated with FABRIC more reproducible by loading the correct (previously used) feedback images when using "send to text2img/img2img".
|
||||
|
|
|
|||
|
|
@ -27,7 +27,7 @@ except ImportError:
|
|||
from modules.ui import create_refresh_button
|
||||
|
||||
|
||||
__version__ = "0.6.4"
|
||||
__version__ = "0.6.5"
|
||||
|
||||
DEBUG = os.getenv("DEBUG", "false").lower() in ("true", "1")
|
||||
|
||||
|
|
|
|||
|
|
@ -17,6 +17,13 @@ from scripts.fabric_utils import image_hash
|
|||
from scripts.weighted_attention import weighted_attention
|
||||
from scripts.merging import compute_merge
|
||||
|
||||
try:
|
||||
import ldm_patched
|
||||
has_webui_forge = True
|
||||
print("[FABRIC] Detected WebUI Forge, running in compatibility mode.")
|
||||
except ImportError:
|
||||
has_webui_forge = False
|
||||
|
||||
|
||||
SD15 = "sd15"
|
||||
SDXL = "sdxl"
|
||||
|
|
@ -43,10 +50,12 @@ def encode_to_latent(p, image, w, h):
|
|||
return z.squeeze(0)
|
||||
|
||||
def forward_noise(p, x_0, t, noise=None):
    """Noise the clean latents ``x_0`` forward to timestep ``t``.

    Computes ``x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise``,
    where ``alpha_bar`` is the cumulative schedule read from
    ``p.sd_model.alphas_cumprod``.

    Args:
        p: Processing object; only ``p.sd_model.alphas_cumprod`` is read.
        x_0: Tensor of clean latents.
        t: Timestep indices, passed straight to ``extract_into_tensor``
            (presumably one integer index per batch element -- confirm
            against ``extract_into_tensor``).
        noise: Optional noise tensor. When ``None``, it is sampled with
            ``torch.randn_like(x_0)``.

    Returns:
        The noised tensor ``x_t``.
    """
    if noise is None:
        noise = torch.randn_like(x_0)

    # Move the schedule to the latents' device once, up front; it may not
    # live on the same device as x_0, and the coefficients below are
    # combined with x_0 and noise.
    alpha_bar = p.sd_model.alphas_cumprod.to(x_0.device)

    sqrt_alpha_bar_t = extract_into_tensor(alpha_bar.sqrt(), t, x_0.shape)
    sqrt_one_minus_alpha_bar_t = extract_into_tensor((1.0 - alpha_bar).sqrt(), t, x_0.shape)

    return sqrt_alpha_bar_t * x_0 + sqrt_one_minus_alpha_bar_t * noise
|
||||
|
||||
|
|
@ -96,13 +105,19 @@ def patch_unet_forward_pass(p, unet, params):
|
|||
if isinstance(p.sd_model, LatentDiffusion):
|
||||
sd_version = SD15
|
||||
num_timesteps = p.sd_model.num_timesteps
|
||||
BasicTransformerBlock = ldm.modules.attention.BasicTransformerBlock
|
||||
elif isinstance(p.sd_model, DiffusionEngine):
|
||||
sd_version = SDXL
|
||||
num_timesteps = len(p.sd_model.alphas_cumprod)
|
||||
BasicTransformerBlock = sgm.modules.attention.BasicTransformerBlock
|
||||
else:
|
||||
raise ValueError(f"[FABRIC] Unsupported SD model: {type(p.sd_model)}")
|
||||
|
||||
transformer_block_type = tuple(
|
||||
[
|
||||
ldm.modules.attention.BasicTransformerBlock, # SD 1.5
|
||||
sgm.modules.attention.BasicTransformerBlock, # SDXL
|
||||
]
|
||||
+ ([ldm_patched.ldm.modules.attention.BasicTransformerBlock] if has_webui_forge else [])
|
||||
)
|
||||
|
||||
batch_size = p.batch_size
|
||||
|
||||
|
|
@ -173,20 +188,13 @@ def patch_unet_forward_pass(p, unet, params):
|
|||
return self._fabric_old_forward(x, timesteps, context, **kwargs)
|
||||
|
||||
# add noise to reference latents
|
||||
if sd_version == SD15:
|
||||
all_zs = []
|
||||
for latent in all_latents:
|
||||
z = p.sd_model.q_sample(latent.unsqueeze(0), timesteps[0].unsqueeze(0))[0]
|
||||
all_zs.append(z)
|
||||
all_zs = torch.stack(all_zs, dim=0)
|
||||
else: # SDXL
|
||||
xs_0 = torch.stack(all_latents, dim=0)
|
||||
ts = timesteps[0, None].expand(xs_0.size(0)) # (bs,)
|
||||
all_zs = forward_noise(p, xs_0, torch.round(ts.float()).long())
|
||||
xs_0 = torch.stack(all_latents, dim=0)
|
||||
ts = timesteps[0, None].expand(xs_0.size(0)) # (bs,)
|
||||
all_zs = forward_noise(p, xs_0, torch.round(ts.float()).long())
|
||||
|
||||
# save original forward pass
|
||||
for module in self.modules():
|
||||
if isinstance(module, BasicTransformerBlock) and not hasattr(module.attn1, "_fabric_old_forward"):
|
||||
if isinstance(module, transformer_block_type) and not hasattr(module.attn1, "_fabric_old_forward"):
|
||||
module.attn1._fabric_old_forward = module.attn1.forward
|
||||
module.attn2._fabric_old_forward = module.attn2.forward
|
||||
|
||||
|
|
@ -214,7 +222,7 @@ def patch_unet_forward_pass(p, unet, params):
|
|||
# patch forward pass to cache hidden states
|
||||
layer_idx = 0
|
||||
for module in self.modules():
|
||||
if isinstance(module, BasicTransformerBlock):
|
||||
if isinstance(module, transformer_block_type):
|
||||
module.attn1.forward = functools.partial(patched_attn1_forward, module.attn1, layer_idx)
|
||||
module.attn2.forward = functools.partial(patched_attn2_forward, module.attn2)
|
||||
layer_idx += 1
|
||||
|
|
@ -260,20 +268,19 @@ def patch_unet_forward_pass(p, unet, params):
|
|||
weights = None
|
||||
return weighted_attention(attn1, attn1._fabric_old_forward, _x, ctx, weights, **kwargs) # (n_cond, seq, dim)
|
||||
|
||||
outs = []
|
||||
out = torch.zeros_like(x)
|
||||
if num_cond > 0:
|
||||
out_cond = attention_with_feedback(x[cond_ids], context[cond_ids], cached_hs[:num_pos], pos_weight) # (n_cond, seq, dim)
|
||||
outs.append(out_cond)
|
||||
out[cond_ids] = out_cond
|
||||
if num_uncond > 0:
|
||||
out_uncond = attention_with_feedback(x[uncond_ids], context[uncond_ids], cached_hs[num_pos:], neg_weight) # (n_cond, seq, dim)
|
||||
outs.append(out_uncond)
|
||||
out = torch.cat(outs, dim=0)
|
||||
out[uncond_ids] = out_uncond
|
||||
return out
|
||||
|
||||
# patch forward pass to inject cached hidden states
|
||||
layer_idx = 0
|
||||
for module in self.modules():
|
||||
if isinstance(module, BasicTransformerBlock):
|
||||
if isinstance(module, transformer_block_type):
|
||||
module.attn1.forward = functools.partial(patched_attn1_forward, module.attn1, layer_idx)
|
||||
layer_idx += 1
|
||||
|
||||
|
|
@ -283,10 +290,10 @@ def patch_unet_forward_pass(p, unet, params):
|
|||
finally:
|
||||
# restore original pass
|
||||
for module in self.modules():
|
||||
if isinstance(module, BasicTransformerBlock) and hasattr(module.attn1, "_fabric_old_forward"):
|
||||
if isinstance(module, transformer_block_type) and hasattr(module.attn1, "_fabric_old_forward"):
|
||||
module.attn1.forward = module.attn1._fabric_old_forward
|
||||
del module.attn1._fabric_old_forward
|
||||
if isinstance(module, BasicTransformerBlock) and hasattr(module.attn2, "_fabric_old_forward"):
|
||||
if isinstance(module, transformer_block_type) and hasattr(module.attn2, "_fabric_old_forward"):
|
||||
module.attn2.forward = module.attn2._fabric_old_forward
|
||||
del module.attn2._fabric_old_forward
|
||||
|
||||
|
|
|
|||
|
|
@ -21,8 +21,24 @@ try:
|
|||
except ImportError:
|
||||
pass
|
||||
|
||||
try:
|
||||
from ldm_patched.modules import model_management
|
||||
has_webui_forge = True
|
||||
print("[FABRIC] Detected WebUI Forge, running in compatibility mode.")
|
||||
except ImportError:
|
||||
has_webui_forge = False
|
||||
|
||||
|
||||
def get_weighted_attn_fn():
|
||||
if has_webui_forge:
|
||||
if model_management.xformers_enabled():
|
||||
return weighted_xformers_attention_forward
|
||||
elif model_management.pytorch_attention_enabled():
|
||||
return weighted_scaled_dot_product_attention_forward
|
||||
else:
|
||||
print(f"[FABRIC] Warning: No attention method enabled. Falling back to split attention.")
|
||||
return weighted_split_cross_attention_forward
|
||||
|
||||
method = sd_hijack.model_hijack.optimization_method
|
||||
if method is None:
|
||||
return weighted_split_cross_attention_forward
|
||||
|
|
@ -283,7 +299,8 @@ def weighted_split_cross_attention_forward(self, x, context=None, mask=None, wei
|
|||
q_in, k_in, v_in = q_in.float(), k_in.float(), v_in if v_in.device.type == 'mps' else v_in.float()
|
||||
|
||||
with devices.without_autocast(disable=not shared.opts.upcast_attn):
|
||||
k_in = k_in * self.scale
|
||||
default_scale = (q_in.shape[-1] / h) ** -0.5
|
||||
k_in = k_in * getattr(self, "scale", default_scale)
|
||||
|
||||
del context, x
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue