add plotting and potential solutions

pull/38/head
dvruette 2024-03-07 10:38:33 +01:00
parent 5a247c9d9e
commit c0b88694c8
1 changed files with 307 additions and 124 deletions

View File

@ -1,7 +1,10 @@
import functools
import os
import torch
import torchvision.transforms.functional as functional
import matplotlib.pyplot as plt
import numpy as np
from modules import devices, images, shared
from modules.processing import StableDiffusionProcessingTxt2Img
@ -98,146 +101,326 @@ def patch_unet_forward_pass(p, unet, params):
"seed": params.tome_seed,
}
x_means = []
x_stds = []
x_norms = []
cond_means = []
cond_stds = []
cond_norms = []
uncond_means = []
uncond_stds = []
uncond_norms = []
DO_FABRIC = True
mean_ema = {}
def new_forward(self, x, timesteps=None, context=None, **kwargs):
_, uncond_ids, context = unmark_prompt_context(context)
cond_ids = [i for i in range(context.size(0)) if i not in uncond_ids]
has_cond = len(cond_ids) > 0
has_uncond = len(uncond_ids) > 0
h_latent, w_latent = x.shape[-2:]
w, h = 8 * w_latent, 8 * h_latent
if has_hires_fix and w == hr_w and h == hr_h:
if not params.feedback_during_high_res_fix:
print("[FABRIC] Skipping feedback during high-res fix")
def plot_hist(out, filename):
plt.figure()
xs = out.detach().cpu().numpy()
for i in range(xs.shape[0]):
hist, bin_edges = np.histogram(xs[i].reshape(-1), bins=100, density=True)
bin_centers = (bin_edges[:-1] + bin_edges[1:]) / 2.
plt.plot(bin_centers, hist)
plt.savefig(filename)
plt.close()
def plot_lines(ys, filename):
plt.figure()
ys = np.stack(ys, axis=1)
for i in range(ys.shape[0]):
plt.plot(ys[i].reshape(-1))
plt.savefig(filename)
plt.close()
if not DO_FABRIC:
out = self._fabric_old_forward(x, timesteps, context, **kwargs)
else:
h_latent, w_latent = x.shape[-2:]
w, h = 8 * w_latent, 8 * h_latent
if has_hires_fix and w == hr_w and h == hr_h:
if not params.feedback_during_high_res_fix:
print("[FABRIC] Skipping feedback during high-res fix")
return self._fabric_old_forward(x, timesteps, context, **kwargs)
pos_weight, neg_weight = get_curr_feedback_weight(p, params, timesteps[0].item())
if pos_weight <= 0 and neg_weight <= 0:
return self._fabric_old_forward(x, timesteps, context, **kwargs)
pos_latents, neg_latents = get_latents_from_params(p, params, w, h)
pos_latents = pos_latents if has_cond else []
neg_latents = neg_latents if has_uncond else []
all_latents = pos_latents + neg_latents
# Note: calls to the VAE with `--medvram` will move the U-Net to CPU, so we need to move it back to GPU
if shared.cmd_opts.medvram:
try:
# Trigger register_forward_pre_hook to move the model to correct device
p.sd_model.model()
except:
pass
if len(all_latents) == 0:
return self._fabric_old_forward(x, timesteps, context, **kwargs)
pos_weight, neg_weight = get_curr_feedback_weight(p, params, timesteps[0].item())
if pos_weight <= 0 and neg_weight <= 0:
return self._fabric_old_forward(x, timesteps, context, **kwargs)
pos_latents, neg_latents = get_latents_from_params(p, params, w, h)
pos_latents = pos_latents if has_cond else []
neg_latents = neg_latents if has_uncond else []
all_latents = pos_latents + neg_latents
# ## intervention 2: std decay
# std = x.std(dim=(2, 3), keepdim=True)
# mask = (x.std(dim=(2, 3), keepdim=True) > 1.0).float()
# x = mask * ((0.05*(std - 1) + 1) * x / std) + (1 - mask) * x
# ## intervention 3: mean clamp
# relative_t = timesteps[0].item() / (p.sd_model.num_timesteps - 1)
# max_mean = 0.5 * (1 - relative_t)
# mean = x.mean(dim=(2, 3), keepdim=True)
# x = x - mean + torch.clamp(mean, min=-max_mean, max=max_mean)
# ## intervention 7: mean decay
# mean = x.mean(dim=(2, 3), keepdim=True)
# x = x - 0.2 * mean
# add noise to reference latents
all_zs = []
for latent in all_latents:
z = p.sd_model.q_sample(latent.unsqueeze(0), torch.round(timesteps.float()).long())[0]
all_zs.append(z)
all_zs = torch.stack(all_zs, dim=0)
# save original forward pass
for module in self.modules():
if isinstance(module, BasicTransformerBlock) and not hasattr(module.attn1, "_fabric_old_forward"):
module.attn1._fabric_old_forward = module.attn1.forward
module.attn2._fabric_old_forward = module.attn2.forward
# Note: calls to the VAE with `--medvram` will move the U-Net to CPU, so we need to move it back to GPU
if shared.cmd_opts.medvram:
try:
# Trigger register_forward_pre_hook to move the model to correct device
p.sd_model.model()
except:
pass
if len(all_latents) == 0:
return self._fabric_old_forward(x, timesteps, context, **kwargs)
# add noise to reference latents
all_zs = []
for latent in all_latents:
z = p.sd_model.q_sample(latent.unsqueeze(0), torch.round(timesteps.float()).long())[0]
all_zs.append(z)
all_zs = torch.stack(all_zs, dim=0)
# save original forward pass
for module in self.modules():
if isinstance(module, BasicTransformerBlock) and not hasattr(module.attn1, "_fabric_old_forward"):
module.attn1._fabric_old_forward = module.attn1.forward
module.attn2._fabric_old_forward = module.attn2.forward
try:
## cache hidden states
cached_hiddens = {}
def patched_attn1_forward(attn1, layer_idx, x, **kwargs):
merge, unmerge = compute_merge(x, args=tome_args, size=(h_latent, w_latent), ratio=params.tome_ratio)
x = merge(x)
if layer_idx not in cached_hiddens:
cached_hiddens[layer_idx] = x.detach().clone().cpu()
else:
cached_hiddens[layer_idx] = torch.cat([cached_hiddens[layer_idx], x.detach().clone().cpu()], dim=0)
out = attn1._fabric_old_forward(x, **kwargs)
out = unmerge(out)
return out
def patched_attn2_forward(attn2, x, **kwargs):
merge, unmerge = compute_merge(x, args=tome_args, size=(h_latent, w_latent), ratio=params.tome_ratio)
x = merge(x)
out = attn2._fabric_old_forward(x, **kwargs)
out = unmerge(out)
return out
# patch forward pass to cache hidden states
layer_idx = 0
for module in self.modules():
if isinstance(module, BasicTransformerBlock):
module.attn1.forward = functools.partial(patched_attn1_forward, module.attn1, layer_idx)
module.attn2.forward = functools.partial(patched_attn2_forward, module.attn2)
layer_idx += 1
# run forward pass just to cache hidden states, output is discarded
for i in range(0, len(all_zs), batch_size):
zs = all_zs[i : i + batch_size].to(x.device, dtype=self.dtype)
ts = timesteps[:1].expand(zs.size(0)) # (bs,)
# use the null prompt for pre-computing hidden states on feedback images
ctx = null_ctx.expand(zs.size(0), -1, -1) # (bs, p_seq, p_dim)
_ = self._fabric_old_forward(zs, ts, ctx)
num_pos = len(pos_latents)
num_neg = len(neg_latents)
num_cond = len(cond_ids)
num_uncond = len(uncond_ids)
tome_h_latent = h_latent * (1 - params.tome_ratio)
def patched_attn1_forward(attn1, idx, x, context=None, **kwargs):
if context is None:
context = x
cached_hs = cached_hiddens[idx].to(x.device)
d_model = x.shape[-1]
def attention_with_feedback(_x, context, feedback_hs, w):
num_xs, num_fb = _x.shape[0], feedback_hs.shape[0]
if num_fb > 0:
feedback_ctx = feedback_hs.view(1, -1, d_model).expand(num_xs, -1, -1) # (n_cond, seq * n_pos, dim)
merge, _ = compute_merge(feedback_ctx, args=tome_args, size=(tome_h_latent * num_fb, w_latent), max_tokens=params.tome_max_tokens)
feedback_ctx = merge(feedback_ctx)
ctx = torch.cat([context, feedback_ctx], dim=1) # (n_cond, seq + seq*n_pos, dim)
weights = torch.ones(ctx.shape[1], device=ctx.device, dtype=ctx.dtype) # (seq + seq*n_pos,)
weights[_x.shape[1]:] = w
## cache hidden states
cached_hiddens = {}
def patched_attn1_forward(attn1, layer_idx, x, **kwargs):
merge, unmerge = compute_merge(x, args=tome_args, size=(h_latent, w_latent), ratio=params.tome_ratio)
x = merge(x)
if layer_idx not in cached_hiddens:
cached_hiddens[layer_idx] = x.detach().clone().cpu()
else:
ctx = context
weights = None
return weighted_attention(attn1, attn1._fabric_old_forward, _x, ctx, weights, **kwargs) # (n_cond, seq, dim)
cached_hiddens[layer_idx] = torch.cat([cached_hiddens[layer_idx], x.detach().clone().cpu()], dim=0)
out = attn1._fabric_old_forward(x, **kwargs)
out = unmerge(out)
return out
def patched_attn2_forward(attn2, x, **kwargs):
merge, unmerge = compute_merge(x, args=tome_args, size=(h_latent, w_latent), ratio=params.tome_ratio)
x = merge(x)
out = attn2._fabric_old_forward(x, **kwargs)
out = unmerge(out)
return out
outs = []
if num_cond > 0:
out_cond = attention_with_feedback(x[cond_ids], context[cond_ids], cached_hs[:num_pos], pos_weight) # (n_cond, seq, dim)
outs.append(out_cond)
if num_uncond > 0:
out_uncond = attention_with_feedback(x[uncond_ids], context[uncond_ids], cached_hs[num_pos:], neg_weight) # (n_cond, seq, dim)
outs.append(out_uncond)
out = torch.cat(outs, dim=0)
return out
# patch forward pass to cache hidden states
layer_idx = 0
for module in self.modules():
if isinstance(module, BasicTransformerBlock):
module.attn1.forward = functools.partial(patched_attn1_forward, module.attn1, layer_idx)
module.attn2.forward = functools.partial(patched_attn2_forward, module.attn2)
layer_idx += 1
# patch forward pass to inject cached hidden states
layer_idx = 0
for module in self.modules():
if isinstance(module, BasicTransformerBlock):
module.attn1.forward = functools.partial(patched_attn1_forward, module.attn1, layer_idx)
layer_idx += 1
# run forward pass just to cache hidden states, output is discarded
for i in range(0, len(all_zs), batch_size):
zs = all_zs[i : i + batch_size].to(x.device, dtype=self.dtype)
ts = timesteps[:1].expand(zs.size(0)) # (bs,)
# use the null prompt for pre-computing hidden states on feedback images
ctx = null_ctx.expand(zs.size(0), -1, -1) # (bs, p_seq, p_dim)
_ = self._fabric_old_forward(zs, ts, ctx)
# run forward pass with cached hidden states
out = self._fabric_old_forward(x, timesteps, context, **kwargs)
num_pos = len(pos_latents)
num_neg = len(neg_latents)
num_cond = len(cond_ids)
num_uncond = len(uncond_ids)
tome_h_latent = h_latent * (1 - params.tome_ratio)
finally:
# restore original pass
for module in self.modules():
if isinstance(module, BasicTransformerBlock) and hasattr(module.attn1, "_fabric_old_forward"):
module.attn1.forward = module.attn1._fabric_old_forward
del module.attn1._fabric_old_forward
if isinstance(module, BasicTransformerBlock) and hasattr(module.attn2, "_fabric_old_forward"):
module.attn2.forward = module.attn2._fabric_old_forward
del module.attn2._fabric_old_forward
def patched_attn1_forward(attn1, idx, x, context=None, **kwargs):
if context is None:
context = x
cached_hs = cached_hiddens[idx].to(x.device)
d_model = x.shape[-1]
def attention_with_feedback(_x, context, feedback_hs, w):
num_xs, num_fb = _x.shape[0], feedback_hs.shape[0]
if num_fb > 0:
feedback_ctx = feedback_hs.view(1, -1, d_model).expand(num_xs, -1, -1) # (n_cond, seq * n_pos, dim)
merge, _ = compute_merge(feedback_ctx, args=tome_args, size=(tome_h_latent * num_fb, w_latent), max_tokens=params.tome_max_tokens)
feedback_ctx = merge(feedback_ctx)
ctx = torch.cat([context, feedback_ctx], dim=1) # (n_cond, seq + seq*n_pos, dim)
weights = torch.ones(ctx.shape[1], device=ctx.device, dtype=ctx.dtype) # (seq + seq*n_pos,)
weights[_x.shape[1]:] = w
else:
ctx = context
weights = None
return weighted_attention(attn1, attn1._fabric_old_forward, _x, ctx, weights, **kwargs) # (n_cond, seq, dim)
outs = []
if num_cond > 0:
out_cond = attention_with_feedback(x[cond_ids], context[cond_ids], cached_hs[:num_pos], pos_weight) # (n_cond, seq, dim)
outs.append(out_cond)
if num_uncond > 0:
out_uncond = attention_with_feedback(x[uncond_ids], context[uncond_ids], cached_hs[num_pos:], neg_weight) # (n_cond, seq, dim)
outs.append(out_uncond)
out = torch.cat(outs, dim=0)
return out
# patch forward pass to inject cached hidden states
layer_idx = 0
for module in self.modules():
if isinstance(module, BasicTransformerBlock):
module.attn1.forward = functools.partial(patched_attn1_forward, module.attn1, layer_idx)
layer_idx += 1
# run forward pass with cached hidden states
out = self._fabric_old_forward(x, timesteps, context, **kwargs)
finally:
# restore original pass
for module in self.modules():
if isinstance(module, BasicTransformerBlock) and hasattr(module.attn1, "_fabric_old_forward"):
module.attn1.forward = module.attn1._fabric_old_forward
del module.attn1._fabric_old_forward
if isinstance(module, BasicTransformerBlock) and hasattr(module.attn2, "_fabric_old_forward"):
module.attn2.forward = module.attn2._fabric_old_forward
del module.attn2._fabric_old_forward
cond_outs = out[cond_ids]
uncond_outs = out[uncond_ids]
xs = x[cond_ids]
t = int(timesteps[0].item())
seed = p.seed
x_means.append(xs.mean(dim=(2, 3)).view(-1).cpu().numpy())
x_stds.append(xs.std(dim=(2, 3)).view(-1).cpu().numpy())
x_norms.append(xs.norm(dim=(2, 3)).view(-1).cpu().numpy())
cond_means.append(cond_outs.mean(dim=(2, 3)).view(-1).cpu().numpy())
cond_stds.append(cond_outs.std(dim=(2, 3)).view(-1).cpu().numpy())
cond_norms.append(cond_outs.norm(dim=(2, 3)).view(-1).cpu().numpy())
uncond_means.append(uncond_outs.mean(dim=(2, 3)).view(-1).cpu().numpy())
uncond_stds.append(uncond_outs.std(dim=(2, 3)).view(-1).cpu().numpy())
uncond_norms.append(uncond_outs.norm(dim=(2, 3)).view(-1).cpu().numpy())
def plot_trajectory(means, stds, ax=None):
if ax is None:
ax = plt.gca()
# means = np.stack(means, axis=1)
# stds = np.stack(stds, axis=1)
ax.plot(means)
ax.fill_between(range(means.shape[0]), means - stds, means + stds, alpha=0.3)
def select(xs, batch_idx):
return np.stack([x[batch_idx] for x in xs], axis=0)
for i in range(len(xs)):
x_means_i = select(x_means, i)
x_stds_i = select(x_stds, i)
cond_means_i = select(cond_means, i)
cond_stds_i = select(cond_stds, i)
uncond_means_i = select(uncond_means, i)
uncond_stds_i = select(uncond_stds, i)
plot_file = f"plots/stats/{seed + i}_{'fabric' if DO_FABRIC else 'default'}.png"
fig, axs = plt.subplots(1, 3, figsize=(15, 5))
plot_trajectory(x_means_i, x_stds_i, ax=axs[0])
plot_trajectory(cond_means_i, cond_stds_i, ax=axs[1])
plot_trajectory(uncond_means_i, uncond_stds_i, ax=axs[2])
fig.savefig(plot_file)
# os.makedirs("plots/lines", exist_ok=True)
# plot_lines(x_means, f"plots/lines/{seed}_x_means{'_fabric' if DO_FABRIC else ''}.png")
# plot_lines(x_stds, f"plots/lines/{seed}_x_stds{'_fabric' if DO_FABRIC else ''}.png")
# plot_lines(x_norms, f"plots/lines/{seed}_x_norms{'_fabric' if DO_FABRIC else ''}.png")
# plot_lines(cond_means, f"plots/lines/{seed}_cond_means{'_fabric' if DO_FABRIC else ''}.png")
# plot_lines(cond_stds, f"plots/lines/{seed}_cond_stds{'_fabric' if DO_FABRIC else ''}.png")
# plot_lines(cond_norms, f"plots/lines/{seed}_cond_norms{'_fabric' if DO_FABRIC else ''}.png")
# plot_lines(uncond_means, f"plots/lines/{seed}_uncond_means{'_fabric' if DO_FABRIC else ''}.png")
# plot_lines(uncond_stds, f"plots/lines/{seed}_uncond_stds{'_fabric' if DO_FABRIC else ''}.png")
# plot_lines(uncond_norms, f"plots/lines/{seed}_uncond_norms{'_fabric' if DO_FABRIC else ''}.png")
# os.makedirs("plots/hists", exist_ok=True)
# for i in range(len(xs)):
# filename = f"plots/hists/{seed}_{t}_x_{i}{'_fabric' if DO_FABRIC else ''}.png"
# plot_hist(x[i], filename)
# for i in range(len(cond_outs)):
# filename = f"plots/hists/{seed}_{t}_cond_{i}{'_fabric' if DO_FABRIC else ''}.png"
# plot_hist(cond_outs[i], filename)
# for i in range(len(uncond_outs)):
# filename = f"plots/hists/{seed}_{t}_uncond_{i}{'_fabric' if DO_FABRIC else ''}.png"
# plot_hist(uncond_outs[i], filename)
# ## intervention 1: mean decay
# out = out - 0.5*out.mean(dim=(2, 3), keepdim=True)
# ## intervention 4: early normalization
# relative_t = 1 - timesteps[0].item() / (p.sd_model.num_timesteps - 1)
# if relative_t < 0.33:
# out = (out - out.mean(dim=(2, 3), keepdim=True)) / out.std(dim=(2, 3), keepdim=True)
# ## intervention 5: momentum mean decay
# dampen = 0.8
# accel = 0.35
# alpha = 0.5
# beta = 0.25
# mean = out.mean(dim=(2, 3), keepdim=True)
# if "momentum" not in mean_ema:
# mean_ema["momentum"] = mean
# else:
# mean_ema["momentum"] = dampen * (mean_ema["momentum"] + accel * out.mean(dim=(2, 3), keepdim=True))
# out = out - alpha * mean - beta * mean_ema["momentum"]
# print()
# print("mean: ", mean.cpu().view(-1))
# print("momentum:", mean_ema["momentum"].cpu().view(-1))
# plot_lines([p.sd_model.betas.cpu().unsqueeze(0).numpy()], f"plots/lines/betas.png")
# plot_lines([p.sd_model.alphas_cumprod.cpu().unsqueeze(0).numpy()], f"plots/lines/alphas_cumprod.png")
# plot_lines([p.sd_model.alphas_cumprod_prev.cpu().unsqueeze(0).numpy()], f"plots/lines/alphas_cumprod_prev.png")
# plot_lines([p.sd_model.posterior_variance.cpu().unsqueeze(0).numpy()], f"plots/lines/posterior_variance.png")
# plot_lines([p.sd_model.posterior_log_variance_clipped.cpu().unsqueeze(0).numpy()], f"plots/lines/posterior_log_variance_clipped.png")
# plot_lines([p.sd_model.posterior_mean_coef1.cpu().unsqueeze(0).numpy()], f"plots/lines/posterior_mean_coef1.png")
# plot_lines([p.sd_model.posterior_mean_coef2.cpu().unsqueeze(0).numpy()], f"plots/lines/posterior_mean_coef2.png")
# plot_lines([p.sd_model.sqrt_alphas_cumprod.cpu().unsqueeze(0).numpy()], f"plots/lines/sqrt_alphas_cumprod.png")
# plot_lines([p.sd_model.sqrt_one_minus_alphas_cumprod.cpu().unsqueeze(0).numpy()], f"plots/lines/sqrt_one_minus_alphas_cumprod.png")
# plot_lines([p.sd_model.log_one_minus_alphas_cumprod.cpu().unsqueeze(0).numpy()], f"plots/lines/log_one_minus_alphas_cumprod.png")
# plot_lines([p.sd_model.sqrt_recip_alphas_cumprod.cpu().unsqueeze(0).numpy()], f"plots/lines/sqrt_recip_alphas_cumprod.png")
# plot_lines([p.sd_model.sqrt_recipm1_alphas_cumprod.cpu().unsqueeze(0).numpy()], f"plots/lines/sqrt_recipm1_alphas_cumprod.png")
# ## intervention 8: dynamic standardization
# # TODO: test how to dynamically bound the mean rather than always subtracting some fraction
# relative_t = 1 - timesteps[0].item() / (p.sd_model.num_timesteps - 1)
# if relative_t < 0.8:
# alpha = 0.9
# sigma = (1 - relative_t)**0.1
# # print(p.n_iter, p.steps)
# # print(p.sampler.get_sigmas(p, p.steps))
# # sigmas = p.sampler.get_sigmas(p, p.steps)
# # print(torch.sqrt(1.0 + sigmas ** 2.0))
# # print(p.sd_model.alphas_cumprod)
# # sigma = p.sampler.get_sigmas(p, timesteps[0].item())
# std = out.std(dim=(2, 3), keepdim=True).clip(min=alpha*sigma, max=(2 - alpha)*sigma)
# mean = out.mean(dim=(2, 3), keepdim=True)
# out = (out - alpha*mean) / std
return out