fix model unload (#2301)

fix model unload
pull/2304/head
lllyasviel 2023-12-06 21:07:37 -08:00 committed by GitHub
parent a6edc5d97b
commit 96dbc601a6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 18 additions and 4 deletions

View File

@ -36,6 +36,7 @@ from PIL import Image, ImageFilter, ImageOps
from scripts.lvminthin import lvmin_thin, nake_nms from scripts.lvminthin import lvmin_thin, nake_nms
from scripts.processor import model_free_preprocessors from scripts.processor import model_free_preprocessors
from scripts.controlnet_model_guess import build_model_by_guess from scripts.controlnet_model_guess import build_model_by_guess
from scripts.hook import torch_dfs
# Gradio 3.32 bug fix # Gradio 3.32 bug fix
@ -44,7 +45,17 @@ gradio_tempfile_path = os.path.join(tempfile.gettempdir(), 'gradio')
os.makedirs(gradio_tempfile_path, exist_ok=True) os.makedirs(gradio_tempfile_path, exist_ok=True)
def clear_all_secondary_control_models(): def clear_all_secondary_control_models(m):
all_modules = torch_dfs(m)
for module in all_modules:
_original_inner_forward_cn_hijack = getattr(module, '_original_inner_forward_cn_hijack', None)
original_forward_cn_hijack = getattr(module, 'original_forward_cn_hijack', None)
if _original_inner_forward_cn_hijack is not None:
module._forward = _original_inner_forward_cn_hijack
if original_forward_cn_hijack is not None:
module.forward = original_forward_cn_hijack
clear_all_lllite() clear_all_lllite()
clear_all_ip_adapter() clear_all_ip_adapter()
@ -702,7 +713,7 @@ class Script(scripts.Script, metaclass=(
self.latest_network.restore() self.latest_network.restore()
# always clear (~0.05s) # always clear (~0.05s)
clear_all_secondary_control_models() clear_all_secondary_control_models(unet)
if not batch_hijack.instance.is_batch: if not batch_hijack.instance.is_batch:
self.enabled_units = Script.get_enabled_units(p) self.enabled_units = Script.get_enabled_units(p)
@ -1060,7 +1071,10 @@ class Script(scripts.Script, metaclass=(
return return
def postprocess(self, p, processed, *args): def postprocess(self, p, processed, *args):
clear_all_secondary_control_models() sd_ldm = p.sd_model
unet = sd_ldm.model.diffusion_model
clear_all_secondary_control_models(unet)
self.noise_modifier = None self.noise_modifier = None

View File

@ -1,4 +1,4 @@
version_flag = 'v1.1.421' version_flag = 'v1.1.422'
from scripts.logging import logger from scripts.logging import logger