From 96dbc601a6c880571d3a2a1314052d0922114604 Mon Sep 17 00:00:00 2001 From: lllyasviel Date: Wed, 6 Dec 2023 21:07:37 -0800 Subject: [PATCH] fix model unload (#2301) fix model unload --- scripts/controlnet.py | 20 +++++++++++++++++--- scripts/controlnet_version.py | 2 +- 2 files changed, 18 insertions(+), 4 deletions(-) diff --git a/scripts/controlnet.py b/scripts/controlnet.py index 283a451..3437bcd 100644 --- a/scripts/controlnet.py +++ b/scripts/controlnet.py @@ -36,6 +36,7 @@ from PIL import Image, ImageFilter, ImageOps from scripts.lvminthin import lvmin_thin, nake_nms from scripts.processor import model_free_preprocessors from scripts.controlnet_model_guess import build_model_by_guess +from scripts.hook import torch_dfs # Gradio 3.32 bug fix @@ -44,7 +45,17 @@ gradio_tempfile_path = os.path.join(tempfile.gettempdir(), 'gradio') os.makedirs(gradio_tempfile_path, exist_ok=True) -def clear_all_secondary_control_models(): +def clear_all_secondary_control_models(m): + all_modules = torch_dfs(m) + + for module in all_modules: + _original_inner_forward_cn_hijack = getattr(module, '_original_inner_forward_cn_hijack', None) + original_forward_cn_hijack = getattr(module, 'original_forward_cn_hijack', None) + if _original_inner_forward_cn_hijack is not None: + module._forward = _original_inner_forward_cn_hijack + if original_forward_cn_hijack is not None: + module.forward = original_forward_cn_hijack + clear_all_lllite() clear_all_ip_adapter() @@ -702,7 +713,7 @@ class Script(scripts.Script, metaclass=( self.latest_network.restore() # always clear (~0.05s) - clear_all_secondary_control_models() + clear_all_secondary_control_models(unet) if not batch_hijack.instance.is_batch: self.enabled_units = Script.get_enabled_units(p) @@ -1060,7 +1071,10 @@ class Script(scripts.Script, metaclass=( return def postprocess(self, p, processed, *args): - clear_all_secondary_control_models() + sd_ldm = p.sd_model + unet = sd_ldm.model.diffusion_model + + clear_all_secondary_control_models(unet) self.noise_modifier = None diff --git a/scripts/controlnet_version.py b/scripts/controlnet_version.py index 500321a..ffbaffc 100644 --- a/scripts/controlnet_version.py +++ b/scripts/controlnet_version.py @@ -1,4 +1,4 @@ -version_flag = 'v1.1.421' +version_flag = 'v1.1.422' from scripts.logging import logger