sd-webui-fabric/scripts/patching.py

import torch
from ldm.modules.attention import BasicTransformerBlock
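# FABRIC-style feedback: the U-Net forward pass is wrapped so that each
# self-attention (attn1) layer also attends to hidden states computed from
# "liked" (positive) and "disliked" (negative) reference latents. Positive
# references are appended to the conditional half of the batch and negative
# references to the unconditional half, so classifier-free guidance nudges
# the sample toward the former and away from the latter.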


def use_feedback(params):
    """Return True only if the current feedback settings can actually affect the output."""
    if not params.enabled:
        return False
    if params.start >= params.end and params.min_weight <= 0:
        return False
    if params.max_weight <= 0:
        return False
    if params.neg_scale <= 0 and len(params.pos_latents) == 0:
        return False
    if len(params.pos_latents) == 0 and len(params.neg_latents) == 0:
        return False
    return True


def patch_unet_forward_pass(p, unet, params):
    """Replace `unet.forward` with a feedback-aware version; revert with `unpatch_unet_forward_pass`."""
    if not use_feedback(params):
        return

    # remember the unpatched forward pass so it can be restored later
    if not hasattr(unet, "_fabric_old_forward"):
        unet._fabric_old_forward = unet.forward

    batch_size = p.batch_size
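
    # The patched forward pass works in two stages on every U-Net call:
    #   1. run the noised reference latents through the original forward pass,
    #      caching the attn1 input of every transformer block, then
    #   2. run the actual input, concatenating those cached hidden states onto
    #      each attn1 context so the current sample attends to the references.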
    def new_forward(self, x, timesteps=None, context=None, **kwargs):
        # save original forward pass
        for module in self.modules():
            if isinstance(module, BasicTransformerBlock):
                module.attn1._fabric_old_forward = module.attn1.forward

        # add noise to reference latents
        all_zs = []
        for latent in params.pos_latents + params.neg_latents:
            z = p.sd_model.q_sample(latent.unsqueeze(0), torch.round(timesteps.float()).long())[0]
            all_zs.append(z)
        all_zs = torch.stack(all_zs, dim=0)

        ## cache hidden states
        cached_hiddens = []

        def patched_attn1_forward(attn1, x, **kwargs):
            cached_hiddens.append(x.detach().clone())
            out = attn1._fabric_old_forward(x, **kwargs)
            return out

        # patch forward pass to cache hidden states
        for module in self.modules():
            if isinstance(module, BasicTransformerBlock):
                module.attn1.forward = patched_attn1_forward.__get__(module.attn1)

        # run forward pass just to cache hidden states, output is discarded
        all_zs = all_zs.to(x.device, dtype=self.dtype)
        ts = timesteps[:1].expand(all_zs.size(0))  # (n_pos + n_neg,)
        # TODO: instead of using the negative prompt, use the null prompt
        ctx = context[batch_size:][:1].expand(all_zs.size(0), -1, -1)  # (n_pos + n_neg, p_seq, p_dim)
        _ = self._fabric_old_forward(all_zs, ts, ctx)
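
        # Second patch: during the real denoising pass, extend each attn1
        # context with the cached reference hidden states along the sequence
        # dimension. Positive references are appended for the conditional half
        # of the batch, negative references for the unconditional half.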
        def patched_attn1_forward(attn1, x, context=None, **kwargs):
            if context is None:
                context = x

            cached_hs = cached_hiddens.pop(0)
            seq_len, d_model = x.shape[1:]
            num_pos = len(params.pos_latents)
            num_neg = len(params.neg_latents)
            pos_hs = cached_hs[:num_pos].view(1, num_pos * seq_len, d_model).expand(batch_size, -1, -1)  # (bs, seq * n_pos, dim)
            neg_hs = cached_hs[num_pos:].view(1, num_neg * seq_len, d_model).expand(batch_size, -1, -1)  # (bs, seq * n_neg, dim)
            x_cond = x[:batch_size]  # (bs, seq, dim)
            x_uncond = x[batch_size:]  # (bs, seq, dim)
            ctx_cond = torch.cat([context[:batch_size], pos_hs], dim=1)  # (bs, seq * (1 + n_pos), dim)
            ctx_uncond = torch.cat([context[batch_size:], neg_hs], dim=1)  # (bs, seq * (1 + n_neg), dim)
            out_cond = attn1._fabric_old_forward(x_cond, ctx_cond, **kwargs)  # (bs, seq, dim)
            out_uncond = attn1._fabric_old_forward(x_uncond, ctx_uncond, **kwargs)  # (bs, seq, dim)
            out = torch.cat([out_cond, out_uncond], dim=0)
            return out

        # patch forward pass to inject cached hidden states
        for module in self.modules():
            if isinstance(module, BasicTransformerBlock):
                module.attn1.forward = patched_attn1_forward.__get__(module.attn1)

        # run forward pass with cached hidden states
        out = self._fabric_old_forward(x, timesteps, context, **kwargs)

        # restore original pass
        for module in self.modules():
            if isinstance(module, BasicTransformerBlock):
                module.attn1.forward = module.attn1._fabric_old_forward
                del module.attn1._fabric_old_forward

        return out

    # bind the patched forward pass to the U-Net instance
    unet.forward = new_forward.__get__(unet)


def unpatch_unet_forward_pass(unet):
    """Restore the original `unet.forward` saved by `patch_unet_forward_pass`."""
    if hasattr(unet, "_fabric_old_forward"):
        unet.forward = unet._fabric_old_forward
        del unet._fabric_old_forward
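
# Typical usage from the extension's script hooks (illustrative only; the exact
# call sites live elsewhere in the repo):
#
#     unet = p.sd_model.model.diffusion_model
#     patch_unet_forward_pass(p, unet, params)   # before sampling
#     ...                                        # run sampling as usual
#     unpatch_unet_forward_pass(unet)            # after sampling / on error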