diff --git a/scripts/fabric.py b/scripts/fabric.py index 1cd1117..d3c2b68 100644 --- a/scripts/fabric.py +++ b/scripts/fabric.py @@ -139,6 +139,7 @@ class FabricParams: tome_ratio: float = 0.5 tome_max_tokens: int = 4*4096 tome_seed: int = -1 + burnout_protection: bool = False # TODO: replace global state with Gradio state @@ -218,6 +219,7 @@ class FabricScript(modules.scripts.Script): feedback_max_weight = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, value=0.8, label="Feedback Strength", elem_id="fabric_max_weight") tome_enabled = gr.Checkbox(label="Enable Token Merging (faster, less VRAM, less accurate)", value=False) + burnout_protection = gr.Checkbox(label="Burnout protection (enable if results contain artifacts or are especially dark)", value=False) with gr.Accordion("Advanced options", open=DEBUG): with FormGroup(): @@ -302,6 +304,7 @@ class FabricScript(modules.scripts.Script): (tome_ratio, "fabric_tome_ratio"), (tome_max_tokens, "fabric_tome_max_tokens"), (tome_seed, "fabric_tome_seed"), + (burnout_protection, "fabric_burnout_protection"), (feedback_during_high_res_fix, "fabric_feedback_during_high_res_fix"), (liked_paths, lambda d: gr.update(value=_load_feedback_paths(d, "fabric_pos_images")) if "fabric_pos_images" in d else None), (disliked_paths, lambda d: gr.update(value=_load_feedback_paths(d, "fabric_neg_images")) if "fabric_neg_images" in d else None), @@ -323,6 +326,7 @@ class FabricScript(modules.scripts.Script): tome_ratio, tome_max_tokens, tome_seed, + burnout_protection, ] @@ -429,6 +433,7 @@ class FabricScript(modules.scripts.Script): tome_ratio, tome_max_tokens, tome_seed, + burnout_protection, ) = args # restore original U-Net forward pass in case previous batch errored out @@ -454,6 +459,7 @@ class FabricScript(modules.scripts.Script): tome_ratio=(round(tome_ratio * 16) / 16), tome_max_tokens=tome_max_tokens, tome_seed=get_fixed_seed(int(tome_seed)), + burnout_protection=burnout_protection, ) diff --git a/scripts/patching.py 
b/scripts/patching.py index 3080bb2..a298a45 100644 --- a/scripts/patching.py +++ b/scripts/patching.py @@ -1,10 +1,7 @@ import functools -import os import torch import torchvision.transforms.functional as functional -import matplotlib.pyplot as plt -import numpy as np from modules import devices, images, shared from modules.processing import StableDiffusionProcessingTxt2Img @@ -159,327 +156,174 @@ def patch_unet_forward_pass(p, unet, params): "seed": params.tome_seed, } - x_means = [] - x_stds = [] - x_norms = [] - cond_means = [] - cond_stds = [] - cond_norms = [] - uncond_means = [] - uncond_stds = [] - uncond_norms = [] - - - DO_FABRIC = True - mean_ema = {} + prev_vals = { + "weight_modifier": 1.0, + } def new_forward(self, x, timesteps=None, context=None, **kwargs): _, uncond_ids, cond_ids, context = unmark_prompt_context(context) has_cond = len(cond_ids) > 0 has_uncond = len(uncond_ids) > 0 - - def plot_hist(out, filename): - plt.figure() - xs = out.detach().cpu().numpy() - for i in range(xs.shape[0]): - hist, bin_edges = np.histogram(xs[i].reshape(-1), bins=100, density=True) - bin_centers = (bin_edges[:-1] + bin_edges[1:]) / 2. 
- plt.plot(bin_centers, hist) - plt.savefig(filename) - plt.close() - - def plot_lines(ys, filename): - plt.figure() - ys = np.stack(ys, axis=1) - for i in range(ys.shape[0]): - plt.plot(ys[i].reshape(-1)) - plt.savefig(filename) - plt.close() - - - if not DO_FABRIC: - out = self._fabric_old_forward(x, timesteps, context, **kwargs) - else: - - h_latent, w_latent = x.shape[-2:] - w, h = 8 * w_latent, 8 * h_latent - if has_hires_fix and w == hr_w and h == hr_h: - if not params.feedback_during_high_res_fix: - print("[FABRIC] Skipping feedback during high-res fix") - return self._fabric_old_forward(x, timesteps, context, **kwargs) - - pos_weight, neg_weight = get_curr_feedback_weight(p, params, timesteps[0].item()) - if pos_weight <= 0 and neg_weight <= 0: - return self._fabric_old_forward(x, timesteps, context, **kwargs) - - pos_weight, neg_weight = get_curr_feedback_weight(p, params, timesteps[0].item(), num_timesteps=num_timesteps) - if pos_weight <= 0 and neg_weight <= 0: + h_latent, w_latent = x.shape[-2:] + w, h = 8 * w_latent, 8 * h_latent + if has_hires_fix and w == hr_w and h == hr_h: + if not params.feedback_during_high_res_fix: + print("[FABRIC] Skipping feedback during high-res fix") + return self._fabric_old_forward(x, timesteps, context, **kwargs) + + pos_weight, neg_weight = get_curr_feedback_weight(p, params, timesteps[0].item(), num_timesteps=num_timesteps) + if pos_weight <= 0 and neg_weight <= 0: + return self._fabric_old_forward(x, timesteps, context, **kwargs) - pos_latents, neg_latents = get_latents_from_params(p, params, w, h) - pos_latents = pos_latents if has_cond else [] - neg_latents = neg_latents if has_uncond else [] - all_latents = pos_latents + neg_latents + if params.burnout_protection and "cond" in prev_vals and "uncond" in prev_vals: + # burnout protection: if the difference between cond/uncond was too high in the previous step (sign of instability), slash the weight modifier + diff_std = (prev_vals["cond"] - 
prev_vals["uncond"]).std(dim=(2, 3)).max().item() + diff_abs_mean = (prev_vals["cond"] - prev_vals["uncond"]).mean(dim=(2, 3)).abs().max().item() + if diff_std > 0.06 or diff_abs_mean > 0.02: + prev_vals["weight_modifier"] *= 0.5 + else: + prev_vals["weight_modifier"] = min(1.0, 1.5 * prev_vals["weight_modifier"]) + + pos_weight, neg_weight = pos_weight * prev_vals["weight_modifier"], neg_weight * prev_vals["weight_modifier"] - # Note: calls to the VAE with `--medvram` will move the U-Net to CPU, so we need to move it back to GPU - if shared.cmd_opts.medvram: - try: - # Trigger register_forward_pre_hook to move the model to correct device - p.sd_model.model() - except: - pass - - if len(all_latents) == 0: - return self._fabric_old_forward(x, timesteps, context, **kwargs) - - - # ## intervention 2: std decay - # std = x.std(dim=(2, 3), keepdim=True) - # mask = (x.std(dim=(2, 3), keepdim=True) > 1.0).float() - # x = mask * ((0.05*(std - 1) + 1) * x / std) + (1 - mask) * x - - # ## intervention 3: mean clamp - # relative_t = timesteps[0].item() / (p.sd_model.num_timesteps - 1) - # max_mean = 0.5 * (1 - relative_t) - # mean = x.mean(dim=(2, 3), keepdim=True) - # x = x - mean + torch.clamp(mean, min=-max_mean, max=max_mean) - - # ## intervention 7: mean decay - # mean = x.mean(dim=(2, 3), keepdim=True) - # x = x - 0.2 * mean - - # add noise to reference latents - xs_0 = torch.stack(all_latents, dim=0) - ts = timesteps[0, None].expand(xs_0.size(0)) # (bs,) - all_zs = forward_noise(p, xs_0, torch.round(ts.float()).long()) - - # save original forward pass - for module in self.modules(): - if isinstance(module, transformer_block_type) and not hasattr(module.attn1, "_fabric_old_forward"): - module.attn1._fabric_old_forward = module.attn1.forward - module.attn2._fabric_old_forward = module.attn2.forward + pos_latents, neg_latents = get_latents_from_params(p, params, w, h) + pos_latents = pos_latents if has_cond else [] + neg_latents = neg_latents if has_uncond else [] + 
all_latents = pos_latents + neg_latents + # Note: calls to the VAE with `--medvram` will move the U-Net to CPU, so we need to move it back to GPU + if shared.cmd_opts.medvram: try: - ## cache hidden states - cached_hiddens = {} - def patched_attn1_forward(attn1, layer_idx, x, **kwargs): - merge, unmerge = compute_merge(x, args=tome_args, size=(h_latent, w_latent), ratio=params.tome_ratio) - x = merge(x) - if layer_idx not in cached_hiddens: - cached_hiddens[layer_idx] = x.detach().clone().cpu() + # Trigger register_forward_pre_hook to move the model to correct device + p.sd_model.model() + except: + pass + + if len(all_latents) == 0: + return self._fabric_old_forward(x, timesteps, context, **kwargs) + + # add noise to reference latents + xs_0 = torch.stack(all_latents, dim=0) + ts = timesteps[0, None].expand(xs_0.size(0)) # (bs,) + all_zs = forward_noise(p, xs_0, torch.round(ts.float()).long()) + + # save original forward pass + for module in self.modules(): + if isinstance(module, transformer_block_type) and not hasattr(module.attn1, "_fabric_old_forward"): + module.attn1._fabric_old_forward = module.attn1.forward + module.attn2._fabric_old_forward = module.attn2.forward + + try: + ## cache hidden states + cached_hiddens = {} + def patched_attn1_forward(attn1, layer_idx, x, **kwargs): + merge, unmerge = compute_merge(x, args=tome_args, size=(h_latent, w_latent), ratio=params.tome_ratio) + x = merge(x) + if layer_idx not in cached_hiddens: + cached_hiddens[layer_idx] = x.detach().clone().cpu() + else: + cached_hiddens[layer_idx] = torch.cat([cached_hiddens[layer_idx], x.detach().clone().cpu()], dim=0) + out = attn1._fabric_old_forward(x, **kwargs) + out = unmerge(out) + return out + + def patched_attn2_forward(attn2, x, **kwargs): + merge, unmerge = compute_merge(x, args=tome_args, size=(h_latent, w_latent), ratio=params.tome_ratio) + x = merge(x) + out = attn2._fabric_old_forward(x, **kwargs) + out = unmerge(out) + return out + + # patch forward pass to cache 
hidden states + layer_idx = 0 + for module in self.modules(): + if isinstance(module, transformer_block_type): + module.attn1.forward = functools.partial(patched_attn1_forward, module.attn1, layer_idx) + module.attn2.forward = functools.partial(patched_attn2_forward, module.attn2) + layer_idx += 1 + + # run forward pass just to cache hidden states, output is discarded + for i in range(0, len(all_zs), batch_size): + zs = all_zs[i : i + batch_size].to(x.device, dtype=self.dtype) + ts = timesteps[:1].expand(zs.size(0)) # (bs,) + # use the null prompt for pre-computing hidden states on feedback images + ctx_args = {} + if sd_version == SD15: + ctx_args["context"] = null_ctx.expand(zs.size(0), -1, -1) # (bs, seq_len, d_model) + else: # SDXL + ctx_args["context"] = null_ctx["crossattn"].expand(zs.size(0), -1, -1) # (bs, seq_len, d_model) + ctx_args["y"] = null_ctx["vector"].expand(zs.size(0), -1) # (bs, d_vector) + _ = self._fabric_old_forward(zs, ts, **ctx_args) + + num_pos = len(pos_latents) + num_neg = len(neg_latents) + num_cond = len(cond_ids) + num_uncond = len(uncond_ids) + tome_h_latent = h_latent * (1 - params.tome_ratio) + + def patched_attn1_forward(attn1, idx, x, context=None, **kwargs): + if context is None: + context = x + + cached_hs = cached_hiddens[idx].to(x.device) + + d_model = x.shape[-1] + + def attention_with_feedback(_x, context, feedback_hs, w): + num_xs, num_fb = _x.shape[0], feedback_hs.shape[0] + if num_fb > 0: + feedback_ctx = feedback_hs.view(1, -1, d_model).expand(num_xs, -1, -1) # (n_cond, seq * n_pos, dim) + merge, _ = compute_merge(feedback_ctx, args=tome_args, size=(tome_h_latent * num_fb, w_latent), max_tokens=params.tome_max_tokens) + feedback_ctx = merge(feedback_ctx) + ctx = torch.cat([context, feedback_ctx], dim=1) # (n_cond, seq + seq*n_pos, dim) + weights = torch.ones(ctx.shape[1], device=ctx.device, dtype=ctx.dtype) # (seq + seq*n_pos,) + weights[_x.shape[1]:] = w else: - cached_hiddens[layer_idx] = 
torch.cat([cached_hiddens[layer_idx], x.detach().clone().cpu()], dim=0) - out = attn1._fabric_old_forward(x, **kwargs) - out = unmerge(out) - return out - - def patched_attn2_forward(attn2, x, **kwargs): - merge, unmerge = compute_merge(x, args=tome_args, size=(h_latent, w_latent), ratio=params.tome_ratio) - x = merge(x) - out = attn2._fabric_old_forward(x, **kwargs) - out = unmerge(out) - return out + ctx = context + weights = None + return weighted_attention(attn1, attn1._fabric_old_forward, _x, ctx, weights, **kwargs) # (n_cond, seq, dim) - # patch forward pass to cache hidden states - layer_idx = 0 - for module in self.modules(): - if isinstance(module, transformer_block_type): - module.attn1.forward = functools.partial(patched_attn1_forward, module.attn1, layer_idx) - module.attn2.forward = functools.partial(patched_attn2_forward, module.attn2) - layer_idx += 1 + out = torch.zeros_like(x, dtype=devices.dtype_unet) + if num_cond > 0: + out_cond = attention_with_feedback(x[cond_ids], context[cond_ids], cached_hs[:num_pos], pos_weight) # (n_cond, seq, dim) + out[cond_ids] = out_cond + if num_uncond > 0: + out_uncond = attention_with_feedback(x[uncond_ids], context[uncond_ids], cached_hs[num_pos:], neg_weight) # (n_cond, seq, dim) + out[uncond_ids] = out_uncond + return out - # run forward pass just to cache hidden states, output is discarded - for i in range(0, len(all_zs), batch_size): - zs = all_zs[i : i + batch_size].to(x.device, dtype=self.dtype) - ts = timesteps[:1].expand(zs.size(0)) # (bs,) - # use the null prompt for pre-computing hidden states on feedback images - ctx_args = {} - if sd_version == SD15: - ctx_args["context"] = null_ctx.expand(zs.size(0), -1, -1) # (bs, seq_len, d_model) - else: # SDXL - ctx_args["context"] = null_ctx["crossattn"].expand(zs.size(0), -1, -1) # (bs, seq_len, d_model) - ctx_args["y"] = null_ctx["vector"].expand(zs.size(0), -1) # (bs, d_vector) - _ = self._fabric_old_forward(zs, ts, **ctx_args) + # patch forward pass to inject 
cached hidden states + layer_idx = 0 + for module in self.modules(): + if isinstance(module, transformer_block_type): + module.attn1.forward = functools.partial(patched_attn1_forward, module.attn1, layer_idx) + layer_idx += 1 - num_pos = len(pos_latents) - num_neg = len(neg_latents) - num_cond = len(cond_ids) - num_uncond = len(uncond_ids) - tome_h_latent = h_latent * (1 - params.tome_ratio) + # run forward pass with cached hidden states + out = self._fabric_old_forward(x, timesteps, context, **kwargs) - def patched_attn1_forward(attn1, idx, x, context=None, **kwargs): - if context is None: - context = x + cond_outs = out[cond_ids] + uncond_outs = out[uncond_ids] + + if has_cond: + prev_vals["cond"] = cond_outs.detach().clone() + if has_uncond: + prev_vals["uncond"] = uncond_outs.detach().clone() - cached_hs = cached_hiddens[idx].to(x.device) - - d_model = x.shape[-1] - - def attention_with_feedback(_x, context, feedback_hs, w): - num_xs, num_fb = _x.shape[0], feedback_hs.shape[0] - if num_fb > 0: - feedback_ctx = feedback_hs.view(1, -1, d_model).expand(num_xs, -1, -1) # (n_cond, seq * n_pos, dim) - merge, _ = compute_merge(feedback_ctx, args=tome_args, size=(tome_h_latent * num_fb, w_latent), max_tokens=params.tome_max_tokens) - feedback_ctx = merge(feedback_ctx) - ctx = torch.cat([context, feedback_ctx], dim=1) # (n_cond, seq + seq*n_pos, dim) - weights = torch.ones(ctx.shape[1], device=ctx.device, dtype=ctx.dtype) # (seq + seq*n_pos,) - weights[_x.shape[1]:] = w - else: - ctx = context - weights = None - return weighted_attention(attn1, attn1._fabric_old_forward, _x, ctx, weights, **kwargs) # (n_cond, seq, dim) - - out = torch.zeros_like(x, dtype=devices.unet_dtype) - if num_cond > 0: - out_cond = attention_with_feedback(x[cond_ids], context[cond_ids], cached_hs[:num_pos], pos_weight) # (n_cond, seq, dim) - out[cond_ids] = out_cond - if num_uncond > 0: - out_uncond = attention_with_feedback(x[uncond_ids], context[uncond_ids], cached_hs[num_pos:], neg_weight) # 
(n_cond, seq, dim) - out[uncond_ids] = out_uncond - return out - - # patch forward pass to inject cached hidden states - layer_idx = 0 - for module in self.modules(): - if isinstance(module, transformer_block_type): - module.attn1.forward = functools.partial(patched_attn1_forward, module.attn1, layer_idx) - layer_idx += 1 - - # run forward pass with cached hidden states - out = self._fabric_old_forward(x, timesteps, context, **kwargs) - - cond_outs = out[cond_ids] - uncond_outs = out[uncond_ids] - xs = x[cond_ids] - t = int(timesteps[0].item()) - seed = p.seed - - x_means.append(xs.mean(dim=(2, 3)).view(-1).cpu().numpy()) - x_stds.append(xs.std(dim=(2, 3)).view(-1).cpu().numpy()) - x_norms.append(xs.norm(dim=(2, 3)).view(-1).cpu().numpy()) - cond_means.append(cond_outs.mean(dim=(2, 3)).view(-1).cpu().numpy()) - cond_stds.append(cond_outs.std(dim=(2, 3)).view(-1).cpu().numpy()) - cond_norms.append(cond_outs.norm(dim=(2, 3)).view(-1).cpu().numpy()) - uncond_means.append(uncond_outs.mean(dim=(2, 3)).view(-1).cpu().numpy()) - uncond_stds.append(uncond_outs.std(dim=(2, 3)).view(-1).cpu().numpy()) - uncond_norms.append(uncond_outs.norm(dim=(2, 3)).view(-1).cpu().numpy()) - - - def plot_trajectory(means, stds, ax=None): - if ax is None: - ax = plt.gca() - # means = np.stack(means, axis=1) - # stds = np.stack(stds, axis=1) - ax.plot(means) - ax.fill_between(range(means.shape[0]), means - stds, means + stds, alpha=0.3) - - def select(xs, batch_idx): - return np.stack([x[batch_idx] for x in xs], axis=0) - - for i in range(len(xs)): - x_means_i = select(x_means, i) - x_stds_i = select(x_stds, i) - cond_means_i = select(cond_means, i) - cond_stds_i = select(cond_stds, i) - uncond_means_i = select(uncond_means, i) - uncond_stds_i = select(uncond_stds, i) - - plot_file = f"plots/stats/{seed + i}_{'fabric' if DO_FABRIC else 'default'}.png" - fig, axs = plt.subplots(1, 3, figsize=(15, 5)) - plot_trajectory(x_means_i, x_stds_i, ax=axs[0]) - plot_trajectory(cond_means_i, 
cond_stds_i, ax=axs[1]) - plot_trajectory(uncond_means_i, uncond_stds_i, ax=axs[2]) - fig.savefig(plot_file) - - - # os.makedirs("plots/lines", exist_ok=True) - # plot_lines(x_means, f"plots/lines/{seed}_x_means{'_fabric' if DO_FABRIC else ''}.png") - # plot_lines(x_stds, f"plots/lines/{seed}_x_stds{'_fabric' if DO_FABRIC else ''}.png") - # plot_lines(x_norms, f"plots/lines/{seed}_x_norms{'_fabric' if DO_FABRIC else ''}.png") - # plot_lines(cond_means, f"plots/lines/{seed}_cond_means{'_fabric' if DO_FABRIC else ''}.png") - # plot_lines(cond_stds, f"plots/lines/{seed}_cond_stds{'_fabric' if DO_FABRIC else ''}.png") - # plot_lines(cond_norms, f"plots/lines/{seed}_cond_norms{'_fabric' if DO_FABRIC else ''}.png") - # plot_lines(uncond_means, f"plots/lines/{seed}_uncond_means{'_fabric' if DO_FABRIC else ''}.png") - # plot_lines(uncond_stds, f"plots/lines/{seed}_uncond_stds{'_fabric' if DO_FABRIC else ''}.png") - # plot_lines(uncond_norms, f"plots/lines/{seed}_uncond_norms{'_fabric' if DO_FABRIC else ''}.png") - - - # os.makedirs("plots/hists", exist_ok=True) - # for i in range(len(xs)): - # filename = f"plots/hists/{seed}_{t}_x_{i}{'_fabric' if DO_FABRIC else ''}.png" - # plot_hist(x[i], filename) - - # for i in range(len(cond_outs)): - # filename = f"plots/hists/{seed}_{t}_cond_{i}{'_fabric' if DO_FABRIC else ''}.png" - # plot_hist(cond_outs[i], filename) - - # for i in range(len(uncond_outs)): - # filename = f"plots/hists/{seed}_{t}_uncond_{i}{'_fabric' if DO_FABRIC else ''}.png" - # plot_hist(uncond_outs[i], filename) - - # ## intervention 1: mean decay - # out = out - 0.5*out.mean(dim=(2, 3), keepdim=True) - - # ## intervention 4: early normalization - # relative_t = 1 - timesteps[0].item() / (p.sd_model.num_timesteps - 1) - # if relative_t < 0.33: - # out = (out - out.mean(dim=(2, 3), keepdim=True)) / out.std(dim=(2, 3), keepdim=True) - - # ## intervention 5: momentum mean decay - # dampen = 0.8 - # accel = 0.35 - # alpha = 0.5 - # beta = 0.25 - # mean = 
out.mean(dim=(2, 3), keepdim=True) - # if "momentum" not in mean_ema: - # mean_ema["momentum"] = mean - # else: - # mean_ema["momentum"] = dampen * (mean_ema["momentum"] + accel * out.mean(dim=(2, 3), keepdim=True)) - # out = out - alpha * mean - beta * mean_ema["momentum"] - - # print() - # print("mean: ", mean.cpu().view(-1)) - # print("momentum:", mean_ema["momentum"].cpu().view(-1)) - - # plot_lines([p.sd_model.betas.cpu().unsqueeze(0).numpy()], f"plots/lines/betas.png") - # plot_lines([p.sd_model.alphas_cumprod.cpu().unsqueeze(0).numpy()], f"plots/lines/alphas_cumprod.png") - # plot_lines([p.sd_model.alphas_cumprod_prev.cpu().unsqueeze(0).numpy()], f"plots/lines/alphas_cumprod_prev.png") - # plot_lines([p.sd_model.posterior_variance.cpu().unsqueeze(0).numpy()], f"plots/lines/posterior_variance.png") - # plot_lines([p.sd_model.posterior_log_variance_clipped.cpu().unsqueeze(0).numpy()], f"plots/lines/posterior_log_variance_clipped.png") - # plot_lines([p.sd_model.posterior_mean_coef1.cpu().unsqueeze(0).numpy()], f"plots/lines/posterior_mean_coef1.png") - # plot_lines([p.sd_model.posterior_mean_coef2.cpu().unsqueeze(0).numpy()], f"plots/lines/posterior_mean_coef2.png") - - # plot_lines([p.sd_model.sqrt_alphas_cumprod.cpu().unsqueeze(0).numpy()], f"plots/lines/sqrt_alphas_cumprod.png") - # plot_lines([p.sd_model.sqrt_one_minus_alphas_cumprod.cpu().unsqueeze(0).numpy()], f"plots/lines/sqrt_one_minus_alphas_cumprod.png") - # plot_lines([p.sd_model.log_one_minus_alphas_cumprod.cpu().unsqueeze(0).numpy()], f"plots/lines/log_one_minus_alphas_cumprod.png") - # plot_lines([p.sd_model.sqrt_recip_alphas_cumprod.cpu().unsqueeze(0).numpy()], f"plots/lines/sqrt_recip_alphas_cumprod.png") - # plot_lines([p.sd_model.sqrt_recipm1_alphas_cumprod.cpu().unsqueeze(0).numpy()], f"plots/lines/sqrt_recipm1_alphas_cumprod.png") - - # ## intervention 8: dynamic standardization - # # TODO: test how to dynamically bound the mean rather than always subtracting some fraction - # relative_t = 
1 - timesteps[0].item() / (p.sd_model.num_timesteps - 1) - # if relative_t < 0.8: - # alpha = 0.9 - # sigma = (1 - relative_t)**0.1 - # # print(p.n_iter, p.steps) - # # print(p.sampler.get_sigmas(p, p.steps)) - # # sigmas = p.sampler.get_sigmas(p, p.steps) - # # print(torch.sqrt(1.0 + sigmas ** 2.0)) - # # print(p.sd_model.alphas_cumprod) - # # sigma = p.sampler.get_sigmas(p, timesteps[0].item()) - # std = out.std(dim=(2, 3), keepdim=True).clip(min=alpha*sigma, max=(2 - alpha)*sigma) - # mean = out.mean(dim=(2, 3), keepdim=True) - # out = (out - alpha*mean) / std - finally: - # restore original pass - for module in self.modules(): - if isinstance(module, transformer_block_type) and hasattr(module.attn1, "_fabric_old_forward"): - module.attn1.forward = module.attn1._fabric_old_forward - del module.attn1._fabric_old_forward - if isinstance(module, transformer_block_type) and hasattr(module.attn2, "_fabric_old_forward"): - module.attn2.forward = module.attn2._fabric_old_forward - del module.attn2._fabric_old_forward + if params.burnout_protection: + # burnout protection: recenter the output to prevent instabilities caused by mean drift + mean = out.mean(dim=(2, 3), keepdim=True) + out = out - 0.5 * mean + finally: + # restore original pass + for module in self.modules(): + if isinstance(module, transformer_block_type) and hasattr(module.attn1, "_fabric_old_forward"): + module.attn1.forward = module.attn1._fabric_old_forward + del module.attn1._fabric_old_forward + if isinstance(module, transformer_block_type) and hasattr(module.attn2, "_fabric_old_forward"): + module.attn2.forward = module.attn2._fabric_old_forward + del module.attn2._fabric_old_forward return out