Reduce the chance of crashing on webui 1.5

main
a2569875 2023-07-25 18:34:18 +08:00
parent e8f461f0e9
commit 73de4b8f0f
4 changed files with 108 additions and 20 deletions

View File

@ -4,6 +4,7 @@ import torch
import composable_lora_step
import composable_lycoris
import plot_helper
import lora_ext
from modules import extra_networks, devices
def lora_forward(compvis_module: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.MultiheadAttention], input, res):
@ -13,6 +14,7 @@ def lora_forward(compvis_module: Union[torch.nn.Conv2d, torch.nn.Linear, torch.n
global should_print
global first_log_drawing
global drawing_lora_first_index
import lora
if composable_lycoris.has_webui_lycoris:
@ -28,20 +30,23 @@ def lora_forward(compvis_module: Union[torch.nn.Conv2d, torch.nn.Linear, torch.n
log_lora()
drawing_lora_first_index = drawing_data[0]
if len(lora.loaded_loras) == 0:
if len(lora_ext.get_loaded_lora()) == 0:
return res
if hasattr(devices, "cond_cast_unet"):
input = devices.cond_cast_unet(input)
lora_layer_name_loading : Optional[str] = getattr(compvis_module, 'lora_layer_name', None)
if lora_layer_name_loading is None:
lora_layer_name_loading = getattr(compvis_module, 'network_layer_name', None)
if lora_layer_name_loading is None:
return res
#make sure the value is actually a string
lora_layer_name : str = str(lora_layer_name_loading)
del lora_layer_name_loading
num_loras = len(lora.loaded_loras)
lora_loaded_loras = lora_ext.get_loaded_lora()
num_loras = len(lora_loaded_loras)
if composable_lycoris.has_webui_lycoris:
num_loras += len(lycoris.loaded_lycos)
@ -51,7 +56,7 @@ def lora_forward(compvis_module: Union[torch.nn.Conv2d, torch.nn.Linear, torch.n
tmp_check_loras = [] #store which loras have already been applied
tmp_check_loras.clear()
for m_lora in lora.loaded_loras:
for m_lora in lora_loaded_loras:
module = m_lora.modules.get(lora_layer_name, None)
if module is None:
#fix the lyCORIS issue
@ -164,7 +169,7 @@ def add_step_counters():
def log_lora():
import lora
loaded_loras = lora.loaded_loras
loaded_loras = lora_ext.get_loaded_lora()
loaded_lycos = []
if composable_lycoris.has_webui_lycoris:
import lycoris
@ -431,9 +436,11 @@ def lora_Linear_forward(self, input):
if old_lyco_count > 0 and lyco_count <= 0:
clear_cache_lora(self, True)
self.old_lyco_count = lyco_count
torch.nn.Linear_forward_before_lyco = lora.lora_Linear_forward
lora_ext.load_lora_ext()
torch.nn.Linear_forward_before_lyco = lora_ext.lora_Linear_forward
torch.nn.Linear_forward_before_network = Linear_forward_before_clora
#if lyco_count <= 0:
# return lora.lora_Linear_forward(self, input)
# return lora_ext.lora_Linear_forward(self, input)
if 'lyco_notfound' in locals() or 'lyco_notfound' in globals():
if lyco_notfound:
backup_Linear_forward = torch.nn.Linear_forward_before_lora
@ -441,7 +448,6 @@ def lora_Linear_forward(self, input):
result = lycoris.lyco_Linear_forward(self, input)
torch.nn.Linear_forward_before_lora = backup_Linear_forward
return result
return lycoris.lyco_Linear_forward(self, input)
clear_cache_lora(self, False)
if (not self.weight.is_cuda) and input.is_cuda: #if variables not on the same device (between cpu and gpu)
@ -468,9 +474,11 @@ def lora_Conv2d_forward(self, input):
if old_lyco_count > 0 and lyco_count <= 0:
clear_cache_lora(self, True)
self.old_lyco_count = lyco_count
torch.nn.Conv2d_forward_before_lyco = lora.lora_Conv2d_forward
lora_ext.load_lora_ext()
torch.nn.Conv2d_forward_before_lyco = lora_ext.lora_Conv2d_forward
torch.nn.Conv2d_forward_before_network = Conv2d_forward_before_clora
#if lyco_count <= 0:
# return lora.lora_Conv2d_forward(self, input)
# return lora_ext.lora_Conv2d_forward(self, input)
if 'lyco_notfound' in locals() or 'lyco_notfound' in globals():
if lyco_notfound:
backup_Conv2d_forward = torch.nn.Conv2d_forward_before_lora
@ -505,9 +513,12 @@ def lora_MultiheadAttention_forward(self, input):
if old_lyco_count > 0 and lyco_count <= 0:
clear_cache_lora(self, True)
self.old_lyco_count = lyco_count
torch.nn.MultiheadAttention_forward_before_lyco = lora.lora_MultiheadAttention_forward
lora_ext.load_lora_ext()
torch.nn.MultiheadAttention_forward_before_lyco = lora_ext.lora_MultiheadAttention_forward
torch.nn.MultiheadAttention_forward_before_network = MultiheadAttention_forward_before_clora
#if lyco_count <= 0:
# return lora.lora_MultiheadAttention_forward(self, input)
# return lora_ext.lora_MultiheadAttention_forward(self, input)
if 'lyco_notfound' in locals() or 'lyco_notfound' in globals():
if lyco_notfound:
backup_MultiheadAttention_forward = torch.nn.MultiheadAttention_forward_before_lora

View File

@ -1,6 +1,7 @@
from typing import Optional, Union
import re
import torch
import lora_ext
from modules import shared, devices
#support for <lyco:MODEL>
@ -22,7 +23,7 @@ def lycoris_forward(compvis_module: Union[torch.nn.Conv2d, torch.nn.Linear, torc
del lycoris_layer_name_loading
sd_module = shared.sd_model.lora_layer_mapping.get(lycoris_layer_name, None)
num_loras = len(lora.loaded_loras) + len(lycoris.loaded_lycos)
num_loras = len(lora_ext.get_loaded_lora()) + len(lycoris.loaded_lycos)
if lora_controller.text_model_encoder_counter == -1:
lora_controller.text_model_encoder_counter = len(lora_controller.prompt_loras) * num_loras
@ -105,7 +106,10 @@ def get_lora_patch(module, input, res, lora_layer_name):
if inference is not None:
return inference
else:
converted_module = convert_lycoris(module, shared.sd_model.lora_layer_mapping.get(lora_layer_name, None))
if hasattr(shared.sd_model, "network_layer_mapping"):
converted_module = convert_lycoris(module, shared.sd_model.network_layer_mapping.get(lora_layer_name, None))
else:
converted_module = convert_lycoris(module, shared.sd_model.lora_layer_mapping.get(lora_layer_name, None))
if converted_module is not None:
return get_lora_inference(converted_module, input)
else:
@ -340,7 +344,8 @@ def convert_lycoris(lycoris_module, sd_module):
result_module = getattr(lycoris_module, 'lyco_converted_lora_module', None)
if result_module is not None:
return result_module
if lycoris_module.__class__.__name__ == "LycoUpDownModule" or lycoris_module.__class__.__name__ == "LoraUpDownModule":
if lycoris_module.__class__.__name__ == "LycoUpDownModule" or lycoris_module.__class__.__name__ == "LoraUpDownModule"\
or lycoris_module.__class__.__name__ == "NetworkModuleLora":
result_module = LoraUpDownModule()
if (type(sd_module) == torch.nn.Linear
or type(sd_module) == torch.nn.modules.linear.NonDynamicallyQuantizableLinear
@ -365,7 +370,7 @@ def convert_lycoris(lycoris_module, sd_module):
result_module.up_model.weight,
result_module.inference
)
elif lycoris_module.__class__.__name__ == "FullModule":
elif lycoris_module.__class__.__name__ == "FullModule" or lycoris_module.__class__.__name__ == "NetworkModuleFull":
result_module = FullModule()
result_module.weight = lycoris_module.weight#.to(device=devices.device, dtype=devices.dtype)
result_module.alpha = lycoris_module.alpha
@ -388,7 +393,7 @@ def convert_lycoris(lycoris_module, sd_module):
}
setattr(lycoris_module, "lyco_converted_lora_module", result_module)
return result_module
elif lycoris_module.__class__.__name__ == "IA3Module":
elif lycoris_module.__class__.__name__ == "IA3Module" or lycoris_module.__class__.__name__ == "NetworkModuleIa3":
result_module = IA3Module()
result_module.w = lycoris_module.w
result_module.alpha = lycoris_module.alpha
@ -401,7 +406,8 @@ def convert_lycoris(lycoris_module, sd_module):
result_module.op = torch.nn.functional.linear
elif type(sd_module) == torch.nn.Conv2d:
result_module.op = torch.nn.functional.conv2d
elif lycoris_module.__class__.__name__ == "LycoHadaModule" or lycoris_module.__class__.__name__ == "LoraHadaModule":
elif lycoris_module.__class__.__name__ == "LycoHadaModule" or lycoris_module.__class__.__name__ == "LoraHadaModule"\
or lycoris_module.__class__.__name__ == "NetworkModuleHada":
result_module = LoraHadaModule()
result_module.t1 = lycoris_module.t1
result_module.w1a = lycoris_module.w1a
@ -427,7 +433,8 @@ def convert_lycoris(lycoris_module, sd_module):
'stride': sd_module.stride,
'padding': sd_module.padding
}
elif lycoris_module.__class__.__name__ == "LycoKronModule" or lycoris_module.__class__.__name__ == "LoraKronModule" :
elif lycoris_module.__class__.__name__ == "LycoKronModule" or lycoris_module.__class__.__name__ == "LoraKronModule"\
or lycoris_module.__class__.__name__ == "NetworkModuleLokr" :
result_module = LoraKronModule()
result_module.w1 = lycoris_module.w1
result_module.w1a = lycoris_module.w1a

64
lora_ext.py Normal file
View File

@ -0,0 +1,64 @@
"""Compatibility layer over the webui LoRA backend.

stable-diffusion-webui 1.5 replaced the built-in ``lora`` module's public
hooks (``lora_Linear_forward`` etc.) with a new ``networks`` module
(``network_Linear_forward`` etc.).  This module probes which backend is
importable and exposes one uniform set of names either way.
"""

# Forward / load_state_dict hooks of the active LoRA backend.
# All start as None and are resolved lazily by load_lora_ext().
lora_Linear_forward = None
lora_Linear_load_state_dict = None
lora_Conv2d_forward = None
lora_Conv2d_load_state_dict = None
lora_MultiheadAttention_forward = None
lora_MultiheadAttention_load_state_dict = None

# True when the webui 1.5+ `networks` module is importable.
is_sd_1_5 = False


def get_loaded_lora():
    """Return the list of currently loaded LoRA networks.

    On webui 1.5+ this is ``networks.loaded_networks``; on older versions
    it is ``lora.loaded_loras``.  Also ensures the backend hooks have been
    resolved at least once.
    """
    global is_sd_1_5
    if lora_Linear_forward is None:
        load_lora_ext()
    import lora
    try:
        import networks
        is_sd_1_5 = True
    except ImportError:
        pass
    if is_sd_1_5:
        return networks.loaded_networks
    return lora.loaded_loras


def load_lora_ext():
    """Resolve the backend hooks once and cache them in module globals.

    Safe to call repeatedly; returns immediately after the first
    successful resolution.
    """
    global is_sd_1_5
    global lora_Linear_forward
    global lora_Linear_load_state_dict
    global lora_Conv2d_forward
    global lora_Conv2d_load_state_dict
    global lora_MultiheadAttention_forward
    global lora_MultiheadAttention_load_state_dict
    if lora_Linear_forward is not None:
        return  # already resolved
    import lora
    is_sd_1_5 = False
    try:
        import networks
        is_sd_1_5 = True
    except ImportError:
        pass
    if is_sd_1_5:
        # webui 1.5+: take the hooks from the new `networks` module.
        if hasattr(networks, "network_Linear_forward"):
            lora_Linear_forward = networks.network_Linear_forward
        if hasattr(networks, "network_Linear_load_state_dict"):
            lora_Linear_load_state_dict = networks.network_Linear_load_state_dict
        if hasattr(networks, "network_Conv2d_forward"):
            lora_Conv2d_forward = networks.network_Conv2d_forward
        if hasattr(networks, "network_Conv2d_load_state_dict"):
            lora_Conv2d_load_state_dict = networks.network_Conv2d_load_state_dict
        if hasattr(networks, "network_MultiheadAttention_forward"):
            lora_MultiheadAttention_forward = networks.network_MultiheadAttention_forward
        if hasattr(networks, "network_MultiheadAttention_load_state_dict"):
            lora_MultiheadAttention_load_state_dict = networks.network_MultiheadAttention_load_state_dict
    else:
        # Pre-1.5: fall back to the legacy `lora` module.
        # BUG FIX: the original guarded these assignments with
        # hasattr(networks, ...), but `networks` is unbound in this branch
        # (its import just failed), so every pre-1.5 webui crashed with
        # NameError.  Probe the `lora` module instead.
        if hasattr(lora, "lora_Linear_forward"):
            lora_Linear_forward = lora.lora_Linear_forward
        if hasattr(lora, "lora_Linear_load_state_dict"):
            lora_Linear_load_state_dict = lora.lora_Linear_load_state_dict
        if hasattr(lora, "lora_Conv2d_forward"):
            lora_Conv2d_forward = lora.lora_Conv2d_forward
        if hasattr(lora, "lora_Conv2d_load_state_dict"):
            lora_Conv2d_load_state_dict = lora.lora_Conv2d_load_state_dict
        if hasattr(lora, "lora_MultiheadAttention_forward"):
            lora_MultiheadAttention_forward = lora.lora_MultiheadAttention_forward
        if hasattr(lora, "lora_MultiheadAttention_load_state_dict"):
            lora_MultiheadAttention_load_state_dict = lora.lora_MultiheadAttention_load_state_dict

View File

@ -6,6 +6,7 @@ import gradio as gr
import composable_lora
import composable_lora_function_handler
import lora_ext
import modules.scripts as scripts
from modules import script_callbacks
from modules.processing import StableDiffusionProcessing
@ -56,8 +57,8 @@ if hasattr(torch.nn, 'Linear_forward_before_lyco'):
else:
composable_lora.lyco_notfound = True
torch.nn.Linear.forward = composable_lora.lora_Linear_forward
torch.nn.Conv2d.forward = composable_lora.lora_Conv2d_forward
#torch.nn.Linear.forward = composable_lora.lora_Linear_forward
#torch.nn.Conv2d.forward = composable_lora.lora_Conv2d_forward
def check_install_state():
if not hasattr(composable_lora, "noop"):
@ -97,6 +98,11 @@ class ComposableLoraScript(scripts.Script):
opt_uc_text_model_encoder: bool, opt_uc_diffusion_model:
bool, opt_plot_lora_weight: bool, opt_single_no_uc:
bool, opt_hires_step_as_global: bool):
lora_ext.load_lora_ext()
if lora_ext.is_sd_1_5:
import composable_lycoris
if composable_lycoris.has_webui_lycoris:
print("Error! in sd webui 1.5, composable-lora not support with sd-webui-lycoris extension.")
composable_lora.enabled = enabled
composable_lora.opt_uc_text_model_encoder = opt_uc_text_model_encoder
composable_lora.opt_uc_diffusion_model = opt_uc_diffusion_model