pull/464/head
hako-mikan 2025-01-20 23:24:58 +09:00
parent dca1b990c0
commit 54e0ddcaf8
1 changed file with 107 additions and 86 deletions

@@ -38,6 +38,7 @@ import collections
PREFIXFIX = ("double_blocks","single_blocks","time_in","vector_in","txt_in")
BNB = ".quant_state.bitsandbytes__"
PREFIX_M = "model.diffusion_model."
QTYPES = ["fp4", "nf4"]
try:
ui_version = int(launch.git_tag().split("-",1)[0].replace("v","").replace(".",""))
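For context on the BNB and QTYPES constants above: a bitsandbytes 4-bit layer stores its quantization metadata in sibling keys next to the packed weight tensor. A minimal sketch of the expected key layout (the layer name is hypothetical):

# Sketch: state-dict keys written by a bitsandbytes 4-bit ("fp4"/"nf4") layer.
BNB = ".quant_state.bitsandbytes__"

def quant_keys(layer, qtype="nf4"):
    base = layer + ".weight"
    return [
        base,                # packed 4-bit payload (stored as uint8)
        base + ".absmax",    # per-block absolute maxima used for dequantization
        base + ".quant_map", # 16-entry codebook for the 4-bit values
        base + BNB + qtype,  # packed QuantState (shape, blocksize, quant_type, ...)
    ]

print(quant_keys("double_blocks.0.img_mlp.0"))  # hypothetical Flux layer name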
@@ -132,7 +133,6 @@ def smergegen(weights_a,weights_b,model_a,model_b,model_c,base_alpha,base_beta,m
model_loader(checkpoint_info, theta_0, metadata, currentmodel)
cachedealer(False)
del theta_0
@@ -214,6 +214,8 @@ def smerge(weights_a,weights_b,model_a,model_b,model_c,base_alpha,base_beta,mode
useblocks,custom_name,save_sets,id_sets,wpresets,deep,fine,bake_in_vae,opt_value,inex,ex_blocks,ex_elems,deepprint,lucks,main = [False,False,False]):
caster("merge start",hearm)
theta_0 = theta_1 = theta_2 = None
qdtypes = [None, None, None]
global hear,mergedmodel,stopmerge,statistics, revert_target
stopmerge = False
@@ -293,6 +295,15 @@ def smerge(weights_a,weights_b,model_a,model_b,model_c,base_alpha,base_beta,mode
if model_a =="" or model_b =="" or ((not MODES[0] in mode) and model_c=="") :
return "ERROR: Necessary model is not selected",*NON4
#exclude/include
ex_elems = ex_elems.split(",")
#adjust
if fine.rstrip(",0") != "":
fine = fineman(fine,isxl)
else:
fine = ""
#for MBW text to list
if useblocks:
weights_a_t=weights_a.split(',',1)
@@ -306,19 +317,17 @@ def smerge(weights_a,weights_b,model_a,model_b,model_c,base_alpha,base_beta,mode
weights_b = [float(w) for w in weights_b_t[1].split(',')]
caster(f"from {weights_b_t}, beta = {base_beta},weights_a ={weights_b}",hearm)
if not(len(weights_b) == 25 or len(weights_b) == 19 or len(weights_a) == 60): return f"ERROR: weights beta value must be 20 or 26 or 61.",*NON4
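For reference, the MBW text-to-list step above splits a comma-separated preset into one base ratio plus per-block ratios (25 blocks for SD1/2, 19 for SDXL, 60 for Flux). A standalone sketch of the same parsing:

# Sketch: parse "base,block1,block2,..." into (base ratio, per-block weights).
def parse_mbw(text):
    head, rest = text.split(",", 1)
    blocks = [float(w) for w in rest.split(",")]
    assert len(blocks) in (25, 19, 60), "expected 25 (SD), 19 (SDXL) or 60 (Flux) block weights"
    return float(head), blocks

base_alpha, weights = parse_mbw("0.5," + ",".join(["1"] * 25))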
caster("model load start",hearm)
#model loading start
caster("Model loading start",hearm)
printstart(model_a,model_b,model_c,base_alpha,base_beta,weights_a,weights_b,mode,useblocks,calcmode,deep,lucks['ceed'],fine,inex,ex_blocks,ex_elems,device)
theta_1 = load_model_weights_m(model_b,2,cachetarget,device).copy()
qdtypes[1] = qdtyper(theta_1)
prefixer(theta_1)
isxl = "conditioner.embedders.1.model.transformer.resblocks.9.mlp.c_proj.weight" in theta_1.keys()
#adjust
if fine.rstrip(",0") != "":
fine = fineman(fine,isxl)
else:
fine = ""
isflux = any("double_block" in k for k in theta_1.keys())
if isxl and useblocks:
if len(weights_a) == 25:
@@ -331,22 +340,31 @@ def smerge(weights_a,weights_b,model_a,model_b,model_c,base_alpha,base_beta,mode
if len(weights_a) == 19: weights_a = weights_a + [0]
if len(weights_b) == 19: weights_b = weights_b + [0]
if MODES[1] in mode:#Add
if stopmerge: return "STOPPED", *NON4
if calcmode == "trainDifference" or calcmode == "extract":
theta_2 = load_model_weights_m(model_c,3,cachetarget,device).copy()
else:
theta_2 = load_model_weights_m(model_c,3,cachetarget,device).copy()
for key in tqdm(theta_1.keys()):
if stopmerge: return "STOPPED", *NON4
if not (MODES[0] in mode): #Add, Twice, Triple
theta_2 = load_model_weights_m(model_c,3,cachetarget,device).copy()
prefixer(theta_2)
qdtypes[2] = qdtyper(theta_2)
if MODES[1] in mode: #Add
if not(calcmode == "trainDifference" or calcmode == "extract"):
if isflux and qdtypes[1] != qdtypes[2]:
to_qdtype(theta_1, theta_2, qdtypes[1], qdtypes[2], device, "Model B", "Model C")
for key in tqdm(theta_1, desc="Stage 0/2, Add difference"):
if 'model' in key:
if stopmerge: return "STOPPED", *NON4
if not ("weight" in key or "bias" in key): continue
if key in theta_2:
t2 = theta_2.get(key, torch.zeros_like(theta_1[key]))
theta_1[key] = theta_1[key]- t2
if uselerp:
theta_1[key] = torch.lerp(theta_1[key].to(torch.float32), theta_1[key].to(torch.float32) - theta_2[key].to(torch.float32), 1.0).to(theta_1[key].dtype)
else:
theta_1[key] = (theta_1[key].to(torch.float32) -theta_2[key].to(torch.float32)).to(theta_1[key].dtype)
else:
theta_1[key] = torch.zeros_like(theta_1[key])
del theta_2
if stopmerge: return "STOPPED", *NON4
del theta_2
theta_2 = None
devices.torch_gc()
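Stage 0 above folds model C into theta_1 in place, so later stages can treat theta_1 as the difference B-C. The upcast to float32 before subtracting avoids fp16 overflow and rounding loss; a minimal sketch of that pattern:

import torch

def sub_f32(b, c):
    # Compute b - c in float32, then return to b's storage dtype.
    return (b.to(torch.float32) - c.to(torch.float32)).to(b.dtype)

diff = sub_f32(torch.randn(4, dtype=torch.float16), torch.randn(4, dtype=torch.float16))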
if "tensor" in calcmode or "self" in calcmode:
theta_t = load_model_weights_m(model_a,1,cachetarget,device).copy()
@@ -355,44 +373,37 @@ def smerge(weights_a,weights_b,model_a,model_b,model_c,base_alpha,base_beta,mode
theta_0[key] = theta_t[key].clone()
del theta_t
else:
theta_0=load_model_weights_m(model_a,1,cachetarget,device).copy()
theta_0 = load_model_weights_m(model_a,1,cachetarget,device).copy()
if MODES[2] in mode or MODES[3] in mode:#Triple or Twice
theta_2 = load_model_weights_m(model_c,3,cachetarget,device).copy()
else:
if not (calcmode == "trainDifference" or calcmode == "extract"):
theta_2 = {}
qdtypes[0] = qdtyper(theta_0)
need_revert = prefixer(theta_0)
alpha = base_alpha
beta = base_beta
print(f"Model precisions : {qdtypes}")
ex_elems = ex_elems.split(",")
if qdtypes[0] != qdtypes[1]:
print(f"Precision of model B (or B-C) is changing to {qdtypes[0]}...")
to_qdtype(theta_0, theta_1, qdtypes[0], qdtypes[1], device, "Model A", "Model B")
keyratio = []
key_and_alpha = {}
if theta_2 is not None:
to_qdtype(theta_0, theta_2, qdtypes[0], qdtypes[2], device, "Model A", "Model C")
#flux, quantize
flux = any("double_block" in k for k in theta_0.keys())
if flux:
need_revert = prefixer(theta_0)
prefixer(theta_1)
prefixer(theta_2)
qtype = [q_type(theta_0),q_type(theta_1),q_type(theta_2)]
else:
qtype = [False,False,False]
need_revert = False
##### Stage 0/2 in Cosine
if "cosine" in calcmode:
sim, sims = precosine("A" in calcmode,theta_0,theta_1)
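precosine is defined elsewhere in this file; the idea of the cosine calc modes is to precompute a per-key similarity between A and B and later modulate the merge ratio with it. A hedged sketch of the underlying measurement, not the extension's exact implementation:

import torch

def key_similarity(a, b):
    # Cosine similarity of two flattened weight tensors, computed in float32.
    return torch.nn.functional.cosine_similarity(
        a.flatten().to(torch.float32), b.flatten().to(torch.float32), dim=0).item()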
##### Stage 1/2
alpha = base_alpha
beta = base_beta
keyratio = []
key_and_alpha = {}
for num, key in enumerate(tqdm(theta_0.keys(), desc="Stage 1/2")):
if "weight." in key: continue #skip flux quantization metadata keys (.absmax, .quant_map, ...)
if stopmerge: return "STOPPED", *NON4
if flux:
if isflux:
if key not in theta_1: continue
else:
if not ("model" in key and key in theta_1): continue
@@ -404,17 +415,6 @@ def smerge(weights_a,weights_b,model_a,model_b,model_c,base_alpha,base_beta,mode
if usebeta and (not key in theta_2) and (not theta_2 == {}) :
continue
##### Dequantize
if flux and qtype[0] and "weight" in key:
theta_0[key] = q_dequantize(theta_0,key,qtype[0],device)
#print("Dequantize Model A")
if flux and qtype[1] and "weight" in key:
theta_1[key] = q_dequantize(theta_1,key,qtype[1],device).to(theta_0[key].device)
#print(key,"Dequantize Model B")
if theta_2 != {} and qtype[2] and "weight" in key:
theta_2[key] = q_dequantize(theta_2,key,qtype[2],device).to(theta_0[key].device)
#print("Dequantize Model C")
theta_0[key] = theta_0[key].to(device)
theta_1[key] = theta_1[key].to(device)
@@ -445,10 +445,10 @@ def smerge(weights_a,weights_b,model_a,model_b,model_c,base_alpha,base_beta,mode
assert a[1] == 9 and b[1] == 4, f"Bad dimensions for merged layer {key}: A={a}, B={b}"
result_is_inpainting_model = True
block,blocks26 = blockfromkey(key,isxl,flux)
block,blocks26 = blockfromkey(key,isxl,isflux)
#if block == "Not Merge": continue
if inex != "Off" and (ex_blocks or (ex_elems != [""])) and excluder(blocks26,inex,ex_blocks,ex_elems,key): continue
if flux and blocks26 in BLOCKIDFLUX:
if isflux and blocks26 in BLOCKIDFLUX:
weight_index = BLOCKIDFLUX.index(blocks26)
elif isxl and blocks26 in BLOCKIDXLL:
weight_index = BLOCKIDXLL.index(blocks26)
@@ -478,7 +478,10 @@ def smerge(weights_a,weights_b,model_a,model_b,model_c,base_alpha,base_beta,mode
if MODES[1] in mode:#Add
caster(f"{num}, {block}, {model_a}+{current_alpha}+*({model_b}-{model_c}),{key}",hear)
theta_0_a = theta_0_a + current_alpha * theta_1[key]
if uselerp:
theta_0_a = torch.lerp(theta_0_a.to(torch.float32),theta_0_a.to(torch.float32) + current_alpha * theta_1[key].to(torch.float32),1.0).to(theta_0_a.dtype)
else:
theta_0_a = (theta_0_a.to(torch.float32) + current_alpha * theta_1[key].to(torch.float32)).to(theta_0_a.dtype)
elif MODES[2] in mode:#Triple
caster(f"{num}, {block}, {model_a}+{1-current_alpha-current_beta}+{model_b}*{current_alpha}+ {model_c}*{current_beta}",hear)
@@ -565,7 +568,7 @@ def smerge(weights_a,weights_b,model_a,model_b,model_c,base_alpha,base_beta,mode
else :theta_0[key] =theta_0[key] + torch.tensor(fine[5]).to(theta_0[key].device)
##### del quantize info
if flux and not calcmode == "smoothAdd MT":
if isflux and not calcmode == "smoothAdd MT":
theta_1[key] = None
del theta_1[key]
@@ -574,18 +577,6 @@ def smerge(weights_a,weights_b,model_a,model_b,model_c,base_alpha,base_beta,mode
theta_1[key] = theta_1[key].to("cpu")
except:
pass
#flux
if qtype[0]:
dellist = []
for key in theta_0.keys():
if "weight" in key and "weight." not in key:
dellist.append(key + ".absmax")
dellist.append(key + BNB + qtype[0])
dellist.append(key + ".quant_map")
for key in dellist:
if key in theta_0: del theta_0[key]
if calcmode == "smoothAdd MT":
# setting threads to higher than 8 doesn't significantly affect the time for merging
@@ -600,12 +591,11 @@ def smerge(weights_a,weights_b,model_a,model_b,model_c,base_alpha,base_beta,mode
keys = list(theta_0.keys())
for key in keys:
theta_0[key.replace(PREFIX_M,"")] = theta_0.pop(key)
print(key)
currentmodel = makemodelname(weights_a,weights_b,model_a, model_b,model_c, base_alpha,base_beta,useblocks,mode,calcmode)
for key in tqdm(theta_1.keys(), desc="Stage 2/2"):
if key in CHCKPOINT_DICT_SKIP_ON_MERGE or flux:
if key in CHCKPOINT_DICT_SKIP_ON_MERGE or isflux:
continue
if "model" in key and key not in theta_0:
theta_0.update({key:theta_1[key]})
@@ -1376,8 +1366,8 @@ def blocker(blocks,blockids):
return output
def blockfromkey(key,isxl,flux=False):
if not isxl and not flux:
def blockfromkey(key,isxl,isflux=False):
if not isxl and not isflux:
re_inp = re.compile(r'\.input_blocks\.(\d+)\.') # 12
re_mid = re.compile(r'\.middle_block\.(\d+)\.') # 1
re_out = re.compile(r'\.output_blocks\.(\d+)\.') # 12
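These patterns map a UNet state-dict key to its block ID (BASE, IN00-IN11, M00, OUT00-OUT11 for SD1/2). A self-contained sketch of the same classification:

import re

def block_of(key):
    # Mirrors the regexes above: 12 input blocks, 1 middle block, 12 output blocks.
    m = re.search(r"\.input_blocks\.(\d+)\.", key)
    if m: return "IN{:02}".format(int(m.group(1)))
    if re.search(r"\.middle_block\.(\d+)\.", key): return "M00"
    m = re.search(r"\.output_blocks\.(\d+)\.", key)
    if m: return "OUT{:02}".format(int(m.group(1)))
    return "BASE"

print(block_of("model.diffusion_model.input_blocks.4.1.proj_in.weight"))  # IN04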
@@ -1409,7 +1399,7 @@ def blockfromkey(key,isxl,flux=False):
weight_index = NUM_INPUT_BLOCKS + NUM_MID_BLOCK + out_idx
return BLOCKID[weight_index+1] ,BLOCKID[weight_index+1]
elif flux:
elif isflux:
# Extract the two-digit number using regex
if "vae" in key:
return "VAE", "Not Merge"
@@ -1603,6 +1593,8 @@ def model_loader(checkpoint_info, state_dict,metadata, currentmodel):
sd_models.model_data.__init__()
load_model(checkpoint_info, already_loaded_state_dict=state_dict)
else:
memory_management.free_memory(1e30,torch.device("cpu"))
load_forge_model(state_dict,checkpoint_info)
################################################
@@ -1765,23 +1757,52 @@ def forge_save(filename):
###############################################################
######## QLoRA
def q_type(theta_0):
if any("fp4" in k for k in theta_0.keys()):
def qdtyper(sd):
if any("fp4" in k for k in sd):
return "fp4"
elif any("nf4" in k for k in theta_0.keys()):
elif any("nf4" in k for k in sd):
return "nf4"
for key in sd:
if hasattr(sd[key],"dtype"):
return sd[key].dtype
def q_dequantize(sd,key,qtype,device):
def to_qdtype(sd_1, sd_2, qd_1, qd_2, device, m1, m2):
if qd_1 in QTYPES and qd_2 in QTYPES:
t1 = t2 = torch.float16
else:
t1 = t2 = None
if qd_1 in QTYPES:
print(f"Changing dtype of {m1} to {qd_2 if t1 is None else t1}")
q_dequantize(sd_1,qd_1,device,qd_2 if t1 is None else t1)
if qd_2 in QTYPES:
print(f"Changing dtype of {m2} to {qd_1 if t2 is None else t2}")
q_dequantize(sd_2,qd_2,device,qd_1 if t2 is None else t2)
devices.torch_gc()
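In words, to_qdtype reconciles two state dicts before merging: if both sides are 4-bit quantized, both are dequantized to float16; if only one side is quantized, it is dequantized to the other side's dtype; if neither is, nothing happens. A distilled sketch of just that decision:

import torch

QTYPES = ["fp4", "nf4"]

def target_dtypes(qd_1, qd_2):
    # Returns (dtype for sd_1, dtype for sd_2); None means "leave as-is".
    if qd_1 in QTYPES and qd_2 in QTYPES:
        return torch.float16, torch.float16
    t1 = qd_2 if qd_1 in QTYPES else None  # dequantize side 1 to side 2's dtype
    t2 = qd_1 if qd_2 in QTYPES else None  # dequantize side 2 to side 1's dtype
    return t1, t2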
def q_dequantize(sd,qtype,device,dtype):
dellist = []
from bitsandbytes.functional import dequantize_4bit
qs = q_tensor_to_dict(sd[key + BNB + qtype])
out = torch.empty(qs["shape"],device=device)
weight = dequantize_4bit(sd[key].to(device),out=out, absmax=sd[key + ".absmax"].to(device),blocksize=qs["blocksize"],quant_type=qs["quant_type"])
return weight.to(torch.float16)
for key in tqdm(sd):
if ("weight" in key) and ("weight." not in key) and (key + BNB + qtype in sd):
qs = q_tensor_to_dict(sd[key + BNB + qtype])
out = torch.empty(qs["shape"],device="cuda:0")
sd[key] = dequantize_4bit(sd[key].to("cuda:0"),out=out, absmax=sd[key + ".absmax"].to("cuda:0"),blocksize=qs["blocksize"],quant_type=qs["quant_type"]).to(device,dtype)
dellist.append(key + ".absmax")
dellist.append(key + BNB + qtype)
dellist.append(key + ".quant_map")
elif isinstance(sd[key], torch.Tensor):
sd[key] = sd[key].to(dtype)
def q_quantize(weight,qtype,shape):
for key in dellist:
if key in sd:
del sd[key]
def q_quantize(weight,qtype):
from bitsandbytes.functional import quantize_4bit
weight, state = quantize_4bit(weight,quant_type=qtype)
return weight.to("cpu"), state.absmax.to("cpu")
return quantize_4bit(weight, quant_type=qtype)
def q_tensor_to_dict(tensor):
num_list = tensor.tolist()
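In bitsandbytes' packed serialization, the tensor stored under key + BNB + qtype is a uint8 tensor holding JSON-encoded metadata (shape, blocksize, quant_type, dtype). Assuming that format, a plausible completion of the decoding helper:

import json
import torch

def unpack_quant_state(tensor):
    # Assumes bitsandbytes' packed QuantState: a uint8 tensor of UTF-8 JSON bytes.
    return json.loads(bytes(tensor.tolist()).decode("utf-8"))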