add burnout protection

pull/38/head
dvruette 2024-03-08 02:01:09 +01:00
parent 31319d5b12
commit ab78cc0eae
2 changed files with 157 additions and 307 deletions

View File

@ -139,6 +139,7 @@ class FabricParams:
tome_ratio: float = 0.5
tome_max_tokens: int = 4*4096
tome_seed: int = -1
burnout_protection: bool = False
# TODO: replace global state with Gradio state
@ -218,6 +219,7 @@ class FabricScript(modules.scripts.Script):
feedback_max_weight = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, value=0.8, label="Feedback Strength", elem_id="fabric_max_weight")
tome_enabled = gr.Checkbox(label="Enable Token Merging (faster, less VRAM, less accurate)", value=False)
burnout_protection = gr.Checkbox(label="Burnout protection (enable if results contain artifacts or are especially dark)", value=False)
with gr.Accordion("Advanced options", open=DEBUG):
with FormGroup():
@ -302,6 +304,7 @@ class FabricScript(modules.scripts.Script):
(tome_ratio, "fabric_tome_ratio"),
(tome_max_tokens, "fabric_tome_max_tokens"),
(tome_seed, "fabric_tome_seed"),
(burnout_protection, "fabric_burnout_protection"),
(feedback_during_high_res_fix, "fabric_feedback_during_high_res_fix"),
(liked_paths, lambda d: gr.update(value=_load_feedback_paths(d, "fabric_pos_images")) if "fabric_pos_images" in d else None),
(disliked_paths, lambda d: gr.update(value=_load_feedback_paths(d, "fabric_neg_images")) if "fabric_neg_images" in d else None),
@ -323,6 +326,7 @@ class FabricScript(modules.scripts.Script):
tome_ratio,
tome_max_tokens,
tome_seed,
burnout_protection,
]
@ -429,6 +433,7 @@ class FabricScript(modules.scripts.Script):
tome_ratio,
tome_max_tokens,
tome_seed,
burnout_protection,
) = args
# restore original U-Net forward pass in case previous batch errored out
@ -454,6 +459,7 @@ class FabricScript(modules.scripts.Script):
tome_ratio=(round(tome_ratio * 16) / 16),
tome_max_tokens=tome_max_tokens,
tome_seed=get_fixed_seed(int(tome_seed)),
burnout_protection=burnout_protection,
)

View File

@ -1,10 +1,7 @@
import functools
import os
import torch
import torchvision.transforms.functional as functional
import matplotlib.pyplot as plt
import numpy as np
from modules import devices, images, shared
from modules.processing import StableDiffusionProcessingTxt2Img
@ -159,327 +156,174 @@ def patch_unet_forward_pass(p, unet, params):
"seed": params.tome_seed,
}
x_means = []
x_stds = []
x_norms = []
cond_means = []
cond_stds = []
cond_norms = []
uncond_means = []
uncond_stds = []
uncond_norms = []
DO_FABRIC = True
mean_ema = {}
prev_vals = {
"weight_modifier": 1.0,
}
def new_forward(self, x, timesteps=None, context=None, **kwargs):
_, uncond_ids, cond_ids, context = unmark_prompt_context(context)
has_cond = len(cond_ids) > 0
has_uncond = len(uncond_ids) > 0
def plot_hist(out, filename):
plt.figure()
xs = out.detach().cpu().numpy()
for i in range(xs.shape[0]):
hist, bin_edges = np.histogram(xs[i].reshape(-1), bins=100, density=True)
bin_centers = (bin_edges[:-1] + bin_edges[1:]) / 2.
plt.plot(bin_centers, hist)
plt.savefig(filename)
plt.close()
def plot_lines(ys, filename):
plt.figure()
ys = np.stack(ys, axis=1)
for i in range(ys.shape[0]):
plt.plot(ys[i].reshape(-1))
plt.savefig(filename)
plt.close()
if not DO_FABRIC:
out = self._fabric_old_forward(x, timesteps, context, **kwargs)
else:
h_latent, w_latent = x.shape[-2:]
w, h = 8 * w_latent, 8 * h_latent
if has_hires_fix and w == hr_w and h == hr_h:
if not params.feedback_during_high_res_fix:
print("[FABRIC] Skipping feedback during high-res fix")
return self._fabric_old_forward(x, timesteps, context, **kwargs)
pos_weight, neg_weight = get_curr_feedback_weight(p, params, timesteps[0].item())
if pos_weight <= 0 and neg_weight <= 0:
return self._fabric_old_forward(x, timesteps, context, **kwargs)
pos_weight, neg_weight = get_curr_feedback_weight(p, params, timesteps[0].item(), num_timesteps=num_timesteps)
if pos_weight <= 0 and neg_weight <= 0:
h_latent, w_latent = x.shape[-2:]
w, h = 8 * w_latent, 8 * h_latent
if has_hires_fix and w == hr_w and h == hr_h:
if not params.feedback_during_high_res_fix:
print("[FABRIC] Skipping feedback during high-res fix")
return self._fabric_old_forward(x, timesteps, context, **kwargs)
pos_weight, neg_weight = get_curr_feedback_weight(p, params, timesteps[0].item(), num_timesteps=num_timesteps)
if pos_weight <= 0 and neg_weight <= 0:
return self._fabric_old_forward(x, timesteps, context, **kwargs)
pos_latents, neg_latents = get_latents_from_params(p, params, w, h)
pos_latents = pos_latents if has_cond else []
neg_latents = neg_latents if has_uncond else []
all_latents = pos_latents + neg_latents
if params.burnout_protection and "cond" in prev_vals and "uncond" in prev_vals:
# burnout protection: if the difference betwen cond/uncond was too high in the previous step (sign of instability), slash the weight modifier
diff_std = (prev_vals["cond"] - prev_vals["uncond"]).std(dim=(2, 3)).max().item()
diff_abs_mean = (prev_vals["cond"] - prev_vals["uncond"]).mean(dim=(2, 3)).abs().max().item()
if diff_std > 0.06 or diff_abs_mean > 0.02:
prev_vals["weight_modifier"] *= 0.5
else:
prev_vals["weight_modifier"] = min(1.0, 1.5 * prev_vals["weight_modifier"])
pos_weight, neg_weight = pos_weight * prev_vals["weight_modifier"], neg_weight * prev_vals["weight_modifier"]
# Note: calls to the VAE with `--medvram` will move the U-Net to CPU, so we need to move it back to GPU
if shared.cmd_opts.medvram:
try:
# Trigger register_forward_pre_hook to move the model to correct device
p.sd_model.model()
except:
pass
if len(all_latents) == 0:
return self._fabric_old_forward(x, timesteps, context, **kwargs)
# ## intervention 2: std decay
# std = x.std(dim=(2, 3), keepdim=True)
# mask = (x.std(dim=(2, 3), keepdim=True) > 1.0).float()
# x = mask * ((0.05*(std - 1) + 1) * x / std) + (1 - mask) * x
# ## intervention 3: mean clamp
# relative_t = timesteps[0].item() / (p.sd_model.num_timesteps - 1)
# max_mean = 0.5 * (1 - relative_t)
# mean = x.mean(dim=(2, 3), keepdim=True)
# x = x - mean + torch.clamp(mean, min=-max_mean, max=max_mean)
# ## intervention 7: mean decay
# mean = x.mean(dim=(2, 3), keepdim=True)
# x = x - 0.2 * mean
# add noise to reference latents
xs_0 = torch.stack(all_latents, dim=0)
ts = timesteps[0, None].expand(xs_0.size(0)) # (bs,)
all_zs = forward_noise(p, xs_0, torch.round(ts.float()).long())
# save original forward pass
for module in self.modules():
if isinstance(module, transformer_block_type) and not hasattr(module.attn1, "_fabric_old_forward"):
module.attn1._fabric_old_forward = module.attn1.forward
module.attn2._fabric_old_forward = module.attn2.forward
pos_latents, neg_latents = get_latents_from_params(p, params, w, h)
pos_latents = pos_latents if has_cond else []
neg_latents = neg_latents if has_uncond else []
all_latents = pos_latents + neg_latents
# Note: calls to the VAE with `--medvram` will move the U-Net to CPU, so we need to move it back to GPU
if shared.cmd_opts.medvram:
try:
## cache hidden states
cached_hiddens = {}
def patched_attn1_forward(attn1, layer_idx, x, **kwargs):
merge, unmerge = compute_merge(x, args=tome_args, size=(h_latent, w_latent), ratio=params.tome_ratio)
x = merge(x)
if layer_idx not in cached_hiddens:
cached_hiddens[layer_idx] = x.detach().clone().cpu()
# Trigger register_forward_pre_hook to move the model to correct device
p.sd_model.model()
except:
pass
if len(all_latents) == 0:
return self._fabric_old_forward(x, timesteps, context, **kwargs)
# add noise to reference latents
xs_0 = torch.stack(all_latents, dim=0)
ts = timesteps[0, None].expand(xs_0.size(0)) # (bs,)
all_zs = forward_noise(p, xs_0, torch.round(ts.float()).long())
# save original forward pass
for module in self.modules():
if isinstance(module, transformer_block_type) and not hasattr(module.attn1, "_fabric_old_forward"):
module.attn1._fabric_old_forward = module.attn1.forward
module.attn2._fabric_old_forward = module.attn2.forward
try:
## cache hidden states
cached_hiddens = {}
def patched_attn1_forward(attn1, layer_idx, x, **kwargs):
merge, unmerge = compute_merge(x, args=tome_args, size=(h_latent, w_latent), ratio=params.tome_ratio)
x = merge(x)
if layer_idx not in cached_hiddens:
cached_hiddens[layer_idx] = x.detach().clone().cpu()
else:
cached_hiddens[layer_idx] = torch.cat([cached_hiddens[layer_idx], x.detach().clone().cpu()], dim=0)
out = attn1._fabric_old_forward(x, **kwargs)
out = unmerge(out)
return out
def patched_attn2_forward(attn2, x, **kwargs):
merge, unmerge = compute_merge(x, args=tome_args, size=(h_latent, w_latent), ratio=params.tome_ratio)
x = merge(x)
out = attn2._fabric_old_forward(x, **kwargs)
out = unmerge(out)
return out
# patch forward pass to cache hidden states
layer_idx = 0
for module in self.modules():
if isinstance(module, transformer_block_type):
module.attn1.forward = functools.partial(patched_attn1_forward, module.attn1, layer_idx)
module.attn2.forward = functools.partial(patched_attn2_forward, module.attn2)
layer_idx += 1
# run forward pass just to cache hidden states, output is discarded
for i in range(0, len(all_zs), batch_size):
zs = all_zs[i : i + batch_size].to(x.device, dtype=self.dtype)
ts = timesteps[:1].expand(zs.size(0)) # (bs,)
# use the null prompt for pre-computing hidden states on feedback images
ctx_args = {}
if sd_version == SD15:
ctx_args["context"] = null_ctx.expand(zs.size(0), -1, -1) # (bs, seq_len, d_model)
else: # SDXL
ctx_args["context"] = null_ctx["crossattn"].expand(zs.size(0), -1, -1) # (bs, seq_len, d_model)
ctx_args["y"] = null_ctx["vector"].expand(zs.size(0), -1) # (bs, d_vector)
_ = self._fabric_old_forward(zs, ts, **ctx_args)
num_pos = len(pos_latents)
num_neg = len(neg_latents)
num_cond = len(cond_ids)
num_uncond = len(uncond_ids)
tome_h_latent = h_latent * (1 - params.tome_ratio)
def patched_attn1_forward(attn1, idx, x, context=None, **kwargs):
if context is None:
context = x
cached_hs = cached_hiddens[idx].to(x.device)
d_model = x.shape[-1]
def attention_with_feedback(_x, context, feedback_hs, w):
num_xs, num_fb = _x.shape[0], feedback_hs.shape[0]
if num_fb > 0:
feedback_ctx = feedback_hs.view(1, -1, d_model).expand(num_xs, -1, -1) # (n_cond, seq * n_pos, dim)
merge, _ = compute_merge(feedback_ctx, args=tome_args, size=(tome_h_latent * num_fb, w_latent), max_tokens=params.tome_max_tokens)
feedback_ctx = merge(feedback_ctx)
ctx = torch.cat([context, feedback_ctx], dim=1) # (n_cond, seq + seq*n_pos, dim)
weights = torch.ones(ctx.shape[1], device=ctx.device, dtype=ctx.dtype) # (seq + seq*n_pos,)
weights[_x.shape[1]:] = w
else:
cached_hiddens[layer_idx] = torch.cat([cached_hiddens[layer_idx], x.detach().clone().cpu()], dim=0)
out = attn1._fabric_old_forward(x, **kwargs)
out = unmerge(out)
return out
def patched_attn2_forward(attn2, x, **kwargs):
merge, unmerge = compute_merge(x, args=tome_args, size=(h_latent, w_latent), ratio=params.tome_ratio)
x = merge(x)
out = attn2._fabric_old_forward(x, **kwargs)
out = unmerge(out)
return out
ctx = context
weights = None
return weighted_attention(attn1, attn1._fabric_old_forward, _x, ctx, weights, **kwargs) # (n_cond, seq, dim)
# patch forward pass to cache hidden states
layer_idx = 0
for module in self.modules():
if isinstance(module, transformer_block_type):
module.attn1.forward = functools.partial(patched_attn1_forward, module.attn1, layer_idx)
module.attn2.forward = functools.partial(patched_attn2_forward, module.attn2)
layer_idx += 1
out = torch.zeros_like(x, dtype=devices.dtype_unet)
if num_cond > 0:
out_cond = attention_with_feedback(x[cond_ids], context[cond_ids], cached_hs[:num_pos], pos_weight) # (n_cond, seq, dim)
out[cond_ids] = out_cond
if num_uncond > 0:
out_uncond = attention_with_feedback(x[uncond_ids], context[uncond_ids], cached_hs[num_pos:], neg_weight) # (n_cond, seq, dim)
out[uncond_ids] = out_uncond
return out
# run forward pass just to cache hidden states, output is discarded
for i in range(0, len(all_zs), batch_size):
zs = all_zs[i : i + batch_size].to(x.device, dtype=self.dtype)
ts = timesteps[:1].expand(zs.size(0)) # (bs,)
# use the null prompt for pre-computing hidden states on feedback images
ctx_args = {}
if sd_version == SD15:
ctx_args["context"] = null_ctx.expand(zs.size(0), -1, -1) # (bs, seq_len, d_model)
else: # SDXL
ctx_args["context"] = null_ctx["crossattn"].expand(zs.size(0), -1, -1) # (bs, seq_len, d_model)
ctx_args["y"] = null_ctx["vector"].expand(zs.size(0), -1) # (bs, d_vector)
_ = self._fabric_old_forward(zs, ts, **ctx_args)
# patch forward pass to inject cached hidden states
layer_idx = 0
for module in self.modules():
if isinstance(module, transformer_block_type):
module.attn1.forward = functools.partial(patched_attn1_forward, module.attn1, layer_idx)
layer_idx += 1
num_pos = len(pos_latents)
num_neg = len(neg_latents)
num_cond = len(cond_ids)
num_uncond = len(uncond_ids)
tome_h_latent = h_latent * (1 - params.tome_ratio)
# run forward pass with cached hidden states
out = self._fabric_old_forward(x, timesteps, context, **kwargs)
def patched_attn1_forward(attn1, idx, x, context=None, **kwargs):
if context is None:
context = x
cond_outs = out[cond_ids]
uncond_outs = out[uncond_ids]
if has_cond:
prev_vals["cond"] = cond_outs.detach().clone()
if has_uncond:
prev_vals["uncond"] = uncond_outs.detach().clone()
cached_hs = cached_hiddens[idx].to(x.device)
d_model = x.shape[-1]
def attention_with_feedback(_x, context, feedback_hs, w):
num_xs, num_fb = _x.shape[0], feedback_hs.shape[0]
if num_fb > 0:
feedback_ctx = feedback_hs.view(1, -1, d_model).expand(num_xs, -1, -1) # (n_cond, seq * n_pos, dim)
merge, _ = compute_merge(feedback_ctx, args=tome_args, size=(tome_h_latent * num_fb, w_latent), max_tokens=params.tome_max_tokens)
feedback_ctx = merge(feedback_ctx)
ctx = torch.cat([context, feedback_ctx], dim=1) # (n_cond, seq + seq*n_pos, dim)
weights = torch.ones(ctx.shape[1], device=ctx.device, dtype=ctx.dtype) # (seq + seq*n_pos,)
weights[_x.shape[1]:] = w
else:
ctx = context
weights = None
return weighted_attention(attn1, attn1._fabric_old_forward, _x, ctx, weights, **kwargs) # (n_cond, seq, dim)
out = torch.zeros_like(x, dtype=devices.unet_dtype)
if num_cond > 0:
out_cond = attention_with_feedback(x[cond_ids], context[cond_ids], cached_hs[:num_pos], pos_weight) # (n_cond, seq, dim)
out[cond_ids] = out_cond
if num_uncond > 0:
out_uncond = attention_with_feedback(x[uncond_ids], context[uncond_ids], cached_hs[num_pos:], neg_weight) # (n_cond, seq, dim)
out[uncond_ids] = out_uncond
return out
# patch forward pass to inject cached hidden states
layer_idx = 0
for module in self.modules():
if isinstance(module, transformer_block_type):
module.attn1.forward = functools.partial(patched_attn1_forward, module.attn1, layer_idx)
layer_idx += 1
# run forward pass with cached hidden states
out = self._fabric_old_forward(x, timesteps, context, **kwargs)
cond_outs = out[cond_ids]
uncond_outs = out[uncond_ids]
xs = x[cond_ids]
t = int(timesteps[0].item())
seed = p.seed
x_means.append(xs.mean(dim=(2, 3)).view(-1).cpu().numpy())
x_stds.append(xs.std(dim=(2, 3)).view(-1).cpu().numpy())
x_norms.append(xs.norm(dim=(2, 3)).view(-1).cpu().numpy())
cond_means.append(cond_outs.mean(dim=(2, 3)).view(-1).cpu().numpy())
cond_stds.append(cond_outs.std(dim=(2, 3)).view(-1).cpu().numpy())
cond_norms.append(cond_outs.norm(dim=(2, 3)).view(-1).cpu().numpy())
uncond_means.append(uncond_outs.mean(dim=(2, 3)).view(-1).cpu().numpy())
uncond_stds.append(uncond_outs.std(dim=(2, 3)).view(-1).cpu().numpy())
uncond_norms.append(uncond_outs.norm(dim=(2, 3)).view(-1).cpu().numpy())
def plot_trajectory(means, stds, ax=None):
if ax is None:
ax = plt.gca()
# means = np.stack(means, axis=1)
# stds = np.stack(stds, axis=1)
ax.plot(means)
ax.fill_between(range(means.shape[0]), means - stds, means + stds, alpha=0.3)
def select(xs, batch_idx):
return np.stack([x[batch_idx] for x in xs], axis=0)
for i in range(len(xs)):
x_means_i = select(x_means, i)
x_stds_i = select(x_stds, i)
cond_means_i = select(cond_means, i)
cond_stds_i = select(cond_stds, i)
uncond_means_i = select(uncond_means, i)
uncond_stds_i = select(uncond_stds, i)
plot_file = f"plots/stats/{seed + i}_{'fabric' if DO_FABRIC else 'default'}.png"
fig, axs = plt.subplots(1, 3, figsize=(15, 5))
plot_trajectory(x_means_i, x_stds_i, ax=axs[0])
plot_trajectory(cond_means_i, cond_stds_i, ax=axs[1])
plot_trajectory(uncond_means_i, uncond_stds_i, ax=axs[2])
fig.savefig(plot_file)
# os.makedirs("plots/lines", exist_ok=True)
# plot_lines(x_means, f"plots/lines/{seed}_x_means{'_fabric' if DO_FABRIC else ''}.png")
# plot_lines(x_stds, f"plots/lines/{seed}_x_stds{'_fabric' if DO_FABRIC else ''}.png")
# plot_lines(x_norms, f"plots/lines/{seed}_x_norms{'_fabric' if DO_FABRIC else ''}.png")
# plot_lines(cond_means, f"plots/lines/{seed}_cond_means{'_fabric' if DO_FABRIC else ''}.png")
# plot_lines(cond_stds, f"plots/lines/{seed}_cond_stds{'_fabric' if DO_FABRIC else ''}.png")
# plot_lines(cond_norms, f"plots/lines/{seed}_cond_norms{'_fabric' if DO_FABRIC else ''}.png")
# plot_lines(uncond_means, f"plots/lines/{seed}_uncond_means{'_fabric' if DO_FABRIC else ''}.png")
# plot_lines(uncond_stds, f"plots/lines/{seed}_uncond_stds{'_fabric' if DO_FABRIC else ''}.png")
# plot_lines(uncond_norms, f"plots/lines/{seed}_uncond_norms{'_fabric' if DO_FABRIC else ''}.png")
# os.makedirs("plots/hists", exist_ok=True)
# for i in range(len(xs)):
# filename = f"plots/hists/{seed}_{t}_x_{i}{'_fabric' if DO_FABRIC else ''}.png"
# plot_hist(x[i], filename)
# for i in range(len(cond_outs)):
# filename = f"plots/hists/{seed}_{t}_cond_{i}{'_fabric' if DO_FABRIC else ''}.png"
# plot_hist(cond_outs[i], filename)
# for i in range(len(uncond_outs)):
# filename = f"plots/hists/{seed}_{t}_uncond_{i}{'_fabric' if DO_FABRIC else ''}.png"
# plot_hist(uncond_outs[i], filename)
# ## intervention 1: mean decay
# out = out - 0.5*out.mean(dim=(2, 3), keepdim=True)
# ## intervention 4: early normalization
# relative_t = 1 - timesteps[0].item() / (p.sd_model.num_timesteps - 1)
# if relative_t < 0.33:
# out = (out - out.mean(dim=(2, 3), keepdim=True)) / out.std(dim=(2, 3), keepdim=True)
# ## intervention 5: momentum mean decay
# dampen = 0.8
# accel = 0.35
# alpha = 0.5
# beta = 0.25
# mean = out.mean(dim=(2, 3), keepdim=True)
# if "momentum" not in mean_ema:
# mean_ema["momentum"] = mean
# else:
# mean_ema["momentum"] = dampen * (mean_ema["momentum"] + accel * out.mean(dim=(2, 3), keepdim=True))
# out = out - alpha * mean - beta * mean_ema["momentum"]
# print()
# print("mean: ", mean.cpu().view(-1))
# print("momentum:", mean_ema["momentum"].cpu().view(-1))
# plot_lines([p.sd_model.betas.cpu().unsqueeze(0).numpy()], f"plots/lines/betas.png")
# plot_lines([p.sd_model.alphas_cumprod.cpu().unsqueeze(0).numpy()], f"plots/lines/alphas_cumprod.png")
# plot_lines([p.sd_model.alphas_cumprod_prev.cpu().unsqueeze(0).numpy()], f"plots/lines/alphas_cumprod_prev.png")
# plot_lines([p.sd_model.posterior_variance.cpu().unsqueeze(0).numpy()], f"plots/lines/posterior_variance.png")
# plot_lines([p.sd_model.posterior_log_variance_clipped.cpu().unsqueeze(0).numpy()], f"plots/lines/posterior_log_variance_clipped.png")
# plot_lines([p.sd_model.posterior_mean_coef1.cpu().unsqueeze(0).numpy()], f"plots/lines/posterior_mean_coef1.png")
# plot_lines([p.sd_model.posterior_mean_coef2.cpu().unsqueeze(0).numpy()], f"plots/lines/posterior_mean_coef2.png")
# plot_lines([p.sd_model.sqrt_alphas_cumprod.cpu().unsqueeze(0).numpy()], f"plots/lines/sqrt_alphas_cumprod.png")
# plot_lines([p.sd_model.sqrt_one_minus_alphas_cumprod.cpu().unsqueeze(0).numpy()], f"plots/lines/sqrt_one_minus_alphas_cumprod.png")
# plot_lines([p.sd_model.log_one_minus_alphas_cumprod.cpu().unsqueeze(0).numpy()], f"plots/lines/log_one_minus_alphas_cumprod.png")
# plot_lines([p.sd_model.sqrt_recip_alphas_cumprod.cpu().unsqueeze(0).numpy()], f"plots/lines/sqrt_recip_alphas_cumprod.png")
# plot_lines([p.sd_model.sqrt_recipm1_alphas_cumprod.cpu().unsqueeze(0).numpy()], f"plots/lines/sqrt_recipm1_alphas_cumprod.png")
# ## intervention 8: dynamic standardization
# # TODO: test how to dynamically bound the mean rather than always subtracting some fraction
# relative_t = 1 - timesteps[0].item() / (p.sd_model.num_timesteps - 1)
# if relative_t < 0.8:
# alpha = 0.9
# sigma = (1 - relative_t)**0.1
# # print(p.n_iter, p.steps)
# # print(p.sampler.get_sigmas(p, p.steps))
# # sigmas = p.sampler.get_sigmas(p, p.steps)
# # print(torch.sqrt(1.0 + sigmas ** 2.0))
# # print(p.sd_model.alphas_cumprod)
# # sigma = p.sampler.get_sigmas(p, timesteps[0].item())
# std = out.std(dim=(2, 3), keepdim=True).clip(min=alpha*sigma, max=(2 - alpha)*sigma)
# mean = out.mean(dim=(2, 3), keepdim=True)
# out = (out - alpha*mean) / std
finally:
# restore original pass
for module in self.modules():
if isinstance(module, transformer_block_type) and hasattr(module.attn1, "_fabric_old_forward"):
module.attn1.forward = module.attn1._fabric_old_forward
del module.attn1._fabric_old_forward
if isinstance(module, transformer_block_type) and hasattr(module.attn2, "_fabric_old_forward"):
module.attn2.forward = module.attn2._fabric_old_forward
del module.attn2._fabric_old_forward
if params.burnout_protection:
# burnout protection: recenter the output to prevent instabilities caused by mean drift
mean = out.mean(dim=(2, 3), keepdim=True)
out = out - 0.5 * mean
finally:
# restore original pass
for module in self.modules():
if isinstance(module, transformer_block_type) and hasattr(module.attn1, "_fabric_old_forward"):
module.attn1.forward = module.attn1._fabric_old_forward
del module.attn1._fabric_old_forward
if isinstance(module, transformer_block_type) and hasattr(module.attn2, "_fabric_old_forward"):
module.attn2.forward = module.attn2._fabric_old_forward
del module.attn2._fabric_old_forward
return out