diff --git a/scripts/mergers/pluslora.py b/scripts/mergers/pluslora.py index 29b34e2..17eeb15 100644 --- a/scripts/mergers/pluslora.py +++ b/scripts/mergers/pluslora.py @@ -293,14 +293,14 @@ def pluslora(lnames,loraratios,settings,output,model,precision): # print(f"apply {key} to {module}") - down_weight = lora_sd[key] - up_weight = lora_sd[up_key] + down_weight = lora_sd[key].to(device="cpu") + up_weight = lora_sd[up_key].to(device="cpu") dim = down_weight.size()[0] alpha = lora_sd.get(alpha_key, dim) scale = alpha / dim # W <- W + U * D - weight = theta_0[keychanger[msd_key]] + weight = theta_0[keychanger[msd_key]].to(device="cpu") if not len(down_weight.size()) == 4: # linear