try fix #48, update controlnet core
parent
71c40ed77e
commit
748119041e
|
|
@ -58,8 +58,9 @@ The latest version `v3.1` is synced & tested with:
|
|||
|
||||
⚪ Fixups
|
||||
|
||||
- 2023/07/05: sync sd-webui-controlnet to `v1.1.229`
|
||||
- 2023/04/30: update controlnet core to `v1.1.116`
|
||||
- 2023/12/29: fix bad ffmpeg envvar, update controlnet to `v1.1.424`
|
||||
- 2023/07/05: update controlnet to `v1.1.229`
|
||||
- 2023/04/30: update controlnet to `v1.1.116`
|
||||
- 2023/03/29: `v2.4` bug fixes on script hook, now working correctly with extra networks & [sd-webui-controlnet](https://github.com/Mikubill/sd-webui-controlnet)
|
||||
- 2023/01/31: keep up with webui's updates, (issue #14: `ImportError: cannot import name 'single_sample_to_image'`)
|
||||
- 2023/01/28: keep up with webui's updates, extra-networks rework
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
# This extension works with [Mikubill/sd-webui-controlnet](https://github.com/Mikubill/sd-webui-controlnet)
|
||||
# version: v1.1.229
|
||||
# version: v1.1.424
|
||||
|
||||
LOG_PREFIX = '[ControlNet-Travel]'
|
||||
|
||||
|
|
@ -31,6 +31,12 @@ if 'externel repo sanity check':
|
|||
|
||||
# ↑↑↑ EXIT EARLY IF EXTERNAL REPOSITORY NOT FOUND ↑↑↑
|
||||
|
||||
TOOL_PATH = ME_PATH / 'tools'
|
||||
paths_ext = []
|
||||
paths_ext.append(str(TOOL_PATH))
|
||||
paths_ext.append(str(TOOL_PATH / 'rife-ncnn-vulkan'))
|
||||
import os
|
||||
os.environ['PATH'] += os.path.pathsep + os.path.pathsep.join(paths_ext)
|
||||
|
||||
import sys
|
||||
from subprocess import Popen
|
||||
|
|
@ -90,11 +96,13 @@ def run_cmd(cmd:str) -> bool:
|
|||
|
||||
# ↓↓↓ the following is modified from 'sd-webui-controlnet/scripts/hook.py' ↓↓↓
|
||||
|
||||
def hook_hijack(self:UnetHook, model:UNetModel, sd_ldm:LatentDiffusion, control_params:List[ControlParams], process:Processing):
|
||||
def hook_hijack(self:UnetHook, model:UNetModel, sd_ldm:LatentDiffusion, control_params:List[ControlParams], process:Processing, batch_option_uint_separate=False, batch_option_style_align=False):
|
||||
self.model = model
|
||||
self.sd_ldm = sd_ldm
|
||||
self.control_params = control_params
|
||||
|
||||
model_is_sdxl = getattr(self.sd_ldm, 'is_sdxl', False)
|
||||
|
||||
outer = self
|
||||
|
||||
def process_sample(*args, **kwargs):
|
||||
|
|
@ -112,9 +120,13 @@ def hook_hijack(self:UnetHook, model:UNetModel, sd_ldm:LatentDiffusion, control_
|
|||
return process.sample_before_CN_hack(*args, **kwargs)
|
||||
|
||||
# NOTE: ↓↓↓ only hack this method ↓↓↓
|
||||
def forward(self:UNetModel, x:Tensor, timesteps:Tensor=None, context:Tensor=None, **kwargs):
|
||||
total_controlnet_embedding = [0.0] * 13
|
||||
def forward(self, x, timesteps=None, context=None, y=None, **kwargs):
|
||||
is_sdxl = y is not None and model_is_sdxl
|
||||
total_t2i_adapter_embedding = [0.0] * 4
|
||||
if is_sdxl:
|
||||
total_controlnet_embedding = [0.0] * 10
|
||||
else:
|
||||
total_controlnet_embedding = [0.0] * 13
|
||||
require_inpaint_hijack = False
|
||||
is_in_high_res_fix = False
|
||||
batch_size = int(x.shape[0])
|
||||
|
|
@ -127,19 +139,41 @@ def hook_hijack(self:UnetHook, model:UNetModel, sd_ldm:LatentDiffusion, control_
|
|||
kwargs: dict # {}
|
||||
|
||||
# Handle cond-uncond marker
|
||||
cond_mark, outer.current_uc_indices, context = unmark_prompt_context(context)
|
||||
cond_mark, outer.current_uc_indices, outer.current_c_indices, context = unmark_prompt_context(context)
|
||||
outer.model.cond_mark = cond_mark
|
||||
# logger.info(str(cond_mark[:, 0, 0, 0].detach().cpu().numpy().tolist()) + ' - ' + str(outer.current_uc_indices))
|
||||
|
||||
# Revision
|
||||
if is_sdxl:
|
||||
revision_y1280 = 0
|
||||
|
||||
for param in outer.control_params:
|
||||
if param.guidance_stopped:
|
||||
continue
|
||||
if param.control_model_type == ControlModelType.ReVision:
|
||||
if param.vision_hint_count is None:
|
||||
k = torch.Tensor([int(param.preprocessor['threshold_a'] * 1000)]).to(param.hint_cond).long().clip(0, 999)
|
||||
param.vision_hint_count = outer.revision_q_sampler.q_sample(param.hint_cond, k)
|
||||
revision_emb = param.vision_hint_count
|
||||
if isinstance(revision_emb, torch.Tensor):
|
||||
revision_y1280 += revision_emb * param.weight
|
||||
|
||||
if isinstance(revision_y1280, torch.Tensor):
|
||||
y[:, :1280] = revision_y1280 * cond_mark[:, :, 0, 0]
|
||||
if any('ignore_prompt' in param.preprocessor['name'] for param in outer.control_params) \
|
||||
or (getattr(process, 'prompt', '') == '' and getattr(process, 'negative_prompt', '') == ''):
|
||||
context = torch.zeros_like(context)
|
||||
|
||||
# High-res fix
|
||||
for param in outer.control_params:
|
||||
# select which hint_cond to use
|
||||
if param.used_hint_cond is None:
|
||||
param.used_hint_cond = param.hint_cond # NOTE: input hint cond tensor, [1, 3, 512, 512]
|
||||
param.used_hint_cond = param.hint_cond
|
||||
param.used_hint_cond_latent = None
|
||||
param.used_hint_inpaint_hijack = None
|
||||
|
||||
# has high-res fix
|
||||
if param.hr_hint_cond is not None and x.ndim == 4 and param.hint_cond.ndim == 4 and param.hr_hint_cond.ndim == 4:
|
||||
if isinstance(param.hr_hint_cond, torch.Tensor) and x.ndim == 4 and param.hint_cond.ndim == 4 and param.hr_hint_cond.ndim == 4:
|
||||
_, _, h_lr, w_lr = param.hint_cond.shape
|
||||
_, _, h_hr, w_hr = param.hr_hint_cond.shape
|
||||
_, _, h, w = x.shape
|
||||
|
|
@ -157,6 +191,10 @@ def hook_hijack(self:UnetHook, model:UNetModel, sd_ldm:LatentDiffusion, control_
|
|||
param.used_hint_cond_latent = None
|
||||
param.used_hint_inpaint_hijack = None
|
||||
|
||||
self.is_in_high_res_fix = is_in_high_res_fix
|
||||
outer.is_in_high_res_fix = is_in_high_res_fix
|
||||
no_high_res_control = is_in_high_res_fix and shared.opts.data.get("control_net_no_high_res_fix", False)
|
||||
|
||||
# NOTE: hint shallow fusion, overwrite param.used_hint_cond
|
||||
for i, param in enumerate(outer.control_params):
|
||||
if interp_alpha == 0.0: # collect hind_cond on key frames
|
||||
|
|
@ -174,16 +212,31 @@ def hook_hijack(self:UnetHook, model:UNetModel, sd_ldm:LatentDiffusion, control_
|
|||
and 'inpaint_only' not in param.preprocessor['name']:
|
||||
continue
|
||||
param.used_hint_cond_latent = outer.call_vae_using_process(process, param.used_hint_cond, batch_size=batch_size)
|
||||
|
||||
|
||||
# vram
|
||||
for param in outer.control_params:
|
||||
if getattr(param.control_model, 'disable_memory_management', False):
|
||||
continue
|
||||
|
||||
if param.control_model is not None:
|
||||
if outer.lowvram and is_sdxl and hasattr(param.control_model, 'aggressive_lowvram'):
|
||||
param.control_model.aggressive_lowvram()
|
||||
elif hasattr(param.control_model, 'fullvram'):
|
||||
param.control_model.fullvram()
|
||||
elif hasattr(param.control_model, 'to'):
|
||||
param.control_model.to(devices.get_device_for("controlnet"))
|
||||
|
||||
# handle prompt token control
|
||||
for param in outer.control_params:
|
||||
if no_high_res_control:
|
||||
continue
|
||||
|
||||
if param.guidance_stopped:
|
||||
continue
|
||||
|
||||
if param.control_model_type not in [ControlModelType.T2I_StyleAdapter]:
|
||||
continue
|
||||
|
||||
param.control_model.to(devices.get_device_for("controlnet"))
|
||||
control = param.control_model(x=x, hint=param.used_hint_cond, timesteps=timesteps, context=context)
|
||||
control = torch.cat([control.clone() for _ in range(batch_size)], dim=0)
|
||||
control *= param.weight
|
||||
|
|
@ -191,14 +244,16 @@ def hook_hijack(self:UnetHook, model:UNetModel, sd_ldm:LatentDiffusion, control_
|
|||
context = torch.cat([context, control.clone()], dim=1)
|
||||
|
||||
# handle ControlNet / T2I_Adapter
|
||||
for param in outer.control_params:
|
||||
for param_index, param in enumerate(outer.control_params):
|
||||
if no_high_res_control:
|
||||
continue
|
||||
|
||||
if param.guidance_stopped:
|
||||
continue
|
||||
|
||||
if param.control_model_type not in [ControlModelType.ControlNet, ControlModelType.T2I_Adapter]:
|
||||
continue
|
||||
|
||||
param.control_model.to(devices.get_device_for("controlnet"))
|
||||
# inpaint model workaround
|
||||
x_in = x
|
||||
control_model = param.control_model.control_model
|
||||
|
|
@ -220,12 +275,12 @@ def hook_hijack(self:UnetHook, model:UNetModel, sd_ldm:LatentDiffusion, control_
|
|||
m = (m > 0.5).float()
|
||||
hint = c * (1 - m) - m
|
||||
|
||||
# NOTE: len(control) == 13, control[i]:Tensor
|
||||
control = param.control_model(x=x_in, hint=hint, timesteps=timesteps, context=context)
|
||||
control_scales = ([param.weight] * 13)
|
||||
control = param.control_model(x=x_in, hint=hint, timesteps=timesteps, context=context, y=y)
|
||||
|
||||
if outer.lowvram:
|
||||
param.control_model.to("cpu")
|
||||
if is_sdxl:
|
||||
control_scales = [param.weight] * 10
|
||||
else:
|
||||
control_scales = [param.weight] * 13
|
||||
|
||||
if param.cfg_injection or param.global_average_pooling:
|
||||
if param.control_model_type == ControlModelType.T2I_Adapter:
|
||||
|
|
@ -250,6 +305,9 @@ def hook_hijack(self:UnetHook, model:UNetModel, sd_ldm:LatentDiffusion, control_
|
|||
elif param.control_model_type == ControlModelType.ControlNet:
|
||||
control_scales = [param.weight * (0.825 ** float(12 - i)) for i in range(13)]
|
||||
|
||||
if is_sdxl and param.control_model_type == ControlModelType.ControlNet:
|
||||
control_scales = control_scales[:10]
|
||||
|
||||
if param.advanced_weighting is not None:
|
||||
control_scales = param.advanced_weighting
|
||||
|
||||
|
|
@ -264,10 +322,21 @@ def hook_hijack(self:UnetHook, model:UNetModel, sd_ldm:LatentDiffusion, control_
|
|||
if param.control_model_type == ControlModelType.T2I_Adapter:
|
||||
target = total_t2i_adapter_embedding
|
||||
if target is not None:
|
||||
target[idx] = item + target[idx]
|
||||
if batch_option_uint_separate:
|
||||
for pi, ci in enumerate(outer.current_c_indices):
|
||||
if pi % len(outer.control_params) != param_index:
|
||||
item[ci] = 0
|
||||
for pi, ci in enumerate(outer.current_uc_indices):
|
||||
if pi % len(outer.control_params) != param_index:
|
||||
item[ci] = 0
|
||||
target[idx] = item + target[idx]
|
||||
else:
|
||||
target[idx] = item + target[idx]
|
||||
|
||||
# Replace x_t to support inpaint models
|
||||
for param in outer.control_params:
|
||||
if not isinstance(param.used_hint_cond, torch.Tensor):
|
||||
continue
|
||||
if param.used_hint_cond.shape[1] != 4:
|
||||
continue
|
||||
if x.shape[1] != 9:
|
||||
|
|
@ -284,8 +353,14 @@ def hook_hijack(self:UnetHook, model:UNetModel, sd_ldm:LatentDiffusion, control_
|
|||
param.used_hint_inpaint_hijack.to(x.dtype).to(x.device)
|
||||
x = torch.cat([x[:, :4, :, :], param.used_hint_inpaint_hijack], dim=1)
|
||||
|
||||
# vram
|
||||
for param in outer.control_params:
|
||||
if param.control_model is not None:
|
||||
if outer.lowvram:
|
||||
param.control_model.to('cpu')
|
||||
|
||||
# A1111 fix for medvram.
|
||||
if shared.cmd_opts.medvram:
|
||||
if shared.cmd_opts.medvram or (getattr(shared.cmd_opts, 'medvram_sdxl', False) and is_sdxl):
|
||||
try:
|
||||
# Trigger the register_forward_pre_hook
|
||||
outer.sd_ldm.model()
|
||||
|
|
@ -303,6 +378,9 @@ def hook_hijack(self:UnetHook, model:UNetModel, sd_ldm:LatentDiffusion, control_
|
|||
|
||||
# Handle attention and AdaIn control
|
||||
for param in outer.control_params:
|
||||
if no_high_res_control:
|
||||
continue
|
||||
|
||||
if param.guidance_stopped:
|
||||
continue
|
||||
|
||||
|
|
@ -312,7 +390,7 @@ def hook_hijack(self:UnetHook, model:UNetModel, sd_ldm:LatentDiffusion, control_
|
|||
if param.control_model_type not in [ControlModelType.AttentionInjection]:
|
||||
continue
|
||||
|
||||
ref_xt = outer.sd_ldm.q_sample(param.used_hint_cond_latent, torch.round(timesteps.float()).long())
|
||||
ref_xt = predict_q_sample(outer.sd_ldm, param.used_hint_cond_latent, torch.round(timesteps.float()).long())
|
||||
|
||||
# Inpaint Hijack
|
||||
if x.shape[1] == 9:
|
||||
|
|
@ -325,6 +403,12 @@ def hook_hijack(self:UnetHook, model:UNetModel, sd_ldm:LatentDiffusion, control_
|
|||
outer.current_style_fidelity = float(param.preprocessor['threshold_a'])
|
||||
outer.current_style_fidelity = max(0.0, min(1.0, outer.current_style_fidelity))
|
||||
|
||||
if is_sdxl:
|
||||
# sdxl's attention hacking is highly unstable.
|
||||
# We have no other methods but to reduce the style_fidelity a bit.
|
||||
# By default, 0.5 ** 3.0 = 0.125
|
||||
outer.current_style_fidelity = outer.current_style_fidelity ** 3.0
|
||||
|
||||
if param.cfg_injection:
|
||||
outer.current_style_fidelity = 1.0
|
||||
elif param.soft_injection or is_in_high_res_fix:
|
||||
|
|
@ -340,11 +424,19 @@ def hook_hijack(self:UnetHook, model:UNetModel, sd_ldm:LatentDiffusion, control_
|
|||
outer.gn_auto_machine = AutoMachine.Write
|
||||
outer.gn_auto_machine_weight = param.weight
|
||||
|
||||
outer.original_forward(
|
||||
x=ref_xt.to(devices.dtype_unet),
|
||||
timesteps=timesteps.to(devices.dtype_unet),
|
||||
context=context.to(devices.dtype_unet)
|
||||
)
|
||||
if is_sdxl:
|
||||
outer.original_forward(
|
||||
x=ref_xt.to(devices.dtype_unet),
|
||||
timesteps=timesteps.to(devices.dtype_unet),
|
||||
context=context.to(devices.dtype_unet),
|
||||
y=y
|
||||
)
|
||||
else:
|
||||
outer.original_forward(
|
||||
x=ref_xt.to(devices.dtype_unet),
|
||||
timesteps=timesteps.to(devices.dtype_unet),
|
||||
context=context.to(devices.dtype_unet)
|
||||
)
|
||||
|
||||
outer.attention_auto_machine = AutoMachine.Read
|
||||
outer.gn_auto_machine = AutoMachine.Read
|
||||
|
|
@ -377,21 +469,35 @@ def hook_hijack(self:UnetHook, model:UNetModel, sd_ldm:LatentDiffusion, control_
|
|||
with th.no_grad():
|
||||
t_emb = cond_cast_unet(timestep_embedding(timesteps, self.model_channels, repeat_only=False))
|
||||
emb = self.time_embed(t_emb)
|
||||
h = x.type(self.dtype)
|
||||
|
||||
if is_sdxl:
|
||||
assert y.shape[0] == x.shape[0]
|
||||
emb = emb + self.label_emb(y)
|
||||
|
||||
h = x
|
||||
for i, module in enumerate(self.input_blocks):
|
||||
self.current_h_shape = (h.shape[0], h.shape[1], h.shape[2], h.shape[3])
|
||||
h = module(h, emb, context)
|
||||
|
||||
if (i + 1) % 3 == 0:
|
||||
t2i_injection = [3, 5, 8] if is_sdxl else [2, 5, 8, 11]
|
||||
|
||||
if i in t2i_injection:
|
||||
h = aligned_adding(h, total_t2i_adapter_embedding.pop(0), require_inpaint_hijack)
|
||||
|
||||
hs.append(h)
|
||||
|
||||
self.current_h_shape = (h.shape[0], h.shape[1], h.shape[2], h.shape[3])
|
||||
h = self.middle_block(h, emb, context)
|
||||
|
||||
# U-Net Middle Block
|
||||
h = aligned_adding(h, total_controlnet_embedding.pop(), require_inpaint_hijack)
|
||||
|
||||
if len(total_t2i_adapter_embedding) > 0 and is_sdxl:
|
||||
h = aligned_adding(h, total_t2i_adapter_embedding.pop(0), require_inpaint_hijack)
|
||||
|
||||
# U-Net Decoder
|
||||
for i, module in enumerate(self.output_blocks):
|
||||
self.current_h_shape = (h.shape[0], h.shape[1], h.shape[2], h.shape[3])
|
||||
h = th.cat([h, aligned_adding(hs.pop(), total_controlnet_embedding.pop(), require_inpaint_hijack)], dim=1)
|
||||
h = module(h, emb, context)
|
||||
|
||||
|
|
@ -407,7 +513,7 @@ def hook_hijack(self:UnetHook, model:UNetModel, sd_ldm:LatentDiffusion, control_
|
|||
continue
|
||||
|
||||
k = int(param.preprocessor['threshold_a'])
|
||||
if is_in_high_res_fix:
|
||||
if is_in_high_res_fix and not no_high_res_control:
|
||||
k *= 2
|
||||
|
||||
# Inpaint hijack
|
||||
|
|
@ -454,18 +560,23 @@ def hook_hijack(self:UnetHook, model:UNetModel, sd_ldm:LatentDiffusion, control_
|
|||
|
||||
return h
|
||||
|
||||
def move_all_control_model_to_cpu():
|
||||
for param in getattr(outer, 'control_params', []) or []:
|
||||
if isinstance(param.control_model, torch.nn.Module):
|
||||
param.control_model.to("cpu")
|
||||
|
||||
def forward_webui(*args, **kwargs):
|
||||
# webui will handle other compoments
|
||||
try:
|
||||
if shared.cmd_opts.lowvram:
|
||||
lowvram.send_everything_to_cpu()
|
||||
|
||||
return forward(*args, **kwargs)
|
||||
except Exception as e:
|
||||
move_all_control_model_to_cpu()
|
||||
raise e
|
||||
finally:
|
||||
if self.lowvram:
|
||||
for param in self.control_params:
|
||||
if isinstance(param.control_model, torch.nn.Module):
|
||||
param.control_model.to("cpu")
|
||||
if outer.lowvram:
|
||||
move_all_control_model_to_cpu()
|
||||
|
||||
def hacked_basic_transformer_inner_forward(self, x, context=None):
|
||||
x_norm1 = self.norm1(x)
|
||||
|
|
@ -492,6 +603,19 @@ def hook_hijack(self:UnetHook, model:UNetModel, sd_ldm:LatentDiffusion, control_
|
|||
self_attn1 = style_cfg * self_attn1_c + (1.0 - style_cfg) * self_attn1_uc
|
||||
self.bank = []
|
||||
self.style_cfgs = []
|
||||
if outer.attention_auto_machine == AutoMachine.StyleAlign and not outer.is_in_high_res_fix:
|
||||
# very VRAM hungry - disable at high_res_fix
|
||||
|
||||
def shared_attn1(inner_x):
|
||||
BB, FF, CC = inner_x.shape
|
||||
return self.attn1(inner_x.reshape(1, BB * FF, CC)).reshape(BB, FF, CC)
|
||||
|
||||
uc_layer = shared_attn1(x_norm1[outer.current_uc_indices])
|
||||
c_layer = shared_attn1(x_norm1[outer.current_c_indices])
|
||||
self_attn1 = torch.zeros_like(x_norm1).to(uc_layer)
|
||||
self_attn1[outer.current_uc_indices] = uc_layer
|
||||
self_attn1[outer.current_c_indices] = c_layer
|
||||
del uc_layer, c_layer
|
||||
if self_attn1 is None:
|
||||
self_attn1 = self.attn1(x_norm1, context=self_attention_context)
|
||||
|
||||
|
|
@ -502,7 +626,7 @@ def hook_hijack(self:UnetHook, model:UNetModel, sd_ldm:LatentDiffusion, control_
|
|||
|
||||
def hacked_group_norm_forward(self, *args, **kwargs):
|
||||
eps = 1e-6
|
||||
x = self.original_forward(*args, **kwargs)
|
||||
x = self.original_forward_cn_hijack(*args, **kwargs)
|
||||
y = None
|
||||
if outer.gn_auto_machine == AutoMachine.Write:
|
||||
if outer.gn_auto_machine_weight > self.gn_weight:
|
||||
|
|
@ -538,45 +662,76 @@ def hook_hijack(self:UnetHook, model:UNetModel, sd_ldm:LatentDiffusion, control_
|
|||
outer.original_forward = model.forward
|
||||
model.forward = forward_webui.__get__(model, UNetModel)
|
||||
|
||||
if model_is_sdxl:
|
||||
register_schedule(sd_ldm)
|
||||
outer.revision_q_sampler = AbstractLowScaleModel()
|
||||
|
||||
need_attention_hijack = False
|
||||
|
||||
for param in outer.control_params:
|
||||
if param.control_model_type in [ControlModelType.AttentionInjection]:
|
||||
need_attention_hijack = True
|
||||
|
||||
if batch_option_style_align:
|
||||
need_attention_hijack = True
|
||||
outer.attention_auto_machine = AutoMachine.StyleAlign
|
||||
outer.gn_auto_machine = AutoMachine.StyleAlign
|
||||
|
||||
all_modules = torch_dfs(model)
|
||||
|
||||
attn_modules = [module for module in all_modules if isinstance(module, BasicTransformerBlock)]
|
||||
attn_modules = sorted(attn_modules, key=lambda x: - x.norm1.normalized_shape[0])
|
||||
if need_attention_hijack:
|
||||
attn_modules = [module for module in all_modules if isinstance(module, BasicTransformerBlock) or isinstance(module, BasicTransformerBlockSGM)]
|
||||
attn_modules = sorted(attn_modules, key=lambda x: - x.norm1.normalized_shape[0])
|
||||
|
||||
for i, module in enumerate(attn_modules):
|
||||
if getattr(module, '_original_inner_forward', None) is None:
|
||||
module._original_inner_forward = module._forward
|
||||
module._forward = hacked_basic_transformer_inner_forward.__get__(module, BasicTransformerBlock)
|
||||
module.bank = []
|
||||
module.style_cfgs = []
|
||||
module.attn_weight = float(i) / float(len(attn_modules))
|
||||
for i, module in enumerate(attn_modules):
|
||||
if getattr(module, '_original_inner_forward_cn_hijack', None) is None:
|
||||
module._original_inner_forward_cn_hijack = module._forward
|
||||
module._forward = hacked_basic_transformer_inner_forward.__get__(module, BasicTransformerBlock)
|
||||
module.bank = []
|
||||
module.style_cfgs = []
|
||||
module.attn_weight = float(i) / float(len(attn_modules))
|
||||
|
||||
gn_modules = [model.middle_block]
|
||||
model.middle_block.gn_weight = 0
|
||||
gn_modules = [model.middle_block]
|
||||
model.middle_block.gn_weight = 0
|
||||
|
||||
input_block_indices = [4, 5, 7, 8, 10, 11]
|
||||
for w, i in enumerate(input_block_indices):
|
||||
module = model.input_blocks[i]
|
||||
module.gn_weight = 1.0 - float(w) / float(len(input_block_indices))
|
||||
gn_modules.append(module)
|
||||
if model_is_sdxl:
|
||||
input_block_indices = [4, 5, 7, 8]
|
||||
output_block_indices = [0, 1, 2, 3, 4, 5]
|
||||
else:
|
||||
input_block_indices = [4, 5, 7, 8, 10, 11]
|
||||
output_block_indices = [0, 1, 2, 3, 4, 5, 6, 7]
|
||||
|
||||
output_block_indices = [0, 1, 2, 3, 4, 5, 6, 7]
|
||||
for w, i in enumerate(output_block_indices):
|
||||
module = model.output_blocks[i]
|
||||
module.gn_weight = float(w) / float(len(output_block_indices))
|
||||
gn_modules.append(module)
|
||||
for w, i in enumerate(input_block_indices):
|
||||
module = model.input_blocks[i]
|
||||
module.gn_weight = 1.0 - float(w) / float(len(input_block_indices))
|
||||
gn_modules.append(module)
|
||||
|
||||
for i, module in enumerate(gn_modules):
|
||||
if getattr(module, 'original_forward', None) is None:
|
||||
module.original_forward = module.forward
|
||||
module.forward = hacked_group_norm_forward.__get__(module, torch.nn.Module)
|
||||
module.mean_bank = []
|
||||
module.var_bank = []
|
||||
module.style_cfgs = []
|
||||
module.gn_weight *= 2
|
||||
for w, i in enumerate(output_block_indices):
|
||||
module = model.output_blocks[i]
|
||||
module.gn_weight = float(w) / float(len(output_block_indices))
|
||||
gn_modules.append(module)
|
||||
|
||||
outer.attn_module_list = attn_modules
|
||||
outer.gn_module_list = gn_modules
|
||||
for i, module in enumerate(gn_modules):
|
||||
if getattr(module, 'original_forward_cn_hijack', None) is None:
|
||||
module.original_forward_cn_hijack = module.forward
|
||||
module.forward = hacked_group_norm_forward.__get__(module, torch.nn.Module)
|
||||
module.mean_bank = []
|
||||
module.var_bank = []
|
||||
module.style_cfgs = []
|
||||
module.gn_weight *= 2
|
||||
|
||||
outer.attn_module_list = attn_modules
|
||||
outer.gn_module_list = gn_modules
|
||||
else:
|
||||
for module in all_modules:
|
||||
_original_inner_forward_cn_hijack = getattr(module, '_original_inner_forward_cn_hijack', None)
|
||||
original_forward_cn_hijack = getattr(module, 'original_forward_cn_hijack', None)
|
||||
if _original_inner_forward_cn_hijack is not None:
|
||||
module._forward = _original_inner_forward_cn_hijack
|
||||
if original_forward_cn_hijack is not None:
|
||||
module.forward = original_forward_cn_hijack
|
||||
outer.attn_module_list = []
|
||||
outer.gn_module_list = []
|
||||
|
||||
scripts.script_callbacks.on_cfg_denoiser(self.guidance_schedule_handler)
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue