commit 1114302bdc
@@ -544,97 +544,121 @@ def merge_lora_models(models, ratios, sets, locon, calc_precision, device):
     return merged_sd

 def merge_lora_models_dim(models, ratios, new_rank, sets, device, calc_precision):
-    merged_sd = {}
-    fugou = 1
+    CHUNK_SIZE = 50

     isv2 = False
     merge_dtype = str_to_dtype(calc_precision)
-    for model, ratios in zip(models, ratios):
-
-        lora_sd, medadata, isv2 = load_state_dict(model, merge_dtype, device)
-
-        # merge
-        print(f"merging {model}: {ratios}")
-        for key in tqdm(list(lora_sd.keys())):
-            if 'lora_down' not in key:
-                continue
-            lora_module_name = key[:key.rfind(".lora_down")]
-
-            down_weight = lora_sd[key]
-            network_dim = down_weight.size()[0]
-
-            up_weight = lora_sd[lora_module_name + '.lora_up.weight']
-            alpha = lora_sd.get(lora_module_name + '.alpha', network_dim)
-
-            in_dim = down_weight.size()[1]
-            out_dim = up_weight.size()[0]
-            conv2d = len(down_weight.size()) == 4
-            # print(lora_module_name, network_dim, alpha, in_dim, out_dim)
-
-            # make original weight if not exist
-            if lora_module_name not in merged_sd:
-                weight = torch.zeros((out_dim, in_dim, 1, 1) if conv2d else (out_dim, in_dim), dtype=merge_dtype, device=device)
-            else:
-                weight = merged_sd[lora_module_name]
-
-            ratio = ratios[blockfromkey(key, LBLCOKS26, isv2)]
-            if "same to Strength" in sets:
-                ratio, fugou = (ratio ** 0.5, 1) if ratio > 0 else (abs(ratio) ** 0.5, -1)
-            # print(lora_module_name, ratio)
-            # W <- W + U * D
-            scale = (alpha / network_dim)
-            if not conv2d:  # linear
-                weight = weight + ratio * (up_weight @ down_weight) * scale * fugou
-            else:
-                weight = weight + ratio * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3) * scale * fugou
-
-            merged_sd[lora_module_name] = weight
-
-        lora_sd = None
-        del lora_sd
-        torch.cuda.empty_cache()
-
-    for key in merged_sd.keys():
-        merged_sd[key] = merged_sd[key].to(torch.float)
+    lora_sds = []
+    print("Loading LoRA models...")
+    for model in models:
+        lora_sd, _, _isv2 = load_state_dict(model, merge_dtype, "cpu")
+        isv2 = isv2 or _isv2
+        lora_sds.append(lora_sd)
+
+    all_lora_module_names = set()
+    for lora_sd in lora_sds:
+        for key in lora_sd.keys():
+            if 'lora_down' in key:
+                lora_module_name = key[:key.rfind(".lora_down")]
+                all_lora_module_names.add(lora_module_name)
+
+    all_lora_module_names = sorted(list(all_lora_module_names))
+    total_modules = len(all_lora_module_names)
+    total_chunks = (total_modules + CHUNK_SIZE - 1) // CHUNK_SIZE
+    print(f"Found {total_modules} unique modules to merge. Processing in {total_chunks} chunks of {CHUNK_SIZE}.")

     # extract from merged weights
     print("extract new lora...")
     merged_lora_sd = {}
-    with torch.no_grad():
-        for lora_module_name, mat in tqdm(list(merged_sd.items())):
-            conv2d = (len(mat.size()) == 4)
-            if conv2d:
-                mat = mat.squeeze()
-
-            U, S, Vh = torch.linalg.svd(mat)
-
-            U = U[:, :new_rank]
-            S = S[:new_rank]
-            U = U @ torch.diag(S)
-            Vh = Vh[:new_rank, :]
-
-            dist = torch.cat([U.flatten(), Vh.flatten()])
-            hi_val = torch.quantile(dist, CLAMP_QUANTILE)
-            low_val = -hi_val
-
-            U = U.clamp(low_val, hi_val)
-            Vh = Vh.clamp(low_val, hi_val)
-
-            up_weight = U
-            down_weight = Vh
-
-            if conv2d:
-                up_weight = up_weight.unsqueeze(2).unsqueeze(3)
-                down_weight = down_weight.unsqueeze(2).unsqueeze(3)
-
-            merged_lora_sd[lora_module_name + '.lora_up.weight'] = up_weight.to("cpu").contiguous()
-            merged_lora_sd[lora_module_name + '.lora_down.weight'] = down_weight.to("cpu").contiguous()
-            merged_lora_sd[lora_module_name + '.alpha'] = torch.tensor(new_rank)
-
-    del merged_sd
-    gc.collect()
-    torch.cuda.empty_cache()
+    with tqdm(total=total_modules, desc="Overall Progress") as pbar_overall:
+        for i in range(0, total_modules, CHUNK_SIZE):
+            chunk = all_lora_module_names[i:i + CHUNK_SIZE]
+            pbar_overall.set_description(f"Processing Chunk {i//CHUNK_SIZE + 1}/{total_chunks}")
+
+            merged_sd_chunk = {}
+            original_shapes_chunk = {}
+            for lora_module_name in chunk:
+                merged_weight = None
+                for j, lora_sd in enumerate(lora_sds):
+                    ratio = ratios[j]
+                    down_key = lora_module_name + '.lora_down.weight'
+                    if down_key not in lora_sd:
+                        continue
+
+                    down_weight = lora_sd[down_key].to(device, non_blocking=True)
+                    up_weight = lora_sd[lora_module_name + '.lora_up.weight'].to(device, non_blocking=True)
+
+                    network_dim = down_weight.size(0)
+                    alpha = lora_sd.get(lora_module_name + '.alpha', torch.tensor(network_dim)).to(device, non_blocking=True)
+                    scale = (alpha / network_dim) if network_dim else 0
+
+                    conv2d = len(down_weight.size()) == 4
+                    if not conv2d:
+                        diff = (up_weight @ down_weight)
+                    else:
+                        diff = torch.nn.functional.conv2d(
+                            down_weight.permute(1, 0, 2, 3), up_weight
+                        ).permute(1, 0, 2, 3)
+
+                    block_ratio = ratio[blockfromkey(down_key, LBLCOKS26, isv2)]
+                    fugou = 1
+                    if "same to Strength" in sets:
+                        block_ratio, fugou = (block_ratio ** 0.5, 1) if block_ratio > 0 else (abs(block_ratio) ** 0.5, -1)
+
+                    if merged_weight is None:
+                        merged_weight = (block_ratio * diff * scale * fugou)
+                    else:
+                        merged_weight += (block_ratio * diff * scale * fugou)
+
+                if merged_weight is not None:
+                    merged_sd_chunk[lora_module_name] = merged_weight
+                    if len(merged_weight.shape) == 4:
+                        original_shapes_chunk[lora_module_name] = merged_weight.shape
+
+            with torch.no_grad():
+                for lora_module_name, mat in merged_sd_chunk.items():
+                    mat = mat.to(torch.float)
+
+                    conv2d = lora_module_name in original_shapes_chunk
+                    if conv2d:
+                        out_dim, in_dim, k_h, k_w = original_shapes_chunk[lora_module_name]
+                        mat = mat.reshape(out_dim, -1)
+
+                    U, S, Vh = torch.linalg.svd(mat)
+
+                    U = U[:, :new_rank]
+                    S = S[:new_rank]
+                    U = U @ torch.diag(S)
+                    Vh = Vh[:new_rank, :]
+
+                    dist = torch.cat([U.flatten(), Vh.flatten()])
+                    hi_val = torch.quantile(dist, CLAMP_QUANTILE)
+                    low_val = -hi_val
+                    U = U.clamp(low_val, hi_val)
+                    Vh = Vh.clamp(low_val, hi_val)
+
+                    new_up_weight = U
+                    new_down_weight = Vh
+
+                    if conv2d:
+                        out_dim, in_dim, k_h, k_w = original_shapes_chunk[lora_module_name]
+                        new_up_weight = new_up_weight.unsqueeze(2).unsqueeze(3)
+                        new_down_weight = new_down_weight.view(new_rank, in_dim, k_h, k_w)
+
+                    merged_lora_sd[lora_module_name + '.lora_up.weight'] = new_up_weight.to("cpu").contiguous()
+                    merged_lora_sd[lora_module_name + '.lora_down.weight'] = new_down_weight.to("cpu").contiguous()
+                    merged_lora_sd[lora_module_name + '.alpha'] = torch.tensor(float(new_rank))
+
+                    pbar_overall.update(1)
+
+            del merged_sd_chunk, original_shapes_chunk, chunk
+            gc.collect()
+            torch.cuda.empty_cache()
+
+    print("LoRA merge process completed.")

     return merged_lora_sd

def extract_two(a,b,pa,pb,ps):
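Note: the core of the rewritten path above is a per-module low-rank re-extraction. For each module, the weighted LoRA diffs are summed into a full weight delta, which is then factored back into an up/down pair by truncated SVD with quantile clamping. The sketch below isolates that step as a standalone helper; the name extract_lora_pair and the 0.99 default quantile are illustrative assumptions for this note, not identifiers from the repository (the committed code uses the module-level CLAMP_QUANTILE constant).

import torch

def extract_lora_pair(delta_w: torch.Tensor, new_rank: int, clamp_quantile: float = 0.99):
    """Hypothetical helper: factor a merged weight delta back into a LoRA up/down pair.

    delta_w: (out_dim, in_dim) for linear layers, or (out_dim, in_dim, kh, kw) for conv.
    Mirrors the commit's SVD-truncate-then-clamp procedure for a single module.
    """
    conv2d = delta_w.dim() == 4
    if conv2d:
        out_dim, in_dim, kh, kw = delta_w.shape
        delta_w = delta_w.reshape(out_dim, -1)        # flatten the kernel into columns

    U, S, Vh = torch.linalg.svd(delta_w.float())
    U = U[:, :new_rank] @ torch.diag(S[:new_rank])    # fold singular values into the up factor
    Vh = Vh[:new_rank, :]

    # Clamp outliers so the re-extracted factors stay numerically tame.
    dist = torch.cat([U.flatten(), Vh.flatten()])
    hi_val = torch.quantile(dist, clamp_quantile)
    U = U.clamp(-hi_val, hi_val)
    Vh = Vh.clamp(-hi_val, hi_val)

    if conv2d:
        U = U.unsqueeze(2).unsqueeze(3)               # (out_dim, new_rank, 1, 1)
        Vh = Vh.view(new_rank, in_dim, kh, kw)        # restore the kernel shape
    return U, Vh

Storing alpha = new_rank afterwards, as the commit does, keeps the effective scale alpha / rank at 1 for the re-extracted pair.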
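A minimal usage sketch, assuming the calling convention implied by ratio[blockfromkey(down_key, LBLCOKS26, isv2)]: each entry of ratios is a per-block weight list for one model (a 26-block SD1/SD2 layout is assumed here), and the returned dict holds CPU tensors keyed by .lora_up.weight / .lora_down.weight / .alpha. All file names, ratios, and option values below are placeholders.

# Illustrative only: paths, ratios, and settings are placeholders.
models = ["style_a.safetensors", "character_b.safetensors"]

# One weight per block, per model (26-block layout assumed).
ratios = [
    [1.0] * 26,   # style_a at full strength everywhere
    [0.5] * 26,   # blend character_b in at half strength
]

merged = merge_lora_models_dim(
    models, ratios,
    new_rank=32,              # rank of the re-extracted LoRA
    sets=[],                  # e.g. ["same to Strength"] to mimic apply-time strength
    device="cuda",
    calc_precision="float",   # passed through str_to_dtype
)
# `merged` is ready to be saved to a LoRA file by the caller.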