try fix #48, update controlnet core

main
Kahsolt 2023-12-29 23:35:11 +08:00
parent 71c40ed77e
commit 748119041e
2 changed files with 222 additions and 66 deletions

View File

@ -58,8 +58,9 @@ The latest version `v3.1` is synced & tested with:
⚪ Fixups
- 2023/07/05: sync sd-webui-controlnet to `v1.1.229`
- 2023/04/30: update controlnet core to `v1.1.116`
- 2023/12/29: fix bad ffmpeg envvar, update controlnet to `v1.1.424`
- 2023/07/05: update controlnet to `v1.1.229`
- 2023/04/30: update controlnet to `v1.1.116`
- 2023/03/29: `v2.4` bug fixes on script hook, now working correctly with extra networks & [sd-webui-controlnet](https://github.com/Mikubill/sd-webui-controlnet)
- 2023/01/31: keep up with webui's updates, (issue #14: `ImportError: cannot import name 'single_sample_to_image'`)
- 2023/01/28: keep up with webui's updates, extra-networks rework

View File

@ -1,5 +1,5 @@
# This extension works with [Mikubill/sd-webui-controlnet](https://github.com/Mikubill/sd-webui-controlnet)
# version: v1.1.229
# version: v1.1.424
LOG_PREFIX = '[ControlNet-Travel]'
@ -31,6 +31,12 @@ if 'externel repo sanity check':
# ↑↑↑ EXIT EARLY IF EXTERNAL REPOSITORY NOT FOUND ↑↑↑
TOOL_PATH = ME_PATH / 'tools'
paths_ext = []
paths_ext.append(str(TOOL_PATH))
paths_ext.append(str(TOOL_PATH / 'rife-ncnn-vulkan'))
import os
os.environ['PATH'] += os.path.pathsep + os.path.pathsep.join(paths_ext)
import sys
from subprocess import Popen
@ -90,11 +96,13 @@ def run_cmd(cmd:str) -> bool:
# ↓↓↓ the following is modified from 'sd-webui-controlnet/scripts/hook.py' ↓↓↓
def hook_hijack(self:UnetHook, model:UNetModel, sd_ldm:LatentDiffusion, control_params:List[ControlParams], process:Processing):
def hook_hijack(self:UnetHook, model:UNetModel, sd_ldm:LatentDiffusion, control_params:List[ControlParams], process:Processing, batch_option_uint_separate=False, batch_option_style_align=False):
self.model = model
self.sd_ldm = sd_ldm
self.control_params = control_params
model_is_sdxl = getattr(self.sd_ldm, 'is_sdxl', False)
outer = self
def process_sample(*args, **kwargs):
@ -112,9 +120,13 @@ def hook_hijack(self:UnetHook, model:UNetModel, sd_ldm:LatentDiffusion, control_
return process.sample_before_CN_hack(*args, **kwargs)
# NOTE: ↓↓↓ only hack this method ↓↓↓
def forward(self:UNetModel, x:Tensor, timesteps:Tensor=None, context:Tensor=None, **kwargs):
total_controlnet_embedding = [0.0] * 13
def forward(self, x, timesteps=None, context=None, y=None, **kwargs):
is_sdxl = y is not None and model_is_sdxl
total_t2i_adapter_embedding = [0.0] * 4
if is_sdxl:
total_controlnet_embedding = [0.0] * 10
else:
total_controlnet_embedding = [0.0] * 13
require_inpaint_hijack = False
is_in_high_res_fix = False
batch_size = int(x.shape[0])
@ -127,19 +139,41 @@ def hook_hijack(self:UnetHook, model:UNetModel, sd_ldm:LatentDiffusion, control_
kwargs: dict # {}
# Handle cond-uncond marker
cond_mark, outer.current_uc_indices, context = unmark_prompt_context(context)
cond_mark, outer.current_uc_indices, outer.current_c_indices, context = unmark_prompt_context(context)
outer.model.cond_mark = cond_mark
# logger.info(str(cond_mark[:, 0, 0, 0].detach().cpu().numpy().tolist()) + ' - ' + str(outer.current_uc_indices))
# Revision
if is_sdxl:
revision_y1280 = 0
for param in outer.control_params:
if param.guidance_stopped:
continue
if param.control_model_type == ControlModelType.ReVision:
if param.vision_hint_count is None:
k = torch.Tensor([int(param.preprocessor['threshold_a'] * 1000)]).to(param.hint_cond).long().clip(0, 999)
param.vision_hint_count = outer.revision_q_sampler.q_sample(param.hint_cond, k)
revision_emb = param.vision_hint_count
if isinstance(revision_emb, torch.Tensor):
revision_y1280 += revision_emb * param.weight
if isinstance(revision_y1280, torch.Tensor):
y[:, :1280] = revision_y1280 * cond_mark[:, :, 0, 0]
if any('ignore_prompt' in param.preprocessor['name'] for param in outer.control_params) \
or (getattr(process, 'prompt', '') == '' and getattr(process, 'negative_prompt', '') == ''):
context = torch.zeros_like(context)
# High-res fix
for param in outer.control_params:
# select which hint_cond to use
if param.used_hint_cond is None:
param.used_hint_cond = param.hint_cond # NOTE: input hint cond tensor, [1, 3, 512, 512]
param.used_hint_cond = param.hint_cond
param.used_hint_cond_latent = None
param.used_hint_inpaint_hijack = None
# has high-res fix
if param.hr_hint_cond is not None and x.ndim == 4 and param.hint_cond.ndim == 4 and param.hr_hint_cond.ndim == 4:
if isinstance(param.hr_hint_cond, torch.Tensor) and x.ndim == 4 and param.hint_cond.ndim == 4 and param.hr_hint_cond.ndim == 4:
_, _, h_lr, w_lr = param.hint_cond.shape
_, _, h_hr, w_hr = param.hr_hint_cond.shape
_, _, h, w = x.shape
@ -157,6 +191,10 @@ def hook_hijack(self:UnetHook, model:UNetModel, sd_ldm:LatentDiffusion, control_
param.used_hint_cond_latent = None
param.used_hint_inpaint_hijack = None
self.is_in_high_res_fix = is_in_high_res_fix
outer.is_in_high_res_fix = is_in_high_res_fix
no_high_res_control = is_in_high_res_fix and shared.opts.data.get("control_net_no_high_res_fix", False)
# NOTE: hint shallow fusion, overwrite param.used_hint_cond
for i, param in enumerate(outer.control_params):
if interp_alpha == 0.0: # collect hind_cond on key frames
@ -174,16 +212,31 @@ def hook_hijack(self:UnetHook, model:UNetModel, sd_ldm:LatentDiffusion, control_
and 'inpaint_only' not in param.preprocessor['name']:
continue
param.used_hint_cond_latent = outer.call_vae_using_process(process, param.used_hint_cond, batch_size=batch_size)
# vram
for param in outer.control_params:
if getattr(param.control_model, 'disable_memory_management', False):
continue
if param.control_model is not None:
if outer.lowvram and is_sdxl and hasattr(param.control_model, 'aggressive_lowvram'):
param.control_model.aggressive_lowvram()
elif hasattr(param.control_model, 'fullvram'):
param.control_model.fullvram()
elif hasattr(param.control_model, 'to'):
param.control_model.to(devices.get_device_for("controlnet"))
# handle prompt token control
for param in outer.control_params:
if no_high_res_control:
continue
if param.guidance_stopped:
continue
if param.control_model_type not in [ControlModelType.T2I_StyleAdapter]:
continue
param.control_model.to(devices.get_device_for("controlnet"))
control = param.control_model(x=x, hint=param.used_hint_cond, timesteps=timesteps, context=context)
control = torch.cat([control.clone() for _ in range(batch_size)], dim=0)
control *= param.weight
@ -191,14 +244,16 @@ def hook_hijack(self:UnetHook, model:UNetModel, sd_ldm:LatentDiffusion, control_
context = torch.cat([context, control.clone()], dim=1)
# handle ControlNet / T2I_Adapter
for param in outer.control_params:
for param_index, param in enumerate(outer.control_params):
if no_high_res_control:
continue
if param.guidance_stopped:
continue
if param.control_model_type not in [ControlModelType.ControlNet, ControlModelType.T2I_Adapter]:
continue
param.control_model.to(devices.get_device_for("controlnet"))
# inpaint model workaround
x_in = x
control_model = param.control_model.control_model
@ -220,12 +275,12 @@ def hook_hijack(self:UnetHook, model:UNetModel, sd_ldm:LatentDiffusion, control_
m = (m > 0.5).float()
hint = c * (1 - m) - m
# NOTE: len(control) == 13, control[i]:Tensor
control = param.control_model(x=x_in, hint=hint, timesteps=timesteps, context=context)
control_scales = ([param.weight] * 13)
control = param.control_model(x=x_in, hint=hint, timesteps=timesteps, context=context, y=y)
if outer.lowvram:
param.control_model.to("cpu")
if is_sdxl:
control_scales = [param.weight] * 10
else:
control_scales = [param.weight] * 13
if param.cfg_injection or param.global_average_pooling:
if param.control_model_type == ControlModelType.T2I_Adapter:
@ -250,6 +305,9 @@ def hook_hijack(self:UnetHook, model:UNetModel, sd_ldm:LatentDiffusion, control_
elif param.control_model_type == ControlModelType.ControlNet:
control_scales = [param.weight * (0.825 ** float(12 - i)) for i in range(13)]
if is_sdxl and param.control_model_type == ControlModelType.ControlNet:
control_scales = control_scales[:10]
if param.advanced_weighting is not None:
control_scales = param.advanced_weighting
@ -264,10 +322,21 @@ def hook_hijack(self:UnetHook, model:UNetModel, sd_ldm:LatentDiffusion, control_
if param.control_model_type == ControlModelType.T2I_Adapter:
target = total_t2i_adapter_embedding
if target is not None:
target[idx] = item + target[idx]
if batch_option_uint_separate:
for pi, ci in enumerate(outer.current_c_indices):
if pi % len(outer.control_params) != param_index:
item[ci] = 0
for pi, ci in enumerate(outer.current_uc_indices):
if pi % len(outer.control_params) != param_index:
item[ci] = 0
target[idx] = item + target[idx]
else:
target[idx] = item + target[idx]
# Replace x_t to support inpaint models
for param in outer.control_params:
if not isinstance(param.used_hint_cond, torch.Tensor):
continue
if param.used_hint_cond.shape[1] != 4:
continue
if x.shape[1] != 9:
@ -284,8 +353,14 @@ def hook_hijack(self:UnetHook, model:UNetModel, sd_ldm:LatentDiffusion, control_
param.used_hint_inpaint_hijack.to(x.dtype).to(x.device)
x = torch.cat([x[:, :4, :, :], param.used_hint_inpaint_hijack], dim=1)
# vram
for param in outer.control_params:
if param.control_model is not None:
if outer.lowvram:
param.control_model.to('cpu')
# A1111 fix for medvram.
if shared.cmd_opts.medvram:
if shared.cmd_opts.medvram or (getattr(shared.cmd_opts, 'medvram_sdxl', False) and is_sdxl):
try:
# Trigger the register_forward_pre_hook
outer.sd_ldm.model()
@ -303,6 +378,9 @@ def hook_hijack(self:UnetHook, model:UNetModel, sd_ldm:LatentDiffusion, control_
# Handle attention and AdaIn control
for param in outer.control_params:
if no_high_res_control:
continue
if param.guidance_stopped:
continue
@ -312,7 +390,7 @@ def hook_hijack(self:UnetHook, model:UNetModel, sd_ldm:LatentDiffusion, control_
if param.control_model_type not in [ControlModelType.AttentionInjection]:
continue
ref_xt = outer.sd_ldm.q_sample(param.used_hint_cond_latent, torch.round(timesteps.float()).long())
ref_xt = predict_q_sample(outer.sd_ldm, param.used_hint_cond_latent, torch.round(timesteps.float()).long())
# Inpaint Hijack
if x.shape[1] == 9:
@ -325,6 +403,12 @@ def hook_hijack(self:UnetHook, model:UNetModel, sd_ldm:LatentDiffusion, control_
outer.current_style_fidelity = float(param.preprocessor['threshold_a'])
outer.current_style_fidelity = max(0.0, min(1.0, outer.current_style_fidelity))
if is_sdxl:
# sdxl's attention hacking is highly unstable.
# We have no other methods but to reduce the style_fidelity a bit.
# By default, 0.5 ** 3.0 = 0.125
outer.current_style_fidelity = outer.current_style_fidelity ** 3.0
if param.cfg_injection:
outer.current_style_fidelity = 1.0
elif param.soft_injection or is_in_high_res_fix:
@ -340,11 +424,19 @@ def hook_hijack(self:UnetHook, model:UNetModel, sd_ldm:LatentDiffusion, control_
outer.gn_auto_machine = AutoMachine.Write
outer.gn_auto_machine_weight = param.weight
outer.original_forward(
x=ref_xt.to(devices.dtype_unet),
timesteps=timesteps.to(devices.dtype_unet),
context=context.to(devices.dtype_unet)
)
if is_sdxl:
outer.original_forward(
x=ref_xt.to(devices.dtype_unet),
timesteps=timesteps.to(devices.dtype_unet),
context=context.to(devices.dtype_unet),
y=y
)
else:
outer.original_forward(
x=ref_xt.to(devices.dtype_unet),
timesteps=timesteps.to(devices.dtype_unet),
context=context.to(devices.dtype_unet)
)
outer.attention_auto_machine = AutoMachine.Read
outer.gn_auto_machine = AutoMachine.Read
@ -377,21 +469,35 @@ def hook_hijack(self:UnetHook, model:UNetModel, sd_ldm:LatentDiffusion, control_
with th.no_grad():
t_emb = cond_cast_unet(timestep_embedding(timesteps, self.model_channels, repeat_only=False))
emb = self.time_embed(t_emb)
h = x.type(self.dtype)
if is_sdxl:
assert y.shape[0] == x.shape[0]
emb = emb + self.label_emb(y)
h = x
for i, module in enumerate(self.input_blocks):
self.current_h_shape = (h.shape[0], h.shape[1], h.shape[2], h.shape[3])
h = module(h, emb, context)
if (i + 1) % 3 == 0:
t2i_injection = [3, 5, 8] if is_sdxl else [2, 5, 8, 11]
if i in t2i_injection:
h = aligned_adding(h, total_t2i_adapter_embedding.pop(0), require_inpaint_hijack)
hs.append(h)
self.current_h_shape = (h.shape[0], h.shape[1], h.shape[2], h.shape[3])
h = self.middle_block(h, emb, context)
# U-Net Middle Block
h = aligned_adding(h, total_controlnet_embedding.pop(), require_inpaint_hijack)
if len(total_t2i_adapter_embedding) > 0 and is_sdxl:
h = aligned_adding(h, total_t2i_adapter_embedding.pop(0), require_inpaint_hijack)
# U-Net Decoder
for i, module in enumerate(self.output_blocks):
self.current_h_shape = (h.shape[0], h.shape[1], h.shape[2], h.shape[3])
h = th.cat([h, aligned_adding(hs.pop(), total_controlnet_embedding.pop(), require_inpaint_hijack)], dim=1)
h = module(h, emb, context)
@ -407,7 +513,7 @@ def hook_hijack(self:UnetHook, model:UNetModel, sd_ldm:LatentDiffusion, control_
continue
k = int(param.preprocessor['threshold_a'])
if is_in_high_res_fix:
if is_in_high_res_fix and not no_high_res_control:
k *= 2
# Inpaint hijack
@ -454,18 +560,23 @@ def hook_hijack(self:UnetHook, model:UNetModel, sd_ldm:LatentDiffusion, control_
return h
def move_all_control_model_to_cpu():
for param in getattr(outer, 'control_params', []) or []:
if isinstance(param.control_model, torch.nn.Module):
param.control_model.to("cpu")
def forward_webui(*args, **kwargs):
# webui will handle other compoments
try:
if shared.cmd_opts.lowvram:
lowvram.send_everything_to_cpu()
return forward(*args, **kwargs)
except Exception as e:
move_all_control_model_to_cpu()
raise e
finally:
if self.lowvram:
for param in self.control_params:
if isinstance(param.control_model, torch.nn.Module):
param.control_model.to("cpu")
if outer.lowvram:
move_all_control_model_to_cpu()
def hacked_basic_transformer_inner_forward(self, x, context=None):
x_norm1 = self.norm1(x)
@ -492,6 +603,19 @@ def hook_hijack(self:UnetHook, model:UNetModel, sd_ldm:LatentDiffusion, control_
self_attn1 = style_cfg * self_attn1_c + (1.0 - style_cfg) * self_attn1_uc
self.bank = []
self.style_cfgs = []
if outer.attention_auto_machine == AutoMachine.StyleAlign and not outer.is_in_high_res_fix:
# very VRAM hungry - disable at high_res_fix
def shared_attn1(inner_x):
BB, FF, CC = inner_x.shape
return self.attn1(inner_x.reshape(1, BB * FF, CC)).reshape(BB, FF, CC)
uc_layer = shared_attn1(x_norm1[outer.current_uc_indices])
c_layer = shared_attn1(x_norm1[outer.current_c_indices])
self_attn1 = torch.zeros_like(x_norm1).to(uc_layer)
self_attn1[outer.current_uc_indices] = uc_layer
self_attn1[outer.current_c_indices] = c_layer
del uc_layer, c_layer
if self_attn1 is None:
self_attn1 = self.attn1(x_norm1, context=self_attention_context)
@ -502,7 +626,7 @@ def hook_hijack(self:UnetHook, model:UNetModel, sd_ldm:LatentDiffusion, control_
def hacked_group_norm_forward(self, *args, **kwargs):
eps = 1e-6
x = self.original_forward(*args, **kwargs)
x = self.original_forward_cn_hijack(*args, **kwargs)
y = None
if outer.gn_auto_machine == AutoMachine.Write:
if outer.gn_auto_machine_weight > self.gn_weight:
@ -538,45 +662,76 @@ def hook_hijack(self:UnetHook, model:UNetModel, sd_ldm:LatentDiffusion, control_
outer.original_forward = model.forward
model.forward = forward_webui.__get__(model, UNetModel)
if model_is_sdxl:
register_schedule(sd_ldm)
outer.revision_q_sampler = AbstractLowScaleModel()
need_attention_hijack = False
for param in outer.control_params:
if param.control_model_type in [ControlModelType.AttentionInjection]:
need_attention_hijack = True
if batch_option_style_align:
need_attention_hijack = True
outer.attention_auto_machine = AutoMachine.StyleAlign
outer.gn_auto_machine = AutoMachine.StyleAlign
all_modules = torch_dfs(model)
attn_modules = [module for module in all_modules if isinstance(module, BasicTransformerBlock)]
attn_modules = sorted(attn_modules, key=lambda x: - x.norm1.normalized_shape[0])
if need_attention_hijack:
attn_modules = [module for module in all_modules if isinstance(module, BasicTransformerBlock) or isinstance(module, BasicTransformerBlockSGM)]
attn_modules = sorted(attn_modules, key=lambda x: - x.norm1.normalized_shape[0])
for i, module in enumerate(attn_modules):
if getattr(module, '_original_inner_forward', None) is None:
module._original_inner_forward = module._forward
module._forward = hacked_basic_transformer_inner_forward.__get__(module, BasicTransformerBlock)
module.bank = []
module.style_cfgs = []
module.attn_weight = float(i) / float(len(attn_modules))
for i, module in enumerate(attn_modules):
if getattr(module, '_original_inner_forward_cn_hijack', None) is None:
module._original_inner_forward_cn_hijack = module._forward
module._forward = hacked_basic_transformer_inner_forward.__get__(module, BasicTransformerBlock)
module.bank = []
module.style_cfgs = []
module.attn_weight = float(i) / float(len(attn_modules))
gn_modules = [model.middle_block]
model.middle_block.gn_weight = 0
gn_modules = [model.middle_block]
model.middle_block.gn_weight = 0
input_block_indices = [4, 5, 7, 8, 10, 11]
for w, i in enumerate(input_block_indices):
module = model.input_blocks[i]
module.gn_weight = 1.0 - float(w) / float(len(input_block_indices))
gn_modules.append(module)
if model_is_sdxl:
input_block_indices = [4, 5, 7, 8]
output_block_indices = [0, 1, 2, 3, 4, 5]
else:
input_block_indices = [4, 5, 7, 8, 10, 11]
output_block_indices = [0, 1, 2, 3, 4, 5, 6, 7]
output_block_indices = [0, 1, 2, 3, 4, 5, 6, 7]
for w, i in enumerate(output_block_indices):
module = model.output_blocks[i]
module.gn_weight = float(w) / float(len(output_block_indices))
gn_modules.append(module)
for w, i in enumerate(input_block_indices):
module = model.input_blocks[i]
module.gn_weight = 1.0 - float(w) / float(len(input_block_indices))
gn_modules.append(module)
for i, module in enumerate(gn_modules):
if getattr(module, 'original_forward', None) is None:
module.original_forward = module.forward
module.forward = hacked_group_norm_forward.__get__(module, torch.nn.Module)
module.mean_bank = []
module.var_bank = []
module.style_cfgs = []
module.gn_weight *= 2
for w, i in enumerate(output_block_indices):
module = model.output_blocks[i]
module.gn_weight = float(w) / float(len(output_block_indices))
gn_modules.append(module)
outer.attn_module_list = attn_modules
outer.gn_module_list = gn_modules
for i, module in enumerate(gn_modules):
if getattr(module, 'original_forward_cn_hijack', None) is None:
module.original_forward_cn_hijack = module.forward
module.forward = hacked_group_norm_forward.__get__(module, torch.nn.Module)
module.mean_bank = []
module.var_bank = []
module.style_cfgs = []
module.gn_weight *= 2
outer.attn_module_list = attn_modules
outer.gn_module_list = gn_modules
else:
for module in all_modules:
_original_inner_forward_cn_hijack = getattr(module, '_original_inner_forward_cn_hijack', None)
original_forward_cn_hijack = getattr(module, 'original_forward_cn_hijack', None)
if _original_inner_forward_cn_hijack is not None:
module._forward = _original_inner_forward_cn_hijack
if original_forward_cn_hijack is not None:
module.forward = original_forward_cn_hijack
outer.attn_module_list = []
outer.gn_module_list = []
scripts.script_callbacks.on_cfg_denoiser(self.guidance_schedule_handler)