diff --git a/scripts/hook.py b/scripts/hook.py index 9deec11..2dfd5a3 100644 --- a/scripts/hook.py +++ b/scripts/hook.py @@ -145,6 +145,8 @@ class UnetHook(nn.Module): total_extra_cond = torch.zeros([0, context.shape[-1]]).to(devices.get_device_for("controlnet")) only_mid_control = outer.only_mid_control require_inpaint_hijack = False + + is_in_high_res_fix = False # handle external cond first for param in outer.control_params: @@ -163,6 +165,7 @@ class UnetHook(nn.Module): else: # we are in high-res path param.used_hint_cond = param.hr_hint_cond + is_in_high_res_fix = True if param.guidance_stopped or not param.is_extra_cond: continue @@ -218,12 +221,13 @@ class UnetHook(nn.Module): uc_mask = param.generate_uc_mask(query_size, dtype=x.dtype, device=x.device)[:, None, None, None] control = [c * uc_mask for c in control] - if param.guess_mode: + if param.guess_mode or is_in_high_res_fix: if param.is_adapter: # see https://github.com/Mikubill/sd-webui-controlnet/issues/269 control_scales = param.weight * [0.25, 0.62, 0.825, 1.0] else: control_scales = [param.weight * (0.825 ** float(12 - i)) for i in range(13)] + if param.advanced_weighting is not None: control_scales = param.advanced_weighting