sd-webui-fabric/scripts/patching.py

import functools

import torch
import torchvision.transforms.functional as functional

from modules import devices, images, shared
from modules.processing import StableDiffusionProcessingTxt2Img
from ldm.modules.attention import BasicTransformerBlock

from scripts.marking import patch_process_sample, unmark_prompt_context
from scripts.weighted_attention import weighted_attention
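

# FABRIC: condition generation on feedback images by injecting their self-attention hidden
# states into the UNet. Positive (liked) images are attended to by the conditional batch
# entries and negative (disliked) images by the unconditional ones.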
def encode_to_latent(p, image, w, h):
    # Encode a single feedback image into the VAE latent space at the target resolution.
    image = images.resize_image(1, image, w, h)
    x = functional.pil_to_tensor(image)
    x = functional.center_crop(x, (h, w))  # just to be safe (torchvision expects (H, W))
    x = x.to(devices.device, dtype=devices.dtype_vae)
    x = ((x / 255.0) * 2.0 - 1.0).unsqueeze(0)
    # TODO: use caching to make this faster
    with devices.autocast():
        vae_output = p.sd_model.encode_first_stage(x)
        z = p.sd_model.get_first_stage_encoding(vae_output)
    return z.squeeze(0)


def get_latents_from_params(p, params, width, height):
    w, h = (width // 8) * 8, (height // 8) * 8
    w_latent, h_latent = width // 8, height // 8

    def get_latents(images, cached_latents=None):
        # check if latents need to be computed or recomputed (e.g. if the image size changed due to high-res fix)
        if cached_latents is None:
            return [encode_to_latent(p, img, w, h) for img in images]
        else:
            ls = []
            for latent, img in zip(cached_latents, images):
                if latent.shape[-2:] != (h_latent, w_latent):  # latent tensors are (..., H, W)
                    print(f"[FABRIC] Recomputing latent for image of size {img.size}")
                    latent = encode_to_latent(p, img, w, h)
                ls.append(latent)
            return ls

    params.pos_latents = get_latents(params.pos_images, params.pos_latents)
    params.neg_latents = get_latents(params.neg_images, params.neg_latents)
    return params.pos_latents, params.neg_latents


def get_curr_feedback_weight(p, params, timestep):
    # Feedback is only applied while denoising progress is inside the [start, end] window.
    progress = 1 - (timestep / (p.sd_model.num_timesteps - 1))
    if params.start <= progress <= params.end:
        w = params.max_weight
    else:
        w = params.min_weight
    return max(0, w), max(0, w * params.neg_scale)
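

# Monkey-patch `unet.forward` with a FABRIC-aware forward pass. The original forward pass is
# kept on the module as `_fabric_old_forward` so it can still be called internally and restored
# by unpatch_unet_forward_pass().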
def patch_unet_forward_pass(p, unet, params):
    if not params.pos_images and not params.neg_images:
        print("[FABRIC] No feedback images found, aborting patching")
        return

    # keep a handle on the original forward pass so it can be restored later
    if not hasattr(unet, "_fabric_old_forward"):
        unet._fabric_old_forward = unet.forward

    batch_size = p.batch_size
    # empty-prompt conditioning, used when pre-computing hidden states for the feedback images
    null_ctx = p.sd_model.get_learned_conditioning([""]).to(devices.device, dtype=devices.dtype_unet)

    width = (p.width // 8) * 8
    height = (p.height // 8) * 8

    # compute the high-res fix target size so the high-res pass can be detected inside new_forward
    has_hires_fix = isinstance(p, StableDiffusionProcessingTxt2Img) and getattr(p, 'enable_hr', False)
    if has_hires_fix:
        if p.hr_resize_x == 0 and p.hr_resize_y == 0:
            hr_w = int(p.width * p.hr_scale)
            hr_h = int(p.height * p.hr_scale)
        else:
            hr_w, hr_h = p.hr_resize_x, p.hr_resize_y
        hr_w = (hr_w // 8) * 8
        hr_h = (hr_h // 8) * 8
    else:
        hr_w = width
        hr_h = height
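
    # Per denoising step: (1) noise the feedback latents to the current timestep, (2) run them
    # through the UNet once with the null prompt while caching every attn1 input, (3) run the
    # real batch with the cached states appended to the attn1 context via weighted attention.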
    def new_forward(self, x, timesteps=None, context=None, **kwargs):
        # recover which batch entries are unconditional from the marked prompt context (see scripts.marking)
        _, uncond_ids, context = unmark_prompt_context(context)
        cond_ids = [i for i in range(context.size(0)) if i not in uncond_ids]
        has_cond = len(cond_ids) > 0
        has_uncond = len(uncond_ids) > 0

        h_latent, w_latent = x.shape[-2:]
        w, h = 8 * w_latent, 8 * h_latent

        if has_hires_fix and w == hr_w and h == hr_h:
            if not params.feedback_during_high_res_fix:
                print("[FABRIC] Skipping feedback during high-res fix")
                return self._fabric_old_forward(x, timesteps, context, **kwargs)

        pos_weight, neg_weight = get_curr_feedback_weight(p, params, timesteps[0].item())
        if pos_weight <= 0 and neg_weight <= 0:
            return self._fabric_old_forward(x, timesteps, context, **kwargs)

        pos_latents, neg_latents = get_latents_from_params(p, params, w, h)
        pos_latents = pos_latents if has_cond else []
        neg_latents = neg_latents if has_uncond else []
        all_latents = pos_latents + neg_latents

        if len(all_latents) == 0:
            return self._fabric_old_forward(x, timesteps, context, **kwargs)

        # add noise to the reference latents to match the current timestep
        all_zs = []
        for latent in all_latents:
            z = p.sd_model.q_sample(latent.unsqueeze(0), torch.round(timesteps.float()).long())[0]
            all_zs.append(z)
        all_zs = torch.stack(all_zs, dim=0)

        # save the original self-attention forward pass of every transformer block
        for module in self.modules():
            if isinstance(module, BasicTransformerBlock) and not hasattr(module.attn1, "_fabric_old_forward"):
                module.attn1._fabric_old_forward = module.attn1.forward

        # fix for medvram option
        if shared.cmd_opts.medvram:
            try:
                # trigger register_forward_pre_hook to move the model to the correct device
                p.sd_model.model()
            except:
                pass

        ## cache hidden states of the feedback latents (first pass)
        cached_hiddens = {}

        def patched_attn1_forward(attn1, idx, x, **kwargs):
            # accumulate this layer's hidden states on the CPU, batch by batch
            if idx not in cached_hiddens:
                cached_hiddens[idx] = x.detach().clone().cpu()
            else:
                cached_hiddens[idx] = torch.cat([cached_hiddens[idx], x.detach().clone().cpu()], dim=0)
            out = attn1._fabric_old_forward(x, **kwargs)
            return out

        # patch the self-attention forward pass to cache hidden states
        layer_idx = 0
        for module in self.modules():
            if isinstance(module, BasicTransformerBlock):
                module.attn1.forward = functools.partial(patched_attn1_forward, module.attn1, layer_idx)
                layer_idx += 1

        # run the forward pass just to cache hidden states, the output is discarded
        for i in range(0, len(all_zs), batch_size):
            zs = all_zs[i : i + batch_size].to(x.device, dtype=self.dtype)
            ts = timesteps[:1].expand(zs.size(0))  # (bs,)
            # use the null prompt for pre-computing hidden states on the feedback images
            ctx = null_ctx.expand(zs.size(0), -1, -1)  # (bs, p_seq, p_dim)
            _ = self._fabric_old_forward(zs, ts, ctx)

        num_pos = len(pos_latents)
        num_neg = len(neg_latents)
        num_cond = len(cond_ids)
        num_uncond = len(uncond_ids)
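
        # Second attn1 patch: concatenate the cached feedback hidden states to the self-attention
        # context of the conditional (positive feedback) and unconditional (negative feedback)
        # samples, with per-token weights (pos_weight / neg_weight) passed to weighted_attention.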
        def patched_attn1_forward(attn1, idx, x, context=None, **kwargs):
            if context is None:
                context = x

            cached_hs = cached_hiddens[idx].to(x.device)
            seq_len, d_model = x.shape[1:]

            outs = []
            if num_cond > 0:
                # append the positive feedback hidden states to the context of the conditional samples
                pos_hs = cached_hs[:num_pos].view(1, num_pos * seq_len, d_model).expand(num_cond, -1, -1)  # (n_cond, seq * n_pos, dim)
                x_cond = x[cond_ids]  # (n_cond, seq, dim)
                ctx_cond = torch.cat([context[cond_ids], pos_hs], dim=1)  # (n_cond, seq * (1 + n_pos), dim)
                ws = torch.ones_like(ctx_cond[0, :, 0])  # (seq * (1 + n_pos),)
                ws[x_cond.size(1):] = pos_weight
                out_cond = weighted_attention(attn1, attn1._fabric_old_forward, x_cond, ctx_cond, ws, **kwargs)  # (n_cond, seq, dim)
                outs.append(out_cond)
            if num_uncond > 0:
                # append the negative feedback hidden states to the context of the unconditional samples
                neg_hs = cached_hs[num_pos:].view(1, num_neg * seq_len, d_model).expand(num_uncond, -1, -1)  # (n_uncond, seq * n_neg, dim)
                x_uncond = x[uncond_ids]  # (n_uncond, seq, dim)
                ctx_uncond = torch.cat([context[uncond_ids], neg_hs], dim=1)  # (n_uncond, seq * (1 + n_neg), dim)
                ws = torch.ones_like(ctx_uncond[0, :, 0])  # (seq * (1 + n_neg),)
                ws[x_uncond.size(1):] = neg_weight
                out_uncond = weighted_attention(attn1, attn1._fabric_old_forward, x_uncond, ctx_uncond, ws, **kwargs)  # (n_uncond, seq, dim)
                outs.append(out_uncond)
            out = torch.cat(outs, dim=0)
            return out

        # patch the self-attention forward pass to inject the cached hidden states
        layer_idx = 0
        for module in self.modules():
            if isinstance(module, BasicTransformerBlock):
                module.attn1.forward = functools.partial(patched_attn1_forward, module.attn1, layer_idx)
                layer_idx += 1

        # run the actual forward pass with the cached hidden states injected (second pass)
        out = self._fabric_old_forward(x, timesteps, context, **kwargs)

        # restore the original self-attention forward pass
        for module in self.modules():
            if isinstance(module, BasicTransformerBlock) and hasattr(module.attn1, "_fabric_old_forward"):
                module.attn1.forward = module.attn1._fabric_old_forward
                del module.attn1._fabric_old_forward

        return out

    # bind the patched forward pass to the UNet instance
    unet.forward = new_forward.__get__(unet)

    # patch the sampling step so prompt contexts are marked for unmark_prompt_context (see scripts.marking)
    patch_process_sample(p)
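

# Restore the original UNet forward pass saved by patch_unet_forward_pass().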
def unpatch_unet_forward_pass(unet):
    if hasattr(unet, "_fabric_old_forward"):
        unet.forward = unet._fabric_old_forward
        del unet._fabric_old_forward