fix traindiff

pull/464/head
hako-mikan 2025-01-21 20:01:53 +09:00
parent eed0ec79a9
commit 54c7ea335a
1 changed file with 4 additions and 17 deletions

View File

@@ -397,28 +397,15 @@ def smerge(weights_a,weights_b,model_a,model_b,model_c,base_alpha,base_beta,mode
key_and_alpha = {}
for num, key in enumerate(tqdm(theta_0.keys(), desc="Stage 1/2") if not False else theta_0.keys()):
if "weight." in key: continue #flux quantize
if stopmerge: return "STOPPED", *NON4
if isflux:
if key not in theta_1: continue
else:
if not ("model" in key and key in theta_1): continue
if (isflux and key not in theta_1) or (not isflux and not ("model" in key and key in theta_1)):continue
if not ("weight" in key or "bias" in key): continue
if calcmode == "trainDifference" or calcmode == "extract":
if key not in theta_2:
continue
else:
if usebeta and (not key in theta_2) and (not theta_2 == {}) :
continue
if theta_2 is not None and key not in theta_2: continue
theta_0[key] = theta_0[key].to(device)
theta_1[key] = theta_1[key].to(device)
try:
if theta_2 is not None:
theta_2[key] = theta_2[key].to(device)
except Exception as e:
pass # Do nothing
weight_index = -1
current_alpha = alpha
@@ -532,7 +519,7 @@ def smerge(weights_a,weights_b,model_a,model_b,model_c,base_alpha,base_beta,mode
filtered_diff = scipy.ndimage.median_filter(theta_1[key].to(torch.float32).cpu().numpy(), size=3)
# Apply Gaussian filter to the filtered differences
filtered_diff = scipy.ndimage.gaussian_filter(filtered_diff, sigma=1)
theta_1[key] = torch.tensor(filtered_diff)
theta_1[key] = torch.tensor(filtered_diff).to(theta_0[key].device)
# Add the filtered differences to the original weights
theta_0[key] = theta_0[key] + current_alpha * theta_1[key]