diff --git a/README.md b/README.md index 566fe83..cb7115c 100644 --- a/README.md +++ b/README.md @@ -58,8 +58,9 @@ The latest version `v3.1` is synced & tested with: ⚪ Fixups -- 2023/07/05: sync sd-webui-controlnet to `v1.1.229` -- 2023/04/30: update controlnet core to `v1.1.116` +- 2023/12/29: fix bad ffmpeg envvar, update controlnet to `v1.1.424` +- 2023/07/05: update controlnet to `v1.1.229` +- 2023/04/30: update controlnet to `v1.1.116` - 2023/03/29: `v2.4` bug fixes on script hook, now working correctly with extra networks & [sd-webui-controlnet](https://github.com/Mikubill/sd-webui-controlnet) - 2023/01/31: keep up with webui's updates, (issue #14: `ImportError: cannot import name 'single_sample_to_image'`) - 2023/01/28: keep up with webui's updates, extra-networks rework diff --git a/scripts/controlnet_travel.py b/scripts/controlnet_travel.py index cf77acc..d83db6f 100644 --- a/scripts/controlnet_travel.py +++ b/scripts/controlnet_travel.py @@ -1,5 +1,5 @@ # This extension works with [Mikubill/sd-webui-controlnet](https://github.com/Mikubill/sd-webui-controlnet) -# version: v1.1.229 +# version: v1.1.424 LOG_PREFIX = '[ControlNet-Travel]' @@ -31,6 +31,12 @@ if 'externel repo sanity check': # ↑↑↑ EXIT EARLY IF EXTERNAL REPOSITORY NOT FOUND ↑↑↑ +TOOL_PATH = ME_PATH / 'tools' +paths_ext = [] +paths_ext.append(str(TOOL_PATH)) +paths_ext.append(str(TOOL_PATH / 'rife-ncnn-vulkan')) +import os +os.environ['PATH'] += os.path.pathsep + os.path.pathsep.join(paths_ext) import sys from subprocess import Popen @@ -90,11 +96,13 @@ def run_cmd(cmd:str) -> bool: # ↓↓↓ the following is modified from 'sd-webui-controlnet/scripts/hook.py' ↓↓↓ -def hook_hijack(self:UnetHook, model:UNetModel, sd_ldm:LatentDiffusion, control_params:List[ControlParams], process:Processing): +def hook_hijack(self:UnetHook, model:UNetModel, sd_ldm:LatentDiffusion, control_params:List[ControlParams], process:Processing, batch_option_uint_separate=False, batch_option_style_align=False): self.model = model self.sd_ldm = sd_ldm self.control_params = control_params + model_is_sdxl = getattr(self.sd_ldm, 'is_sdxl', False) + outer = self def process_sample(*args, **kwargs): @@ -112,9 +120,13 @@ def hook_hijack(self:UnetHook, model:UNetModel, sd_ldm:LatentDiffusion, control_ return process.sample_before_CN_hack(*args, **kwargs) # NOTE: ↓↓↓ only hack this method ↓↓↓ - def forward(self:UNetModel, x:Tensor, timesteps:Tensor=None, context:Tensor=None, **kwargs): - total_controlnet_embedding = [0.0] * 13 + def forward(self, x, timesteps=None, context=None, y=None, **kwargs): + is_sdxl = y is not None and model_is_sdxl total_t2i_adapter_embedding = [0.0] * 4 + if is_sdxl: + total_controlnet_embedding = [0.0] * 10 + else: + total_controlnet_embedding = [0.0] * 13 require_inpaint_hijack = False is_in_high_res_fix = False batch_size = int(x.shape[0]) @@ -127,19 +139,41 @@ def hook_hijack(self:UnetHook, model:UNetModel, sd_ldm:LatentDiffusion, control_ kwargs: dict # {} # Handle cond-uncond marker - cond_mark, outer.current_uc_indices, context = unmark_prompt_context(context) + cond_mark, outer.current_uc_indices, outer.current_c_indices, context = unmark_prompt_context(context) + outer.model.cond_mark = cond_mark # logger.info(str(cond_mark[:, 0, 0, 0].detach().cpu().numpy().tolist()) + ' - ' + str(outer.current_uc_indices)) + # Revision + if is_sdxl: + revision_y1280 = 0 + + for param in outer.control_params: + if param.guidance_stopped: + continue + if param.control_model_type == ControlModelType.ReVision: + if param.vision_hint_count is None: + k = torch.Tensor([int(param.preprocessor['threshold_a'] * 1000)]).to(param.hint_cond).long().clip(0, 999) + param.vision_hint_count = outer.revision_q_sampler.q_sample(param.hint_cond, k) + revision_emb = param.vision_hint_count + if isinstance(revision_emb, torch.Tensor): + revision_y1280 += revision_emb * param.weight + + if isinstance(revision_y1280, torch.Tensor): + y[:, :1280] = revision_y1280 * cond_mark[:, :, 0, 0] + if any('ignore_prompt' in param.preprocessor['name'] for param in outer.control_params) \ + or (getattr(process, 'prompt', '') == '' and getattr(process, 'negative_prompt', '') == ''): + context = torch.zeros_like(context) + # High-res fix for param in outer.control_params: # select which hint_cond to use if param.used_hint_cond is None: - param.used_hint_cond = param.hint_cond # NOTE: input hint cond tensor, [1, 3, 512, 512] + param.used_hint_cond = param.hint_cond param.used_hint_cond_latent = None param.used_hint_inpaint_hijack = None # has high-res fix - if param.hr_hint_cond is not None and x.ndim == 4 and param.hint_cond.ndim == 4 and param.hr_hint_cond.ndim == 4: + if isinstance(param.hr_hint_cond, torch.Tensor) and x.ndim == 4 and param.hint_cond.ndim == 4 and param.hr_hint_cond.ndim == 4: _, _, h_lr, w_lr = param.hint_cond.shape _, _, h_hr, w_hr = param.hr_hint_cond.shape _, _, h, w = x.shape @@ -157,6 +191,10 @@ def hook_hijack(self:UnetHook, model:UNetModel, sd_ldm:LatentDiffusion, control_ param.used_hint_cond_latent = None param.used_hint_inpaint_hijack = None + self.is_in_high_res_fix = is_in_high_res_fix + outer.is_in_high_res_fix = is_in_high_res_fix + no_high_res_control = is_in_high_res_fix and shared.opts.data.get("control_net_no_high_res_fix", False) + # NOTE: hint shallow fusion, overwrite param.used_hint_cond for i, param in enumerate(outer.control_params): if interp_alpha == 0.0: # collect hind_cond on key frames @@ -174,16 +212,31 @@ def hook_hijack(self:UnetHook, model:UNetModel, sd_ldm:LatentDiffusion, control_ and 'inpaint_only' not in param.preprocessor['name']: continue param.used_hint_cond_latent = outer.call_vae_using_process(process, param.used_hint_cond, batch_size=batch_size) - + + # vram + for param in outer.control_params: + if getattr(param.control_model, 'disable_memory_management', False): + continue + + if param.control_model is not None: + if outer.lowvram and is_sdxl and hasattr(param.control_model, 'aggressive_lowvram'): + param.control_model.aggressive_lowvram() + elif hasattr(param.control_model, 'fullvram'): + param.control_model.fullvram() + elif hasattr(param.control_model, 'to'): + param.control_model.to(devices.get_device_for("controlnet")) + # handle prompt token control for param in outer.control_params: + if no_high_res_control: + continue + if param.guidance_stopped: continue if param.control_model_type not in [ControlModelType.T2I_StyleAdapter]: continue - param.control_model.to(devices.get_device_for("controlnet")) control = param.control_model(x=x, hint=param.used_hint_cond, timesteps=timesteps, context=context) control = torch.cat([control.clone() for _ in range(batch_size)], dim=0) control *= param.weight @@ -191,14 +244,16 @@ def hook_hijack(self:UnetHook, model:UNetModel, sd_ldm:LatentDiffusion, control_ context = torch.cat([context, control.clone()], dim=1) # handle ControlNet / T2I_Adapter - for param in outer.control_params: + for param_index, param in enumerate(outer.control_params): + if no_high_res_control: + continue + if param.guidance_stopped: continue if param.control_model_type not in [ControlModelType.ControlNet, ControlModelType.T2I_Adapter]: continue - param.control_model.to(devices.get_device_for("controlnet")) # inpaint model workaround x_in = x control_model = param.control_model.control_model @@ -220,12 +275,12 @@ def hook_hijack(self:UnetHook, model:UNetModel, sd_ldm:LatentDiffusion, control_ m = (m > 0.5).float() hint = c * (1 - m) - m - # NOTE: len(control) == 13, control[i]:Tensor - control = param.control_model(x=x_in, hint=hint, timesteps=timesteps, context=context) - control_scales = ([param.weight] * 13) + control = param.control_model(x=x_in, hint=hint, timesteps=timesteps, context=context, y=y) - if outer.lowvram: - param.control_model.to("cpu") + if is_sdxl: + control_scales = [param.weight] * 10 + else: + control_scales = [param.weight] * 13 if param.cfg_injection or param.global_average_pooling: if param.control_model_type == ControlModelType.T2I_Adapter: @@ -250,6 +305,9 @@ def hook_hijack(self:UnetHook, model:UNetModel, sd_ldm:LatentDiffusion, control_ elif param.control_model_type == ControlModelType.ControlNet: control_scales = [param.weight * (0.825 ** float(12 - i)) for i in range(13)] + if is_sdxl and param.control_model_type == ControlModelType.ControlNet: + control_scales = control_scales[:10] + if param.advanced_weighting is not None: control_scales = param.advanced_weighting @@ -264,10 +322,21 @@ def hook_hijack(self:UnetHook, model:UNetModel, sd_ldm:LatentDiffusion, control_ if param.control_model_type == ControlModelType.T2I_Adapter: target = total_t2i_adapter_embedding if target is not None: - target[idx] = item + target[idx] + if batch_option_uint_separate: + for pi, ci in enumerate(outer.current_c_indices): + if pi % len(outer.control_params) != param_index: + item[ci] = 0 + for pi, ci in enumerate(outer.current_uc_indices): + if pi % len(outer.control_params) != param_index: + item[ci] = 0 + target[idx] = item + target[idx] + else: + target[idx] = item + target[idx] # Replace x_t to support inpaint models for param in outer.control_params: + if not isinstance(param.used_hint_cond, torch.Tensor): + continue if param.used_hint_cond.shape[1] != 4: continue if x.shape[1] != 9: @@ -284,8 +353,14 @@ def hook_hijack(self:UnetHook, model:UNetModel, sd_ldm:LatentDiffusion, control_ param.used_hint_inpaint_hijack.to(x.dtype).to(x.device) x = torch.cat([x[:, :4, :, :], param.used_hint_inpaint_hijack], dim=1) + # vram + for param in outer.control_params: + if param.control_model is not None: + if outer.lowvram: + param.control_model.to('cpu') + # A1111 fix for medvram. - if shared.cmd_opts.medvram: + if shared.cmd_opts.medvram or (getattr(shared.cmd_opts, 'medvram_sdxl', False) and is_sdxl): try: # Trigger the register_forward_pre_hook outer.sd_ldm.model() @@ -303,6 +378,9 @@ def hook_hijack(self:UnetHook, model:UNetModel, sd_ldm:LatentDiffusion, control_ # Handle attention and AdaIn control for param in outer.control_params: + if no_high_res_control: + continue + if param.guidance_stopped: continue @@ -312,7 +390,7 @@ def hook_hijack(self:UnetHook, model:UNetModel, sd_ldm:LatentDiffusion, control_ if param.control_model_type not in [ControlModelType.AttentionInjection]: continue - ref_xt = outer.sd_ldm.q_sample(param.used_hint_cond_latent, torch.round(timesteps.float()).long()) + ref_xt = predict_q_sample(outer.sd_ldm, param.used_hint_cond_latent, torch.round(timesteps.float()).long()) # Inpaint Hijack if x.shape[1] == 9: @@ -325,6 +403,12 @@ def hook_hijack(self:UnetHook, model:UNetModel, sd_ldm:LatentDiffusion, control_ outer.current_style_fidelity = float(param.preprocessor['threshold_a']) outer.current_style_fidelity = max(0.0, min(1.0, outer.current_style_fidelity)) + if is_sdxl: + # sdxl's attention hacking is highly unstable. + # We have no other methods but to reduce the style_fidelity a bit. + # By default, 0.5 ** 3.0 = 0.125 + outer.current_style_fidelity = outer.current_style_fidelity ** 3.0 + if param.cfg_injection: outer.current_style_fidelity = 1.0 elif param.soft_injection or is_in_high_res_fix: @@ -340,11 +424,19 @@ def hook_hijack(self:UnetHook, model:UNetModel, sd_ldm:LatentDiffusion, control_ outer.gn_auto_machine = AutoMachine.Write outer.gn_auto_machine_weight = param.weight - outer.original_forward( - x=ref_xt.to(devices.dtype_unet), - timesteps=timesteps.to(devices.dtype_unet), - context=context.to(devices.dtype_unet) - ) + if is_sdxl: + outer.original_forward( + x=ref_xt.to(devices.dtype_unet), + timesteps=timesteps.to(devices.dtype_unet), + context=context.to(devices.dtype_unet), + y=y + ) + else: + outer.original_forward( + x=ref_xt.to(devices.dtype_unet), + timesteps=timesteps.to(devices.dtype_unet), + context=context.to(devices.dtype_unet) + ) outer.attention_auto_machine = AutoMachine.Read outer.gn_auto_machine = AutoMachine.Read @@ -377,21 +469,35 @@ def hook_hijack(self:UnetHook, model:UNetModel, sd_ldm:LatentDiffusion, control_ with th.no_grad(): t_emb = cond_cast_unet(timestep_embedding(timesteps, self.model_channels, repeat_only=False)) emb = self.time_embed(t_emb) - h = x.type(self.dtype) + + if is_sdxl: + assert y.shape[0] == x.shape[0] + emb = emb + self.label_emb(y) + + h = x for i, module in enumerate(self.input_blocks): + self.current_h_shape = (h.shape[0], h.shape[1], h.shape[2], h.shape[3]) h = module(h, emb, context) - if (i + 1) % 3 == 0: + t2i_injection = [3, 5, 8] if is_sdxl else [2, 5, 8, 11] + + if i in t2i_injection: h = aligned_adding(h, total_t2i_adapter_embedding.pop(0), require_inpaint_hijack) hs.append(h) + + self.current_h_shape = (h.shape[0], h.shape[1], h.shape[2], h.shape[3]) h = self.middle_block(h, emb, context) # U-Net Middle Block h = aligned_adding(h, total_controlnet_embedding.pop(), require_inpaint_hijack) + if len(total_t2i_adapter_embedding) > 0 and is_sdxl: + h = aligned_adding(h, total_t2i_adapter_embedding.pop(0), require_inpaint_hijack) + # U-Net Decoder for i, module in enumerate(self.output_blocks): + self.current_h_shape = (h.shape[0], h.shape[1], h.shape[2], h.shape[3]) h = th.cat([h, aligned_adding(hs.pop(), total_controlnet_embedding.pop(), require_inpaint_hijack)], dim=1) h = module(h, emb, context) @@ -407,7 +513,7 @@ def hook_hijack(self:UnetHook, model:UNetModel, sd_ldm:LatentDiffusion, control_ continue k = int(param.preprocessor['threshold_a']) - if is_in_high_res_fix: + if is_in_high_res_fix and not no_high_res_control: k *= 2 # Inpaint hijack @@ -454,18 +560,23 @@ def hook_hijack(self:UnetHook, model:UNetModel, sd_ldm:LatentDiffusion, control_ return h + def move_all_control_model_to_cpu(): + for param in getattr(outer, 'control_params', []) or []: + if isinstance(param.control_model, torch.nn.Module): + param.control_model.to("cpu") + def forward_webui(*args, **kwargs): # webui will handle other compoments try: if shared.cmd_opts.lowvram: lowvram.send_everything_to_cpu() - return forward(*args, **kwargs) + except Exception as e: + move_all_control_model_to_cpu() + raise e finally: - if self.lowvram: - for param in self.control_params: - if isinstance(param.control_model, torch.nn.Module): - param.control_model.to("cpu") + if outer.lowvram: + move_all_control_model_to_cpu() def hacked_basic_transformer_inner_forward(self, x, context=None): x_norm1 = self.norm1(x) @@ -492,6 +603,19 @@ def hook_hijack(self:UnetHook, model:UNetModel, sd_ldm:LatentDiffusion, control_ self_attn1 = style_cfg * self_attn1_c + (1.0 - style_cfg) * self_attn1_uc self.bank = [] self.style_cfgs = [] + if outer.attention_auto_machine == AutoMachine.StyleAlign and not outer.is_in_high_res_fix: + # very VRAM hungry - disable at high_res_fix + + def shared_attn1(inner_x): + BB, FF, CC = inner_x.shape + return self.attn1(inner_x.reshape(1, BB * FF, CC)).reshape(BB, FF, CC) + + uc_layer = shared_attn1(x_norm1[outer.current_uc_indices]) + c_layer = shared_attn1(x_norm1[outer.current_c_indices]) + self_attn1 = torch.zeros_like(x_norm1).to(uc_layer) + self_attn1[outer.current_uc_indices] = uc_layer + self_attn1[outer.current_c_indices] = c_layer + del uc_layer, c_layer if self_attn1 is None: self_attn1 = self.attn1(x_norm1, context=self_attention_context) @@ -502,7 +626,7 @@ def hook_hijack(self:UnetHook, model:UNetModel, sd_ldm:LatentDiffusion, control_ def hacked_group_norm_forward(self, *args, **kwargs): eps = 1e-6 - x = self.original_forward(*args, **kwargs) + x = self.original_forward_cn_hijack(*args, **kwargs) y = None if outer.gn_auto_machine == AutoMachine.Write: if outer.gn_auto_machine_weight > self.gn_weight: @@ -538,45 +662,76 @@ def hook_hijack(self:UnetHook, model:UNetModel, sd_ldm:LatentDiffusion, control_ outer.original_forward = model.forward model.forward = forward_webui.__get__(model, UNetModel) + if model_is_sdxl: + register_schedule(sd_ldm) + outer.revision_q_sampler = AbstractLowScaleModel() + + need_attention_hijack = False + + for param in outer.control_params: + if param.control_model_type in [ControlModelType.AttentionInjection]: + need_attention_hijack = True + + if batch_option_style_align: + need_attention_hijack = True + outer.attention_auto_machine = AutoMachine.StyleAlign + outer.gn_auto_machine = AutoMachine.StyleAlign + all_modules = torch_dfs(model) - attn_modules = [module for module in all_modules if isinstance(module, BasicTransformerBlock)] - attn_modules = sorted(attn_modules, key=lambda x: - x.norm1.normalized_shape[0]) + if need_attention_hijack: + attn_modules = [module for module in all_modules if isinstance(module, BasicTransformerBlock) or isinstance(module, BasicTransformerBlockSGM)] + attn_modules = sorted(attn_modules, key=lambda x: - x.norm1.normalized_shape[0]) - for i, module in enumerate(attn_modules): - if getattr(module, '_original_inner_forward', None) is None: - module._original_inner_forward = module._forward - module._forward = hacked_basic_transformer_inner_forward.__get__(module, BasicTransformerBlock) - module.bank = [] - module.style_cfgs = [] - module.attn_weight = float(i) / float(len(attn_modules)) + for i, module in enumerate(attn_modules): + if getattr(module, '_original_inner_forward_cn_hijack', None) is None: + module._original_inner_forward_cn_hijack = module._forward + module._forward = hacked_basic_transformer_inner_forward.__get__(module, BasicTransformerBlock) + module.bank = [] + module.style_cfgs = [] + module.attn_weight = float(i) / float(len(attn_modules)) - gn_modules = [model.middle_block] - model.middle_block.gn_weight = 0 + gn_modules = [model.middle_block] + model.middle_block.gn_weight = 0 - input_block_indices = [4, 5, 7, 8, 10, 11] - for w, i in enumerate(input_block_indices): - module = model.input_blocks[i] - module.gn_weight = 1.0 - float(w) / float(len(input_block_indices)) - gn_modules.append(module) + if model_is_sdxl: + input_block_indices = [4, 5, 7, 8] + output_block_indices = [0, 1, 2, 3, 4, 5] + else: + input_block_indices = [4, 5, 7, 8, 10, 11] + output_block_indices = [0, 1, 2, 3, 4, 5, 6, 7] - output_block_indices = [0, 1, 2, 3, 4, 5, 6, 7] - for w, i in enumerate(output_block_indices): - module = model.output_blocks[i] - module.gn_weight = float(w) / float(len(output_block_indices)) - gn_modules.append(module) + for w, i in enumerate(input_block_indices): + module = model.input_blocks[i] + module.gn_weight = 1.0 - float(w) / float(len(input_block_indices)) + gn_modules.append(module) - for i, module in enumerate(gn_modules): - if getattr(module, 'original_forward', None) is None: - module.original_forward = module.forward - module.forward = hacked_group_norm_forward.__get__(module, torch.nn.Module) - module.mean_bank = [] - module.var_bank = [] - module.style_cfgs = [] - module.gn_weight *= 2 + for w, i in enumerate(output_block_indices): + module = model.output_blocks[i] + module.gn_weight = float(w) / float(len(output_block_indices)) + gn_modules.append(module) - outer.attn_module_list = attn_modules - outer.gn_module_list = gn_modules + for i, module in enumerate(gn_modules): + if getattr(module, 'original_forward_cn_hijack', None) is None: + module.original_forward_cn_hijack = module.forward + module.forward = hacked_group_norm_forward.__get__(module, torch.nn.Module) + module.mean_bank = [] + module.var_bank = [] + module.style_cfgs = [] + module.gn_weight *= 2 + + outer.attn_module_list = attn_modules + outer.gn_module_list = gn_modules + else: + for module in all_modules: + _original_inner_forward_cn_hijack = getattr(module, '_original_inner_forward_cn_hijack', None) + original_forward_cn_hijack = getattr(module, 'original_forward_cn_hijack', None) + if _original_inner_forward_cn_hijack is not None: + module._forward = _original_inner_forward_cn_hijack + if original_forward_cn_hijack is not None: + module.forward = original_forward_cn_hijack + outer.attn_module_list = [] + outer.gn_module_list = [] scripts.script_callbacks.on_cfg_denoiser(self.guidance_schedule_handler)