mirror of https://github.com/vladmandic/automatic
175 lines
4.7 KiB
Python
175 lines
4.7 KiB
Python
# https://github.com/ogkalu2/Merge-Stable-Diffusion-models-without-distortion
|
|
from collections import defaultdict
|
|
from random import shuffle
|
|
from typing import NamedTuple
|
|
import torch
|
|
from scipy.optimize import linear_sum_assignment
|
|
from modules.shared import log
|
|
|
|
|
|
# State-dict keys singled out as special cases: the `norm_out` weight/bias of the
# first-stage decoder and encoder, and the `out.0` weight/bias of the diffusion model.
# NOTE(review): nothing in the visible part of this module reads this list —
# confirm where it is consumed before changing it.
SPECIAL_KEYS = [
    # first-stage decoder
    "first_stage_model.decoder.norm_out.weight",
    "first_stage_model.decoder.norm_out.bias",
    # first-stage encoder
    "first_stage_model.encoder.norm_out.weight",
    "first_stage_model.encoder.norm_out.bias",
    # diffusion model output layer
    "model.diffusion_model.out.0.weight",
    "model.diffusion_model.out.0.bias",
]
|
|
|
|
|
|
class PermutationSpec(NamedTuple):
    """Two views of the same weight-axis/permutation relationship.

    The two mappings are inverses of one another; see
    `permutation_spec_from_axes_to_perm` for how the first is derived from
    the second.
    """

    # permutation name -> list of (weight key, axis index) pairs it acts on
    perm_to_axes: dict
    # weight key -> per-axis sequence of permutation names (None = axis not permuted)
    axes_to_perm: dict
|
|
|
|
|
|
def permutation_spec_from_axes_to_perm(axes_to_perm: dict) -> PermutationSpec:
    """Build a `PermutationSpec` by inverting an axes-to-permutation mapping.

    Args:
        axes_to_perm: Maps each weight key to a sequence with one entry per
            axis; an entry is a permutation name, or ``None`` when no
            permutation applies to that axis.

    Returns:
        A `PermutationSpec` whose ``perm_to_axes`` lists, for every permutation
        name, the ``(weight key, axis)`` pairs it appears on (in iteration
        order), and whose ``axes_to_perm`` is the mapping passed in.
    """
    inverted = {}
    for weight_key, per_axis_names in axes_to_perm.items():
        for axis_index, perm_name in enumerate(per_axis_names):
            if perm_name is None:
                # This axis is not tied to any permutation.
                continue
            inverted.setdefault(perm_name, []).append((weight_key, axis_index))
    return PermutationSpec(perm_to_axes=inverted, axes_to_perm=axes_to_perm)
|
|
|
|
|
|
def get_permuted_param(ps: PermutationSpec, perm, k: str, params, except_axis=None):
    """Get parameter `k` from `params`, with the permutations applied.

    Args:
        ps: Permutation spec; ``ps.axes_to_perm[k]`` gives, for each axis of
            the parameter, the name of the permutation acting on that axis or
            ``None`` when the axis is not permuted.
        perm: Mapping from permutation name to an index tensor.
        k: Key of the parameter to fetch from `params`.
        params: Mapping from parameter key to tensor.
        except_axis: Optional axis index that is left unpermuted.

    Returns:
        ``params[k]`` with every permuted axis (other than `except_axis`)
        reordered through ``torch.index_select``.
    """
    w = params[k]
    for axis, p in enumerate(ps.axes_to_perm[k]):
        # Skip the axis we're trying to permute.
        if axis == except_axis:
            continue

        # None indicates that there is no permutation relevant to that axis.
        # Compare against None explicitly (as the spec builder does) so that a
        # falsy permutation name such as 0 or "" is still applied.
        if p is not None:
            w = torch.index_select(w, axis, perm[p].int())

    return w
|
|
|
|
|
|
def apply_permutation(ps: PermutationSpec, perm, params):
    """Return a new dict holding every entry of `params` with `perm` applied.

    Each key of `params` is passed through `get_permuted_param` with no axis
    excluded; `params` itself is not modified by this function.
    """
    permuted = {}
    for name in params.keys():
        permuted[name] = get_permuted_param(ps, perm, name, params)
    return permuted
|
|
|
|
|
|
def update_model_a(ps: PermutationSpec, perm, model_a, new_alpha):
    """Blend each entry of `model_a` with its permuted version, in place.

    For every key the stored value becomes
    ``value * (1 - new_alpha) + new_alpha * permuted_value``, where
    ``permuted_value`` comes from `get_permuted_param`. Keys for which either
    the permutation or the blend raises ``RuntimeError`` keep their value.

    Returns:
        The same `model_a` mapping that was passed in.
    """
    for name in model_a:
        try:
            permuted = get_permuted_param(ps, perm, name, model_a)
            current = model_a[name]
            model_a[name] = current * (1 - new_alpha) + new_alpha * permuted
        except RuntimeError:  # dealing with pix2pix and inpainting models
            # Leave this entry as it is and move on to the next key.
            continue
    return model_a
|
|
|
|
|
|
def inner_matching(
    n,
    ps,
    p,
    params_a,
    params_b,
    usefp16,
    progress,
    number,
    linear_sum,
    perm,
    device,
):
    """Re-solve permutation `p` by linear assignment and update running stats.

    An ``n x n`` score matrix is accumulated over every ``(weight key, axis)``
    pair in ``ps.perm_to_axes[p]``: the `params_a` weight and the `params_b`
    weight (permuted on all axes except the current one) are each moved so the
    current axis comes first, flattened to ``n`` rows, and multiplied together.
    ``linear_sum_assignment`` then picks the column assignment with the largest
    total score.

    Args:
        n: Size of the axis governed by permutation `p`.
        ps: Permutation spec providing ``perm_to_axes`` / ``axes_to_perm``.
        p: Name of the permutation being solved.
        params_a: Mapping of weight key to tensor for the first model.
        params_b: Mapping of weight key to tensor for the second model.
        usefp16: When true, the score matrix and operands use float16.
        progress: Progress flag carried in from the caller.
        number: Running count of updates whose score changed.
        linear_sum: Running sum of absolute score changes.
        perm: Mapping of permutation name to index tensor; entry `p` is
            overwritten with the new assignment.
        device: Device the tensors are moved to for the computation.

    Returns:
        ``(linear_sum, number, perm, progress)`` with the updated running
        statistics, the mutated `perm` mapping and the updated progress flag.
    """
    if usefp16:
        scores = torch.zeros((n, n), dtype=torch.float16)
    else:
        scores = torch.zeros((n, n))
    scores = scores.to(device)

    for weight_key, axis in ps.perm_to_axes[p]:
        lhs = params_a[weight_key]
        rhs = get_permuted_param(ps, perm, weight_key, params_b, except_axis=axis)

        # Bring the matched axis to the front and flatten the rest: lhs is
        # (n, rest), rhs is transposed to (rest, n) so the product is (n, n).
        lhs = torch.moveaxis(lhs, axis, 0).reshape((n, -1)).to(device)
        rhs = torch.moveaxis(rhs, axis, 0).reshape((n, -1)).T.to(device)

        if usefp16:
            lhs = lhs.half().to(device)
            rhs = rhs.half().to(device)

        try:
            scores += torch.matmul(lhs, rhs)
        except RuntimeError:
            # Retry the product on dequantized copies of both operands.
            scores += torch.matmul(torch.dequantize(lhs), torch.dequantize(rhs))

    # The assignment solver works on a NumPy array, so round-trip through CPU.
    scores = scores.cpu()
    row_idx, ci = linear_sum_assignment(scores.detach().numpy(), maximize=True)
    scores = scores.to(device)

    # Rows must come back in natural order so `ci` can be used as a permutation.
    assert (torch.tensor(row_idx) == torch.arange(len(row_idx))).all()

    identity = torch.eye(n).to(device)
    flat_scores = torch.flatten(scores).float()

    # Total score under the current permutation versus the newly solved one.
    oldL = torch.vdot(flat_scores, torch.flatten(identity[perm[p].long()]))
    newL = torch.vdot(flat_scores, torch.flatten(identity[ci, :]))

    if usefp16:
        oldL, newL = oldL.half(), newL.half()

    gain = newL - oldL
    if gain != 0:
        linear_sum += abs(gain.item())
        number += 1
        log.debug(f"Merge Rebasin permutation: {p}={newL-oldL}")

    if not progress:
        progress = newL > oldL + 1e-12

    perm[p] = torch.Tensor(ci).to(device)

    return linear_sum, number, perm, progress
|
|
|
|
|
|
def weight_matching(
    ps: PermutationSpec,
    params_a,
    params_b,
    max_iter=1,
    init_perm=None,
    usefp16=False,
    device="cpu",
):
    """Search for permutations of `params_b` that best match `params_a`.

    Only the permutation named ``"P_bg324"`` is re-solved; each pass delegates
    to `inner_matching`. Iteration stops after `max_iter` passes, or earlier
    once a pass reports no progress.

    Args:
        ps: Permutation spec describing which weight axes each permutation
            acts on.
        params_a: Mapping of weight key to tensor for the reference model.
        params_b: Mapping of weight key to tensor for the model being permuted.
        max_iter: Maximum number of matching passes.
        init_perm: Optional starting mapping of permutation name to index
            tensor. When given it is used directly (and updated in place by
            `inner_matching`); otherwise identity permutations are created.
        usefp16: Forwarded to `inner_matching` to select float16 arithmetic.
        device: Device on which the permutation tensors are created and the
            matching is computed.

    Returns:
        ``(perm, average)`` where `perm` maps permutation names to index
        tensors and `average` is ``linear_sum / number`` as accumulated by
        `inner_matching` (0 when no update changed the score).
    """
    # Size of each permutation, read from the first (weight key, axis) pair
    # that uses it; permutations whose first weight is missing are skipped.
    perm_sizes = {
        p: params_a[axes[0][0]].shape[axes[0][1]]
        for p, axes in ps.perm_to_axes.items()
        if axes[0][0] in params_a.keys()
    }

    if init_perm is None:
        perm = {p: torch.arange(n).to(device) for p, n in perm_sizes.items()}
    else:
        perm = init_perm

    linear_sum = 0
    number = 0

    special_layers = ["P_bg324"]
    for _ in range(max_iter):
        progress = False
        shuffle(special_layers)
        for p in special_layers:
            linear_sum, number, perm, progress = inner_matching(
                perm_sizes[p],
                ps,
                p,
                params_a,
                params_b,
                usefp16,
                progress,
                number,
                linear_sum,
                perm,
                device,
            )
        # Stop as soon as a full pass fails to improve the assignment. The
        # flag used to be overwritten with True here, which made this check
        # unreachable and forced redundant passes over an unchanged problem.
        if not progress:
            break

    average = linear_sum / number if number > 0 else 0
    return perm, average
|