Merge pull request #464 from tukisuwa/main

SVDマージ時の処理方法変更。
ver22
hako-mikan 2025-10-07 22:00:13 +09:00 committed by GitHub
commit 1114302bdc
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 100 additions and 76 deletions

View File

@ -544,97 +544,121 @@ def merge_lora_models(models, ratios, sets, locon, calc_precision, device):
return merged_sd
def merge_lora_models_dim(models, ratios, new_rank, sets, device, calc_precision):
merged_sd = {}
fugou = 1
CHUNK_SIZE = 50
isv2 = False
merge_dtype = str_to_dtype(calc_precision)
for model, ratios in zip(models, ratios):
lora_sd, medadata, isv2 = load_state_dict(model, merge_dtype, device)
# merge
print(f"merging {model}: {ratios}")
for key in tqdm(list(lora_sd.keys())):
if 'lora_down' not in key:
continue
lora_module_name = key[:key.rfind(".lora_down")]
down_weight = lora_sd[key]
network_dim = down_weight.size()[0]
up_weight = lora_sd[lora_module_name + '.lora_up.weight']
alpha = lora_sd.get(lora_module_name + '.alpha', network_dim)
in_dim = down_weight.size()[1]
out_dim = up_weight.size()[0]
conv2d = len(down_weight.size()) == 4
# print(lora_module_name, network_dim, alpha, in_dim, out_dim)
# make original weight if not exist
if lora_module_name not in merged_sd:
weight = torch.zeros((out_dim, in_dim, 1, 1) if conv2d else (out_dim, in_dim), dtype=merge_dtype, device=device)
else:
weight = merged_sd[lora_module_name]
ratio = ratios[blockfromkey(key, LBLCOKS26,isv2)]
if "same to Strength" in sets:
ratio, fugou = (ratio ** 0.5, 1) if ratio > 0 else (abs(ratio) ** 0.5, -1)
# print(lora_module_name, ratio)
# W <- W + U * D
scale = (alpha / network_dim)
if not conv2d: # linear
weight = weight + ratio * (up_weight @ down_weight) * scale * fugou
else:
weight = weight + ratio * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3) * scale * fugou
merged_sd[lora_module_name] = weight
lora_sd = None
del lora_sd
torch.cuda.empty_cache()
for key in merged_sd.keys():
merged_sd[key] = merged_sd[key].to(torch.float)
lora_sds = []
print("Loading LoRA models...")
for model in models:
lora_sd, _, _isv2 = load_state_dict(model, merge_dtype, "cpu")
isv2 = isv2 or _isv2
lora_sds.append(lora_sd)
all_lora_module_names = set()
for lora_sd in lora_sds:
for key in lora_sd.keys():
if 'lora_down' in key:
lora_module_name = key[:key.rfind(".lora_down")]
all_lora_module_names.add(lora_module_name)
all_lora_module_names = sorted(list(all_lora_module_names))
total_modules = len(all_lora_module_names)
total_chunks = (total_modules + CHUNK_SIZE - 1) // CHUNK_SIZE
print(f"Found {total_modules} unique modules to merge. Processing in {total_chunks} chunks of {CHUNK_SIZE}.")
# extract from merged weights
print("extract new lora...")
merged_lora_sd = {}
with torch.no_grad():
for lora_module_name, mat in tqdm(list(merged_sd.items())):
conv2d = (len(mat.size()) == 4)
if conv2d:
mat = mat.squeeze()
U, S, Vh = torch.linalg.svd(mat)
with tqdm(total=total_modules, desc="Overall Progress") as pbar_overall:
for i in range(0, total_modules, CHUNK_SIZE):
chunk = all_lora_module_names[i:i + CHUNK_SIZE]
pbar_overall.set_description(f"Processing Chunk {i//CHUNK_SIZE + 1}/{total_chunks}")
U = U[:, :new_rank]
S = S[:new_rank]
U = U @ torch.diag(S)
merged_sd_chunk = {}
original_shapes_chunk = {}
Vh = Vh[:new_rank, :]
for lora_module_name in chunk:
merged_weight = None
for j, lora_sd in enumerate(lora_sds):
ratio = ratios[j]
down_key = lora_module_name + '.lora_down.weight'
if down_key not in lora_sd:
continue
dist = torch.cat([U.flatten(), Vh.flatten()])
hi_val = torch.quantile(dist, CLAMP_QUANTILE)
low_val = -hi_val
down_weight = lora_sd[down_key].to(device, non_blocking=True)
up_weight = lora_sd[lora_module_name + '.lora_up.weight'].to(device, non_blocking=True)
network_dim = down_weight.size(0)
alpha = lora_sd.get(lora_module_name + '.alpha', torch.tensor(network_dim)).to(device, non_blocking=True)
scale = (alpha / network_dim) if network_dim else 0
U = U.clamp(low_val, hi_val)
Vh = Vh.clamp(low_val, hi_val)
conv2d = len(down_weight.size()) == 4
if not conv2d:
diff = (up_weight @ down_weight)
else:
diff = torch.nn.functional.conv2d(
down_weight.permute(1, 0, 2, 3), up_weight
).permute(1, 0, 2, 3)
up_weight = U
down_weight = Vh
block_ratio = ratio[blockfromkey(down_key, LBLCOKS26, isv2)]
fugou = 1
if "same to Strength" in sets:
block_ratio, fugou = (block_ratio ** 0.5, 1) if block_ratio > 0 else (abs(block_ratio) ** 0.5, -1)
if merged_weight is None:
merged_weight = (block_ratio * diff * scale * fugou)
else:
merged_weight += (block_ratio * diff * scale * fugou)
if conv2d:
up_weight = up_weight.unsqueeze(2).unsqueeze(3)
down_weight = down_weight.unsqueeze(2).unsqueeze(3)
if merged_weight is not None:
merged_sd_chunk[lora_module_name] = merged_weight
if len(merged_weight.shape) == 4:
original_shapes_chunk[lora_module_name] = merged_weight.shape
merged_lora_sd[lora_module_name + '.lora_up.weight'] = up_weight.to("cpu").contiguous()
merged_lora_sd[lora_module_name + '.lora_down.weight'] = down_weight.to("cpu").contiguous()
merged_lora_sd[lora_module_name + '.alpha'] = torch.tensor(new_rank)
with torch.no_grad():
for lora_module_name, mat in merged_sd_chunk.items():
mat = mat.to(torch.float)
conv2d = lora_module_name in original_shapes_chunk
if conv2d:
out_dim, in_dim, k_h, k_w = original_shapes_chunk[lora_module_name]
mat = mat.reshape(out_dim, -1)
del merged_sd
gc.collect()
torch.cuda.empty_cache()
U, S, Vh = torch.linalg.svd(mat)
U = U[:, :new_rank]
S = S[:new_rank]
U = U @ torch.diag(S)
Vh = Vh[:new_rank, :]
dist = torch.cat([U.flatten(), Vh.flatten()])
hi_val = torch.quantile(dist, CLAMP_QUANTILE)
low_val = -hi_val
U = U.clamp(low_val, hi_val)
Vh = Vh.clamp(low_val, hi_val)
new_up_weight = U
new_down_weight = Vh
if conv2d:
out_dim, in_dim, k_h, k_w = original_shapes_chunk[lora_module_name]
new_up_weight = new_up_weight.unsqueeze(2).unsqueeze(3)
new_down_weight = new_down_weight.view(new_rank, in_dim, k_h, k_w)
merged_lora_sd[lora_module_name + '.lora_up.weight'] = new_up_weight.to("cpu").contiguous()
merged_lora_sd[lora_module_name + '.lora_down.weight'] = new_down_weight.to("cpu").contiguous()
merged_lora_sd[lora_module_name + '.alpha'] = torch.tensor(float(new_rank))
pbar_overall.update(1)
del merged_sd_chunk, original_shapes_chunk, chunk
gc.collect()
torch.cuda.empty_cache()
print("LoRA merge process completed.")
return merged_lora_sd
def extract_two(a,b,pa,pb,ps):