pluslora for flux

pull/464/head
hako-mikan 2025-01-23 23:15:27 +09:00
parent 349efad1ec
commit 23015a41f9
2 changed files with 52 additions and 13 deletions

View File

@@ -36,8 +36,8 @@ from scripts.mergers.bcolors import bcolors
 import collections
 PREFIXFIX = ("double_blocks","single_blocks","time_in","vector_in","txt_in")
-BNB = ".quant_state.bitsandbytes__"
+PREFIX_M = "model.diffusion_model."
+BNB = ".quant_state.bitsandbytes__"
 QTYPES = ["fp4", "nf4"]
 try:
@@ -566,9 +566,7 @@ def smerge(weights_a,weights_b,model_a,model_b,model_c,base_alpha,base_beta,mode
         return "STOPPED", *NON4
     if need_revert:
-        keys = list(theta_0.keys())
-        for key in keys:
-            theta_0[key.replace(PREFIX_M,"")] = theta_0.pop(key)
+        prefixer(theta_0, True)
     currentmodel = makemodelname(weights_a,weights_b,model_a, model_b,model_c, base_alpha,base_beta,useblocks,mode,calcmode)
@@ -1730,13 +1728,21 @@ def split_state_dict(sd, additional_state_dicts: list = None):
     return state_dict, guess
 
-def prefixer(t):
+def prefixer(t, revert = False):
     keys = list(t.keys())
+    if revert:
+        for key in keys:
+            t[key.replace(PREFIX_M,"")] = t.pop(key)
+        print('"model.diffusion_model." prefix removed.')
+        return
+    need_revert = False
     for key in keys:
         if key.startswith(PREFIXFIX):
             t["model.diffusion_model." + key] = t.pop(key)
+            need_revert = True
+    if need_revert:
+        print('"model.diffusion_model." prefix added.')
     gc.collect()
+    return need_revert
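
Note: prefixer now works in both directions. A Flux state dict ships its transformer keys bare (double_blocks…, single_blocks…, per PREFIXFIX), and the forward call adds the "model.diffusion_model." prefix, reporting whether anything changed, while revert=True strips it again. A minimal round-trip sketch with a toy state dict (assumes it runs inside the webui environment, where scripts.mergers.mergers is importable):

import torch
from scripts.mergers.mergers import prefixer

sd = {"double_blocks.0.img_attn.qkv.weight": torch.zeros(1)}
need_revert = prefixer(sd)    # keys gain the "model.diffusion_model." prefix
assert need_revert and all(k.startswith("model.diffusion_model.") for k in sd)
# ... merging happens against the prefixed keys ...
prefixer(sd, True)            # strip the prefix again before saving
assert "double_blocks.0.img_attn.qkv.weight" in sd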
@@ -1779,7 +1785,7 @@ def to_qdtype(sd_1, sd_2, qd_1, qd_2, device, m1, m2):
     devices.torch_gc()
 
-def q_dequantize(sd,qtype,device,dtype):
+def q_dequantize(sd,qtype,device,dtype,setbnb = True):
     dellist = []
     from bitsandbytes.functional import dequantize_4bit
     for key in tqdm(sd):
@@ -1788,7 +1794,7 @@ def q_dequantize(sd,qtype,device,dtype):
             out = torch.empty(qs["shape"],device="cuda:0")
             sd[key] = dequantize_4bit(sd[key].to("cuda:0"),out=out, absmax=sd[key + ".absmax"].to("cuda:0"),blocksize=qs["blocksize"],quant_type=qs["quant_type"]).to(device,dtype)
             dellist.append(key + ".absmax")
-            dellist.append(key + BNB + qtype)
+            if setbnb: dellist.append(key + BNB + qtype)
             dellist.append(key + ".quant_map")
         elif isinstance(sd[key], torch.Tensor):
             sd[key] = sd[key].to(dtype)
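
For orientation, a bnb-4bit checkpoint stores several sidecar entries per quantized weight. Passing setbnb=False makes q_dequantize keep the serialized quant-state entry instead of deleting it, so the later requantize pass can leave it untouched. Illustrative layout ("foo.weight" is a placeholder name):

# foo.weight                                -> packed 4-bit payload (uint8 tensor)
# foo.weight.absmax                         -> per-block scale factors
# foo.weight.quant_map                      -> 16-entry code book
# foo.weight.quant_state.bitsandbytes__nf4  -> serialized quant state (key + BNB + qtype)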
@@ -1797,9 +1803,17 @@ def q_dequantize(sd,qtype,device,dtype):
         if key in sd:
             del sd[key]
 
-def q_quantize(weight,qtype):
+def q_quantize(sd:dict,qtype,device,setbnb = True):
     from bitsandbytes.functional import quantize_4bit
-    return quantize_4bit(weight, quant_type=qtype)
+    sd_plus = {}
+    for key in tqdm(sd):
+        if "weight" in key and "weight." not in key:
+            weight, state = quantize_4bit(sd[key].to("cuda:0"), quant_type=qtype)
+            sd[key] = weight.to(device)
+            sd_plus[key + ".absmax"] = state.absmax
+            sd_plus[key + ".quant_map"] = state.code
+            if setbnb: sd_plus[key + BNB + qtype] = state.as_dict(True)["quant_state." + "bitsandbytes__" + qtype]
+    sd.update(sd_plus)
 
 def q_tensor_to_dict(tensor):
     num_list = tensor.tolist()
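
The rewritten q_quantize is the inverse of q_dequantize: weights are merged in float16 and then re-packed into the checkpoint's original fp4/nf4 layout. A self-contained sketch of that round trip (assumes a CUDA device and bitsandbytes; the added noise is a stand-in for the LoRA delta):

import torch
from bitsandbytes.functional import quantize_4bit, dequantize_4bit

weight = torch.randn(128, 128, device="cuda:0", dtype=torch.float16)
packed, state = quantize_4bit(weight, quant_type="nf4")          # pack to 4 bit
restored = dequantize_4bit(packed, quant_state=state)            # back to float16
restored += 0.01 * torch.randn_like(restored)                    # stand-in for the merge
repacked, new_state = quantize_4bit(restored, quant_type="nf4")  # re-pack for saving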

View File

@@ -21,7 +21,7 @@ from safetensors.torch import load_file, save_file
 from scripts.kohyas import extract_lora_from_models as ext
 from scripts.A1111 import networks as nets
 from scripts.mergers.model_util import filenamecutter, savemodel
-from scripts.mergers.mergers import extract_super, unload_forge
+from scripts.mergers.mergers import extract_super, unload_forge, q_dequantize, q_quantize, qdtyper, prefixer
 from tqdm import tqdm
 from modules.ui import versions_html
@@ -757,11 +757,18 @@ def pluslora(lnames,loraratios,settings,output,model,save_precision,calc_precisi
     print(f"Loading {model}")
     theta_0 = read_model_state_dict(checkpoint_info, device)
+    dtype = qdtyper(theta_0)
+    if dtype == "fp4" or dtype == "nf4":
+        print(f"Changing dtype of {model} from {dtype} to float16")
+        qkeys = list(theta_0.keys())
+        q_dequantize(theta_0,dtype,device,torch.float16,False)
 
     isxl = "conditioner.embedders.1.model.transformer.resblocks.9.mlp.c_proj.weight" in theta_0.keys()
     isv2 = "cond_stage_model.model.transformer.resblocks.0.attn.out_proj.weight" in theta_0.keys()
+    isflux = any("double_block" in k for k in theta_0.keys())
+    need_revert = prefixer(theta_0) if isflux else False
 
     try:
         import networks
         is15 = True
@@ -793,6 +800,20 @@ def pluslora(lnames,loraratios,settings,output,model,save_precision,calc_precisi
     theta_0 = newpluslora(theta_0,filenames,lweis,names, calc_precision, isxl,isv2, keychanger)
 
+    if dtype == "fp4" or dtype == "nf4":
+        print(f"Changing dtype of {model} from float16 to {dtype}")
+        q_quantize(theta_0,dtype,device,False)
+        failedkeys = []
+        for key in theta_0:
+            if key not in qkeys:
+                failedkeys.append(key)
+        print(f"Key Check : {'OK' if failedkeys == [] else str(len(failedkeys)) + ' keys failed'}")
+
+    if need_revert:
+        prefixer(theta_0, True)
+
     if orig_checkpoint:
         sd_models.reload_model_weights(info=orig_checkpoint)
     else:
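
Taken together, the quantize/prefix handling in pluslora now forms a symmetric wrapper around the merge. Condensed control flow (simplified from the two hunks above; "..." elides the unchanged arguments):

dtype = qdtyper(theta_0)                       # "fp4"/"nf4" for bnb-quantized checkpoints
if dtype in ("fp4", "nf4"):
    qkeys = list(theta_0.keys())               # snapshot for the key check
    q_dequantize(theta_0, dtype, device, torch.float16, False)
isflux = any("double_block" in k for k in theta_0)
need_revert = prefixer(theta_0) if isflux else False

theta_0 = newpluslora(theta_0, ...)            # LoRA merge in float16

if dtype in ("fp4", "nf4"):
    q_quantize(theta_0, dtype, device, False)  # restore the original 4-bit layout
    failedkeys = [k for k in theta_0 if k not in qkeys]  # any new key means a weight was not re-packed cleanly
if need_revert:
    prefixer(theta_0, True)                    # strip the Flux prefix before saving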
@@ -883,6 +904,7 @@ def newpluslora(theta_0,filenames,lweis,names, calc_precision,isxl,isv2, keychan
             changed = True
     if not changed: print(f"ERROR: {name} weight is not changed")
+    errormodules = []
     for net in nets.loaded_networks:
         net.dyn_dim = None
         for name,module in tqdm(net.modules.items(), desc=f"{net.name}"):
@@ -915,7 +937,10 @@ def newpluslora(theta_0,filenames,lweis,names, calc_precision,isxl,isv2, keychan
                 else:
                     theta_0[keychanger[inkey]], theta_0[keychanger[outkey]], _ = plusweightsqvk(theta_0[keychanger[inkey]],theta_0[keychanger[outkey]], name ,module, net)
             else:
-                print("unchanged key:",msd_key)
+                errormodules.append(msd_key)
+    if errormodules != []:
+        print(f"Unmerged modules in {net.name} : {errormodules}")
     gc.collect()
     return theta_0
@@ -1001,7 +1026,7 @@ def lbw(lora,lwei,isv2):
         print("unknown LoRA")
     if errormodules:
-        print("unchanged modules:", errormodules)
+        print("unchanged modules in lbw:", errormodules)
     else:
         print(f"{lora.name}: Successfully set the ratio {lwei} ")