resove LoCon poblem
parent
d4963e487c
commit
db2ebb9eb2
|
|
@ -49,13 +49,39 @@ def lora_forward(compvis_module, input, res):
|
|||
if text_model_encoder_counter == -1:
|
||||
text_model_encoder_counter = len(prompt_loras) * num_loras
|
||||
|
||||
# print(f"lora.forward lora_layer_name={lora_layer_name} in.shape={input.shape} res.shape={res.shape} num_batches={num_batches} num_prompts={num_prompts}")
|
||||
tmp_check_loras = [] #store which lora are already apply
|
||||
tmp_check_loras.clear()
|
||||
|
||||
# print(f"lora.forward lora_layer_name={lora_layer_name} in.shape={input.shape} res.shape={res.shape} num_batches={num_batches} num_prompts={num_prompts}")
|
||||
for lora in lora.loaded_loras:
|
||||
module = lora.modules.get(lora_layer_name, None)
|
||||
if module is None:
|
||||
#fix the loCon issue
|
||||
if lora_layer_name.endswith("_11_mlp_fc2"): # locon doesn't has _11_mlp_fc2 layer
|
||||
text_model_encoder_counter += 1
|
||||
# c1 c1 c2 c2 .. .. uc uc
|
||||
if text_model_encoder_counter == (len(prompt_loras) + num_batches) * num_loras:
|
||||
text_model_encoder_counter = 0
|
||||
if lora_layer_name.endswith("_11_1_proj_out"): # locon doesn't has _11_1_proj_out layer
|
||||
diffusion_model_counter += res.shape[0]
|
||||
# c1 c2 .. uc
|
||||
if diffusion_model_counter >= (len(prompt_loras) + num_batches) * num_loras:
|
||||
diffusion_model_counter = 0
|
||||
continue
|
||||
|
||||
|
||||
current_lora = lora.name
|
||||
lora_already_used = False
|
||||
for check_lora in tmp_check_loras:
|
||||
if current_lora == check_lora:
|
||||
#find the same lora, marked
|
||||
lora_already_used = True
|
||||
break
|
||||
#store the applied lora into list
|
||||
tmp_check_loras.append(current_lora)
|
||||
#if current lora already apply, skip this lora
|
||||
if lora_already_used == True:
|
||||
continue
|
||||
|
||||
if shared.opts.lora_apply_to_outputs and res.shape == input.shape:
|
||||
patch = module.up(module.down(res))
|
||||
else:
|
||||
|
|
@ -66,7 +92,6 @@ def lora_forward(compvis_module, input, res):
|
|||
num_prompts = len(prompt_loras)
|
||||
|
||||
# print(f"lora.name={lora.name} lora.mul={lora.multiplier} alpha={alpha} pat.shape={patch.shape}")
|
||||
|
||||
if enabled:
|
||||
if lora_layer_name.startswith("transformer_"): # "transformer_text_model_encoder_"
|
||||
#
|
||||
|
|
@ -113,14 +138,16 @@ def lora_forward(compvis_module, input, res):
|
|||
# tensor.shape[1] != uncond.shape[1]
|
||||
cur_num_prompts = res.shape[0]
|
||||
base = (diffusion_model_counter // cur_num_prompts) // num_loras * cur_num_prompts
|
||||
prompt_len = len(prompt_loras)
|
||||
if 0 <= base < len(prompt_loras):
|
||||
# c
|
||||
for off in range(cur_num_prompts):
|
||||
loras = prompt_loras[base + off]
|
||||
multiplier = loras.get(lora.name, 0.0)
|
||||
if multiplier != 0.0:
|
||||
# print(f"c #{base + off} lora.name={lora.name} mul={multiplier}", lora_layer_name=lora_layer_name)
|
||||
res[off] += multiplier * alpha * patch[off]
|
||||
if base + off < prompt_len:
|
||||
loras = prompt_loras[base + off]
|
||||
multiplier = loras.get(lora.name, 0.0)
|
||||
if multiplier != 0.0:
|
||||
# print(f"c #{base + off} lora.name={lora.name} mul={multiplier}", lora_layer_name=lora_layer_name)
|
||||
res[off] += multiplier * alpha * patch[off]
|
||||
else:
|
||||
# uc
|
||||
if opt_uc_diffusion_model and lora.multiplier != 0.0:
|
||||
|
|
@ -142,7 +169,7 @@ def lora_forward(compvis_module, input, res):
|
|||
if lora.multiplier != 0.0:
|
||||
# print(f"DEFAULT {lora_layer_name} lora.name={lora.name} lora.mul={lora.multiplier}")
|
||||
res += lora.multiplier * alpha * patch
|
||||
|
||||
|
||||
return res
|
||||
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue