import torch
import torch.nn as nn

from modules import devices, lowvram, shared, scripts

# fall back to identity on webui builds that do not expose devices.cond_cast_unet
cond_cast_unet = getattr(devices, 'cond_cast_unet', lambda x: x)

from ldm.modules.diffusionmodules.util import timestep_embedding
from ldm.modules.diffusionmodules.openaimodel import UNetModel


class TorchHijackForUnet:
    """
    This is torch, but with cat that resizes tensors to appropriate dimensions if they do not match;
    this makes it possible to create pictures with dimensions that are multiples of 8 rather than 64.
    """

    def __getattr__(self, item):
        if item == 'cat':
            return self.cat

        if hasattr(torch, item):
            return getattr(torch, item)

        raise AttributeError("'{}' object has no attribute '{}'".format(type(self).__name__, item))

    def cat(self, tensors, *args, **kwargs):
        if len(tensors) == 2:
            a, b = tensors
            if a.shape[-2:] != b.shape[-2:]:
                # nearest-neighbour resize so mismatched feature maps can still be concatenated
                a = torch.nn.functional.interpolate(a, b.shape[-2:], mode="nearest")

            tensors = (a, b)

        return torch.cat(tensors, *args, **kwargs)
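

# module-level stand-in for `torch`, used inside the hijacked forward below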
th = TorchHijackForUnet()


class ControlParams:
    def __init__(
        self,
        control_model,
        hint_cond,
        guess_mode,
        weight,
        guidance_stopped,
        stop_guidance_percent,
        advanced_weighting,
        is_adapter
    ):
        self.control_model = control_model
        self.hint_cond = hint_cond
        self.guess_mode = guess_mode
        self.weight = weight
        self.guidance_stopped = guidance_stopped
        self.stop_guidance_percent = stop_guidance_percent
        self.advanced_weighting = advanced_weighting
        self.is_adapter = is_adapter
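

# UnetHook swaps the model's forward for one that injects ControlNet and
# T2I-Adapter residuals during sampling; restore() puts the original back.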
class UnetHook(nn.Module):
    def __init__(self, lowvram=False) -> None:
        super().__init__()
        self.lowvram = lowvram
        self.only_mid_control = shared.opts.data.get("control_net_only_mid_control", False)

    def hook(self, model):
        outer = self

        def guidance_schedule_handler(x):
            # flag each unit as stopped once sampling passes its guidance-stop fraction
            for param in self.control_params:
                param.guidance_stopped = (x.sampling_step / x.total_sampling_steps) > param.stop_guidance_percent
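
        # merge a control residual into `base`; under guess mode or CFG-based
        # guidance, only the conditional half of the batch receives the residual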
        def cfg_based_adder(base, x, require_autocast, is_adapter=False):
            if isinstance(x, float):
                return base + x

            if require_autocast:
                # inpaint hijack: pad the residual with zero channels to match the base
                zeros = torch.zeros_like(base)
                zeros[:, :x.shape[1], ...] = x
                x = zeros

            # k-diffusion samplers batch as [cond, uncond]; vanilla samplers as [uncond, cond]
            # see https://github.com/AUTOMATIC1111/stable-diffusion-webui/blob/0cc0ee1bcb4c24a8c9715f66cede06601bfc00c8/modules/sd_samplers_kdiffusion.py#L114
            if base.shape[0] % 2 == 0 and (self.guess_mode or shared.opts.data.get("control_net_cfg_based_guidance", False)):
                if self.is_vanilla_samplers:
                    uncond, cond = base.chunk(2)
                    if x.shape[0] % 2 == 0:
                        _, x_cond = x.chunk(2)
                        return torch.cat([uncond, cond + x_cond], dim=0)
                    if is_adapter:
                        return torch.cat([uncond, cond + x], dim=0)
                else:
                    cond, uncond = base.chunk(2)
                    if x.shape[0] % 2 == 0:
                        x_cond, _ = x.chunk(2)
                        return torch.cat([cond + x_cond, uncond], dim=0)
                    if is_adapter:
                        return torch.cat([cond + x, uncond], dim=0)

            return base + x
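
        # hijacked UNetModel.forward: run every active control model, scale its
        # residuals, and add them to the UNet's skip connections and middle block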
        def forward(self, x, timesteps=None, context=None, **kwargs):
            # accumulated residuals: 13 ControlNet outputs, 4 T2I-Adapter outputs
            total_control = [0.0] * 13
            total_adapter = [0.0] * 4
            only_mid_control = outer.only_mid_control
            require_inpaint_hijack = False

            for param in outer.control_params:
                if param.guidance_stopped:
                    continue
                if outer.lowvram:
                    param.control_model.to(devices.get_device_for("controlnet"))

                # hires fix: the latent no longer matches the hint resolution
                # note that this heuristic may not work if hr_scale < 1.1
                if abs(x.shape[-1] - param.hint_cond.shape[-1] // 8) > 8:
                    only_mid_control = shared.opts.data.get("control_net_only_midctrl_hires", True)
                    # If you want to completely disable control net, uncomment this.
                    # return self._original_forward(x, timesteps=timesteps, context=context, **kwargs)

                # inpaint model workaround
                x_in = x
                control_model = param.control_model.control_model
                if not param.is_adapter and x.shape[1] != control_model.input_blocks[0][0].in_channels and x.shape[1] == 9:
                    # inpaint_model: 4 data + 4 downscaled image + 1 mask
                    x_in = x[:, :4, ...]
                    require_inpaint_hijack = True

                assert param.hint_cond is not None, "ControlNet is enabled but no input image is given"
                control = param.control_model(x=x_in, hint=param.hint_cond, timesteps=timesteps, context=context)
                control_scales = [param.weight] * 13

                if outer.lowvram:
                    param.control_model.to("cpu")

                if param.guess_mode:
                    if param.is_adapter:
                        # see https://github.com/Mikubill/sd-webui-controlnet/issues/269
                        control_scales = [param.weight * s for s in [0.25, 0.62, 0.825, 1.0]]
                    else:
                        # ramp the weight up exponentially from shallow to deep blocks
                        control_scales = [param.weight * (0.825 ** float(12 - i)) for i in range(13)]

                if param.advanced_weighting is not None:
                    control_scales = param.advanced_weighting

                control = [c * scale for c, scale in zip(control, control_scales)]
                for idx, item in enumerate(control):
                    target = total_adapter if param.is_adapter else total_control
                    target[idx] += item

            control = total_control
            assert timesteps is not None, f"insufficient timestep: {timesteps}"
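
            # walk the UNet: adapter features are added after every third input
            # block; ControlNet residuals at the middle and output blocks below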
            hs = []
            with th.no_grad():
                t_emb = cond_cast_unet(timestep_embedding(timesteps, self.model_channels, repeat_only=False))
                emb = self.time_embed(t_emb)
                h = x.type(self.dtype)
                for i, module in enumerate(self.input_blocks):
                    h = module(h, emb, context)

                    # t2i-adapter, same as openaimodel.py:744
                    if (i + 1) % 3 == 0 and len(total_adapter):
                        h = cfg_based_adder(h, total_adapter.pop(0), require_inpaint_hijack, is_adapter=True)

                    hs.append(h)
                h = self.middle_block(h, emb, context)

            control_in = control.pop()
            h = cfg_based_adder(h, control_in, require_inpaint_hijack)

            for i, module in enumerate(self.output_blocks):
                if only_mid_control:
                    hs_input = hs.pop()
                    h = th.cat([h, hs_input], dim=1)
                else:
                    hs_input, control_input = hs.pop(), control.pop()
                    h = th.cat([h, cfg_based_adder(hs_input, control_input, require_inpaint_hijack)], dim=1)
                h = module(h, emb, context)

            h = h.type(x.dtype)
            return self.out(h)
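
        # wrapper bound as model.forward below; adds low-VRAM offloading around the hijacked forward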
        def forward2(*args, **kwargs):
            # webui will handle other components
            try:
                if shared.cmd_opts.lowvram:
                    lowvram.send_everything_to_cpu()

                return forward(*args, **kwargs)
            finally:
                if self.lowvram:
                    for param in self.control_params:
                        param.control_model.to("cpu")
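
        # keep a handle to the original forward, bind the hijacked one, and
        # register the callback that implements the guidance-stop schedule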
        model._original_forward = model.forward
        model.forward = forward2.__get__(model, UNetModel)
        scripts.script_callbacks.on_cfg_denoiser(guidance_schedule_handler)

    def notify(self, params, is_vanilla_samplers):  # params: list[ControlParams]
        self.is_vanilla_samplers = is_vanilla_samplers
        self.control_params = params
        self.guess_mode = any(param.guess_mode for param in params)

    def restore(self, model):
        scripts.script_callbacks.remove_current_script_callbacks()
        if not hasattr(model, "_original_forward"):
            # no such handle, ignore
            return

        model.forward = model._original_forward
        del model._original_forward