reduce the chance of crashing in 1.5.......
parent
e8f461f0e9
commit
73de4b8f0f
|
|
@ -4,6 +4,7 @@ import torch
|
|||
import composable_lora_step
|
||||
import composable_lycoris
|
||||
import plot_helper
|
||||
import lora_ext
|
||||
from modules import extra_networks, devices
|
||||
|
||||
def lora_forward(compvis_module: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.MultiheadAttention], input, res):
|
||||
|
|
@ -13,6 +14,7 @@ def lora_forward(compvis_module: Union[torch.nn.Conv2d, torch.nn.Linear, torch.n
|
|||
global should_print
|
||||
global first_log_drawing
|
||||
global drawing_lora_first_index
|
||||
|
||||
import lora
|
||||
|
||||
if composable_lycoris.has_webui_lycoris:
|
||||
|
|
@ -28,20 +30,23 @@ def lora_forward(compvis_module: Union[torch.nn.Conv2d, torch.nn.Linear, torch.n
|
|||
log_lora()
|
||||
drawing_lora_first_index = drawing_data[0]
|
||||
|
||||
if len(lora.loaded_loras) == 0:
|
||||
if len(lora_ext.get_loaded_lora()) == 0:
|
||||
return res
|
||||
|
||||
if hasattr(devices, "cond_cast_unet"):
|
||||
input = devices.cond_cast_unet(input)
|
||||
|
||||
lora_layer_name_loading : Optional[str] = getattr(compvis_module, 'lora_layer_name', None)
|
||||
if lora_layer_name_loading is None:
|
||||
lora_layer_name_loading = getattr(compvis_module, 'network_layer_name', None)
|
||||
if lora_layer_name_loading is None:
|
||||
return res
|
||||
#let it type is actually a string
|
||||
lora_layer_name : str = str(lora_layer_name_loading)
|
||||
del lora_layer_name_loading
|
||||
|
||||
num_loras = len(lora.loaded_loras)
|
||||
lora_loaded_loras = lora_ext.get_loaded_lora()
|
||||
num_loras = len(lora_loaded_loras)
|
||||
if composable_lycoris.has_webui_lycoris:
|
||||
num_loras += len(lycoris.loaded_lycos)
|
||||
|
||||
|
|
@ -51,7 +56,7 @@ def lora_forward(compvis_module: Union[torch.nn.Conv2d, torch.nn.Linear, torch.n
|
|||
tmp_check_loras = [] #store which lora are already apply
|
||||
tmp_check_loras.clear()
|
||||
|
||||
for m_lora in lora.loaded_loras:
|
||||
for m_lora in lora_loaded_loras:
|
||||
module = m_lora.modules.get(lora_layer_name, None)
|
||||
if module is None:
|
||||
#fix the lyCORIS issue
|
||||
|
|
@ -164,7 +169,7 @@ def add_step_counters():
|
|||
|
||||
def log_lora():
|
||||
import lora
|
||||
loaded_loras = lora.loaded_loras
|
||||
loaded_loras = lora_ext.get_loaded_lora()
|
||||
loaded_lycos = []
|
||||
if composable_lycoris.has_webui_lycoris:
|
||||
import lycoris
|
||||
|
|
@ -431,9 +436,11 @@ def lora_Linear_forward(self, input):
|
|||
if old_lyco_count > 0 and lyco_count <= 0:
|
||||
clear_cache_lora(self, True)
|
||||
self.old_lyco_count = lyco_count
|
||||
torch.nn.Linear_forward_before_lyco = lora.lora_Linear_forward
|
||||
lora_ext.load_lora_ext()
|
||||
torch.nn.Linear_forward_before_lyco = lora_ext.lora_Linear_forward
|
||||
torch.nn.Linear_forward_before_network = Linear_forward_before_clora
|
||||
#if lyco_count <= 0:
|
||||
# return lora.lora_Linear_forward(self, input)
|
||||
# return lora_ext.lora_Linear_forward(self, input)
|
||||
if 'lyco_notfound' in locals() or 'lyco_notfound' in globals():
|
||||
if lyco_notfound:
|
||||
backup_Linear_forward = torch.nn.Linear_forward_before_lora
|
||||
|
|
@ -441,7 +448,6 @@ def lora_Linear_forward(self, input):
|
|||
result = lycoris.lyco_Linear_forward(self, input)
|
||||
torch.nn.Linear_forward_before_lora = backup_Linear_forward
|
||||
return result
|
||||
|
||||
return lycoris.lyco_Linear_forward(self, input)
|
||||
clear_cache_lora(self, False)
|
||||
if (not self.weight.is_cuda) and input.is_cuda: #if variables not on the same device (between cpu and gpu)
|
||||
|
|
@ -468,9 +474,11 @@ def lora_Conv2d_forward(self, input):
|
|||
if old_lyco_count > 0 and lyco_count <= 0:
|
||||
clear_cache_lora(self, True)
|
||||
self.old_lyco_count = lyco_count
|
||||
torch.nn.Conv2d_forward_before_lyco = lora.lora_Conv2d_forward
|
||||
lora_ext.load_lora_ext()
|
||||
torch.nn.Conv2d_forward_before_lyco = lora_ext.lora_Conv2d_forward
|
||||
torch.nn.Conv2d_forward_before_network = Conv2d_forward_before_clora
|
||||
#if lyco_count <= 0:
|
||||
# return lora.lora_Conv2d_forward(self, input)
|
||||
# return lora_ext.lora_Conv2d_forward(self, input)
|
||||
if 'lyco_notfound' in locals() or 'lyco_notfound' in globals():
|
||||
if lyco_notfound:
|
||||
backup_Conv2d_forward = torch.nn.Conv2d_forward_before_lora
|
||||
|
|
@ -505,9 +513,12 @@ def lora_MultiheadAttention_forward(self, input):
|
|||
if old_lyco_count > 0 and lyco_count <= 0:
|
||||
clear_cache_lora(self, True)
|
||||
self.old_lyco_count = lyco_count
|
||||
torch.nn.MultiheadAttention_forward_before_lyco = lora.lora_MultiheadAttention_forward
|
||||
lora_ext.load_lora_ext()
|
||||
torch.nn.MultiheadAttention_forward_before_lyco = lora_ext.lora_MultiheadAttention_forward
|
||||
torch.nn.MultiheadAttention_forward_before_network = MultiheadAttention_forward_before_clora
|
||||
|
||||
#if lyco_count <= 0:
|
||||
# return lora.lora_MultiheadAttention_forward(self, input)
|
||||
# return lora_ext.lora_MultiheadAttention_forward(self, input)
|
||||
if 'lyco_notfound' in locals() or 'lyco_notfound' in globals():
|
||||
if lyco_notfound:
|
||||
backup_MultiheadAttention_forward = torch.nn.MultiheadAttention_forward_before_lora
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
from typing import Optional, Union
|
||||
import re
|
||||
import torch
|
||||
import lora_ext
|
||||
from modules import shared, devices
|
||||
|
||||
#support for <lyco:MODEL>
|
||||
|
|
@ -22,7 +23,7 @@ def lycoris_forward(compvis_module: Union[torch.nn.Conv2d, torch.nn.Linear, torc
|
|||
del lycoris_layer_name_loading
|
||||
|
||||
sd_module = shared.sd_model.lora_layer_mapping.get(lycoris_layer_name, None)
|
||||
num_loras = len(lora.loaded_loras) + len(lycoris.loaded_lycos)
|
||||
num_loras = len(lora_ext.get_loaded_lora()) + len(lycoris.loaded_lycos)
|
||||
|
||||
if lora_controller.text_model_encoder_counter == -1:
|
||||
lora_controller.text_model_encoder_counter = len(lora_controller.prompt_loras) * num_loras
|
||||
|
|
@ -105,7 +106,10 @@ def get_lora_patch(module, input, res, lora_layer_name):
|
|||
if inference is not None:
|
||||
return inference
|
||||
else:
|
||||
converted_module = convert_lycoris(module, shared.sd_model.lora_layer_mapping.get(lora_layer_name, None))
|
||||
if hasattr(shared.sd_model, "network_layer_mapping"):
|
||||
converted_module = convert_lycoris(module, shared.sd_model.network_layer_mapping.get(lora_layer_name, None))
|
||||
else:
|
||||
converted_module = convert_lycoris(module, shared.sd_model.lora_layer_mapping.get(lora_layer_name, None))
|
||||
if converted_module is not None:
|
||||
return get_lora_inference(converted_module, input)
|
||||
else:
|
||||
|
|
@ -340,7 +344,8 @@ def convert_lycoris(lycoris_module, sd_module):
|
|||
result_module = getattr(lycoris_module, 'lyco_converted_lora_module', None)
|
||||
if result_module is not None:
|
||||
return result_module
|
||||
if lycoris_module.__class__.__name__ == "LycoUpDownModule" or lycoris_module.__class__.__name__ == "LoraUpDownModule":
|
||||
if lycoris_module.__class__.__name__ == "LycoUpDownModule" or lycoris_module.__class__.__name__ == "LoraUpDownModule"\
|
||||
or lycoris_module.__class__.__name__ == "NetworkModuleLora":
|
||||
result_module = LoraUpDownModule()
|
||||
if (type(sd_module) == torch.nn.Linear
|
||||
or type(sd_module) == torch.nn.modules.linear.NonDynamicallyQuantizableLinear
|
||||
|
|
@ -365,7 +370,7 @@ def convert_lycoris(lycoris_module, sd_module):
|
|||
result_module.up_model.weight,
|
||||
result_module.inference
|
||||
)
|
||||
elif lycoris_module.__class__.__name__ == "FullModule":
|
||||
elif lycoris_module.__class__.__name__ == "FullModule" or lycoris_module.__class__.__name__ == "NetworkModuleFull":
|
||||
result_module = FullModule()
|
||||
result_module.weight = lycoris_module.weight#.to(device=devices.device, dtype=devices.dtype)
|
||||
result_module.alpha = lycoris_module.alpha
|
||||
|
|
@ -388,7 +393,7 @@ def convert_lycoris(lycoris_module, sd_module):
|
|||
}
|
||||
setattr(lycoris_module, "lyco_converted_lora_module", result_module)
|
||||
return result_module
|
||||
elif lycoris_module.__class__.__name__ == "IA3Module":
|
||||
elif lycoris_module.__class__.__name__ == "IA3Module" or lycoris_module.__class__.__name__ == "NetworkModuleIa3":
|
||||
result_module = IA3Module()
|
||||
result_module.w = lycoris_module.w
|
||||
result_module.alpha = lycoris_module.alpha
|
||||
|
|
@ -401,7 +406,8 @@ def convert_lycoris(lycoris_module, sd_module):
|
|||
result_module.op = torch.nn.functional.linear
|
||||
elif type(sd_module) == torch.nn.Conv2d:
|
||||
result_module.op = torch.nn.functional.conv2d
|
||||
elif lycoris_module.__class__.__name__ == "LycoHadaModule" or lycoris_module.__class__.__name__ == "LoraHadaModule":
|
||||
elif lycoris_module.__class__.__name__ == "LycoHadaModule" or lycoris_module.__class__.__name__ == "LoraHadaModule"\
|
||||
or lycoris_module.__class__.__name__ == "NetworkModuleHada":
|
||||
result_module = LoraHadaModule()
|
||||
result_module.t1 = lycoris_module.t1
|
||||
result_module.w1a = lycoris_module.w1a
|
||||
|
|
@ -427,7 +433,8 @@ def convert_lycoris(lycoris_module, sd_module):
|
|||
'stride': sd_module.stride,
|
||||
'padding': sd_module.padding
|
||||
}
|
||||
elif lycoris_module.__class__.__name__ == "LycoKronModule" or lycoris_module.__class__.__name__ == "LoraKronModule" :
|
||||
elif lycoris_module.__class__.__name__ == "LycoKronModule" or lycoris_module.__class__.__name__ == "LoraKronModule"\
|
||||
or lycoris_module.__class__.__name__ == "NetworkModuleLokr" :
|
||||
result_module = LoraKronModule()
|
||||
result_module.w1 = lycoris_module.w1
|
||||
result_module.w1a = lycoris_module.w1a
|
||||
|
|
|
|||
|
|
@ -0,0 +1,64 @@
|
|||
lora_Linear_forward = None
|
||||
lora_Linear_load_state_dict = None
|
||||
lora_Conv2d_forward = None
|
||||
lora_Conv2d_load_state_dict = None
|
||||
lora_MultiheadAttention_forward = None
|
||||
lora_MultiheadAttention_load_state_dict = None
|
||||
is_sd_1_5 = False
|
||||
def get_loaded_lora():
|
||||
global is_sd_1_5
|
||||
if lora_Linear_forward is None:
|
||||
load_lora_ext()
|
||||
import lora
|
||||
try:
|
||||
import networks
|
||||
is_sd_1_5 = True
|
||||
except ImportError:
|
||||
pass
|
||||
if is_sd_1_5:
|
||||
return networks.loaded_networks
|
||||
return lora.loaded_loras
|
||||
|
||||
def load_lora_ext():
|
||||
global is_sd_1_5
|
||||
global lora_Linear_forward
|
||||
global lora_Linear_load_state_dict
|
||||
global lora_Conv2d_forward
|
||||
global lora_Conv2d_load_state_dict
|
||||
global lora_MultiheadAttention_forward
|
||||
global lora_MultiheadAttention_load_state_dict
|
||||
if lora_Linear_forward is not None:
|
||||
return
|
||||
import lora
|
||||
is_sd_1_5 = False
|
||||
try:
|
||||
import networks
|
||||
is_sd_1_5 = True
|
||||
except ImportError:
|
||||
pass
|
||||
if is_sd_1_5:
|
||||
if hasattr(networks, "network_Linear_forward"):
|
||||
lora_Linear_forward = networks.network_Linear_forward
|
||||
if hasattr(networks, "network_Linear_load_state_dict"):
|
||||
lora_Linear_load_state_dict = networks.network_Linear_load_state_dict
|
||||
if hasattr(networks, "network_Conv2d_forward"):
|
||||
lora_Conv2d_forward = networks.network_Conv2d_forward
|
||||
if hasattr(networks, "network_Conv2d_load_state_dict"):
|
||||
lora_Conv2d_load_state_dict = networks.network_Conv2d_load_state_dict
|
||||
if hasattr(networks, "network_MultiheadAttention_forward"):
|
||||
lora_MultiheadAttention_forward = networks.network_MultiheadAttention_forward
|
||||
if hasattr(networks, "network_MultiheadAttention_load_state_dict"):
|
||||
lora_MultiheadAttention_load_state_dict = networks.network_MultiheadAttention_load_state_dict
|
||||
else:
|
||||
if hasattr(networks, "network_Linear_forward"):
|
||||
lora_Linear_forward = lora.lora_Linear_forward
|
||||
if hasattr(networks, "network_Linear_load_state_dict"):
|
||||
lora_Linear_load_state_dict = lora.lora_Linear_load_state_dict
|
||||
if hasattr(networks, "network_Conv2d_forward"):
|
||||
lora_Conv2d_forward = lora.lora_Conv2d_forward
|
||||
if hasattr(networks, "network_Conv2d_load_state_dict"):
|
||||
lora_Conv2d_load_state_dict = lora.lora_Conv2d_load_state_dict
|
||||
if hasattr(networks, "network_MultiheadAttention_forward"):
|
||||
lora_MultiheadAttention_forward = lora.lora_MultiheadAttention_forward
|
||||
if hasattr(networks, "network_MultiheadAttention_load_state_dict"):
|
||||
lora_MultiheadAttention_load_state_dict = lora.lora_MultiheadAttention_load_state_dict
|
||||
|
|
@ -6,6 +6,7 @@ import gradio as gr
|
|||
|
||||
import composable_lora
|
||||
import composable_lora_function_handler
|
||||
import lora_ext
|
||||
import modules.scripts as scripts
|
||||
from modules import script_callbacks
|
||||
from modules.processing import StableDiffusionProcessing
|
||||
|
|
@ -56,8 +57,8 @@ if hasattr(torch.nn, 'Linear_forward_before_lyco'):
|
|||
else:
|
||||
composable_lora.lyco_notfound = True
|
||||
|
||||
torch.nn.Linear.forward = composable_lora.lora_Linear_forward
|
||||
torch.nn.Conv2d.forward = composable_lora.lora_Conv2d_forward
|
||||
#torch.nn.Linear.forward = composable_lora.lora_Linear_forward
|
||||
#torch.nn.Conv2d.forward = composable_lora.lora_Conv2d_forward
|
||||
|
||||
def check_install_state():
|
||||
if not hasattr(composable_lora, "noop"):
|
||||
|
|
@ -97,6 +98,11 @@ class ComposableLoraScript(scripts.Script):
|
|||
opt_uc_text_model_encoder: bool, opt_uc_diffusion_model:
|
||||
bool, opt_plot_lora_weight: bool, opt_single_no_uc:
|
||||
bool, opt_hires_step_as_global: bool):
|
||||
lora_ext.load_lora_ext()
|
||||
if lora_ext.is_sd_1_5:
|
||||
import composable_lycoris
|
||||
if composable_lycoris.has_webui_lycoris:
|
||||
print("Error! in sd webui 1.5, composable-lora not support with sd-webui-lycoris extension.")
|
||||
composable_lora.enabled = enabled
|
||||
composable_lora.opt_uc_text_model_encoder = opt_uc_text_model_encoder
|
||||
composable_lora.opt_uc_diffusion_model = opt_uc_diffusion_model
|
||||
|
|
|
|||
Loading…
Reference in New Issue