flux fp8
parent
dca1b990c0
commit
54e0ddcaf8
|
|
@ -38,6 +38,7 @@ import collections
|
|||
PREFIXFIX = ("double_blocks","single_blocks","time_in","vector_in","txt_in")
|
||||
BNB = ".quant_state.bitsandbytes__"
|
||||
PREFIX_M = "model.diffusion_model."
|
||||
QTYPES = ["fp4", "nf4"]
|
||||
|
||||
try:
|
||||
ui_version = int(launch.git_tag().split("-",1)[0].replace("v","").replace(".",""))
|
||||
|
|
@ -132,7 +133,6 @@ def smergegen(weights_a,weights_b,model_a,model_b,model_c,base_alpha,base_beta,m
|
|||
|
||||
model_loader(checkpoint_info, theta_0, metadata, currentmodel)
|
||||
|
||||
|
||||
cachedealer(False)
|
||||
|
||||
del theta_0
|
||||
|
|
@ -214,6 +214,8 @@ def smerge(weights_a,weights_b,model_a,model_b,model_c,base_alpha,base_beta,mode
|
|||
useblocks,custom_name,save_sets,id_sets,wpresets,deep,fine,bake_in_vae,opt_value,inex,ex_blocks,ex_elems,deepprint,lucks,main = [False,False,False]):
|
||||
|
||||
caster("merge start",hearm)
|
||||
theta_0 = theta_1 = theta_2 = None
|
||||
qdtypes = [None, None, None]
|
||||
global hear,mergedmodel,stopmerge,statistics, revert_target
|
||||
stopmerge = False
|
||||
|
||||
|
|
@ -293,6 +295,15 @@ def smerge(weights_a,weights_b,model_a,model_b,model_c,base_alpha,base_beta,mode
|
|||
if model_a =="" or model_b =="" or ((not MODES[0] in mode) and model_c=="") :
|
||||
return "ERROR: Necessary model is not selected",*NON4
|
||||
|
||||
#exclude/include
|
||||
ex_elems = ex_elems.split(",")
|
||||
|
||||
#adjust
|
||||
if fine.rstrip(",0") != "":
|
||||
fine = fineman(fine,isxl)
|
||||
else:
|
||||
fine = ""
|
||||
|
||||
#for MBW text to list
|
||||
if useblocks:
|
||||
weights_a_t=weights_a.split(',',1)
|
||||
|
|
@ -306,19 +317,17 @@ def smerge(weights_a,weights_b,model_a,model_b,model_c,base_alpha,base_beta,mode
|
|||
weights_b = [float(w) for w in weights_b_t[1].split(',')]
|
||||
caster(f"from {weights_b_t}, beta = {base_beta},weights_a ={weights_b}",hearm)
|
||||
if not(len(weights_b) == 25 or len(weights_b) == 19 or len(weights_a) == 60): return f"ERROR: weights beta value must be 20 or 26 or 61.",*NON4
|
||||
|
||||
caster("model load start",hearm)
|
||||
|
||||
#model loading start
|
||||
caster("Model loading start",hearm)
|
||||
printstart(model_a,model_b,model_c,base_alpha,base_beta,weights_a,weights_b,mode,useblocks,calcmode,deep,lucks['ceed'],fine,inex,ex_blocks,ex_elems,device)
|
||||
|
||||
theta_1 = load_model_weights_m(model_b,2,cachetarget,device).copy()
|
||||
qdtypes[1] = qdtyper(theta_1)
|
||||
prefixer(theta_1)
|
||||
|
||||
isxl = "conditioner.embedders.1.model.transformer.resblocks.9.mlp.c_proj.weight" in theta_1.keys()
|
||||
|
||||
#adjust
|
||||
if fine.rstrip(",0") != "":
|
||||
fine = fineman(fine,isxl)
|
||||
else:
|
||||
fine = ""
|
||||
isflux = any("double_block" in k for k in theta_1.keys())
|
||||
|
||||
if isxl and useblocks:
|
||||
if len(weights_a) == 25:
|
||||
|
|
@ -331,22 +340,31 @@ def smerge(weights_a,weights_b,model_a,model_b,model_c,base_alpha,base_beta,mode
|
|||
if len(weights_a) == 19: weights_a = weights_a + [0]
|
||||
if len(weights_b) == 19: weights_b = weights_b + [0]
|
||||
|
||||
if MODES[1] in mode:#Add
|
||||
if stopmerge: return "STOPPED", *NON4
|
||||
if calcmode == "trainDifference" or calcmode == "extract":
|
||||
theta_2 = load_model_weights_m(model_c,3,cachetarget,device).copy()
|
||||
else:
|
||||
theta_2 = load_model_weights_m(model_c,3,cachetarget,device).copy()
|
||||
for key in tqdm(theta_1.keys()):
|
||||
if stopmerge: return "STOPPED", *NON4
|
||||
if not (MODES[0] in mode): #Add, Twice, Triple
|
||||
theta_2 = load_model_weights_m(model_c,3,cachetarget,device).copy()
|
||||
prefixer(theta_2)
|
||||
qdtypes[2] = qdtyper(theta_2)
|
||||
|
||||
if MODES[1] in mode: #Add
|
||||
if not(calcmode == "trainDifference" or calcmode == "extract"):
|
||||
if isflux and qdtypes[1] != qdtypes[2]:
|
||||
to_qdtype(theta_1, theta_2, qdtypes[1], qdtypes[2], device, "Model B", "Model C")
|
||||
for key in tqdm(theta_1, desc="Stage 0/2, Add difference"):
|
||||
if 'model' in key:
|
||||
if stopmerge: return "STOPPED", *NON4
|
||||
if not ("weight" in key or "bias" in key): continue
|
||||
if key in theta_2:
|
||||
t2 = theta_2.get(key, torch.zeros_like(theta_1[key]))
|
||||
theta_1[key] = theta_1[key]- t2
|
||||
if uselerp:
|
||||
theta_1[key] = torch.lerp(theta_1[key].to(torch.float32), -theta_2[key].to(torch.float32), 1.0).to(theta_1[key].dtype)
|
||||
else:
|
||||
theta_1[key] = (theta_1[key].to(torch.float32) -theta_2[key].to(torch.float32)).to(theta_1[key].dtype)
|
||||
else:
|
||||
theta_1[key] = torch.zeros_like(theta_1[key])
|
||||
del theta_2
|
||||
|
||||
if stopmerge: return "STOPPED", *NON4
|
||||
del theta_2
|
||||
theta_2 = None
|
||||
devices.torch_gc()
|
||||
|
||||
if "tensor" in calcmode or "self" in calcmode:
|
||||
theta_t = load_model_weights_m(model_a,1,cachetarget,device).copy()
|
||||
|
|
@ -355,44 +373,37 @@ def smerge(weights_a,weights_b,model_a,model_b,model_c,base_alpha,base_beta,mode
|
|||
theta_0[key] = theta_t[key].clone()
|
||||
del theta_t
|
||||
else:
|
||||
theta_0=load_model_weights_m(model_a,1,cachetarget,device).copy()
|
||||
theta_0 = load_model_weights_m(model_a,1,cachetarget,device).copy()
|
||||
|
||||
if MODES[2] in mode or MODES[3] in mode:#Tripe or Twice
|
||||
theta_2 = load_model_weights_m(model_c,3,cachetarget,device).copy()
|
||||
else:
|
||||
if not (calcmode == "trainDifference" or calcmode == "extract"):
|
||||
theta_2 = {}
|
||||
qdtypes[0] = qdtyper(theta_0)
|
||||
need_revert = prefixer(theta_0)
|
||||
|
||||
alpha = base_alpha
|
||||
beta = base_beta
|
||||
print(f"Model precisions : {qdtypes}")
|
||||
|
||||
ex_elems = ex_elems.split(",")
|
||||
if qdtypes[0] != qdtypes[1]:
|
||||
print(f"Precision of model B (or B-C) is changing to {qdtypes[0]}...")
|
||||
to_qdtype(theta_0, theta_1, qdtypes[0], qdtypes[1], device, "Model A", "Model B")
|
||||
|
||||
keyratio = []
|
||||
key_and_alpha = {}
|
||||
if theta_2 is not None:
|
||||
to_qdtype(theta_0, theta_2, qdtypes[0], qdtypes[2], device, "Model A", "Model C")
|
||||
|
||||
#flux, quantize
|
||||
flux = any("double_block" in k for k in theta_0.keys())
|
||||
if flux:
|
||||
need_revert = prefixer(theta_0)
|
||||
prefixer(theta_1)
|
||||
prefixer(theta_2)
|
||||
qtype = [q_type(theta_0),q_type(theta_1),q_type(theta_2)]
|
||||
else:
|
||||
qtype = [False,False,False]
|
||||
need_revert = False
|
||||
|
||||
##### Stage 0/2 in Cosine
|
||||
if "cosine" in calcmode:
|
||||
sim, sims = precosine("A" in calcmode,theta_0,theta_1)
|
||||
|
||||
##### Stage 1/2
|
||||
|
||||
alpha = base_alpha
|
||||
beta = base_beta
|
||||
|
||||
keyratio = []
|
||||
key_and_alpha = {}
|
||||
|
||||
for num, key in enumerate(tqdm(theta_0.keys(), desc="Stage 1/2") if not False else theta_0.keys()):
|
||||
if "weight." in key: continue #flux quantize
|
||||
|
||||
if stopmerge: return "STOPPED", *NON4
|
||||
if flux:
|
||||
if isflux:
|
||||
if key not in theta_1: continue
|
||||
else:
|
||||
if not ("model" in key and key in theta_1): continue
|
||||
|
|
@ -404,17 +415,6 @@ def smerge(weights_a,weights_b,model_a,model_b,model_c,base_alpha,base_beta,mode
|
|||
if usebeta and (not key in theta_2) and (not theta_2 == {}) :
|
||||
continue
|
||||
|
||||
##### Dequantize
|
||||
if flux and qtype[0] and "weight" in key:
|
||||
theta_0[key] = q_dequantize(theta_0,key,qtype[0],device)
|
||||
#print("Dequantize Model A")
|
||||
if flux and qtype[1] and "weight" in key:
|
||||
theta_1[key] = q_dequantize(theta_1,key,qtype[1],device).to(theta_0[key].device)
|
||||
#print(key,"Dequantize Model B")
|
||||
if theta_2 != {} and qtype[2] and "weight" in key:
|
||||
theta_2[key] = q_dequantize(theta_2,key,qtype[2],device).to(theta_0[key].device)
|
||||
#print("Dequantize Model C")
|
||||
|
||||
theta_0[key] = theta_0[key].to(device)
|
||||
theta_1[key] = theta_1[key].to(device)
|
||||
|
||||
|
|
@ -445,10 +445,10 @@ def smerge(weights_a,weights_b,model_a,model_b,model_c,base_alpha,base_beta,mode
|
|||
assert a[1] == 9 and b[1] == 4, f"Bad dimensions for merged layer {key}: A={a}, B={b}"
|
||||
result_is_inpainting_model = True
|
||||
|
||||
block,blocks26 = blockfromkey(key,isxl,flux)
|
||||
block,blocks26 = blockfromkey(key,isxl,isflux)
|
||||
#if block == "Not Merge": continue
|
||||
if inex != "Off" and (ex_blocks or (ex_elems != [""])) and excluder(blocks26,inex,ex_blocks,ex_elems,key): continue
|
||||
if flux and blocks26 in BLOCKIDFLUX:
|
||||
if isflux and blocks26 in BLOCKIDFLUX:
|
||||
weight_index = BLOCKIDFLUX.index(blocks26)
|
||||
elif isxl and blocks26 in BLOCKIDXLL:
|
||||
weight_index = BLOCKIDXLL.index(blocks26)
|
||||
|
|
@ -478,7 +478,10 @@ def smerge(weights_a,weights_b,model_a,model_b,model_c,base_alpha,base_beta,mode
|
|||
|
||||
if MODES[1] in mode:#Add
|
||||
caster(f"{num}, {block}, {model_a}+{current_alpha}+*({model_b}-{model_c}),{key}",hear)
|
||||
theta_0_a = theta_0_a + current_alpha * theta_1[key]
|
||||
if uselerp:
|
||||
theta_0_a = torch.lerp(theta_0_a.to(torch.float32),theta_0_a.to(torch.float32) + current_alpha * theta_1[key].to(torch.float32),1.0).to(theta_0_a.dtype)
|
||||
else:
|
||||
theta_0_a = (theta_0_a.to(torch.float32) + current_alpha * theta_1[key].to(torch.float32)).to(theta_0_a.dtype)
|
||||
|
||||
elif MODES[2] in mode:#Triple
|
||||
caster(f"{num}, {block}, {model_a}+{1-current_alpha-current_beta}+{model_b}*{current_alpha}+ {model_c}*{current_beta}",hear)
|
||||
|
|
@ -565,7 +568,7 @@ def smerge(weights_a,weights_b,model_a,model_b,model_c,base_alpha,base_beta,mode
|
|||
else :theta_0[key] =theta_0[key] + torch.tensor(fine[5]).to(theta_0[key].device)
|
||||
|
||||
##### del quantize info
|
||||
if flux and not calcmode == "smoothAdd MT":
|
||||
if isflux and not calcmode == "smoothAdd MT":
|
||||
theta_1[key] = None
|
||||
del theta_1[key]
|
||||
|
||||
|
|
@ -574,18 +577,6 @@ def smerge(weights_a,weights_b,model_a,model_b,model_c,base_alpha,base_beta,mode
|
|||
theta_1[key] = theta_1[key].to("cpu")
|
||||
except:
|
||||
pass
|
||||
|
||||
|
||||
#flux
|
||||
if qtype[0]:
|
||||
dellist = []
|
||||
for key in theta_0.keys():
|
||||
if "weight" in key and "weight." not in key:
|
||||
dellist.append(key + ".absmax")
|
||||
dellist.append(key + BNB + qtype[0])
|
||||
dellist.append(key + ".quant_map")
|
||||
for key in dellist:
|
||||
if key in theta_0: del theta_0[key]
|
||||
|
||||
if calcmode == "smoothAdd MT":
|
||||
# setting threads to higher than 8 doesn't significantly affect the time for merging
|
||||
|
|
@ -600,12 +591,11 @@ def smerge(weights_a,weights_b,model_a,model_b,model_c,base_alpha,base_beta,mode
|
|||
keys = list(theta_0.keys())
|
||||
for key in keys:
|
||||
theta_0[key.replace(PREFIX_M,"")] = theta_0.pop(key)
|
||||
print(key)
|
||||
|
||||
currentmodel = makemodelname(weights_a,weights_b,model_a, model_b,model_c, base_alpha,base_beta,useblocks,mode,calcmode)
|
||||
|
||||
for key in tqdm(theta_1.keys(), desc="Stage 2/2"):
|
||||
if key in CHCKPOINT_DICT_SKIP_ON_MERGE or flux:
|
||||
if key in CHCKPOINT_DICT_SKIP_ON_MERGE or isflux:
|
||||
continue
|
||||
if "model" in key and key not in theta_0:
|
||||
theta_0.update({key:theta_1[key]})
|
||||
|
|
@ -1376,8 +1366,8 @@ def blocker(blocks,blockids):
|
|||
return output
|
||||
|
||||
|
||||
def blockfromkey(key,isxl,flux=False):
|
||||
if not isxl and not flux:
|
||||
def blockfromkey(key,isxl,isflux=False):
|
||||
if not isxl and not isflux:
|
||||
re_inp = re.compile(r'\.input_blocks\.(\d+)\.') # 12
|
||||
re_mid = re.compile(r'\.middle_block\.(\d+)\.') # 1
|
||||
re_out = re.compile(r'\.output_blocks\.(\d+)\.') # 12
|
||||
|
|
@ -1409,7 +1399,7 @@ def blockfromkey(key,isxl,flux=False):
|
|||
weight_index = NUM_INPUT_BLOCKS + NUM_MID_BLOCK + out_idx
|
||||
return BLOCKID[weight_index+1] ,BLOCKID[weight_index+1]
|
||||
|
||||
elif flux:
|
||||
elif isflux:
|
||||
# Extract the two-digit number using regex
|
||||
if "vae" in key:
|
||||
return "VAE", "Not Merge"
|
||||
|
|
@ -1603,6 +1593,8 @@ def model_loader(checkpoint_info, state_dict,metadata, currentmodel):
|
|||
sd_models.model_data.__init__()
|
||||
load_model(checkpoint_info, already_loaded_state_dict=state_dict)
|
||||
else:
|
||||
memory_management.free_memory(1e30,torch.device("cpu"))
|
||||
|
||||
load_forge_model(state_dict,checkpoint_info)
|
||||
|
||||
################################################
|
||||
|
|
@ -1765,23 +1757,52 @@ def forge_save(filename):
|
|||
|
||||
###############################################################
|
||||
######## QLoRA
|
||||
def q_type(theta_0):
|
||||
if any("fp4" in k for k in theta_0.keys()):
|
||||
def qdtyper(sd):
|
||||
if any("fp4" in k for k in sd):
|
||||
return "fp4"
|
||||
elif any("nf4" in k for k in theta_0.keys()):
|
||||
elif any("nf4" in k for k in sd):
|
||||
return "nf4"
|
||||
for key in sd:
|
||||
if hasattr(sd[key],"dtype"):
|
||||
return sd[key].dtype
|
||||
|
||||
def q_dequantize(sd,key,qtype,device):
|
||||
def to_qdtype(sd_1, sd_2, qd_1, qd_2, device, m1, m2):
|
||||
if qd_1 in QTYPES and qd_2 in QTYPES:
|
||||
t1 = t2 = torch.float16
|
||||
else:
|
||||
t1 = t2 = None
|
||||
|
||||
if qd_1 in QTYPES:
|
||||
print(f"Changing dtype of {m1} to {qd_2 if t1 is None else t1}")
|
||||
q_dequantize(sd_1,qd_1,device,qd_2)
|
||||
|
||||
if qd_2 in QTYPES:
|
||||
print(f"Changing dtype of {m2} to {qd_1 if t2 is None else t2}")
|
||||
q_dequantize(sd_2,qd_2,device,qd_1)
|
||||
|
||||
devices.torch_gc()
|
||||
|
||||
def q_dequantize(sd,qtype,device,dtype):
|
||||
dellist = []
|
||||
from bitsandbytes.functional import dequantize_4bit
|
||||
qs = q_tensor_to_dict(sd[key + BNB + qtype])
|
||||
out = torch.empty(qs["shape"],device=device)
|
||||
weight = dequantize_4bit(sd[key].to(device),out=out, absmax=sd[key + ".absmax"].to(device),blocksize=qs["blocksize"],quant_type=qs["quant_type"])
|
||||
return weight.to(torch.float16)
|
||||
for key in tqdm(sd):
|
||||
if ("weight" in key) and ("weight." not in key) and (key + BNB + qtype in sd):
|
||||
qs = q_tensor_to_dict(sd[key + BNB + qtype])
|
||||
out = torch.empty(qs["shape"],device="cuda:0")
|
||||
sd[key] = dequantize_4bit(sd[key].to("cuda:0"),out=out, absmax=sd[key + ".absmax"].to("cuda:0"),blocksize=qs["blocksize"],quant_type=qs["quant_type"]).to(device,dtype)
|
||||
dellist.append(key + ".absmax")
|
||||
dellist.append(key + BNB + qtype)
|
||||
dellist.append(key + ".quant_map")
|
||||
elif isinstance(sd[key], torch.Tensor):
|
||||
sd[key] = sd[key].to(dtype)
|
||||
|
||||
def q_quantize(weight,qtype,shape):
|
||||
for key in dellist:
|
||||
if key in sd:
|
||||
del sd[key]
|
||||
|
||||
def q_quantize(weight,qtype):
|
||||
from bitsandbytes.functional import quantize_4bit
|
||||
weight, state = quantize_4bit(weight,quant_type=qtype)
|
||||
return weight.to("cpu"), state.absmax.to("cpu")
|
||||
return quantize_4bit(weight, quant_type=qtype)
|
||||
|
||||
def q_tensor_to_dict(tensor):
|
||||
num_list = tensor.tolist()
|
||||
|
|
|
|||
Loading…
Reference in New Issue