# stable-diffusion-webui composable-lora extension: composable_lycoris.py

from typing import Optional, Union
import re
import torch
import lora_ext
from modules import shared, devices
#support for <lyco:MODEL> syntax in prompts
def lycoris_forward(compvis_module: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.MultiheadAttention], input, res):
    """Forward hook adding every loaded LyCORIS model's patch for this layer.

    Takes the wrapped compvis module, its *input*, and the module's original
    output *res*; returns *res* with the composable LyCORIS patches applied.
    """
    import composable_lora as lora_controller
    import lora
    import lycoris
    # fast path: no LyCORIS networks loaded, nothing to apply
    if len(lycoris.loaded_lycos) == 0:
        return res
    # newer webui builds expose cond_cast_unet for dtype casting of unet inputs
    if hasattr(devices, "cond_cast_unet"):
        input = devices.cond_cast_unet(input)
    lycoris_layer_name_loading: Optional[str] = getattr(compvis_module, 'lyco_layer_name', None)
    if lycoris_layer_name_loading is None:  # layer was never tagged by the lyco extension
        return res
    # make sure we hold a plain str from here on
    lycoris_layer_name: str = str(lycoris_layer_name_loading)
    del lycoris_layer_name_loading
    sd_module = shared.sd_model.lora_layer_mapping.get(lycoris_layer_name, None)
    # total networks (plain LoRA + LyCORIS) — used to scale the step counters
    num_loras = len(lora_ext.get_loaded_lora()) + len(lycoris.loaded_lycos)
    # -1 marks an uninitialized text-encoder counter; seed it for this batch
    if lora_controller.text_model_encoder_counter == -1:
        lora_controller.text_model_encoder_counter = len(lora_controller.prompt_loras) * num_loras
    tmp_check_loras = []  # names of LyCORIS models already applied to this layer
    tmp_check_loras.clear()
    for m_lycoris in lycoris.loaded_lycos:
        module = m_lycoris.modules.get(lycoris_layer_name, None)
        if module is None:
            # model has no weights for this layer; still advance the end-of-layer counters
            check_lycoris_end_layer(lycoris_layer_name, res, num_loras)
            continue
        current_lora = normalize_lora_name(m_lycoris.name)
        lora_already_used = False
        if current_lora in tmp_check_loras:
            lora_already_used = True
        # record this model so a duplicate entry is applied only once
        tmp_check_loras.append(current_lora)
        if lora_already_used:
            check_lycoris_end_layer(lycoris_layer_name, res, num_loras)
            continue
        # adapt the lyco module to the local patch classes (cached per module)
        converted_module = convert_lycoris(module, sd_module)
        if converted_module is None:
            check_lycoris_end_layer(lycoris_layer_name, res, num_loras)
            continue
        patch = get_lora_patch(converted_module, input, res, lycoris_layer_name)
        alpha = get_lora_alpha(converted_module, 1.0)
        num_prompts = len(lora_controller.prompt_loras)
        # print(f"lora.name={m_lora.name} lora.mul={m_lora.multiplier} alpha={alpha} pat.shape={patch.shape}")
        res = lora_controller.apply_composable_lora(lycoris_layer_name, m_lycoris, converted_module, "lyco", patch, alpha, res, num_loras, num_prompts)
    return res
def composable_forward(module, patch, alpha, multiplier, res):
    """Combine *patch* into *res*, deferring to the module's own
    composable_forward when it provides one; otherwise apply the default
    additive rule res + multiplier * alpha * patch."""
    custom = getattr(module, 'composable_forward', None)
    if custom is not None:
        return custom(patch, alpha, multiplier, res)
    return res + multiplier * alpha * patch
# matches the " added by lora block weight ..." suffix (any _/whitespace separators)
re_lora_block_weight = re.compile(r"[_\s]*added[_\s]*by[_\s]*lora[_\s]*block[_\s]*weight[_\s]*.*$")

def normalize_lora_name(lora_name):
    """Strip the "added by lora block weight" suffix from a LoRA name.

    Fix: the precompiled module-level pattern was defined but unused — the
    function re-spelled the identical regex inline; it now uses the
    compiled pattern.
    """
    return re_lora_block_weight.sub("", lora_name)
def get_lora_inference(module, input):
    """Run the module's patch computation on *input*.

    lyCORIS-style modules expose .inference; classic LoRA modules expose
    up/down projections. Returns None when the module supports neither.
    """
    if hasattr(module, 'inference'):  # lyCORIS interface
        return module.inference(input)
    if not hasattr(module, 'up'):
        return None
    # classic LoRA: up(down(x)); move projection weights to the active device first
    up_proj = module.up
    down_proj = module.down
    if hasattr(up_proj, "to"):
        up_proj.to(device=devices.device)
    if hasattr(down_proj, "to"):
        down_proj.to(device=devices.device)
    return up_proj(down_proj(input))
def get_lora_patch(module, input, res, lora_layer_name):
    """Compute the raw patch tensor for *module* at this layer.

    Feeds either the layer input or (when the "apply to outputs" option is
    on and shapes match) the layer output through the module. Falls back to
    converting the module first; raises NotImplementedError when even the
    converted module cannot produce a patch.
    """
    if is_loha(module):
        if input.is_cuda:  # if computing on CUDA, move the factor tensors there; otherwise do nothing
            pass_loha_to_gpu(module)
    # "lora_apply_to_outputs" option: run the patch on the layer's output instead
    if getattr(shared.opts, "lora_apply_to_outputs", False) and res.shape == input.shape:
        inference = get_lora_inference(module, res)
        if inference is not None:
            return inference
        else:
            # module exposed neither .inference nor .up — retry after conversion
            converted_module = convert_lycoris(module, shared.sd_model.lora_layer_mapping.get(lora_layer_name, None))
            if converted_module is not None:
                return get_lora_inference(converted_module, res)
            else:
                raise NotImplementedError(
                    "Your settings, extensions or models are not compatible with each other."
                )
    else:
        inference = get_lora_inference(module, input)
        if inference is not None:
            return inference
        else:
            # newer webui renamed lora_layer_mapping to network_layer_mapping
            if hasattr(shared.sd_model, "network_layer_mapping"):
                converted_module = convert_lycoris(module, shared.sd_model.network_layer_mapping.get(lora_layer_name, None))
            else:
                converted_module = convert_lycoris(module, shared.sd_model.lora_layer_mapping.get(lora_layer_name, None))
            if converted_module is not None:
                return get_lora_inference(converted_module, input)
            else:
                raise NotImplementedError(
                    "Your settings, extensions or models are not compatible with each other."
                )
def get_lora_alpha(module, default_val):
    """Return the module's scaling factor alpha / rank.

    Classic LoRA modules derive the rank from up.weight's second dim;
    lyCORIS modules carry an explicit .dim. Falls back to *default_val*
    when alpha is unset (falsy) or the module has neither attribute.
    """
    if hasattr(module, 'up'):
        if not module.alpha:
            return default_val
        return module.alpha / module.up.weight.shape[1]
    if hasattr(module, 'dim'):  # lyCORIS interface
        if not module.alpha:
            return default_val
        return module.alpha / module.dim
    return default_val
def check_lycoris_end_layer(lora_layer_name: str, res, num_loras):
    """Advance the composable-lora step counters when the final text-encoder
    or unet layer of a pass is reached, even for models that contributed no
    patch at this layer.

    NOTE(review): nesting reconstructed from flattened source — the counter
    reset and add_step_counters() placement should be confirmed against the
    composable_lora module.
    """
    # _11_mlp_fc2 is the last text-encoder layer; _11_1_proj_out the last unet layer
    if lora_layer_name.endswith("_11_mlp_fc2") or lora_layer_name.endswith("_11_1_proj_out"):
        import composable_lora as lora_controller
        if lora_layer_name.endswith("_11_mlp_fc2"):  # lyCORIS maybe doesn't have a _11_mlp_fc2 layer
            lora_controller.text_model_encoder_counter += 1
            # wrap the counter once every prompt x batch x network combination was seen
            if lora_controller.text_model_encoder_counter == (len(lora_controller.prompt_loras) + lora_controller.num_batches) * num_loras:
                lora_controller.text_model_encoder_counter = 0
        if lora_layer_name.endswith("_11_1_proj_out"):  # lyCORIS maybe doesn't have a _11_1_proj_out layer
            # unet counter advances by batch size (res.shape[0])
            lora_controller.diffusion_model_counter += res.shape[0]
            if lora_controller.diffusion_model_counter >= (len(lora_controller.prompt_loras) + lora_controller.num_batches) * num_loras:
                lora_controller.diffusion_model_counter = 0
                lora_controller.add_step_counters()
def lycoris_get_multiplier(lycoris_model, lora_layer_name):
    """Return the strength multiplier for this layer.

    Models with split strengths use te_multiplier for text-encoder
    ('transformer...') layers and unet_multiplier otherwise; single-strength
    models use .multiplier; defaults to 1.0.
    """
    if hasattr(lycoris_model, 'te_multiplier'):
        # text-encoder layer names start with 'transformer'
        if 'transformer' in lora_layer_name[:20]:
            return lycoris_model.te_multiplier
        return lycoris_model.unet_multiplier
    if hasattr(lycoris_model, 'multiplier'):
        return getattr(lycoris_model, 'multiplier', 1.0)
    return 1.0
def lycoris_get_multiplier_normalized(lycoris_model, lora_layer_name):
    """Return the layer multiplier normalized so the text-encoder strength is 1.0.

    Unet layers get unet_multiplier / te_multiplier; models without split
    strengths always return 1.0.
    """
    if not hasattr(lycoris_model, 'te_multiplier'):
        return 1.0
    # compute the ratio unconditionally, matching the original evaluation order
    unet_ratio = lycoris_model.unet_multiplier / lycoris_model.te_multiplier
    return 1.0 if 'transformer' in lora_layer_name[:20] else unet_ratio
class FakeModule(torch.nn.Module):
    """Minimal nn.Module facade: exposes a .weight attribute while delegating
    the forward pass entirely to an arbitrary callable."""

    def __init__(self, weight, func):
        super().__init__()
        self.weight, self.func = weight, func

    def forward(self, x):
        # no computation of its own — the wrapped callable does all the work
        return self.func(x)
class FullModule:
    """Full weight-replacement patch: inference applies op(x, weight) directly,
    with an identity down projection for interface compatibility."""

    def __init__(self):
        self.weight = None      # full replacement weight tensor
        self.alpha = None       # scaling factor (may stay unset)
        self.op = None          # functional op, e.g. F.linear / F.conv2d
        self.extra_args = {}    # extra keyword args forwarded to op
        self.shape = None       # target weight shape
        self.up = None          # FakeModule facade for webui compatibility

    def down(self, x):
        # identity — there is no low-rank decomposition here
        return x

    def inference(self, x):
        return self.op(x, self.weight, **self.extra_args)
class IA3Module:
    """(IA)^3-style patch: a learned rescaling vector w applied through op;
    composable_forward rescales the result multiplicatively."""

    def __init__(self):
        self.w = None           # learned scaling weights
        self.alpha = None       # scaling factor (may stay unset)
        self.on_input = None    # whether scaling targets the layer input
        self.shape = None       # target weight shape
        self.op = None          # functional op, e.g. F.linear / F.conv2d
        self.extra_args = {}    # extra keyword args forwarded to op

    def down(self, x):
        # identity — interface compatibility only
        return x

    def inference(self, x):
        return self.op(x, self.w, **self.extra_args)

    def composable_forward(self, patch, alpha, multiplier, res):
        # multiplicative combination instead of the default additive rule
        patch = patch.to(res.dtype)
        return res * (1 + patch * alpha * multiplier)
class LoraUpDownModule:
    """Low-rank (up @ down, optionally with a mid/Tucker factor) patch module.

    When a bias tensor is present, the full weight is rebuilt and applied via
    op; otherwise the factor models are applied sequentially.
    """

    def __init__(self):
        self.up_model = None    # up projection (out_dim x rank)
        self.mid_model = None   # optional mid factor (Tucker/CP decompositions)
        self.down_model = None  # down projection (rank x in_dim)
        self.alpha = None       # scaling factor (may stay unset)
        self.dim = None         # rank of the decomposition
        self.op = None          # functional op, e.g. F.linear / F.conv2d
        self.extra_args = {}    # extra keyword args forwarded to op
        self.shape = None       # target weight shape
        self.bias = None        # optional additive bias merged into the weight
        self.up = None          # FakeModule facade for webui compatibility

    def down(self, x):
        # identity — interface compatibility only
        return x

    def inference(self, x):
        if isinstance(getattr(self, 'bias', None), torch.Tensor):
            # rebuild the full weight (up @ down + bias) and apply it via op
            out_dim = self.up_model.weight.size(0)
            rank = self.down_model.weight.size(0)
            up_flat = self.up_model.weight.reshape(out_dim, -1)
            down_flat = self.down_model.weight.reshape(rank, -1)
            rebuilt = (up_flat @ down_flat + self.bias).reshape(self.shape)
            return self.op(x, rebuilt, bias=None, **self.extra_args)
        # no bias: apply the factor models in sequence
        h = self.down_model(x)
        if self.mid_model is not None:
            h = self.mid_model(h)
        return self.up_model(h)
def make_weight_cp(t, wa, wb):
    """Rebuild a full 4-D weight from a CP/Tucker decomposition: core tensor t
    contracted with factor matrices wb (dim j) then wa (dim i)."""
    partial = torch.einsum('i j k l, j r -> i r k l', t, wb)
    return torch.einsum('i j k l, i r -> r j k l', partial, wa)
class LoraHadaModule:
    """LoHa patch: the weight is the Hadamard (elementwise) product of two
    low-rank factors, each either a plain matrix product (w*a @ w*b) or a
    CP/Tucker rebuild when a core tensor t* is present."""

    def __init__(self):
        self.t1 = None          # optional core tensor of the first factor
        self.w1a = None         # first factor, left matrix
        self.w1b = None         # first factor, right matrix
        self.t2 = None          # optional core tensor of the second factor
        self.w2a = None         # second factor, left matrix
        self.w2b = None         # second factor, right matrix
        self.alpha = None       # scaling factor (may stay unset)
        self.dim = None         # rank of the decomposition
        self.op = None          # functional op, e.g. F.linear / F.conv2d
        self.extra_args = {}    # extra keyword args forwarded to op
        self.shape = None       # target weight shape
        self.bias = None        # optional additive bias merged into the weight
        self.up = None          # FakeModule facade for webui compatibility

    def down(self, x):
        # identity — interface compatibility only
        return x

    def inference(self, x):
        bias = self.bias if isinstance(getattr(self, 'bias', None), torch.Tensor) else 0
        if self.t1 is None:
            # plain low-rank factors
            weight = (self.w1a @ self.w1b) * (self.w2a @ self.w2b) + bias
        else:
            # CP/Tucker-decomposed factors
            weight = (make_weight_cp(self.t1, self.w1a, self.w1b)
                      * make_weight_cp(self.t2, self.w2a, self.w2b) + bias)
        return self.op(x, weight.view(self.shape), bias=None, **self.extra_args)
def make_kron(orig_shape, w1, w2):
    """Kronecker product of the two factors reshaped to *orig_shape*;
    lifts w1 to 4-D (and makes w2 contiguous) when w2 is a conv factor."""
    if w2.dim() == 4:
        w1 = w1.unsqueeze(2).unsqueeze(2)
        w2 = w2.contiguous()
    return torch.kron(w1, w2).reshape(orig_shape)
class LoraKronModule:
    """LoKr patch: the weight is a Kronecker product of w1 with a second
    factor that is either low-rank (w2a @ w2b) or CP/Tucker-decomposed."""

    def __init__(self):
        self.w1 = None          # first Kronecker factor (full)
        self.w1a = None         # first factor, left matrix (when decomposed)
        self.w1b = None         # first factor, right matrix (when decomposed)
        self.w2 = None          # second Kronecker factor (full)
        self.t2 = None          # optional core tensor of the second factor
        self.w2a = None         # second factor, left matrix
        self.w2b = None         # second factor, right matrix
        self._alpha = None      # backing store for the alpha property
        self.dim = None         # rank of the decomposition
        self.op = None          # functional op, e.g. F.linear / F.conv2d
        self.extra_args = {}    # extra keyword args forwarded to op
        self.shape = None       # target weight shape
        self.bias = None        # optional additive bias merged into the weight
        self.up = None          # FakeModule facade for webui compatibility

    @property
    def alpha(self):
        # alpha is only meaningful when at least one factor is decomposed
        if self.w1a is None and self.w2a is None:
            return None
        return self._alpha

    @alpha.setter
    def alpha(self, value):
        self._alpha = value

    def down(self, x):
        # identity — interface compatibility only
        return x

    def inference(self, x):
        bias = self.bias if isinstance(getattr(self, 'bias', None), torch.Tensor) else 0
        if self.t2 is None:
            second = self.w2a @ self.w2b
        else:
            # will raise NotImplemented Error
            second = make_weight_cp(self.t2, self.w2a, self.w2b)
        return self.op(
            x,
            (torch.kron(self.w1, second) + bias).view(self.shape),
            **self.extra_args
        )
def convert_lycoris(lycoris_module, sd_module):
    """Adapt a module from the lycoris/lora extensions (or webui's builtin
    networks) to the local patch classes defined in this file.

    Dispatch is done on the source module's class *name* so every extension
    variant is covered. Successful conversions are cached on the source
    module via the 'lyco_converted_lora_module' attribute. Returns None when
    the module type or the target sd_module layer type is unsupported.
    """
    # return the cached conversion when this module was already handled
    result_module = getattr(lycoris_module, 'lyco_converted_lora_module', None)
    if result_module is not None:
        return result_module
    if lycoris_module.__class__.__name__ == "LycoUpDownModule" or lycoris_module.__class__.__name__ == "LoraUpDownModule"\
        or lycoris_module.__class__.__name__ == "NetworkModuleLora":
        # plain low-rank (up @ down) module
        result_module = LoraUpDownModule()
        if (type(sd_module) == torch.nn.Linear
            or type(sd_module) == torch.nn.modules.linear.NonDynamicallyQuantizableLinear
            or type(sd_module) == torch.nn.MultiheadAttention):
            result_module.op = torch.nn.functional.linear
        elif type(sd_module) == torch.nn.Conv2d:
            result_module.op = torch.nn.functional.conv2d
            result_module.extra_args = {
                'stride': sd_module.stride,
                'padding': sd_module.padding
            }
        else:
            # unsupported target layer type
            return None
        result_module.up_model = lycoris_module.up_model
        result_module.mid_model = lycoris_module.mid_model
        result_module.down_model = lycoris_module.down_model
        result_module.alpha = lycoris_module.alpha
        result_module.dim = lycoris_module.dim
        result_module.shape = lycoris_module.shape
        result_module.bias = lycoris_module.bias
        # FakeModule facade so code that expects .up.weight keeps working
        result_module.up = FakeModule(
            result_module.up_model.weight,
            result_module.inference
        )
    elif lycoris_module.__class__.__name__ == "FullModule" or lycoris_module.__class__.__name__ == "NetworkModuleFull":
        # full weight-replacement module
        result_module = FullModule()
        result_module.weight = lycoris_module.weight#.to(device=devices.device, dtype=devices.dtype)
        result_module.alpha = lycoris_module.alpha
        result_module.shape = lycoris_module.shape
        result_module.up = FakeModule(
            result_module.weight,
            result_module.inference
        )
        # choose the functional op from the weight's rank (2-D -> linear)
        if len(result_module.weight.shape)==2:
            result_module.op = torch.nn.functional.linear
            result_module.extra_args = {
                'bias': None
            }
        else:
            result_module.op = torch.nn.functional.conv2d
            result_module.extra_args = {
                'stride': sd_module.stride,
                'padding': sd_module.padding,
                'bias': None
            }
        # this branch caches and returns directly; the others share the epilogue below
        setattr(lycoris_module, "lyco_converted_lora_module", result_module)
        return result_module
    elif lycoris_module.__class__.__name__ == "IA3Module" or lycoris_module.__class__.__name__ == "NetworkModuleIa3":
        # (IA)^3: elementwise rescaling module
        result_module = IA3Module()
        result_module.w = lycoris_module.w
        result_module.alpha = lycoris_module.alpha
        result_module.on_input = lycoris_module.on_input
        if hasattr(sd_module, 'weight'):
            result_module.shape = sd_module.weight.shape
        if (type(sd_module) == torch.nn.Linear
            or type(sd_module) == torch.nn.modules.linear.NonDynamicallyQuantizableLinear
            or type(sd_module) == torch.nn.MultiheadAttention):
            result_module.op = torch.nn.functional.linear
        elif type(sd_module) == torch.nn.Conv2d:
            result_module.op = torch.nn.functional.conv2d
    elif lycoris_module.__class__.__name__ == "LycoHadaModule" or lycoris_module.__class__.__name__ == "LoraHadaModule"\
        or lycoris_module.__class__.__name__ == "NetworkModuleHada":
        # LoHa: Hadamard product of two low-rank factors
        result_module = LoraHadaModule()
        result_module.t1 = lycoris_module.t1
        result_module.w1a = lycoris_module.w1a
        result_module.w1b = lycoris_module.w1b
        result_module.t2 = lycoris_module.t2
        result_module.w2a = lycoris_module.w2a
        result_module.w2b = lycoris_module.w2b
        result_module.alpha = lycoris_module.alpha
        result_module.dim = lycoris_module.dim
        result_module.shape = lycoris_module.shape
        result_module.bias = lycoris_module.bias
        result_module.up = FakeModule(
            result_module.t1 if result_module.t1 is not None else result_module.w1a,
            result_module.inference
        )
        if (type(sd_module) == torch.nn.Linear
            or type(sd_module) == torch.nn.modules.linear.NonDynamicallyQuantizableLinear
            or type(sd_module) == torch.nn.MultiheadAttention):
            result_module.op = torch.nn.functional.linear
        elif type(sd_module) == torch.nn.Conv2d:
            result_module.op = torch.nn.functional.conv2d
            result_module.extra_args = {
                'stride': sd_module.stride,
                'padding': sd_module.padding
            }
    elif lycoris_module.__class__.__name__ == "LycoKronModule" or lycoris_module.__class__.__name__ == "LoraKronModule"\
        or lycoris_module.__class__.__name__ == "NetworkModuleLokr" :
        # LoKr: Kronecker-factored weights
        result_module = LoraKronModule()
        result_module.w1 = lycoris_module.w1
        result_module.w1a = lycoris_module.w1a
        result_module.w1b = lycoris_module.w1b
        result_module.w2 = lycoris_module.w2
        result_module.t2 = lycoris_module.t2
        result_module.w2a = lycoris_module.w2a
        result_module.w2b = lycoris_module.w2b
        result_module._alpha = lycoris_module._alpha
        result_module.dim = lycoris_module.dim
        result_module.shape = lycoris_module.shape
        result_module.bias = lycoris_module.bias
        result_module.up = FakeModule(
            result_module.w1a if result_module.w1a is not None else result_module.w2a,
            result_module.inference
        )
        if (any(isinstance(sd_module, torch_layer) for torch_layer in
            [torch.nn.Linear, torch.nn.modules.linear.NonDynamicallyQuantizableLinear, torch.nn.MultiheadAttention])):
            result_module.op = torch.nn.functional.linear
        elif isinstance(sd_module, torch.nn.Conv2d):
            result_module.op = torch.nn.functional.conv2d
            result_module.extra_args = {
                'stride': sd_module.stride,
                'padding': sd_module.padding
            }
    # shared epilogue: cache successful conversions on the source module
    if result_module is not None:
        setattr(lycoris_module, "lyco_converted_lora_module", result_module)
        return result_module
    return None
def is_loha(m_lora):
    """True when the module carries factored weights (any of w1a/w1b/w2a/w2b),
    i.e. LoHa/LoKr-style decompositions."""
    return any(hasattr(m_lora, name) for name in ('w1a', 'w1b', 'w2a', 'w2b'))
def pass_loha_to_gpu(m_loha):
    """Move every factor tensor of a LoHa/LoKr/IA3-style module to the active device.

    Fix: the original copy-pasted an identical hasattr/isinstance/move
    sequence ten times, once per attribute; this iterates the attribute
    names instead — same behavior, one tenth of the code.

    Tensors already on CUDA, missing attributes, and non-tensor values are
    left untouched.
    """
    # every tensor attribute any supported module class may carry
    for attr_name in ('bias', 't1', 't2', 'w', 'w1', 'w1a', 'w1b', 'w2', 'w2a', 'w2b'):
        tensor = getattr(m_loha, attr_name, None)
        if isinstance(tensor, torch.Tensor) and not tensor.is_cuda:
            moved = tensor.to(device=devices.device)
            # clear the attribute and drop our reference before rebinding,
            # mirroring the original's explicit del to release the CPU copy promptly
            setattr(m_loha, attr_name, None)
            del tensor
            setattr(m_loha, attr_name, moved)
# flag for webui-side lycoris support; presumably flipped elsewhere — TODO confirm
has_webui_lycoris: bool = False