# Excerpt reconstructed from the post-patch ("+" and context) side of a
# line-collapsed git diff of scripts/patching.py.  In the real file everything
# below is nested inside patch_unet_forward_pass(p, unet, params); it is shown
# dedented here.  `p`, `params`, `tome_args`, `has_hires_fix`, `hr_w`, `hr_h`,
# `batch_size` and `null_ctx` are free variables expected from that enclosing
# scope (not visible in this excerpt), as are the module-level names the code
# calls (torch, functools, shared, BasicTransformerBlock, compute_merge, ...).
import os

import numpy as np

# Debug statistics: one entry is appended per instrumented U-Net call.  Each
# entry is the array returned by channel_stats(), indexed [sample, channel].
x_means = []
x_stds = []
x_norms = []
cond_means = []
cond_stds = []
cond_norms = []
uncond_means = []
uncond_stds = []
uncond_norms = []

# False bypasses the feedback mechanism and only calls the original forward,
# which gives a baseline for the statistics above.
DO_FABRIC = True
# Record the statistics and rewrite the trajectory plots after every
# instrumented call.  Set to False to disable all plotting/disk I/O.
SAVE_STATS_PLOTS = True
STATS_PLOT_DIR = "plots/stats"
# Scratch state for the (currently disabled) "momentum mean decay" intervention.
mean_ema = {}


def plot_hist(out, filename):
    """Save a plot with one density histogram per leading-axis slice of `out`.

    `out` is a tensor; it is detached and moved to the CPU before binning.
    The figure is always closed, even if saving fails.
    """
    import matplotlib.pyplot as plt  # local: only needed when debugging

    values = out.detach().float().cpu().numpy()
    fig = plt.figure()
    try:
        for row in values:
            hist, bin_edges = np.histogram(row.reshape(-1), bins=100, density=True)
            bin_centers = (bin_edges[:-1] + bin_edges[1:]) / 2.0
            plt.plot(bin_centers, hist)
        plt.savefig(filename)
    finally:
        plt.close(fig)


def plot_lines(ys, filename):
    """Stack the arrays in `ys` along axis 1 and plot each row as one line."""
    import matplotlib.pyplot as plt  # local: only needed when debugging

    stacked = np.stack(ys, axis=1)
    fig = plt.figure()
    try:
        for row in stacked:
            plt.plot(row.reshape(-1))
        plt.savefig(filename)
    finally:
        plt.close(fig)


def channel_stats(tensor):
    """Return (mean, std, norm) of `tensor` reduced over dims 2 and 3.

    Each result is a numpy array with the two leading dims of `tensor`
    (assumed to be sample and channel -- TODO confirm for non-4D inputs).
    The tensor is detached and upcast to float32 first so the conversion to
    numpy also works for grad-tracking and reduced-precision inputs.
    """
    tensor = tensor.detach().float()
    reductions = (
        tensor.mean(dim=(2, 3)),
        tensor.std(dim=(2, 3)),
        tensor.norm(dim=(2, 3)),
    )
    return tuple(r.cpu().numpy() for r in reductions)


def record_stats(xs, cond_outs, uncond_outs):
    """Append the stats of the input and of both output groups to the history."""
    targets = (
        (xs, (x_means, x_stds, x_norms)),
        (cond_outs, (cond_means, cond_stds, cond_norms)),
        (uncond_outs, (uncond_means, uncond_stds, uncond_norms)),
    )
    for tensor, histories in targets:
        for history, value in zip(histories, channel_stats(tensor)):
            history.append(value)


def select(xs, batch_idx):
    """Return the history of sample `batch_idx` as a (calls, channels) array.

    Calls whose batch did not contain that sample are skipped; returns None
    when no call contained it (e.g. a call without unconditional samples).
    """
    rows = [entry[batch_idx] for entry in xs if batch_idx < len(entry)]
    return np.stack(rows, axis=0) if rows else None


def plot_trajectory(means, stds, ax=None):
    """Plot one mean line with a +/- std band per channel column."""
    import matplotlib.pyplot as plt  # local: only needed when debugging

    if ax is None:
        ax = plt.gca()
    steps = np.arange(means.shape[0])
    for channel in range(means.shape[1]):
        centre = means[:, channel]
        spread = stds[:, channel]
        ax.plot(steps, centre)
        ax.fill_between(steps, centre - spread, centre + spread, alpha=0.3)


def save_stats_plots(num_samples, seed):
    """Write one three-panel figure (input / cond out / uncond out) per sample."""
    import matplotlib.pyplot as plt  # local: only needed when debugging

    # The target directory is not guaranteed to exist; savefig does not create it.
    os.makedirs(STATS_PLOT_DIR, exist_ok=True)
    tag = "fabric" if DO_FABRIC else "default"
    panels = (
        (x_means, x_stds),
        (cond_means, cond_stds),
        (uncond_means, uncond_stds),
    )
    for i in range(num_samples):
        fig, axs = plt.subplots(1, 3, figsize=(15, 5))
        try:
            for ax, (means, stds) in zip(axs, panels):
                sample_means = select(means, i)
                if sample_means is not None:
                    plot_trajectory(sample_means, select(stds, i), ax=ax)
            fig.savefig(f"{STATS_PLOT_DIR}/{seed + i}_{tag}.png")
        finally:
            # Without this every call leaks one open figure per sample.
            plt.close(fig)


def new_forward(self, x, timesteps=None, context=None, **kwargs):
    """Forward pass that injects feedback-image hidden states into attn1.

    Wraps `self._fabric_old_forward`.  With DO_FABRIC enabled it
      1. noises the reference latents to the current timestep,
      2. runs the original forward on them with the null prompt while
         caching the (token-merged) input of every BasicTransformerBlock's
         attn1,
      3. runs the original forward on `x` with attn1 patched so that its
         context is extended by the cached hidden states, weighted by the
         current positive/negative feedback weight.
    The attention modules are always restored afterwards.

    Returns the output of the (possibly patched) original forward.  The
    early-return paths below (feedback skipped) bypass the debug statistics,
    so those calls do not show up in the trajectory plots.
    """
    _, uncond_ids, context = unmark_prompt_context(context)
    cond_ids = [i for i in range(context.size(0)) if i not in uncond_ids]
    has_cond = len(cond_ids) > 0
    has_uncond = len(uncond_ids) > 0

    if not DO_FABRIC:
        out = self._fabric_old_forward(x, timesteps, context, **kwargs)
    else:
        h_latent, w_latent = x.shape[-2:]
        # Latent size times 8 is compared against the high-res-fix target size.
        w, h = 8 * w_latent, 8 * h_latent
        if has_hires_fix and w == hr_w and h == hr_h:
            if not params.feedback_during_high_res_fix:
                print("[FABRIC] Skipping feedback during high-res fix")
                return self._fabric_old_forward(x, timesteps, context, **kwargs)

        pos_weight, neg_weight = get_curr_feedback_weight(p, params, timesteps[0].item())
        if pos_weight <= 0 and neg_weight <= 0:
            return self._fabric_old_forward(x, timesteps, context, **kwargs)

        pos_latents, neg_latents = get_latents_from_params(p, params, w, h)
        pos_latents = pos_latents if has_cond else []
        neg_latents = neg_latents if has_uncond else []
        all_latents = pos_latents + neg_latents

        # Note: calls to the VAE with `--medvram` will move the U-Net to CPU, so we need to move it back to GPU
        if shared.cmd_opts.medvram:
            try:
                # Trigger register_forward_pre_hook to move the model to correct device
                p.sd_model.model()
            except Exception:
                # Best effort: the call is only made for its pre-hook side
                # effect, so its own failure is ignored (but Ctrl-C is not).
                pass

        if len(all_latents) == 0:
            return self._fabric_old_forward(x, timesteps, context, **kwargs)

        # ## intervention 2: std decay
        # std = x.std(dim=(2, 3), keepdim=True)
        # mask = (x.std(dim=(2, 3), keepdim=True) > 1.0).float()
        # x = mask * ((0.05*(std - 1) + 1) * x / std) + (1 - mask) * x

        # ## intervention 3: mean clamp
        # relative_t = timesteps[0].item() / (p.sd_model.num_timesteps - 1)
        # max_mean = 0.5 * (1 - relative_t)
        # mean = x.mean(dim=(2, 3), keepdim=True)
        # x = x - mean + torch.clamp(mean, min=-max_mean, max=max_mean)

        # ## intervention 7: mean decay
        # mean = x.mean(dim=(2, 3), keepdim=True)
        # x = x - 0.2 * mean

        # add noise to reference latents
        noise_steps = torch.round(timesteps.float()).long()
        all_zs = torch.stack(
            [p.sd_model.q_sample(latent.unsqueeze(0), noise_steps)[0] for latent in all_latents],
            dim=0,
        )

        # save original forward pass
        for module in self.modules():
            if isinstance(module, BasicTransformerBlock) and not hasattr(module.attn1, "_fabric_old_forward"):
                module.attn1._fabric_old_forward = module.attn1.forward
                module.attn2._fabric_old_forward = module.attn2.forward

        try:
            ## cache hidden states
            cached_hiddens = {}

            def patched_attn1_forward(attn1, layer_idx, x, **kwargs):
                merge, unmerge = compute_merge(x, args=tome_args, size=(h_latent, w_latent), ratio=params.tome_ratio)
                x = merge(x)
                # Cache on the CPU; batches of the same layer are concatenated.
                if layer_idx not in cached_hiddens:
                    cached_hiddens[layer_idx] = x.detach().clone().cpu()
                else:
                    cached_hiddens[layer_idx] = torch.cat([cached_hiddens[layer_idx], x.detach().clone().cpu()], dim=0)
                out = attn1._fabric_old_forward(x, **kwargs)
                out = unmerge(out)
                return out

            def patched_attn2_forward(attn2, x, **kwargs):
                merge, unmerge = compute_merge(x, args=tome_args, size=(h_latent, w_latent), ratio=params.tome_ratio)
                x = merge(x)
                out = attn2._fabric_old_forward(x, **kwargs)
                out = unmerge(out)
                return out

            # patch forward pass to cache hidden states
            layer_idx = 0
            for module in self.modules():
                if isinstance(module, BasicTransformerBlock):
                    module.attn1.forward = functools.partial(patched_attn1_forward, module.attn1, layer_idx)
                    module.attn2.forward = functools.partial(patched_attn2_forward, module.attn2)
                    layer_idx += 1

            # run forward pass just to cache hidden states, output is discarded
            for i in range(0, len(all_zs), batch_size):
                zs = all_zs[i : i + batch_size].to(x.device, dtype=self.dtype)
                ts = timesteps[:1].expand(zs.size(0))  # (bs,)
                # use the null prompt for pre-computing hidden states on feedback images
                ctx = null_ctx.expand(zs.size(0), -1, -1)  # (bs, p_seq, p_dim)
                _ = self._fabric_old_forward(zs, ts, ctx)

            num_pos = len(pos_latents)
            num_cond = len(cond_ids)
            num_uncond = len(uncond_ids)
            tome_h_latent = h_latent * (1 - params.tome_ratio)

            # Deliberately rebinds the name: second-stage attn1 that attends
            # to the cached feedback hidden states.
            def patched_attn1_forward(attn1, idx, x, context=None, **kwargs):
                if context is None:
                    context = x

                cached_hs = cached_hiddens[idx].to(x.device)

                d_model = x.shape[-1]

                def attention_with_feedback(_x, context, feedback_hs, w):
                    num_xs, num_fb = _x.shape[0], feedback_hs.shape[0]
                    if num_fb > 0:
                        feedback_ctx = feedback_hs.view(1, -1, d_model).expand(num_xs, -1, -1)  # (n_cond, seq * n_pos, dim)
                        merge, _ = compute_merge(feedback_ctx, args=tome_args, size=(tome_h_latent * num_fb, w_latent), max_tokens=params.tome_max_tokens)
                        feedback_ctx = merge(feedback_ctx)
                        ctx = torch.cat([context, feedback_ctx], dim=1)  # (n_cond, seq + seq*n_pos, dim)
                        weights = torch.ones(ctx.shape[1], device=ctx.device, dtype=ctx.dtype)  # (seq + seq*n_pos,)
                        weights[_x.shape[1]:] = w
                    else:
                        ctx = context
                        weights = None
                    return weighted_attention(attn1, attn1._fabric_old_forward, _x, ctx, weights, **kwargs)  # (n_cond, seq, dim)

                # Cached states are ordered positives first, then negatives.
                outs = []
                if num_cond > 0:
                    out_cond = attention_with_feedback(x[cond_ids], context[cond_ids], cached_hs[:num_pos], pos_weight)  # (n_cond, seq, dim)
                    outs.append(out_cond)
                if num_uncond > 0:
                    out_uncond = attention_with_feedback(x[uncond_ids], context[uncond_ids], cached_hs[num_pos:], neg_weight)  # (n_cond, seq, dim)
                    outs.append(out_uncond)
                out = torch.cat(outs, dim=0)
                return out

            # patch forward pass to inject cached hidden states
            layer_idx = 0
            for module in self.modules():
                if isinstance(module, BasicTransformerBlock):
                    module.attn1.forward = functools.partial(patched_attn1_forward, module.attn1, layer_idx)
                    layer_idx += 1

            # run forward pass with cached hidden states
            out = self._fabric_old_forward(x, timesteps, context, **kwargs)

        finally:
            # restore original pass
            for module in self.modules():
                if isinstance(module, BasicTransformerBlock) and hasattr(module.attn1, "_fabric_old_forward"):
                    module.attn1.forward = module.attn1._fabric_old_forward
                    del module.attn1._fabric_old_forward
                if isinstance(module, BasicTransformerBlock) and hasattr(module.attn2, "_fabric_old_forward"):
                    module.attn2.forward = module.attn2._fabric_old_forward
                    del module.attn2._fabric_old_forward

    if SAVE_STATS_PLOTS:
        record_stats(x[cond_ids], out[cond_ids], out[uncond_ids])
        save_stats_plots(len(cond_ids), p.seed)

    # Further diagnostics, currently disabled (entries are flattened so that
    # plot_lines draws one line per sample/channel pair):
    # suffix = "_fabric" if DO_FABRIC else ""
    # os.makedirs("plots/lines", exist_ok=True)
    # for name, history in (
    #     ("x_means", x_means), ("x_stds", x_stds), ("x_norms", x_norms),
    #     ("cond_means", cond_means), ("cond_stds", cond_stds), ("cond_norms", cond_norms),
    #     ("uncond_means", uncond_means), ("uncond_stds", uncond_stds), ("uncond_norms", uncond_norms),
    # ):
    #     plot_lines([entry.reshape(-1) for entry in history], f"plots/lines/{p.seed}_{name}{suffix}.png")
    #
    # os.makedirs("plots/hists", exist_ok=True)
    # t = int(timesteps[0].item())
    # for name, batch in (("x", x[cond_ids]), ("cond", out[cond_ids]), ("uncond", out[uncond_ids])):
    #     for i, sample in enumerate(batch):
    #         plot_hist(sample, f"plots/hists/{p.seed}_{t}_{name}_{i}{suffix}.png")
    #
    # for name in (
    #     "betas", "alphas_cumprod", "alphas_cumprod_prev", "posterior_variance",
    #     "posterior_log_variance_clipped", "posterior_mean_coef1", "posterior_mean_coef2",
    #     "sqrt_alphas_cumprod", "sqrt_one_minus_alphas_cumprod", "log_one_minus_alphas_cumprod",
    #     "sqrt_recip_alphas_cumprod", "sqrt_recipm1_alphas_cumprod",
    # ):
    #     plot_lines([getattr(p.sd_model, name).cpu().unsqueeze(0).numpy()], f"plots/lines/{name}.png")

    # ## intervention 1: mean decay
    # out = out - 0.5*out.mean(dim=(2, 3), keepdim=True)

    # ## intervention 4: early normalization
    # relative_t = 1 - timesteps[0].item() / (p.sd_model.num_timesteps - 1)
    # if relative_t < 0.33:
    #     out = (out - out.mean(dim=(2, 3), keepdim=True)) / out.std(dim=(2, 3), keepdim=True)

    # ## intervention 5: momentum mean decay
    # dampen = 0.8
    # accel = 0.35
    # alpha = 0.5
    # beta = 0.25
    # mean = out.mean(dim=(2, 3), keepdim=True)
    # if "momentum" not in mean_ema:
    #     mean_ema["momentum"] = mean
    # else:
    #     mean_ema["momentum"] = dampen * (mean_ema["momentum"] + accel * out.mean(dim=(2, 3), keepdim=True))
    # out = out - alpha * mean - beta * mean_ema["momentum"]
    # print()
    # print("mean:    ", mean.cpu().view(-1))
    # print("momentum:", mean_ema["momentum"].cpu().view(-1))

    # ## intervention 8: dynamic standardization
    # # TODO: test how to dynamically bound the mean rather than always subtracting some fraction
    # relative_t = 1 - timesteps[0].item() / (p.sd_model.num_timesteps - 1)
    # if relative_t < 0.8:
    #     alpha = 0.9
    #     sigma = (1 - relative_t)**0.1
    #     # print(p.n_iter, p.steps)
    #     # print(p.sampler.get_sigmas(p, p.steps))
    #     # sigmas = p.sampler.get_sigmas(p, p.steps)
    #     # print(torch.sqrt(1.0 + sigmas ** 2.0))
    #     # print(p.sd_model.alphas_cumprod)
    #     # sigma = p.sampler.get_sigmas(p, timesteps[0].item())
    #     std = out.std(dim=(2, 3), keepdim=True).clip(min=alpha*sigma, max=(2 - alpha)*sigma)
    #     mean = out.mean(dim=(2, 3), keepdim=True)
    #     out = (out - alpha*mean) / std

    return out