import functools
import os

import torch
import torchvision.transforms.functional as functional
import matplotlib.pyplot as plt
import numpy as np

from modules import devices, images, shared
from modules.processing import StableDiffusionProcessingTxt2Img

from ldm.modules.attention import BasicTransformerBlock

from scripts.marking import apply_marking_patch, unmark_prompt_context
from scripts.fabric_utils import image_hash
from scripts.weighted_attention import weighted_attention
from scripts.merging import compute_merge


def encode_to_latent(p, image, w, h):
    """Encode a PIL feedback image into a VAE latent of shape (C, h/8, w/8).

    The image is resized to (w, h), converted to a [-1, 1] tensor and pushed
    through the model's first-stage (VAE) encoder.
    """
    image = images.resize_image(1, image, w, h)
    x = functional.pil_to_tensor(image)
    # NOTE(review): torchvision `center_crop` expects (height, width); passing
    # (w, h) transposes the crop for non-square sizes — confirm intended.
    x = functional.center_crop(x, (w, h))  # just to be safe
    x = x.to(devices.device, dtype=devices.dtype_vae)
    x = ((x / 255.0) * 2.0 - 1.0).unsqueeze(0)
    # TODO: use caching to make this faster
    with devices.autocast():
        vae_output = p.sd_model.encode_first_stage(x)
        z = p.sd_model.get_first_stage_encoding(vae_output)
    return z.squeeze(0)


def get_latents_from_params(p, params, width, height):
    """Return (pos_latents, neg_latents) for the feedback images in `params`.

    Latents are cached per image hash on `params`; a cached latent is
    recomputed when its spatial shape no longer matches the requested size
    (e.g. after high-res fix changed the generation resolution).
    """
    w, h = (width // 8) * 8, (height // 8) * 8
    w_latent, h_latent = width // 8, height // 8

    def get_latents(images, cached_latents=None):
        # check if latents need to be computed or recomputed (if image size
        # changed e.g. due to high-res fix)
        if cached_latents is None:
            cached_latents = {}
        latents = []
        for img in images:
            img_hash = image_hash(img)
            if img_hash not in cached_latents:
                cached_latents[img_hash] = encode_to_latent(p, img, w, h)
            elif cached_latents[img_hash].shape[-2:] != (w_latent, h_latent):
                # NOTE(review): latent shape[-2:] is (H, W) order; compared
                # against (w_latent, h_latent) — consistent with the transposed
                # center_crop above, but verify for non-square resolutions.
                print(f"[FABRIC] Recomputing latent for image of size {img.size}")
                cached_latents[img_hash] = encode_to_latent(p, img, w, h)
            latents.append(cached_latents[img_hash])
        return latents, cached_latents

    params.pos_latents, params.pos_latent_cache = get_latents(params.pos_images, params.pos_latent_cache)
    params.neg_latents, params.neg_latent_cache = get_latents(params.neg_images, params.neg_latent_cache)
    return params.pos_latents, params.neg_latents


def get_curr_feedback_weight(p, params, timestep):
    """Return (pos_weight, neg_weight) for the current diffusion timestep.

    `max_weight` applies while sampling progress is inside
    [params.start, params.end], otherwise `min_weight`; both are clamped
    to be non-negative.
    """
    progress = 1 - (timestep / (p.sd_model.num_timesteps - 1))
    if progress >= params.start and progress <= params.end:
        w = params.max_weight
    else:
        w = params.min_weight
    return max(0, w), max(0, w * params.neg_scale)


def patch_unet_forward_pass(p, unet, params):
    """Monkey-patch `unet.forward` to inject FABRIC feedback attention.

    On every U-Net call the patched forward (1) noises the feedback-image
    latents to the current timestep, (2) runs them through the original U-Net
    with temporarily patched self-attention layers to cache their hidden
    states, then (3) re-runs the real input with self-attention extended to
    attend (with positive/negative weights) over those cached states.
    This debug variant additionally records per-step latent statistics and
    writes trajectory plots under plots/stats/.
    """
    if not params.pos_images and not params.neg_images:
        print("[FABRIC] No feedback images found, aborting patching")
        return

    if not hasattr(unet, "_fabric_old_forward"):
        unet._fabric_old_forward = unet.forward

    batch_size = p.batch_size
    null_ctx = p.sd_model.get_learned_conditioning([""]).to(devices.device, dtype=devices.dtype_unet)

    width = (p.width // 8) * 8
    height = (p.height // 8) * 8

    # Resolve the high-res-fix target resolution so the patched forward can
    # detect (by latent size) whether it is currently inside the hires pass.
    has_hires_fix = isinstance(p, StableDiffusionProcessingTxt2Img) and getattr(p, 'enable_hr', False)
    if has_hires_fix:
        if p.hr_resize_x == 0 and p.hr_resize_y == 0:
            hr_w = int(p.width * p.hr_scale)
            hr_h = int(p.height * p.hr_scale)
        else:
            hr_w, hr_h = p.hr_resize_x, p.hr_resize_y
        hr_w = (hr_w // 8) * 8
        hr_h = (hr_h // 8) * 8
    else:
        hr_w = width
        hr_h = height

    tome_args = {
        "enabled": params.tome_enabled,
        "sx": 2,
        "sy": 2,
        "use_rand": True,
        "generator": None,
        "seed": params.tome_seed,
    }

    # Debug accumulators: one entry per U-Net call, grow for the lifetime of
    # the patch (unbounded across steps — debug-only instrumentation).
    x_means = []
    x_stds = []
    x_norms = []
    cond_means = []
    cond_stds = []
    cond_norms = []
    uncond_means = []
    uncond_stds = []
    uncond_norms = []
    DO_FABRIC = True  # debug switch: False bypasses feedback entirely
    mean_ema = {}  # state for the commented-out "momentum mean decay" experiment

    def new_forward(self, x, timesteps=None, context=None, **kwargs):
        _, uncond_ids, context = unmark_prompt_context(context)
        cond_ids = [i for i in range(context.size(0)) if i not in uncond_ids]
        has_cond = len(cond_ids) > 0
        has_uncond = len(uncond_ids) > 0

        def plot_hist(out, filename):
            # Debug helper: per-sample histogram of tensor values.
            plt.figure()
            xs = out.detach().cpu().numpy()
            for i in range(xs.shape[0]):
                hist, bin_edges = np.histogram(xs[i].reshape(-1), bins=100, density=True)
                bin_centers = (bin_edges[:-1] + bin_edges[1:]) / 2.
                plt.plot(bin_centers, hist)
            plt.savefig(filename)
            plt.close()

        def plot_lines(ys, filename):
            # Debug helper: one line per series in the stacked statistics.
            plt.figure()
            ys = np.stack(ys, axis=1)
            for i in range(ys.shape[0]):
                plt.plot(ys[i].reshape(-1))
            plt.savefig(filename)
            plt.close()

        if not DO_FABRIC:
            out = self._fabric_old_forward(x, timesteps, context, **kwargs)
        else:
            h_latent, w_latent = x.shape[-2:]
            w, h = 8 * w_latent, 8 * h_latent

            if has_hires_fix and w == hr_w and h == hr_h:
                if not params.feedback_during_high_res_fix:
                    print("[FABRIC] Skipping feedback during high-res fix")
                    return self._fabric_old_forward(x, timesteps, context, **kwargs)

            pos_weight, neg_weight = get_curr_feedback_weight(p, params, timesteps[0].item())
            if pos_weight <= 0 and neg_weight <= 0:
                return self._fabric_old_forward(x, timesteps, context, **kwargs)

            pos_latents, neg_latents = get_latents_from_params(p, params, w, h)
            pos_latents = pos_latents if has_cond else []
            neg_latents = neg_latents if has_uncond else []
            all_latents = pos_latents + neg_latents

            # Note: calls to the VAE with `--medvram` will move the U-Net to CPU, so we need to move it back to GPU
            if shared.cmd_opts.medvram:
                try:
                    # Trigger register_forward_pre_hook to move the model to correct device
                    p.sd_model.model()
                except Exception:
                    # Best-effort device bump; any failure is non-fatal here.
                    pass

            if len(all_latents) == 0:
                return self._fabric_old_forward(x, timesteps, context, **kwargs)

            # ## intervention 2: std decay
            # std = x.std(dim=(2, 3), keepdim=True)
            # mask = (x.std(dim=(2, 3), keepdim=True) > 1.0).float()
            # x = mask * ((0.05*(std - 1) + 1) * x / std) + (1 - mask) * x

            # ## intervention 3: mean clamp
            # relative_t = timesteps[0].item() / (p.sd_model.num_timesteps - 1)
            # max_mean = 0.5 * (1 - relative_t)
            # mean = x.mean(dim=(2, 3), keepdim=True)
            # x = x - mean + torch.clamp(mean, min=-max_mean, max=max_mean)

            # ## intervention 7: mean decay
            # mean = x.mean(dim=(2, 3), keepdim=True)
            # x = x - 0.2 * mean

            # add noise to reference latents
            all_zs = []
            for latent in all_latents:
                z = p.sd_model.q_sample(latent.unsqueeze(0), torch.round(timesteps.float()).long())[0]
                all_zs.append(z)
            all_zs = torch.stack(all_zs, dim=0)

            # save original forward pass
            for module in self.modules():
                if isinstance(module, BasicTransformerBlock) and not hasattr(module.attn1, "_fabric_old_forward"):
                    module.attn1._fabric_old_forward = module.attn1.forward
                    module.attn2._fabric_old_forward = module.attn2.forward

            try:
                ## cache hidden states
                cached_hiddens = {}

                def patched_attn1_forward(attn1, layer_idx, x, **kwargs):
                    # Cache the (token-merged) self-attention input per layer so it
                    # can later be injected as extra context for the real input.
                    merge, unmerge = compute_merge(x, args=tome_args, size=(h_latent, w_latent), ratio=params.tome_ratio)
                    x = merge(x)
                    if layer_idx not in cached_hiddens:
                        cached_hiddens[layer_idx] = x.detach().clone().cpu()
                    else:
                        cached_hiddens[layer_idx] = torch.cat([cached_hiddens[layer_idx], x.detach().clone().cpu()], dim=0)
                    out = attn1._fabric_old_forward(x, **kwargs)
                    out = unmerge(out)
                    return out

                def patched_attn2_forward(attn2, x, **kwargs):
                    # Cross-attention only gets token merging, no caching.
                    merge, unmerge = compute_merge(x, args=tome_args, size=(h_latent, w_latent), ratio=params.tome_ratio)
                    x = merge(x)
                    out = attn2._fabric_old_forward(x, **kwargs)
                    out = unmerge(out)
                    return out

                # patch forward pass to cache hidden states
                layer_idx = 0
                for module in self.modules():
                    if isinstance(module, BasicTransformerBlock):
                        module.attn1.forward = functools.partial(patched_attn1_forward, module.attn1, layer_idx)
                        module.attn2.forward = functools.partial(patched_attn2_forward, module.attn2)
                        layer_idx += 1

                # run forward pass just to cache hidden states, output is discarded
                for i in range(0, len(all_zs), batch_size):
                    zs = all_zs[i : i + batch_size].to(x.device, dtype=self.dtype)
                    ts = timesteps[:1].expand(zs.size(0))  # (bs,)
                    # use the null prompt for pre-computing hidden states on feedback images
                    ctx = null_ctx.expand(zs.size(0), -1, -1)  # (bs, p_seq, p_dim)
                    _ = self._fabric_old_forward(zs, ts, ctx)

                num_pos = len(pos_latents)
                num_neg = len(neg_latents)
                num_cond = len(cond_ids)
                num_uncond = len(uncond_ids)

                # Effective latent height after token merging (may be fractional).
                tome_h_latent = h_latent * (1 - params.tome_ratio)

                def patched_attn1_forward(attn1, idx, x, context=None, **kwargs):
                    # Second-stage patch: extend self-attention context with the
                    # cached feedback hidden states, weighted per cond/uncond.
                    if context is None:
                        context = x

                    cached_hs = cached_hiddens[idx].to(x.device)
                    d_model = x.shape[-1]

                    def attention_with_feedback(_x, context, feedback_hs, w):
                        num_xs, num_fb = _x.shape[0], feedback_hs.shape[0]
                        if num_fb > 0:
                            feedback_ctx = feedback_hs.view(1, -1, d_model).expand(num_xs, -1, -1)  # (n_cond, seq * n_pos, dim)
                            merge, _ = compute_merge(feedback_ctx, args=tome_args, size=(tome_h_latent * num_fb, w_latent), max_tokens=params.tome_max_tokens)
                            feedback_ctx = merge(feedback_ctx)
                            ctx = torch.cat([context, feedback_ctx], dim=1)  # (n_cond, seq + seq*n_pos, dim)
                            weights = torch.ones(ctx.shape[1], device=ctx.device, dtype=ctx.dtype)  # (seq + seq*n_pos,)
                            weights[_x.shape[1]:] = w
                        else:
                            ctx = context
                            weights = None
                        return weighted_attention(attn1, attn1._fabric_old_forward, _x, ctx, weights, **kwargs)  # (n_cond, seq, dim)

                    outs = []
                    if num_cond > 0:
                        out_cond = attention_with_feedback(x[cond_ids], context[cond_ids], cached_hs[:num_pos], pos_weight)  # (n_cond, seq, dim)
                        outs.append(out_cond)
                    if num_uncond > 0:
                        out_uncond = attention_with_feedback(x[uncond_ids], context[uncond_ids], cached_hs[num_pos:], neg_weight)  # (n_cond, seq, dim)
                        outs.append(out_uncond)
                    out = torch.cat(outs, dim=0)
                    return out

                # patch forward pass to inject cached hidden states
                layer_idx = 0
                for module in self.modules():
                    if isinstance(module, BasicTransformerBlock):
                        module.attn1.forward = functools.partial(patched_attn1_forward, module.attn1, layer_idx)
                        layer_idx += 1

                # run forward pass with cached hidden states
                out = self._fabric_old_forward(x, timesteps, context, **kwargs)
            finally:
                # restore original pass (always, even if the patched pass raised)
                for module in self.modules():
                    if isinstance(module, BasicTransformerBlock) and hasattr(module.attn1, "_fabric_old_forward"):
                        module.attn1.forward = module.attn1._fabric_old_forward
                        del module.attn1._fabric_old_forward
                    if isinstance(module, BasicTransformerBlock) and hasattr(module.attn2, "_fabric_old_forward"):
                        module.attn2.forward = module.attn2._fabric_old_forward
                        del module.attn2._fabric_old_forward

        # ---- debug instrumentation: record and plot per-step statistics ----
        cond_outs = out[cond_ids]
        uncond_outs = out[uncond_ids]
        xs = x[cond_ids]
        t = int(timesteps[0].item())
        seed = p.seed
        x_means.append(xs.mean(dim=(2, 3)).view(-1).cpu().numpy())
        x_stds.append(xs.std(dim=(2, 3)).view(-1).cpu().numpy())
        x_norms.append(xs.norm(dim=(2, 3)).view(-1).cpu().numpy())
        cond_means.append(cond_outs.mean(dim=(2, 3)).view(-1).cpu().numpy())
        cond_stds.append(cond_outs.std(dim=(2, 3)).view(-1).cpu().numpy())
        cond_norms.append(cond_outs.norm(dim=(2, 3)).view(-1).cpu().numpy())
        uncond_means.append(uncond_outs.mean(dim=(2, 3)).view(-1).cpu().numpy())
        uncond_stds.append(uncond_outs.std(dim=(2, 3)).view(-1).cpu().numpy())
        uncond_norms.append(uncond_outs.norm(dim=(2, 3)).view(-1).cpu().numpy())

        def plot_trajectory(means, stds, ax=None):
            # Plot mean trajectory with a +/- std band.
            if ax is None:
                ax = plt.gca()
            # means = np.stack(means, axis=1)
            # stds = np.stack(stds, axis=1)
            ax.plot(means)
            ax.fill_between(range(means.shape[0]), means - stds, means + stds, alpha=0.3)

        def select(xs, batch_idx):
            # Gather one batch element's value across all recorded steps.
            return np.stack([x[batch_idx] for x in xs], axis=0)

        # Fix: ensure the output directory exists before saving (the commented
        # siblings below do the same for plots/lines and plots/hists).
        os.makedirs("plots/stats", exist_ok=True)
        for i in range(len(xs)):
            x_means_i = select(x_means, i)
            x_stds_i = select(x_stds, i)
            cond_means_i = select(cond_means, i)
            cond_stds_i = select(cond_stds, i)
            uncond_means_i = select(uncond_means, i)
            uncond_stds_i = select(uncond_stds, i)
            plot_file = f"plots/stats/{seed + i}_{'fabric' if DO_FABRIC else 'default'}.png"
            fig, axs = plt.subplots(1, 3, figsize=(15, 5))
            plot_trajectory(x_means_i, x_stds_i, ax=axs[0])
            plot_trajectory(cond_means_i, cond_stds_i, ax=axs[1])
            plot_trajectory(uncond_means_i, uncond_stds_i, ax=axs[2])
            fig.savefig(plot_file)
            # Fix: close the figure; this runs on every U-Net call and would
            # otherwise leak one open matplotlib figure per step.
            plt.close(fig)

        # os.makedirs("plots/lines", exist_ok=True)
        # plot_lines(x_means, f"plots/lines/{seed}_x_means{'_fabric' if DO_FABRIC else ''}.png")
        # plot_lines(x_stds, f"plots/lines/{seed}_x_stds{'_fabric' if DO_FABRIC else ''}.png")
        # plot_lines(x_norms, f"plots/lines/{seed}_x_norms{'_fabric' if DO_FABRIC else ''}.png")
        # plot_lines(cond_means, f"plots/lines/{seed}_cond_means{'_fabric' if DO_FABRIC else ''}.png")
        # plot_lines(cond_stds, f"plots/lines/{seed}_cond_stds{'_fabric' if DO_FABRIC else ''}.png")
        # plot_lines(cond_norms, f"plots/lines/{seed}_cond_norms{'_fabric' if DO_FABRIC else ''}.png")
        # plot_lines(uncond_means, f"plots/lines/{seed}_uncond_means{'_fabric' if DO_FABRIC else ''}.png")
        # plot_lines(uncond_stds, f"plots/lines/{seed}_uncond_stds{'_fabric' if DO_FABRIC else ''}.png")
        # plot_lines(uncond_norms, f"plots/lines/{seed}_uncond_norms{'_fabric' if DO_FABRIC else ''}.png")

        # os.makedirs("plots/hists", exist_ok=True)
        # for i in range(len(xs)):
        #     filename = f"plots/hists/{seed}_{t}_x_{i}{'_fabric' if DO_FABRIC else ''}.png"
        #     plot_hist(x[i], filename)
        # for i in range(len(cond_outs)):
        #     filename = f"plots/hists/{seed}_{t}_cond_{i}{'_fabric' if DO_FABRIC else ''}.png"
        #     plot_hist(cond_outs[i], filename)
        # for i in range(len(uncond_outs)):
        #     filename = f"plots/hists/{seed}_{t}_uncond_{i}{'_fabric' if DO_FABRIC else ''}.png"
        #     plot_hist(uncond_outs[i], filename)

        # ## intervention 1: mean decay
        # out = out - 0.5*out.mean(dim=(2, 3), keepdim=True)

        # ## intervention 4: early normalization
        # relative_t = 1 - timesteps[0].item() / (p.sd_model.num_timesteps - 1)
        # if relative_t < 0.33:
        #     out = (out - out.mean(dim=(2, 3), keepdim=True)) / out.std(dim=(2, 3), keepdim=True)

        # ## intervention 5: momentum mean decay
        # dampen = 0.8
        # accel = 0.35
        # alpha = 0.5
        # beta = 0.25
        # mean = out.mean(dim=(2, 3), keepdim=True)
        # if "momentum" not in mean_ema:
        #     mean_ema["momentum"] = mean
        # else:
        #     mean_ema["momentum"] = dampen * (mean_ema["momentum"] + accel * out.mean(dim=(2, 3), keepdim=True))
        # out = out - alpha * mean - beta * mean_ema["momentum"]
        # print()
        # print("mean: ", mean.cpu().view(-1))
        # print("momentum:", mean_ema["momentum"].cpu().view(-1))

        # plot_lines([p.sd_model.betas.cpu().unsqueeze(0).numpy()], f"plots/lines/betas.png")
        # plot_lines([p.sd_model.alphas_cumprod.cpu().unsqueeze(0).numpy()], f"plots/lines/alphas_cumprod.png")
        # plot_lines([p.sd_model.alphas_cumprod_prev.cpu().unsqueeze(0).numpy()], f"plots/lines/alphas_cumprod_prev.png")
        # plot_lines([p.sd_model.posterior_variance.cpu().unsqueeze(0).numpy()], f"plots/lines/posterior_variance.png")
        # plot_lines([p.sd_model.posterior_log_variance_clipped.cpu().unsqueeze(0).numpy()], f"plots/lines/posterior_log_variance_clipped.png")
        # plot_lines([p.sd_model.posterior_mean_coef1.cpu().unsqueeze(0).numpy()], f"plots/lines/posterior_mean_coef1.png")
        # plot_lines([p.sd_model.posterior_mean_coef2.cpu().unsqueeze(0).numpy()], f"plots/lines/posterior_mean_coef2.png")
        # plot_lines([p.sd_model.sqrt_alphas_cumprod.cpu().unsqueeze(0).numpy()], f"plots/lines/sqrt_alphas_cumprod.png")
        # plot_lines([p.sd_model.sqrt_one_minus_alphas_cumprod.cpu().unsqueeze(0).numpy()], f"plots/lines/sqrt_one_minus_alphas_cumprod.png")
        # plot_lines([p.sd_model.log_one_minus_alphas_cumprod.cpu().unsqueeze(0).numpy()], f"plots/lines/log_one_minus_alphas_cumprod.png")
        # plot_lines([p.sd_model.sqrt_recip_alphas_cumprod.cpu().unsqueeze(0).numpy()], f"plots/lines/sqrt_recip_alphas_cumprod.png")
        # plot_lines([p.sd_model.sqrt_recipm1_alphas_cumprod.cpu().unsqueeze(0).numpy()], f"plots/lines/sqrt_recipm1_alphas_cumprod.png")

        # ## intervention 8: dynamic standardization
        # # TODO: test how to dynamically bound the mean rather than always subtracting some fraction
        # relative_t = 1 - timesteps[0].item() / (p.sd_model.num_timesteps - 1)
        # if relative_t < 0.8:
        #     alpha = 0.9
        #     sigma = (1 - relative_t)**0.1
        #     # print(p.n_iter, p.steps)
        #     # print(p.sampler.get_sigmas(p, p.steps))
        #     # sigmas = p.sampler.get_sigmas(p, p.steps)
        #     # print(torch.sqrt(1.0 + sigmas ** 2.0))
        #     # print(p.sd_model.alphas_cumprod)
        #     # sigma = p.sampler.get_sigmas(p, timesteps[0].item())
        #     std = out.std(dim=(2, 3), keepdim=True).clip(min=alpha*sigma, max=(2 - alpha)*sigma)
        #     mean = out.mean(dim=(2, 3), keepdim=True)
        #     out = (out - alpha*mean) / std

        return out

    # Bind as a method on the unet instance and enable prompt marking so the
    # patched forward can tell cond from uncond batches.
    unet.forward = new_forward.__get__(unet)
    apply_marking_patch(p)


def unpatch_unet_forward_pass(unet):
    """Restore the original `unet.forward` saved by `patch_unet_forward_pass`."""
    if hasattr(unet, "_fabric_old_forward"):
        print("[FABRIC] Restoring original U-Net forward pass")
        unet.forward = unet._fabric_old_forward
        del unet._fabric_old_forward