549 lines
20 KiB
Python
549 lines
20 KiB
Python
from typing import Optional
|
|
import re
|
|
import torch
|
|
from modules import shared
|
|
|
|
#support for <lyco:MODEL>
|
|
def lycoris_forward(compvis_module, input, res):
    """Composable-LoRA forward hook applying LyCORIS patches to one layer.

    For every loaded lyco model that has weights for this layer, computes the
    layer patch and lets composable_lora blend it into *res*. Models with no
    weights for this layer (or that fail conversion) still call
    check_lycoris_end_layer so the global step counters stay in sync.

    Parameters:
        compvis_module: the wrapped torch module; its ``lyco_layer_name``
            attribute (set elsewhere by the lyco extension) names the layer.
        input: tensor fed to the layer.
        res: the layer's current output tensor.
    Returns the output tensor with lyco patches applied.
    """
    # Imported lazily: these extension modules only exist at runtime inside the WebUI.
    import composable_lora as lora_controller
    import lora
    import lycoris
    if len(lycoris.loaded_lycos) == 0:
        return res

    lycoris_layer_name_loading : Optional[str] = getattr(compvis_module, 'lyco_layer_name', None)
    if lycoris_layer_name_loading is None:
        return res
    # coerce to a plain str so the rest of the function can rely on string ops
    lycoris_layer_name : str = str(lycoris_layer_name_loading)
    del lycoris_layer_name_loading

    sd_module = shared.sd_model.lora_layer_mapping.get(lycoris_layer_name, None)
    num_loras = len(lora.loaded_loras) + len(lycoris.loaded_lycos)

    # -1 marks an uninitialized counter; seed it with one slot per prompt per model
    if lora_controller.text_model_encoder_counter == -1:
        lora_controller.text_model_encoder_counter = len(lora_controller.prompt_loras) * num_loras

    tmp_check_loras = [] # names of lyco models already applied to this layer
    tmp_check_loras.clear()  # NOTE(review): redundant — the list was just created empty

    for m_lycoris in lycoris.loaded_lycos:
        module = m_lycoris.modules.get(lycoris_layer_name, None)
        if module is None:
            # model has no weights for this layer: still advance the counters
            check_lycoris_end_layer(lycoris_layer_name, res, num_loras)
            continue

        current_lora = normalize_lora_name(m_lycoris.name)
        lora_already_used = False
        if current_lora in tmp_check_loras:
            lora_already_used = True
        tmp_check_loras.append(current_lora)
        # skip duplicates: names only differing by a block-weight suffix
        # normalize to the same base name
        if lora_already_used:
            check_lycoris_end_layer(lycoris_layer_name, res, num_loras)
            continue

        converted_module = convert_lycoris(module, sd_module)
        if converted_module is None:
            check_lycoris_end_layer(lycoris_layer_name, res, num_loras)
            continue

        patch = get_lora_patch(converted_module, input, res, lycoris_layer_name)
        alpha = get_lora_alpha(converted_module, 1.0)
        num_prompts = len(lora_controller.prompt_loras)

        res = lora_controller.apply_composable_lora(lycoris_layer_name, m_lycoris, converted_module, "lyco", patch, alpha, res, num_loras, num_prompts)

    return res
|
|
|
|
def composable_forward(module, patch, alpha, multiplier, res):
    """Blend *patch* into *res*, honoring a module-specific blend rule.

    Modules that define their own ``composable_forward`` (e.g. IA3Module)
    get full control; everything else uses the additive LoRA formula.
    """
    if not hasattr(module, 'composable_forward'):
        # default additive blend
        return res + multiplier * alpha * patch
    return module.composable_forward(patch, alpha, multiplier, res)
|
|
|
|
# Matches the "...added_by_lora_block_weight..." suffix that the LoRA Block
# Weight extension appends to lora names.
re_lora_block_weight = re.compile(r"[_\s]*added[_\s]*by[_\s]*lora[_\s]*block[_\s]*weight[_\s]*.*$")


def normalize_lora_name(lora_name):
    """Strip the LoRA-Block-Weight suffix from *lora_name*.

    Fix: reuses the precompiled module-level ``re_lora_block_weight`` pattern
    instead of recompiling an identical inline pattern on every call.
    """
    return re_lora_block_weight.sub("", lora_name)
|
|
|
|
def get_lora_inference(module, input):
    """Run *module*'s patch computation on *input*.

    Supports LyCORIS-style wrappers (an ``inference`` method) and plain LoRA
    modules (``up``/``down`` projections). Returns None for anything else.
    """
    if hasattr(module, 'inference'):
        # LyCORIS wrapper classes expose inference()
        return module.inference(input)
    if hasattr(module, 'up'):
        # classic LoRA: project down, then back up
        return module.up(module.down(input))
    return None
|
|
|
|
def get_lora_patch(module, input, res, lora_layer_name):
    """Compute the patch tensor that *module* contributes to this layer.

    When the 'lora_apply_to_outputs' option is on (and shapes match), the
    patch is computed from the layer output *res*; otherwise from *input*.

    Fix: the two option branches were byte-for-byte duplicates differing only
    in which tensor they used — factored into _get_lora_patch_from.

    Raises:
        NotImplementedError: if *module* cannot run inference even after
            conversion via convert_lycoris.
    """
    if is_loha(module):
        if input.is_cuda: # if input is on GPU, move LoHa tensors there too; otherwise do nothing
            pass_loha_to_gpu(module)
    if getattr(shared.opts, "lora_apply_to_outputs", False) and res.shape == input.shape:
        return _get_lora_patch_from(module, res, lora_layer_name)
    return _get_lora_patch_from(module, input, lora_layer_name)


def _get_lora_patch_from(module, tensor, lora_layer_name):
    """Run *module* on *tensor*, converting it via convert_lycoris first when
    its own inference path is unsupported."""
    inference = get_lora_inference(module, tensor)
    if inference is not None:
        return inference
    converted_module = convert_lycoris(module, shared.sd_model.lora_layer_mapping.get(lora_layer_name, None))
    if converted_module is not None:
        return get_lora_inference(converted_module, tensor)
    raise NotImplementedError(
        "Your settings, extensions or models are not compatible with each other."
    )
|
|
|
|
def get_lora_alpha(module, default_val):
    """Return the alpha scaling factor for *module*.

    Falls back to *default_val* when the module carries no usable alpha
    (missing, None, or zero — the original relies on truthiness).
    """
    if hasattr(module, 'up'):
        # classic LoRA: alpha / rank, rank being the up-weight's input dim
        return module.alpha / module.up.weight.shape[1] if module.alpha else default_val
    if hasattr(module, 'dim'):
        # LyCORIS wrappers store the rank in .dim
        return module.alpha / module.dim if module.alpha else default_val
    return default_val
|
|
|
|
def check_lycoris_end_layer(lora_layer_name: str, res, num_loras):
    """Advance composable-lora's global step counters at the final layers.

    Called for layers a lyco model skipped, so the shared counters still see
    the same number of "end of text encoder / end of U-Net" events as models
    that processed the layer.

    Parameters:
        lora_layer_name: the layer being processed; only the two terminal
            layer suffixes trigger any work.
        res: layer output tensor; its batch dimension sizes the U-Net counter step.
        num_loras: total loaded lora + lyco model count.
    """
    if lora_layer_name.endswith("_11_mlp_fc2") or lora_layer_name.endswith("_11_1_proj_out"):
        # lazy import: the counters live on the extension module itself
        import composable_lora as lora_controller
        if lora_layer_name.endswith("_11_mlp_fc2"): # final text-encoder layer; lyCORIS models may lack it
            lora_controller.text_model_encoder_counter += 1
            # wrap around once every prompt (+ batch) slot has been counted for every model
            if lora_controller.text_model_encoder_counter == (len(lora_controller.prompt_loras) + lora_controller.num_batches) * num_loras:
                lora_controller.text_model_encoder_counter = 0
        if lora_layer_name.endswith("_11_1_proj_out"): # final U-Net layer; lyCORIS models may lack it
            lora_controller.diffusion_model_counter += res.shape[0]
            # >= (not ==) because the counter advances by batch size
            if lora_controller.diffusion_model_counter >= (len(lora_controller.prompt_loras) + lora_controller.num_batches) * num_loras:
                lora_controller.diffusion_model_counter = 0
                lora_controller.add_step_counters()
|
|
|
|
def lycoris_get_multiplier(lycoris_model, lora_layer_name):
    """Pick the strength multiplier for *lycoris_model* at this layer.

    Newer lyco models carry separate text-encoder / U-Net multipliers and the
    layer-name prefix decides which applies; older models expose a single
    ``multiplier``. Defaults to 1.0 when neither attribute exists.
    """
    if hasattr(lycoris_model, 'te_multiplier'):
        # 'transformer' near the start of the name marks a text-encoder layer
        in_text_encoder = 'transformer' in lora_layer_name[:20]
        return lycoris_model.te_multiplier if in_text_encoder else lycoris_model.unet_multiplier
    if hasattr(lycoris_model, 'multiplier'):
        return getattr(lycoris_model, 'multiplier', 1.0)
    return 1.0
|
|
|
|
def lycoris_get_multiplier_normalized(lycoris_model, lora_layer_name):
    """Like lycoris_get_multiplier, but normalized against the TE strength.

    Text-encoder layers get 1.0; U-Net layers get unet/te so relative balance
    is preserved. Models without split multipliers return 1.0.
    """
    if not hasattr(lycoris_model, 'te_multiplier'):
        return 1.0
    # computed unconditionally, as in the original (divides by te_multiplier)
    unet_ratio = lycoris_model.unet_multiplier / lycoris_model.te_multiplier
    return 1.0 if 'transformer' in lora_layer_name[:20] else unet_ratio
|
|
|
|
class FakeModule(torch.nn.Module):
    """nn.Module facade exposing a borrowed ``weight`` and a callable forward.

    Lets code expecting a LoRA "up" module (object with ``.weight`` plus a
    forward) delegate to an arbitrary inference function instead.
    """

    def __init__(self, weight, func):
        super().__init__()
        # the borrowed weight tensor (kept only so callers can inspect it)
        self.weight = weight
        # the callable that actually computes the forward pass
        self.func = func

    def forward(self, x):
        # delegate entirely to the wrapped callable
        return self.func(x)
|
|
|
|
class FullModule:
    """Wrapper for a full replacement weight applied directly via ``op``."""

    def __init__(self):
        self.weight = None      # full replacement weight
        self.shape = None       # target weight shape
        self.alpha = None       # scaling factor (may stay None)
        self.op = None          # functional op (linear / conv2d) used by inference
        self.extra_args = {}    # keyword args forwarded to ``op``
        self.up = None          # FakeModule facade installed by the converter

    def down(self, x):
        # identity: a full-weight module has no down-projection
        return x

    def inference(self, x):
        return self.op(x, self.weight, **self.extra_args)
|
|
|
|
class IA3Module:
    """(IA)^3 module: rescales activations by a learned vector ``w``."""

    def __init__(self):
        self.w = None           # learned scale vector
        self.alpha = None       # scaling factor (may stay None)
        self.on_input = None    # NOTE(review): copied from the source module, unused here
        self.shape = None       # target weight shape
        self.op = None          # functional op used by inference
        self.extra_args = {}    # keyword args forwarded to ``op``

    def down(self, x):
        # identity: (IA)^3 has no down-projection
        return x

    def inference(self, x):
        return self.op(x, self.w, **self.extra_args)

    def composable_forward(self, patch, alpha, multiplier, res):
        """Multiplicative blend: res * (1 + patch * alpha * multiplier)."""
        patch = patch.to(res.dtype)
        return res * (1 + patch * alpha * multiplier)
|
|
|
|
class LoraUpDownModule:
    """Classic LoRA decomposition: up @ (mid @) down, or a rebuilt dense
    weight when an additive ``bias`` tensor is attached."""

    def __init__(self):
        self.up_model = None     # up projection module
        self.mid_model = None    # optional middle module (CP/Tucker variant)
        self.down_model = None   # down projection module
        self.alpha = None        # scaling factor
        self.dim = None          # rank
        self.op = None           # functional op used when rebuilding the weight
        self.extra_args = {}     # keyword args forwarded to ``op``
        self.shape = None        # dense weight shape for the rebuild path
        self.bias = None         # optional additive bias tensor
        self.up = None           # FakeModule facade installed by the converter

    def down(self, x):
        # identity: real projection happens in down_model
        return x

    def inference(self, x):
        bias = getattr(self, 'bias', None)
        if isinstance(bias, torch.Tensor):
            # rebuild the dense weight (up @ down + bias) and apply op directly
            out_dim = self.up_model.weight.size(0)
            rank = self.down_model.weight.size(0)
            flat_up = self.up_model.weight.reshape(out_dim, -1)
            flat_down = self.down_model.weight.reshape(rank, -1)
            rebuild_weight = (flat_up @ flat_down + bias).reshape(self.shape)
            return self.op(x, rebuild_weight, bias=None, **self.extra_args)
        hidden = self.down_model(x)
        if self.mid_model is not None:
            hidden = self.mid_model(hidden)
        return self.up_model(hidden)
|
|
|
|
def make_weight_cp(t, wa, wb):
    """Rebuild a conv weight from its CP/Tucker factors.

    Contracts core ``t`` (i, j, k, l) with ``wb`` over axis j, then with
    ``wa`` over axis i, yielding a (wa-rank, wb-rank, k, l) weight.
    """
    partial = torch.einsum('i j k l, j r -> i r k l', t, wb)
    return torch.einsum('i j k l, i r -> r j k l', partial, wa)
|
|
|
|
class LoraHadaModule:
    """Hadamard-product (LoHa) module: weight = (w1a@w1b) * (w2a@w2b),
    optionally with CP-decomposed factors via ``t1``/``t2``."""

    def __init__(self):
        self.t1 = None           # optional CP core for the first factor
        self.w1a = None
        self.w1b = None
        self.t2 = None           # optional CP core for the second factor
        self.w2a = None
        self.w2b = None
        self.alpha = None        # scaling factor
        self.dim = None          # rank
        self.op = None           # functional op applied to the rebuilt weight
        self.extra_args = {}     # keyword args forwarded to ``op``
        self.shape = None        # dense weight shape
        self.bias = None         # optional additive bias tensor
        self.up = None           # FakeModule facade installed by the converter

    def down(self, x):
        # identity: the rebuilt weight does all the work
        return x

    def inference(self, x):
        bias = self.bias if isinstance(getattr(self, 'bias', None), torch.Tensor) else 0
        if self.t1 is None:
            first = self.w1a @ self.w1b
            second = self.w2a @ self.w2b
        else:
            # CP-decomposed variant
            first = make_weight_cp(self.t1, self.w1a, self.w1b)
            second = make_weight_cp(self.t2, self.w2a, self.w2b)
        weight = (first * second + bias).view(self.shape)
        return self.op(x, weight, bias=None, **self.extra_args)
|
|
|
|
def make_kron(orig_shape, w1, w2):
    """Kronecker product of w1 and w2, reshaped to *orig_shape*.

    For a 4-D (conv) w2, w1 gains two trailing singleton dims so the product
    covers the spatial axes as well.
    """
    factor1 = w1
    if w2.dim() == 4:
        factor1 = w1.unsqueeze(2).unsqueeze(2)
    return torch.kron(factor1, w2.contiguous()).reshape(orig_shape)
|
|
|
|
class LoraKronModule:
    """Kronecker-product (LoKr) module: weight = kron(w1, w2a@w2b)."""

    def __init__(self):
        self.w1 = None           # full first factor
        self.w1a = None          # decomposed first factor (a)
        self.w1b = None          # decomposed first factor (b)
        self.w2 = None           # full second factor
        self.t2 = None           # optional CP core for the second factor
        self.w2a = None          # decomposed second factor (a)
        self.w2b = None          # decomposed second factor (b)
        self._alpha = None       # backing store for the alpha property
        self.dim = None          # rank
        self.op = None           # functional op applied to the rebuilt weight
        self.extra_args = {}     # keyword args forwarded to ``op``
        self.shape = None        # dense weight shape
        self.bias = None         # optional additive bias tensor
        self.up = None           # FakeModule facade installed by the converter

    @property
    def alpha(self):
        # alpha only applies when at least one factor is decomposed
        if self.w1a is None and self.w2a is None:
            return None
        return self._alpha

    @alpha.setter
    def alpha(self, value):
        self._alpha = value

    def down(self, x):
        # identity: the rebuilt weight does all the work
        return x

    def inference(self, x):
        bias = self.bias if isinstance(getattr(self, 'bias', None), torch.Tensor) else 0
        if self.t2 is None:
            second = self.w2a @ self.w2b
        else:
            # NOTE(review): original comment says this path is expected to
            # raise NotImplementedError downstream
            second = make_weight_cp(self.t2, self.w2a, self.w2b)
        return self.op(
            x,
            (torch.kron(self.w1, second) + bias).view(self.shape),
            **self.extra_args
        )
|
|
|
|
def convert_lycoris(lycoris_module, sd_module):
    """Convert a module created by the LyCORIS/lora extensions into the
    equivalent wrapper class defined in this file.

    Dispatch is on the source module's class *name* (the real classes live in
    another extension, so isinstance is unavailable). The result is cached on
    *lycoris_module* under ``lyco_converted_lora_module``.

    Parameters:
        lycoris_module: source module to convert.
        sd_module: the target torch layer, used to choose the functional op
            (linear vs conv2d) and its extra arguments.
    Returns the wrapper module, or None when the combination is unsupported.
    """
    result_module = getattr(lycoris_module, 'lyco_converted_lora_module', None)
    if result_module is not None:
        return result_module  # already converted earlier
    if lycoris_module.__class__.__name__ == "LycoUpDownModule" or lycoris_module.__class__.__name__ == "LoraUpDownModule":
        result_module = LoraUpDownModule()
        if (type(sd_module) == torch.nn.Linear
                or type(sd_module) == torch.nn.modules.linear.NonDynamicallyQuantizableLinear
                or type(sd_module) == torch.nn.MultiheadAttention):
            result_module.op = torch.nn.functional.linear
        elif type(sd_module) == torch.nn.Conv2d:
            result_module.op = torch.nn.functional.conv2d
            result_module.extra_args = {
                'stride': sd_module.stride,
                'padding': sd_module.padding
            }
        else:
            return None  # unsupported target layer type
        result_module.up_model = lycoris_module.up_model
        result_module.mid_model = lycoris_module.mid_model
        result_module.down_model = lycoris_module.down_model
        result_module.alpha = lycoris_module.alpha
        result_module.dim = lycoris_module.dim
        result_module.shape = lycoris_module.shape
        result_module.bias = lycoris_module.bias
        # FakeModule gives callers the usual `.up` interface backed by our inference
        result_module.up = FakeModule(
            result_module.up_model.weight,
            result_module.inference
        )
    elif lycoris_module.__class__.__name__ == "FullModule":
        result_module = FullModule()
        result_module.weight = lycoris_module.weight#.to(device=devices.device, dtype=devices.dtype)
        result_module.alpha = lycoris_module.alpha
        result_module.shape = lycoris_module.shape
        result_module.up = FakeModule(
            result_module.weight,
            result_module.inference
        )
        # 2-D weight means a linear layer; anything else is treated as conv
        if len(result_module.weight.shape)==2:
            result_module.op = torch.nn.functional.linear
            result_module.extra_args = {
                'bias': None
            }
        else:
            result_module.op = torch.nn.functional.conv2d
            result_module.extra_args = {
                'stride': sd_module.stride,
                'padding': sd_module.padding,
                'bias': None
            }
        # NOTE(review): this branch returns early, bypassing the shared tail
        # below — the caching behavior is the same, just duplicated.
        setattr(lycoris_module, "lyco_converted_lora_module", result_module)
        return result_module
    elif lycoris_module.__class__.__name__ == "IA3Module":
        result_module = IA3Module()
        result_module.w = lycoris_module.w
        result_module.alpha = lycoris_module.alpha
        result_module.on_input = lycoris_module.on_input
        if hasattr(sd_module, 'weight'):
            result_module.shape = sd_module.weight.shape
        if (type(sd_module) == torch.nn.Linear
                or type(sd_module) == torch.nn.modules.linear.NonDynamicallyQuantizableLinear
                or type(sd_module) == torch.nn.MultiheadAttention):
            result_module.op = torch.nn.functional.linear
        # NOTE(review): unlike the other branches, conv2d here gets no
        # stride/padding extra_args, and an unsupported sd_module leaves
        # op as None — confirm intended.
        elif type(sd_module) == torch.nn.Conv2d:
            result_module.op = torch.nn.functional.conv2d
    elif lycoris_module.__class__.__name__ == "LycoHadaModule" or lycoris_module.__class__.__name__ == "LoraHadaModule":
        result_module = LoraHadaModule()
        result_module.t1 = lycoris_module.t1
        result_module.w1a = lycoris_module.w1a
        result_module.w1b = lycoris_module.w1b
        result_module.t2 = lycoris_module.t2
        result_module.w2a = lycoris_module.w2a
        result_module.w2b = lycoris_module.w2b
        result_module.alpha = lycoris_module.alpha
        result_module.dim = lycoris_module.dim
        result_module.shape = lycoris_module.shape
        result_module.bias = lycoris_module.bias
        result_module.up = FakeModule(
            result_module.t1 if result_module.t1 is not None else result_module.w1a,
            result_module.inference
        )
        if (type(sd_module) == torch.nn.Linear
                or type(sd_module) == torch.nn.modules.linear.NonDynamicallyQuantizableLinear
                or type(sd_module) == torch.nn.MultiheadAttention):
            result_module.op = torch.nn.functional.linear
        elif type(sd_module) == torch.nn.Conv2d:
            result_module.op = torch.nn.functional.conv2d
            result_module.extra_args = {
                'stride': sd_module.stride,
                'padding': sd_module.padding
            }
    elif lycoris_module.__class__.__name__ == "LycoKronModule" or lycoris_module.__class__.__name__ == "LoraKronModule" :
        result_module = LoraKronModule()
        result_module.w1 = lycoris_module.w1
        result_module.w1a = lycoris_module.w1a
        result_module.w1b = lycoris_module.w1b
        result_module.w2 = lycoris_module.w2
        result_module.t2 = lycoris_module.t2
        result_module.w2a = lycoris_module.w2a
        result_module.w2b = lycoris_module.w2b
        result_module._alpha = lycoris_module._alpha
        result_module.dim = lycoris_module.dim
        result_module.shape = lycoris_module.shape
        result_module.bias = lycoris_module.bias
        result_module.up = FakeModule(
            result_module.w1a if result_module.w1a is not None else result_module.w2a,
            result_module.inference
        )
        # NOTE(review): this branch uses isinstance (accepting subclasses)
        # where the others use exact type() checks — confirm intended.
        if (any(isinstance(sd_module, torch_layer) for torch_layer in
                [torch.nn.Linear, torch.nn.modules.linear.NonDynamicallyQuantizableLinear, torch.nn.MultiheadAttention])):
            result_module.op = torch.nn.functional.linear
        elif isinstance(sd_module, torch.nn.Conv2d):
            result_module.op = torch.nn.functional.conv2d
            result_module.extra_args = {
                'stride': sd_module.stride,
                'padding': sd_module.padding
            }
    if result_module is not None:
        # cache so repeated layer passes reuse the same wrapper
        setattr(lycoris_module, "lyco_converted_lora_module", result_module)
        return result_module
    return None
|
|
|
|
def is_loha(m_lora):
    """Return True when *m_lora* carries any LoHa/LoKr factor attribute."""
    return any(hasattr(m_lora, attr) for attr in ('w1a', 'w1b', 'w2a', 'w2b'))
|
|
|
|
# Every LoHa/LoKr weight attribute that may need to be moved to the GPU,
# in the same order the original code processed them.
_LOHA_GPU_ATTRS = ('bias', 't1', 't2', 'w', 'w1', 'w1a', 'w1b', 'w2', 'w2a', 'w2b')


def pass_loha_to_gpu(m_loha):
    """Move every tensor-valued LoHa weight attribute of *m_loha* onto CUDA.

    Fix: the original repeated an identical 8-line move-and-delete dance for
    each of ten attributes; this iterates one helper loop instead. Attributes
    that are absent, non-tensors, or already on the GPU are left untouched;
    the CPU tensor is dropped as soon as its CUDA copy replaces it.
    """
    for attr in _LOHA_GPU_ATTRS:
        value = getattr(m_loha, attr, None)
        if isinstance(value, torch.Tensor) and not value.is_cuda:
            setattr(m_loha, attr, value.cuda())
|
|
|
|
# Whether the WebUI's builtin lyco support was detected.
# NOTE(review): never assigned in this chunk — presumably set by the loader; confirm.
has_webui_lycoris : bool = False