339 lines
15 KiB
Python
339 lines
15 KiB
Python
import functools
|
|
|
|
import torch
|
|
import torchvision.transforms.functional as functional
|
|
|
|
from modules import devices, images, shared
|
|
from modules.processing import StableDiffusionProcessingTxt2Img
|
|
|
|
import ldm.modules.attention
|
|
import sgm.modules.attention
|
|
from ldm.models.diffusion.ddpm import LatentDiffusion
|
|
from ldm.models.diffusion.ddpm import extract_into_tensor
|
|
from sgm.models.diffusion import DiffusionEngine
|
|
|
|
from scripts.marking import apply_marking_patch, unmark_prompt_context
|
|
from scripts.fabric_utils import image_hash
|
|
from scripts.weighted_attention import weighted_attention
|
|
from scripts.merging import compute_merge
|
|
|
|
try:
|
|
import ldm_patched.ldm.modules.attention
|
|
has_webui_forge = True
|
|
print("[FABRIC] Detected WebUI Forge, running in compatibility mode.")
|
|
except ImportError:
|
|
has_webui_forge = False
|
|
|
|
|
|
SD15 = "sd15"
|
|
SDXL = "sdxl"
|
|
|
|
|
|
def encode_to_latent(p, image, w, h):
|
|
image = images.resize_image(1, image, w, h)
|
|
x = functional.pil_to_tensor(image)
|
|
x = functional.center_crop(x, (w, h)) # just to be safe
|
|
x = x.to(devices.device, dtype=devices.dtype_vae)
|
|
x = ((x / 255.0) * 2.0 - 1.0).unsqueeze(0)
|
|
|
|
# TODO: use caching to make this faster
|
|
with devices.autocast():
|
|
vae_output = p.sd_model.encode_first_stage(x)
|
|
z = p.sd_model.get_first_stage_encoding(vae_output)
|
|
if torch.isnan(z).any():
|
|
print(f"[FABRIC] NaNs in VAE output found, retrying with 32-bit precision. To always start with 32-bit VAE, use --no-half-vae commandline flag.")
|
|
devices.dtype_vae = torch.float32
|
|
x = x.to(devices.dtype_vae)
|
|
p.sd_model.first_stage_model.to(devices.dtype_vae)
|
|
vae_output = p.sd_model.encode_first_stage(x)
|
|
z = p.sd_model.get_first_stage_encoding(vae_output)
|
|
z = z.to(devices.dtype_unet)
|
|
return z.squeeze(0)
|
|
|
|
def forward_noise(p, x_0, t, noise=None):
|
|
device = x_0.device
|
|
if noise is None:
|
|
noise = torch.randn_like(x_0)
|
|
alpha_bar = p.sd_model.alphas_cumprod.to(device)
|
|
sqrt_alpha_bar_t = extract_into_tensor(alpha_bar.sqrt(), t, x_0.shape)
|
|
sqrt_one_minus_alpha_bar_t = extract_into_tensor((1.0 - alpha_bar).sqrt(), t, x_0.shape)
|
|
x_t = sqrt_alpha_bar_t * x_0 + sqrt_one_minus_alpha_bar_t * noise
|
|
return x_t
|
|
|
|
|
|
def get_latents_from_params(p, params, width, height):
|
|
w, h = (width // 8) * 8, (height // 8) * 8
|
|
w_latent, h_latent = width // 8, height // 8
|
|
|
|
def get_latents(images, cached_latents=None):
|
|
# check if latents need to be computed or recomputed (if image size changed e.g. due to high-res fix)
|
|
if cached_latents is None:
|
|
cached_latents = {}
|
|
|
|
latents = []
|
|
for img in images:
|
|
img_hash = image_hash(img)
|
|
if img_hash not in cached_latents:
|
|
cached_latents[img_hash] = encode_to_latent(p, img, w, h)
|
|
elif cached_latents[img_hash].shape[-2:] != (w_latent, h_latent):
|
|
print(f"[FABRIC] Recomputing latent for image of size {img.size}")
|
|
cached_latents[img_hash] = encode_to_latent(p, img, w, h)
|
|
latents.append(cached_latents[img_hash])
|
|
return latents, cached_latents
|
|
|
|
params.pos_latents, params.pos_latent_cache = get_latents(params.pos_images, params.pos_latent_cache)
|
|
params.neg_latents, params.neg_latent_cache = get_latents(params.neg_images, params.neg_latent_cache)
|
|
return params.pos_latents, params.neg_latents
|
|
|
|
|
|
def get_curr_feedback_weight(p, params, timestep, num_timesteps=1000):
|
|
progress = 1 - (timestep / (num_timesteps - 1))
|
|
if progress >= params.start and progress <= params.end:
|
|
w = params.max_weight
|
|
else:
|
|
w = params.min_weight
|
|
return max(0, w), max(0, w * params.neg_scale)
|
|
|
|
|
|
def patch_unet_forward_pass(p, unet, params):
|
|
if not params.pos_images and not params.neg_images:
|
|
print("[FABRIC] No feedback images found, aborting patching")
|
|
return
|
|
|
|
if not hasattr(unet, "_fabric_old_forward"):
|
|
unet._fabric_old_forward = unet.forward
|
|
|
|
if isinstance(p.sd_model, LatentDiffusion):
|
|
sd_version = SD15
|
|
num_timesteps = p.sd_model.num_timesteps
|
|
elif isinstance(p.sd_model, DiffusionEngine):
|
|
sd_version = SDXL
|
|
num_timesteps = len(p.sd_model.alphas_cumprod)
|
|
else:
|
|
raise ValueError(f"[FABRIC] Unsupported SD model: {type(p.sd_model)}")
|
|
|
|
transformer_block_type = tuple(
|
|
[
|
|
ldm.modules.attention.BasicTransformerBlock, # SD 1.5
|
|
sgm.modules.attention.BasicTransformerBlock, # SDXL
|
|
]
|
|
+ ([ldm_patched.ldm.modules.attention.BasicTransformerBlock] if has_webui_forge else [])
|
|
)
|
|
|
|
batch_size = p.batch_size
|
|
|
|
null_ctx = p.sd_model.get_learned_conditioning([""])
|
|
if isinstance(null_ctx, torch.Tensor): # SD1.5
|
|
null_ctx = null_ctx.to(devices.device, dtype=devices.dtype_unet)
|
|
elif isinstance(null_ctx, dict): # SDXL
|
|
for key in null_ctx:
|
|
if isinstance(null_ctx[key], torch.Tensor):
|
|
null_ctx[key] = null_ctx[key].to(devices.device, dtype=devices.dtype_unet)
|
|
else:
|
|
raise ValueError(f"[FABRIC] Unsupported context type: {type(null_ctx)}")
|
|
|
|
width = (p.width // 8) * 8
|
|
height = (p.height // 8) * 8
|
|
|
|
has_hires_fix = isinstance(p, StableDiffusionProcessingTxt2Img) and getattr(p, 'enable_hr', False)
|
|
if has_hires_fix:
|
|
if p.hr_resize_x == 0 and p.hr_resize_y == 0:
|
|
hr_w = int(p.width * p.hr_scale)
|
|
hr_h = int(p.height * p.hr_scale)
|
|
else:
|
|
hr_w, hr_h = p.hr_resize_x, p.hr_resize_y
|
|
hr_w = (hr_w // 8) * 8
|
|
hr_h = (hr_h // 8) * 8
|
|
else:
|
|
hr_w = width
|
|
hr_h = height
|
|
|
|
tome_args = {
|
|
"enabled": params.tome_enabled,
|
|
"sx": 2, "sy": 2,
|
|
"use_rand": True,
|
|
"generator": None,
|
|
"seed": params.tome_seed,
|
|
}
|
|
|
|
prev_vals = {
|
|
"weight_modifier": 1.0,
|
|
}
|
|
|
|
def new_forward(self, x, timesteps=None, context=None, **kwargs):
|
|
_, uncond_ids, cond_ids, context = unmark_prompt_context(context)
|
|
has_cond = len(cond_ids) > 0
|
|
has_uncond = len(uncond_ids) > 0
|
|
|
|
h_latent, w_latent = x.shape[-2:]
|
|
w, h = 8 * w_latent, 8 * h_latent
|
|
if has_hires_fix and w == hr_w and h == hr_h:
|
|
if not params.feedback_during_high_res_fix:
|
|
print("[FABRIC] Skipping feedback during high-res fix")
|
|
return self._fabric_old_forward(x, timesteps, context, **kwargs)
|
|
|
|
pos_weight, neg_weight = get_curr_feedback_weight(p, params, timesteps[0].item(), num_timesteps=num_timesteps)
|
|
if pos_weight <= 0 and neg_weight <= 0:
|
|
return self._fabric_old_forward(x, timesteps, context, **kwargs)
|
|
|
|
if params.burnout_protection and "cond" in prev_vals and "uncond" in prev_vals:
|
|
# burnout protection: if the difference betwen cond/uncond was too high in the previous step (sign of instability), slash the weight modifier
|
|
diff_std = (prev_vals["cond"] - prev_vals["uncond"]).std(dim=(2, 3)).max().item()
|
|
diff_abs_mean = (prev_vals["cond"] - prev_vals["uncond"]).mean(dim=(2, 3)).abs().max().item()
|
|
if diff_std > 0.06 or diff_abs_mean > 0.02:
|
|
prev_vals["weight_modifier"] *= 0.5
|
|
else:
|
|
prev_vals["weight_modifier"] = min(1.0, 1.5 * prev_vals["weight_modifier"])
|
|
|
|
pos_weight, neg_weight = pos_weight * prev_vals["weight_modifier"], neg_weight * prev_vals["weight_modifier"]
|
|
|
|
pos_latents, neg_latents = get_latents_from_params(p, params, w, h)
|
|
pos_latents = pos_latents if has_cond else []
|
|
neg_latents = neg_latents if has_uncond else []
|
|
all_latents = pos_latents + neg_latents
|
|
|
|
# Note: calls to the VAE with `--medvram` will move the U-Net to CPU, so we need to move it back to GPU
|
|
if shared.cmd_opts.medvram:
|
|
try:
|
|
# Trigger register_forward_pre_hook to move the model to correct device
|
|
p.sd_model.model()
|
|
except:
|
|
pass
|
|
|
|
if len(all_latents) == 0:
|
|
return self._fabric_old_forward(x, timesteps, context, **kwargs)
|
|
|
|
# add noise to reference latents
|
|
xs_0 = torch.stack(all_latents, dim=0)
|
|
ts = timesteps[0, None].expand(xs_0.size(0)) # (bs,)
|
|
all_zs = forward_noise(p, xs_0, torch.round(ts.float()).long())
|
|
|
|
# save original forward pass
|
|
for module in self.modules():
|
|
if isinstance(module, transformer_block_type) and not hasattr(module.attn1, "_fabric_old_forward"):
|
|
module.attn1._fabric_old_forward = module.attn1.forward
|
|
module.attn2._fabric_old_forward = module.attn2.forward
|
|
|
|
try:
|
|
## cache hidden states
|
|
cached_hiddens = {}
|
|
def patched_attn1_forward(attn1, layer_idx, x, **kwargs):
|
|
merge, unmerge = compute_merge(x, args=tome_args, size=(h_latent, w_latent), ratio=params.tome_ratio)
|
|
x = merge(x)
|
|
if layer_idx not in cached_hiddens:
|
|
cached_hiddens[layer_idx] = x.detach().clone().cpu()
|
|
else:
|
|
cached_hiddens[layer_idx] = torch.cat([cached_hiddens[layer_idx], x.detach().clone().cpu()], dim=0)
|
|
out = attn1._fabric_old_forward(x, **kwargs)
|
|
out = unmerge(out)
|
|
return out
|
|
|
|
def patched_attn2_forward(attn2, x, **kwargs):
|
|
merge, unmerge = compute_merge(x, args=tome_args, size=(h_latent, w_latent), ratio=params.tome_ratio)
|
|
x = merge(x)
|
|
out = attn2._fabric_old_forward(x, **kwargs)
|
|
out = unmerge(out)
|
|
return out
|
|
|
|
# patch forward pass to cache hidden states
|
|
layer_idx = 0
|
|
for module in self.modules():
|
|
if isinstance(module, transformer_block_type):
|
|
module.attn1.forward = functools.partial(patched_attn1_forward, module.attn1, layer_idx)
|
|
module.attn2.forward = functools.partial(patched_attn2_forward, module.attn2)
|
|
layer_idx += 1
|
|
|
|
# run forward pass just to cache hidden states, output is discarded
|
|
for i in range(0, len(all_zs), batch_size):
|
|
zs = all_zs[i : i + batch_size].to(x.device, dtype=self.dtype)
|
|
ts = timesteps[:1].expand(zs.size(0)) # (bs,)
|
|
# use the null prompt for pre-computing hidden states on feedback images
|
|
ctx_args = {}
|
|
if sd_version == SD15:
|
|
ctx_args["context"] = null_ctx.expand(zs.size(0), -1, -1) # (bs, seq_len, d_model)
|
|
else: # SDXL
|
|
ctx_args["context"] = null_ctx["crossattn"].expand(zs.size(0), -1, -1) # (bs, seq_len, d_model)
|
|
ctx_args["y"] = null_ctx["vector"].expand(zs.size(0), -1) # (bs, d_vector)
|
|
_ = self._fabric_old_forward(zs, ts, **ctx_args)
|
|
|
|
num_pos = len(pos_latents)
|
|
num_neg = len(neg_latents)
|
|
num_cond = len(cond_ids)
|
|
num_uncond = len(uncond_ids)
|
|
tome_h_latent = h_latent * (1 - params.tome_ratio)
|
|
|
|
def patched_attn1_forward(attn1, idx, x, context=None, **kwargs):
|
|
if context is None:
|
|
context = x
|
|
|
|
cached_hs = cached_hiddens[idx].to(x.device)
|
|
|
|
d_model = x.shape[-1]
|
|
|
|
def attention_with_feedback(_x, context, feedback_hs, w):
|
|
num_xs, num_fb = _x.shape[0], feedback_hs.shape[0]
|
|
if num_fb > 0:
|
|
feedback_ctx = feedback_hs.view(1, -1, d_model).expand(num_xs, -1, -1) # (n_cond, seq * n_pos, dim)
|
|
merge, _ = compute_merge(feedback_ctx, args=tome_args, size=(tome_h_latent * num_fb, w_latent), max_tokens=params.tome_max_tokens)
|
|
feedback_ctx = merge(feedback_ctx)
|
|
ctx = torch.cat([context, feedback_ctx], dim=1) # (n_cond, seq + seq*n_pos, dim)
|
|
weights = torch.ones(ctx.shape[1], device=ctx.device, dtype=ctx.dtype) # (seq + seq*n_pos,)
|
|
weights[_x.shape[1]:] = w
|
|
else:
|
|
ctx = context
|
|
weights = None
|
|
return weighted_attention(attn1, attn1._fabric_old_forward, _x, ctx, weights, **kwargs) # (n_cond, seq, dim)
|
|
|
|
out = torch.zeros_like(x, dtype=devices.dtype_unet)
|
|
if num_cond > 0:
|
|
out_cond = attention_with_feedback(x[cond_ids], context[cond_ids], cached_hs[:num_pos], pos_weight) # (n_cond, seq, dim)
|
|
out[cond_ids] = out_cond
|
|
if num_uncond > 0:
|
|
out_uncond = attention_with_feedback(x[uncond_ids], context[uncond_ids], cached_hs[num_pos:], neg_weight) # (n_cond, seq, dim)
|
|
out[uncond_ids] = out_uncond
|
|
return out
|
|
|
|
# patch forward pass to inject cached hidden states
|
|
layer_idx = 0
|
|
for module in self.modules():
|
|
if isinstance(module, transformer_block_type):
|
|
module.attn1.forward = functools.partial(patched_attn1_forward, module.attn1, layer_idx)
|
|
layer_idx += 1
|
|
|
|
# run forward pass with cached hidden states
|
|
out = self._fabric_old_forward(x, timesteps, context, **kwargs)
|
|
|
|
cond_outs = out[cond_ids]
|
|
uncond_outs = out[uncond_ids]
|
|
|
|
if has_cond:
|
|
prev_vals["cond"] = cond_outs.detach().clone()
|
|
if has_uncond:
|
|
prev_vals["uncond"] = uncond_outs.detach().clone()
|
|
|
|
if params.burnout_protection:
|
|
# burnout protection: recenter the output to prevent instabilities caused by mean drift
|
|
mean = out.mean(dim=(2, 3), keepdim=True)
|
|
out = out - 0.5 * mean
|
|
finally:
|
|
# restore original pass
|
|
for module in self.modules():
|
|
if isinstance(module, transformer_block_type) and hasattr(module.attn1, "_fabric_old_forward"):
|
|
module.attn1.forward = module.attn1._fabric_old_forward
|
|
del module.attn1._fabric_old_forward
|
|
if isinstance(module, transformer_block_type) and hasattr(module.attn2, "_fabric_old_forward"):
|
|
module.attn2.forward = module.attn2._fabric_old_forward
|
|
del module.attn2._fabric_old_forward
|
|
|
|
return out
|
|
|
|
unet.forward = new_forward.__get__(unet)
|
|
|
|
apply_marking_patch(p)
|
|
|
|
def unpatch_unet_forward_pass(unet):
|
|
if hasattr(unet, "_fabric_old_forward"):
|
|
print("[FABRIC] Restoring original U-Net forward pass")
|
|
unet.forward = unet._fabric_old_forward
|
|
del unet._fabric_old_forward
|