pluslora for flux
parent
349efad1ec
commit
23015a41f9
|
|
@ -36,8 +36,8 @@ from scripts.mergers.bcolors import bcolors
|
|||
import collections
|
||||
|
||||
PREFIXFIX = ("double_blocks","single_blocks","time_in","vector_in","txt_in")
|
||||
BNB = ".quant_state.bitsandbytes__"
|
||||
PREFIX_M = "model.diffusion_model."
|
||||
BNB = ".quant_state.bitsandbytes__"
|
||||
QTYPES = ["fp4", "nf4"]
|
||||
|
||||
try:
|
||||
|
|
@ -566,9 +566,7 @@ def smerge(weights_a,weights_b,model_a,model_b,model_c,base_alpha,base_beta,mode
|
|||
return "STOPPED", *NON4
|
||||
|
||||
if need_revert:
|
||||
keys = list(theta_0.keys())
|
||||
for key in keys:
|
||||
theta_0[key.replace(PREFIX_M,"")] = theta_0.pop(key)
|
||||
prefixer(theta_0, True)
|
||||
|
||||
currentmodel = makemodelname(weights_a,weights_b,model_a, model_b,model_c, base_alpha,base_beta,useblocks,mode,calcmode)
|
||||
|
||||
|
|
@ -1730,13 +1728,21 @@ def split_state_dict(sd, additional_state_dicts: list = None):
|
|||
|
||||
return state_dict, guess
|
||||
|
||||
def prefixer(t):
|
||||
def prefixer(t, revert = False):
|
||||
keys = list(t.keys())
|
||||
if revert:
|
||||
for key in keys:
|
||||
t[key.replace(PREFIX_M,"")] = t.pop(key)
|
||||
print('"model.diffusion_model." removed from prifix.')
|
||||
return
|
||||
|
||||
need_revert = False
|
||||
for key in keys:
|
||||
if key.startswith(PREFIXFIX):
|
||||
t["model.diffusion_model." + key] = t.pop(key)
|
||||
need_revert = True
|
||||
if need_revert:
|
||||
print('"model.diffusion_model." added to prifix.')
|
||||
gc.collect()
|
||||
return need_revert
|
||||
|
||||
|
|
@ -1779,7 +1785,7 @@ def to_qdtype(sd_1, sd_2, qd_1, qd_2, device, m1, m2):
|
|||
|
||||
devices.torch_gc()
|
||||
|
||||
def q_dequantize(sd,qtype,device,dtype):
|
||||
def q_dequantize(sd,qtype,device,dtype,setbnb = True):
|
||||
dellist = []
|
||||
from bitsandbytes.functional import dequantize_4bit
|
||||
for key in tqdm(sd):
|
||||
|
|
@ -1788,7 +1794,7 @@ def q_dequantize(sd,qtype,device,dtype):
|
|||
out = torch.empty(qs["shape"],device="cuda:0")
|
||||
sd[key] = dequantize_4bit(sd[key].to("cuda:0"),out=out, absmax=sd[key + ".absmax"].to("cuda:0"),blocksize=qs["blocksize"],quant_type=qs["quant_type"]).to(device,dtype)
|
||||
dellist.append(key + ".absmax")
|
||||
dellist.append(key + BNB + qtype)
|
||||
if setbnb:dellist.append(key + BNB + qtype)
|
||||
dellist.append(key + ".quant_map")
|
||||
elif isinstance(sd[key], torch.Tensor):
|
||||
sd[key] = sd[key].to(dtype)
|
||||
|
|
@ -1797,9 +1803,17 @@ def q_dequantize(sd,qtype,device,dtype):
|
|||
if key in sd:
|
||||
del sd[key]
|
||||
|
||||
def q_quantize(weight,qtype):
|
||||
def q_quantize(sd:dict,qtype,device,setbnb = True):
|
||||
from bitsandbytes.functional import quantize_4bit
|
||||
return quantize_4bit(weight, quant_type=qtype)
|
||||
sd_plus = {}
|
||||
for key in tqdm(sd):
|
||||
if "weight" in key and "weight." not in key:
|
||||
weight, state = quantize_4bit(sd[key].to("cuda:0"), quant_type=qtype)
|
||||
sd[key] = weight.to(device)
|
||||
sd_plus[key + ".absmax"] = state.absmax
|
||||
sd_plus[key + ".quant_map"] = state.code
|
||||
if setbnb: sd_plus[key + BNB + qtype] = state.as_dict(True)["quant_state." + "bitsandbytes__" + qtype]
|
||||
sd.update(sd_plus)
|
||||
|
||||
def q_tensor_to_dict(tensor):
|
||||
num_list = tensor.tolist()
|
||||
|
|
|
|||
|
|
@ -21,7 +21,7 @@ from safetensors.torch import load_file, save_file
|
|||
from scripts.kohyas import extract_lora_from_models as ext
|
||||
from scripts.A1111 import networks as nets
|
||||
from scripts.mergers.model_util import filenamecutter, savemodel
|
||||
from scripts.mergers.mergers import extract_super, unload_forge
|
||||
from scripts.mergers.mergers import extract_super, unload_forge, q_dequantize, q_quantize, qdtyper, prefixer
|
||||
from tqdm import tqdm
|
||||
from modules.ui import versions_html
|
||||
|
||||
|
|
@ -757,11 +757,18 @@ def pluslora(lnames,loraratios,settings,output,model,save_precision,calc_precisi
|
|||
print(f"Loading {model}")
|
||||
|
||||
theta_0 = read_model_state_dict(checkpoint_info, device)
|
||||
dtype = qdtyper(theta_0)
|
||||
|
||||
if dtype == "fp4" or dtype == "nf4":
|
||||
print(f"Changing dtype of {model} from {dtype} to float16")
|
||||
qkeys = list(theta_0.keys())
|
||||
q_dequantize(theta_0,dtype,device,torch.float16,False)
|
||||
|
||||
isxl = "conditioner.embedders.1.model.transformer.resblocks.9.mlp.c_proj.weight" in theta_0.keys()
|
||||
isv2 = "cond_stage_model.model.transformer.resblocks.0.attn.out_proj.weight" in theta_0.keys()
|
||||
isflux = any("double_block" in k for k in theta_0.keys())
|
||||
|
||||
need_revert = prefixer(theta_0) if isflux else False
|
||||
|
||||
try:
|
||||
import networks
|
||||
is15 = True
|
||||
|
|
@ -793,6 +800,20 @@ def pluslora(lnames,loraratios,settings,output,model,save_precision,calc_precisi
|
|||
|
||||
theta_0 = newpluslora(theta_0,filenames,lweis,names, calc_precision, isxl,isv2, keychanger)
|
||||
|
||||
if dtype == "fp4" or dtype == "nf4":
|
||||
print(f"Changing dtype of {model} from float16 to {dtype}")
|
||||
q_quantize(theta_0,dtype,device,False)
|
||||
|
||||
failedkeys = []
|
||||
for key in theta_0:
|
||||
if key not in qkeys:
|
||||
failedkeys.append(key)
|
||||
|
||||
print(f"Key Check : {'OK' if failedkeys == [] else str(len(failedkeys)) + ' keys failed'}")
|
||||
|
||||
if need_revert:
|
||||
prefixer(theta_0, True)
|
||||
|
||||
if orig_checkpoint:
|
||||
sd_models.reload_model_weights(info=orig_checkpoint)
|
||||
else:
|
||||
|
|
@ -883,6 +904,7 @@ def newpluslora(theta_0,filenames,lweis,names, calc_precision,isxl,isv2, keychan
|
|||
changed = True
|
||||
if not changed: "ERROR: {name}weight is not changed"
|
||||
|
||||
errormodules = []
|
||||
for net in nets.loaded_networks:
|
||||
net.dyn_dim = None
|
||||
for name,module in tqdm(net.modules.items(), desc=f"{net.name}"):
|
||||
|
|
@ -915,7 +937,10 @@ def newpluslora(theta_0,filenames,lweis,names, calc_precision,isxl,isv2, keychan
|
|||
else:
|
||||
theta_0[keychanger[inkey]] ,theta_0[keychanger[outkey]], _= plusweightsqvk(theta_0[keychanger[inkey]],theta_0[keychanger[outkey]], name ,module, net)
|
||||
else:
|
||||
print("unchanged key:",msd_key)
|
||||
errormodules.append(msd_key)
|
||||
|
||||
if errormodules != []:
|
||||
print(f"Unmerged modules in {net.name} : {errormodules}")
|
||||
gc.collect()
|
||||
return theta_0
|
||||
|
||||
|
|
@ -1001,7 +1026,7 @@ def lbw(lora,lwei,isv2):
|
|||
print("unkwon LoRA")
|
||||
|
||||
if errormodules:
|
||||
print("unchanged modules:", errormodules)
|
||||
print("unchanged modules in lbw:", errormodules)
|
||||
else:
|
||||
print(f"{lora.name}: Successfully set the ratio {lwei} ")
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue