From 73de4b8f0fd9b1d9c5f20dc47e8ee3573dda3c30 Mon Sep 17 00:00:00 2001 From: a2569875 Date: Tue, 25 Jul 2023 18:34:18 +0800 Subject: [PATCH] reduce the chance of crashing in 1.5....... --- composable_lora.py | 33 ++++++++++------ composable_lycoris.py | 21 ++++++---- lora_ext.py | 64 +++++++++++++++++++++++++++++++ scripts/composable_lora_script.py | 10 ++++- 4 files changed, 108 insertions(+), 20 deletions(-) create mode 100644 lora_ext.py diff --git a/composable_lora.py b/composable_lora.py index 549c70c..1b6678b 100644 --- a/composable_lora.py +++ b/composable_lora.py @@ -4,6 +4,7 @@ import torch import composable_lora_step import composable_lycoris import plot_helper +import lora_ext from modules import extra_networks, devices def lora_forward(compvis_module: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.MultiheadAttention], input, res): @@ -13,6 +14,7 @@ def lora_forward(compvis_module: Union[torch.nn.Conv2d, torch.nn.Linear, torch.n global should_print global first_log_drawing global drawing_lora_first_index + import lora if composable_lycoris.has_webui_lycoris: @@ -28,20 +30,23 @@ def lora_forward(compvis_module: Union[torch.nn.Conv2d, torch.nn.Linear, torch.n log_lora() drawing_lora_first_index = drawing_data[0] - if len(lora.loaded_loras) == 0: + if len(lora_ext.get_loaded_lora()) == 0: return res if hasattr(devices, "cond_cast_unet"): input = devices.cond_cast_unet(input) lora_layer_name_loading : Optional[str] = getattr(compvis_module, 'lora_layer_name', None) + if lora_layer_name_loading is None: + lora_layer_name_loading = getattr(compvis_module, 'network_layer_name', None) if lora_layer_name_loading is None: return res #let it type is actually a string lora_layer_name : str = str(lora_layer_name_loading) del lora_layer_name_loading - num_loras = len(lora.loaded_loras) + lora_loaded_loras = lora_ext.get_loaded_lora() + num_loras = len(lora_loaded_loras) if composable_lycoris.has_webui_lycoris: num_loras += len(lycoris.loaded_lycos) @@ -51,7 +56,7 @@ def lora_forward(compvis_module: Union[torch.nn.Conv2d, torch.nn.Linear, torch.n tmp_check_loras = [] #store which lora are already apply tmp_check_loras.clear() - for m_lora in lora.loaded_loras: + for m_lora in lora_loaded_loras: module = m_lora.modules.get(lora_layer_name, None) if module is None: #fix the lyCORIS issue @@ -164,7 +169,7 @@ def add_step_counters(): def log_lora(): import lora - loaded_loras = lora.loaded_loras + loaded_loras = lora_ext.get_loaded_lora() loaded_lycos = [] if composable_lycoris.has_webui_lycoris: import lycoris @@ -431,9 +436,11 @@ def lora_Linear_forward(self, input): if old_lyco_count > 0 and lyco_count <= 0: clear_cache_lora(self, True) self.old_lyco_count = lyco_count - torch.nn.Linear_forward_before_lyco = lora.lora_Linear_forward + lora_ext.load_lora_ext() + torch.nn.Linear_forward_before_lyco = lora_ext.lora_Linear_forward + torch.nn.Linear_forward_before_network = Linear_forward_before_clora #if lyco_count <= 0: - # return lora.lora_Linear_forward(self, input) + # return lora_ext.lora_Linear_forward(self, input) if 'lyco_notfound' in locals() or 'lyco_notfound' in globals(): if lyco_notfound: backup_Linear_forward = torch.nn.Linear_forward_before_lora @@ -441,7 +448,6 @@ def lora_Linear_forward(self, input): result = lycoris.lyco_Linear_forward(self, input) torch.nn.Linear_forward_before_lora = backup_Linear_forward return result - return lycoris.lyco_Linear_forward(self, input) clear_cache_lora(self, False) if (not self.weight.is_cuda) and input.is_cuda: #if variables not on the same device (between cpu and gpu) @@ -468,9 +474,11 @@ def lora_Conv2d_forward(self, input): if old_lyco_count > 0 and lyco_count <= 0: clear_cache_lora(self, True) self.old_lyco_count = lyco_count - torch.nn.Conv2d_forward_before_lyco = lora.lora_Conv2d_forward + lora_ext.load_lora_ext() + torch.nn.Conv2d_forward_before_lyco = lora_ext.lora_Conv2d_forward + torch.nn.Conv2d_forward_before_network = Conv2d_forward_before_clora #if lyco_count <= 0: - # return lora.lora_Conv2d_forward(self, input) + # return lora_ext.lora_Conv2d_forward(self, input) if 'lyco_notfound' in locals() or 'lyco_notfound' in globals(): if lyco_notfound: backup_Conv2d_forward = torch.nn.Conv2d_forward_before_lora @@ -505,9 +513,12 @@ def lora_MultiheadAttention_forward(self, input): if old_lyco_count > 0 and lyco_count <= 0: clear_cache_lora(self, True) self.old_lyco_count = lyco_count - torch.nn.MultiheadAttention_forward_before_lyco = lora.lora_MultiheadAttention_forward + lora_ext.load_lora_ext() + torch.nn.MultiheadAttention_forward_before_lyco = lora_ext.lora_MultiheadAttention_forward + torch.nn.MultiheadAttention_forward_before_network = MultiheadAttention_forward_before_clora + #if lyco_count <= 0: - # return lora.lora_MultiheadAttention_forward(self, input) + # return lora_ext.lora_MultiheadAttention_forward(self, input) if 'lyco_notfound' in locals() or 'lyco_notfound' in globals(): if lyco_notfound: backup_MultiheadAttention_forward = torch.nn.MultiheadAttention_forward_before_lora diff --git a/composable_lycoris.py b/composable_lycoris.py index 79801b8..e26e97a 100644 --- a/composable_lycoris.py +++ b/composable_lycoris.py @@ -1,6 +1,7 @@ from typing import Optional, Union import re import torch +import lora_ext from modules import shared, devices #support for @@ -22,7 +23,7 @@ def lycoris_forward(compvis_module: Union[torch.nn.Conv2d, torch.nn.Linear, torc del lycoris_layer_name_loading sd_module = shared.sd_model.lora_layer_mapping.get(lycoris_layer_name, None) - num_loras = len(lora.loaded_loras) + len(lycoris.loaded_lycos) + num_loras = len(lora_ext.get_loaded_lora()) + len(lycoris.loaded_lycos) if lora_controller.text_model_encoder_counter == -1: lora_controller.text_model_encoder_counter = len(lora_controller.prompt_loras) * num_loras @@ -105,7 +106,10 @@ def get_lora_patch(module, input, res, lora_layer_name): if inference is not None: return inference else: - converted_module = convert_lycoris(module, shared.sd_model.lora_layer_mapping.get(lora_layer_name, None)) + if hasattr(shared.sd_model, "network_layer_mapping"): + converted_module = convert_lycoris(module, shared.sd_model.network_layer_mapping.get(lora_layer_name, None)) + else: + converted_module = convert_lycoris(module, shared.sd_model.lora_layer_mapping.get(lora_layer_name, None)) if converted_module is not None: return get_lora_inference(converted_module, input) else: @@ -340,7 +344,8 @@ def convert_lycoris(lycoris_module, sd_module): result_module = getattr(lycoris_module, 'lyco_converted_lora_module', None) if result_module is not None: return result_module - if lycoris_module.__class__.__name__ == "LycoUpDownModule" or lycoris_module.__class__.__name__ == "LoraUpDownModule": + if lycoris_module.__class__.__name__ == "LycoUpDownModule" or lycoris_module.__class__.__name__ == "LoraUpDownModule"\ + or lycoris_module.__class__.__name__ == "NetworkModuleLora": result_module = LoraUpDownModule() if (type(sd_module) == torch.nn.Linear or type(sd_module) == torch.nn.modules.linear.NonDynamicallyQuantizableLinear @@ -365,7 +370,7 @@ def convert_lycoris(lycoris_module, sd_module): result_module.up_model.weight, result_module.inference ) - elif lycoris_module.__class__.__name__ == "FullModule": + elif lycoris_module.__class__.__name__ == "FullModule" or lycoris_module.__class__.__name__ == "NetworkModuleFull": result_module = FullModule() result_module.weight = lycoris_module.weight#.to(device=devices.device, dtype=devices.dtype) result_module.alpha = lycoris_module.alpha @@ -388,7 +393,7 @@ def convert_lycoris(lycoris_module, sd_module): } setattr(lycoris_module, "lyco_converted_lora_module", result_module) return result_module - elif lycoris_module.__class__.__name__ == "IA3Module": + elif lycoris_module.__class__.__name__ == "IA3Module" or lycoris_module.__class__.__name__ == "NetworkModuleIa3": result_module = IA3Module() result_module.w = lycoris_module.w result_module.alpha = lycoris_module.alpha @@ -401,7 +406,8 @@ def convert_lycoris(lycoris_module, sd_module): result_module.op = torch.nn.functional.linear elif type(sd_module) == torch.nn.Conv2d: result_module.op = torch.nn.functional.conv2d - elif lycoris_module.__class__.__name__ == "LycoHadaModule" or lycoris_module.__class__.__name__ == "LoraHadaModule": + elif lycoris_module.__class__.__name__ == "LycoHadaModule" or lycoris_module.__class__.__name__ == "LoraHadaModule"\ + or lycoris_module.__class__.__name__ == "NetworkModuleHada": result_module = LoraHadaModule() result_module.t1 = lycoris_module.t1 result_module.w1a = lycoris_module.w1a @@ -427,7 +433,8 @@ def convert_lycoris(lycoris_module, sd_module): 'stride': sd_module.stride, 'padding': sd_module.padding } - elif lycoris_module.__class__.__name__ == "LycoKronModule" or lycoris_module.__class__.__name__ == "LoraKronModule" : + elif lycoris_module.__class__.__name__ == "LycoKronModule" or lycoris_module.__class__.__name__ == "LoraKronModule"\ + or lycoris_module.__class__.__name__ == "NetworkModuleLokr" : result_module = LoraKronModule() result_module.w1 = lycoris_module.w1 result_module.w1a = lycoris_module.w1a diff --git a/lora_ext.py b/lora_ext.py new file mode 100644 index 0000000..5e50662 --- /dev/null +++ b/lora_ext.py @@ -0,0 +1,64 @@ +lora_Linear_forward = None +lora_Linear_load_state_dict = None +lora_Conv2d_forward = None +lora_Conv2d_load_state_dict = None +lora_MultiheadAttention_forward = None +lora_MultiheadAttention_load_state_dict = None +is_sd_1_5 = False +def get_loaded_lora(): + global is_sd_1_5 + if lora_Linear_forward is None: + load_lora_ext() + import lora + try: + import networks + is_sd_1_5 = True + except ImportError: + pass + if is_sd_1_5: + return networks.loaded_networks + return lora.loaded_loras + +def load_lora_ext(): + global is_sd_1_5 + global lora_Linear_forward + global lora_Linear_load_state_dict + global lora_Conv2d_forward + global lora_Conv2d_load_state_dict + global lora_MultiheadAttention_forward + global lora_MultiheadAttention_load_state_dict + if lora_Linear_forward is not None: + return + import lora + is_sd_1_5 = False + try: + import networks + is_sd_1_5 = True + except ImportError: + pass + if is_sd_1_5: + if hasattr(networks, "network_Linear_forward"): + lora_Linear_forward = networks.network_Linear_forward + if hasattr(networks, "network_Linear_load_state_dict"): + lora_Linear_load_state_dict = networks.network_Linear_load_state_dict + if hasattr(networks, "network_Conv2d_forward"): + lora_Conv2d_forward = networks.network_Conv2d_forward + if hasattr(networks, "network_Conv2d_load_state_dict"): + lora_Conv2d_load_state_dict = networks.network_Conv2d_load_state_dict + if hasattr(networks, "network_MultiheadAttention_forward"): + lora_MultiheadAttention_forward = networks.network_MultiheadAttention_forward + if hasattr(networks, "network_MultiheadAttention_load_state_dict"): + lora_MultiheadAttention_load_state_dict = networks.network_MultiheadAttention_load_state_dict + else: + if hasattr(networks, "network_Linear_forward"): + lora_Linear_forward = lora.lora_Linear_forward + if hasattr(networks, "network_Linear_load_state_dict"): + lora_Linear_load_state_dict = lora.lora_Linear_load_state_dict + if hasattr(networks, "network_Conv2d_forward"): + lora_Conv2d_forward = lora.lora_Conv2d_forward + if hasattr(networks, "network_Conv2d_load_state_dict"): + lora_Conv2d_load_state_dict = lora.lora_Conv2d_load_state_dict + if hasattr(networks, "network_MultiheadAttention_forward"): + lora_MultiheadAttention_forward = lora.lora_MultiheadAttention_forward + if hasattr(networks, "network_MultiheadAttention_load_state_dict"): + lora_MultiheadAttention_load_state_dict = lora.lora_MultiheadAttention_load_state_dict diff --git a/scripts/composable_lora_script.py b/scripts/composable_lora_script.py index 6ba71b1..f0c98cd 100644 --- a/scripts/composable_lora_script.py +++ b/scripts/composable_lora_script.py @@ -6,6 +6,7 @@ import gradio as gr import composable_lora import composable_lora_function_handler +import lora_ext import modules.scripts as scripts from modules import script_callbacks from modules.processing import StableDiffusionProcessing @@ -56,8 +57,8 @@ if hasattr(torch.nn, 'Linear_forward_before_lyco'): else: composable_lora.lyco_notfound = True -torch.nn.Linear.forward = composable_lora.lora_Linear_forward -torch.nn.Conv2d.forward = composable_lora.lora_Conv2d_forward +#torch.nn.Linear.forward = composable_lora.lora_Linear_forward +#torch.nn.Conv2d.forward = composable_lora.lora_Conv2d_forward def check_install_state(): if not hasattr(composable_lora, "noop"): @@ -97,6 +98,11 @@ class ComposableLoraScript(scripts.Script): opt_uc_text_model_encoder: bool, opt_uc_diffusion_model: bool, opt_plot_lora_weight: bool, opt_single_no_uc: bool, opt_hires_step_as_global: bool): + lora_ext.load_lora_ext() + if lora_ext.is_sd_1_5: + import composable_lycoris + if composable_lycoris.has_webui_lycoris: + print("Error! in sd webui 1.5, composable-lora not support with sd-webui-lycoris extension.") composable_lora.enabled = enabled composable_lora.opt_uc_text_model_encoder = opt_uc_text_model_encoder composable_lora.opt_uc_diffusion_model = opt_uc_diffusion_model