From db2ebb9eb28c46e2ba8f1f84c734e5135833d8c5 Mon Sep 17 00:00:00 2001 From: a2569875 Date: Thu, 13 Apr 2023 18:33:40 +0800 Subject: [PATCH] resove LoCon poblem --- composable_lora.py | 45 ++++++++++++++++++++++++++++++++++++--------- 1 file changed, 36 insertions(+), 9 deletions(-) diff --git a/composable_lora.py b/composable_lora.py index 6258df1..ef0c506 100644 --- a/composable_lora.py +++ b/composable_lora.py @@ -49,13 +49,39 @@ def lora_forward(compvis_module, input, res): if text_model_encoder_counter == -1: text_model_encoder_counter = len(prompt_loras) * num_loras - # print(f"lora.forward lora_layer_name={lora_layer_name} in.shape={input.shape} res.shape={res.shape} num_batches={num_batches} num_prompts={num_prompts}") + tmp_check_loras = [] #store which lora are already apply + tmp_check_loras.clear() + # print(f"lora.forward lora_layer_name={lora_layer_name} in.shape={input.shape} res.shape={res.shape} num_batches={num_batches} num_prompts={num_prompts}") for lora in lora.loaded_loras: module = lora.modules.get(lora_layer_name, None) if module is None: + #fix the loCon issue + if lora_layer_name.endswith("_11_mlp_fc2"): # locon doesn't has _11_mlp_fc2 layer + text_model_encoder_counter += 1 + # c1 c1 c2 c2 .. .. uc uc + if text_model_encoder_counter == (len(prompt_loras) + num_batches) * num_loras: + text_model_encoder_counter = 0 + if lora_layer_name.endswith("_11_1_proj_out"): # locon doesn't has _11_1_proj_out layer + diffusion_model_counter += res.shape[0] + # c1 c2 .. uc + if diffusion_model_counter >= (len(prompt_loras) + num_batches) * num_loras: + diffusion_model_counter = 0 continue - + + current_lora = lora.name + lora_already_used = False + for check_lora in tmp_check_loras: + if current_lora == check_lora: + #find the same lora, marked + lora_already_used = True + break + #store the applied lora into list + tmp_check_loras.append(current_lora) + #if current lora already apply, skip this lora + if lora_already_used == True: + continue + if shared.opts.lora_apply_to_outputs and res.shape == input.shape: patch = module.up(module.down(res)) else: @@ -66,7 +92,6 @@ def lora_forward(compvis_module, input, res): num_prompts = len(prompt_loras) # print(f"lora.name={lora.name} lora.mul={lora.multiplier} alpha={alpha} pat.shape={patch.shape}") - if enabled: if lora_layer_name.startswith("transformer_"): # "transformer_text_model_encoder_" # @@ -113,14 +138,16 @@ def lora_forward(compvis_module, input, res): # tensor.shape[1] != uncond.shape[1] cur_num_prompts = res.shape[0] base = (diffusion_model_counter // cur_num_prompts) // num_loras * cur_num_prompts + prompt_len = len(prompt_loras) if 0 <= base < len(prompt_loras): # c for off in range(cur_num_prompts): - loras = prompt_loras[base + off] - multiplier = loras.get(lora.name, 0.0) - if multiplier != 0.0: - # print(f"c #{base + off} lora.name={lora.name} mul={multiplier}", lora_layer_name=lora_layer_name) - res[off] += multiplier * alpha * patch[off] + if base + off < prompt_len: + loras = prompt_loras[base + off] + multiplier = loras.get(lora.name, 0.0) + if multiplier != 0.0: + # print(f"c #{base + off} lora.name={lora.name} mul={multiplier}", lora_layer_name=lora_layer_name) + res[off] += multiplier * alpha * patch[off] else: # uc if opt_uc_diffusion_model and lora.multiplier != 0.0: @@ -142,7 +169,7 @@ def lora_forward(compvis_module, input, res): if lora.multiplier != 0.0: # print(f"DEFAULT {lora_layer_name} lora.name={lora.name} lora.mul={lora.multiplier}") res += lora.multiplier * alpha * patch - + return res