fix traindiff
parent
eed0ec79a9
commit
54c7ea335a
|
|
@ -397,28 +397,15 @@ def smerge(weights_a,weights_b,model_a,model_b,model_c,base_alpha,base_beta,mode
|
|||
key_and_alpha = {}
|
||||
|
||||
for num, key in enumerate(tqdm(theta_0.keys(), desc="Stage 1/2") if not False else theta_0.keys()):
|
||||
if "weight." in key: continue #flux quantize
|
||||
|
||||
if stopmerge: return "STOPPED", *NON4
|
||||
if isflux:
|
||||
if key not in theta_1: continue
|
||||
else:
|
||||
if not ("model" in key and key in theta_1): continue
|
||||
if (isflux and key not in theta_1) or (not isflux and not ("model" in key and key in theta_1)):continue
|
||||
if not ("weight" in key or "bias" in key): continue
|
||||
if calcmode == "trainDifference" or calcmode == "extract":
|
||||
if key not in theta_2:
|
||||
continue
|
||||
else:
|
||||
if usebeta and (not key in theta_2) and (not theta_2 == {}) :
|
||||
continue
|
||||
if theta_2 is not None and key not in theta_2: continue
|
||||
|
||||
theta_0[key] = theta_0[key].to(device)
|
||||
theta_1[key] = theta_1[key].to(device)
|
||||
|
||||
try:
|
||||
if theta_2 is not None:
|
||||
theta_2[key] = theta_2[key].to(device)
|
||||
except Exception as e:
|
||||
pass # Do nothing
|
||||
|
||||
weight_index = -1
|
||||
current_alpha = alpha
|
||||
|
|
@ -532,7 +519,7 @@ def smerge(weights_a,weights_b,model_a,model_b,model_c,base_alpha,base_beta,mode
|
|||
filtered_diff = scipy.ndimage.median_filter(theta_1[key].to(torch.float32).cpu().numpy(), size=3)
|
||||
# Apply Gaussian filter to the filtered differences
|
||||
filtered_diff = scipy.ndimage.gaussian_filter(filtered_diff, sigma=1)
|
||||
theta_1[key] = torch.tensor(filtered_diff)
|
||||
theta_1[key] = torch.tensor(filtered_diff).to(theta_0[key].device)
|
||||
# Add the filtered differences to the original weights
|
||||
theta_0[key] = theta_0[key] + current_alpha * theta_1[key]
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue