ver21
parent
b165dcfd58
commit
5744e6fef6
|
|
@ -406,19 +406,19 @@ def smerge(weights_a,weights_b,model_a,model_b,model_c,base_alpha,base_beta,mode
|
|||
|
||||
##### Dequantize
|
||||
if flux and qtype[0] and "weight" in key:
|
||||
theta_0[key] = q_dequantize(theta_0,key,qtype[0])
|
||||
theta_0[key] = q_dequantize(theta_0,key,qtype[0],device)
|
||||
#print("Dequantize Model A")
|
||||
if flux and qtype[1] and "weight" in key:
|
||||
theta_1[key] = q_dequantize(theta_1,key,qtype[1]).to(theta_0[key].device)
|
||||
theta_1[key] = q_dequantize(theta_1,key,qtype[1],device).to(theta_0[key].device)
|
||||
#print(key,"Dequantize Model B")
|
||||
if theta_2 != {} and qtype[2] and "weight" in key:
|
||||
theta_2[key] = q_dequantize(theta_2,key,qtype[2]).to(theta_0[key].device)
|
||||
theta_2[key] = q_dequantize(theta_2,key,qtype[2],device).to(theta_0[key].device)
|
||||
#print("Dequantize Model C")
|
||||
|
||||
theta_0[key] = theta_0[key].to("cuda")
|
||||
theta_1[key] = theta_1[key].to("cuda")
|
||||
theta_0[key] = theta_0[key].to(device)
|
||||
theta_1[key] = theta_1[key].to(device)
|
||||
try:
|
||||
theta_2[key] = theta_2[key].to("cuda")
|
||||
theta_2[key] = theta_2[key].to(device)
|
||||
except Exception as e:
|
||||
pass # Do nothing
|
||||
|
||||
|
|
@ -511,7 +511,8 @@ def smerge(weights_a,weights_b,model_a,model_b,model_c,base_alpha,base_beta,mode
|
|||
theta_0[key][:, 0:4, :, :] = theta_0_a
|
||||
else:
|
||||
theta_0[key] = theta_0_a
|
||||
|
||||
|
||||
theta_0_a = a = b = None
|
||||
del theta_0_a, a, b
|
||||
|
||||
elif "cosine" in calcmode:
|
||||
|
|
@ -563,10 +564,14 @@ def smerge(weights_a,weights_b,model_a,model_b,model_c,base_alpha,base_beta,mode
|
|||
else :theta_0[key] =theta_0[key] + torch.tensor(fine[5]).to(theta_0[key].device)
|
||||
|
||||
##### del quantize info
|
||||
if flux and qtype[0] and "weight" in key:
|
||||
theta_0[key] = theta_0[key].to("cpu")
|
||||
if flux and not calcmode == "smoothAdd MT":
|
||||
theta_1[key] = None
|
||||
del theta_1[key]
|
||||
theta_0[key] = theta_0[key].to("cpu")
|
||||
try:
|
||||
theta_1[key] = theta_1[key].to("cpu")
|
||||
except:
|
||||
pass
|
||||
|
||||
#flux
|
||||
if qtype[0]:
|
||||
|
|
@ -602,9 +607,12 @@ def smerge(weights_a,weights_b,model_a,model_b,model_c,base_alpha,base_beta,mode
|
|||
if "model" in key and key not in theta_0:
|
||||
theta_0.update({key:theta_1[key]})
|
||||
|
||||
theta_1 = None
|
||||
del theta_1
|
||||
if calcmode == "trainDifference" or calcmode == "extract":
|
||||
theta_2 = None
|
||||
del theta_2
|
||||
gc.collect()
|
||||
|
||||
##### BakeVAE
|
||||
bake_in_vae_filename = sd_vae.vae_dict.get(bake_in_vae, None)
|
||||
|
|
@ -616,7 +624,7 @@ def smerge(weights_a,weights_b,model_a,model_b,model_c,base_alpha,base_beta,mode
|
|||
theta_0_key = 'first_stage_model.' + key
|
||||
if theta_0_key in theta_0:
|
||||
theta_0[theta_0_key] = vae_dict[key]
|
||||
|
||||
vae_dict = None
|
||||
del vae_dict
|
||||
|
||||
modelid = rwmergelog(currentmodel,mergedmodel)
|
||||
|
|
@ -1652,8 +1660,6 @@ def load_forge_model(state_dict,checkpoint_info = None):
|
|||
unet_storage_dtype=fsd.dynamic_args['forge_unet_storage_dtype']
|
||||
)
|
||||
|
||||
return sd_model, True
|
||||
|
||||
COMP_NAME_AND_PREFIX = {"transformer":PREFIX_M, "text_encoder": "clip_l" , "text_encoder2": "t5xxl", "vae": "vae."}
|
||||
|
||||
@torch.inference_mode()
|
||||
|
|
@ -1719,6 +1725,7 @@ def prefixer(t):
|
|||
if key.startswith(PREFIXFIX):
|
||||
t["model.diffusion_model." + key] = t.pop(key)
|
||||
need_revert = True
|
||||
gc.collect()
|
||||
return need_revert
|
||||
|
||||
def forge_save(filename):
|
||||
|
|
@ -1741,12 +1748,12 @@ def q_type(theta_0):
|
|||
elif any("nf4" in k for k in theta_0.keys()):
|
||||
return "nf4"
|
||||
|
||||
def q_dequantize(sd,key,qtype):
|
||||
def q_dequantize(sd,key,qtype,device):
|
||||
from bitsandbytes.functional import dequantize_4bit
|
||||
qs = q_tensor_to_dict(sd[key + BNB + qtype])
|
||||
out = torch.empty(qs["shape"],device="cuda")
|
||||
weight = dequantize_4bit(sd[key].to("cuda"),out=out, absmax=sd[key + ".absmax"].to("cuda"),blocksize=qs["blocksize"],quant_type=qs["quant_type"])
|
||||
return weight
|
||||
out = torch.empty(qs["shape"],device=device)
|
||||
weight = dequantize_4bit(sd[key].to(device),out=out, absmax=sd[key + ".absmax"].to(device),blocksize=qs["blocksize"],quant_type=qs["quant_type"])
|
||||
return weight.to(torch.float16)
|
||||
|
||||
def q_quantize(weight,qtype,shape):
|
||||
from bitsandbytes.functional import quantize_4bit
|
||||
|
|
|
|||
|
|
@ -19,7 +19,6 @@ from modules import extra_networks, scripts, sd_models, lowvram
|
|||
from modules.ui import create_refresh_button
|
||||
from safetensors.torch import load_file, save_file
|
||||
from scripts.kohyas import extract_lora_from_models as ext
|
||||
from scripts.kohyas import lora as klora
|
||||
from scripts.A1111 import networks as nets
|
||||
from scripts.mergers.model_util import (filenamecutter, savemodel)
|
||||
from scripts.mergers.mergers import extract_super, unload_forge
|
||||
|
|
@ -216,7 +215,7 @@ def on_ui_tabs():
|
|||
outs = sorted(set(outs))
|
||||
return outs + outs_list
|
||||
|
||||
def calculatedim(calcsets):
|
||||
def calculatedim(calcsets, device):
|
||||
# CSVから読み込む
|
||||
if "Load from CSV" in calcsets:
|
||||
with open(dimpath, mode='r', encoding='utf-8') as csv_file:
|
||||
|
|
@ -230,7 +229,12 @@ def on_ui_tabs():
|
|||
if name in ldict and ldict[n[0]] != ["","",""]:
|
||||
continue
|
||||
c_lora = lora.available_loras.get(n[0], None)
|
||||
d, t, s = dimgetter(c_lora.filename)
|
||||
|
||||
try:
|
||||
d, t, s = dimgetter(c_lora.filename, device)
|
||||
except:
|
||||
d, t, s = dimgetter(c_lora.filename)
|
||||
|
||||
ldict[name] = [d,t,s]
|
||||
|
||||
# CSVに保存
|
||||
|
|
@ -247,7 +251,7 @@ def on_ui_tabs():
|
|||
|
||||
sml_calcdim.click(
|
||||
fn=calculatedim,
|
||||
inputs=[sml_calcsets],
|
||||
inputs=[sml_calcsets, device],
|
||||
outputs=[sml_loras,sml_dims,sml_dims_xl]
|
||||
)
|
||||
|
||||
|
|
@ -287,22 +291,16 @@ def makelora(model_a,model_b,dim,saveto,settings,alpha,beta,save_precision,calc_
|
|||
return "ERROR: No model Selected"
|
||||
gc.collect()
|
||||
|
||||
try:
|
||||
currentinfo = shared.sd_model.sd_checkpoint_info
|
||||
except:
|
||||
currentinfo = None
|
||||
currentinfo = shared.sd_model.sd_checkpoint_info
|
||||
|
||||
checkpoint_info = sd_models.get_closet_checkpoint_match(model_a)
|
||||
load_model(checkpoint_info)
|
||||
|
||||
model = shared.sd_model
|
||||
print(type(model).__name__)
|
||||
print("XL" in type(model).__name__)
|
||||
|
||||
is_sdxl = type(model).__name__ == "StableDiffusionXL" or getattr(model,'is_sdxl', False)
|
||||
is_sd2 = type(model).__name__ == "StableDiffusion2" or getattr(model,'is_sd2', False)
|
||||
is_sd1 = type(model).__name__ == "StableDiffusion" or getattr(model,'is_sd1', False)
|
||||
is_flux = type(model).__name__ == "Flux" or getattr(model,'is_flux', False)
|
||||
is_sdxl = hasattr(model, 'conditioner')
|
||||
is_sd2 = not model.is_sdxl and hasattr(model.cond_stage_model, 'model')
|
||||
is_sd1 = not model.is_sdxl and not model.is_sd2
|
||||
|
||||
print(f"Detected model type: SDXL: {is_sdxl}, SD2.X: {is_sd2}, SD1.X: {is_sd1}")
|
||||
|
||||
|
|
@ -339,8 +337,7 @@ def makelora(model_a,model_b,dim,saveto,settings,alpha,beta,save_precision,calc_
|
|||
|
||||
result = ext.svd(args)
|
||||
|
||||
if currentinfo:
|
||||
load_model(currentinfo)
|
||||
load_model(currentinfo)
|
||||
return result
|
||||
|
||||
##############################################################
|
||||
|
|
@ -396,7 +393,7 @@ def lmerge(loranames,loraratioss,settings,filename,dim,save_precision,calc_preci
|
|||
c_lora = lora.available_loras.get(n[0], None)
|
||||
ln.append(c_lora.filename)
|
||||
lr.append(ratio)
|
||||
d, t, s = dimgetter(c_lora.filename)
|
||||
d, t, s = dimgetter(c_lora.filename, device)
|
||||
if t == "LoCon" and isinstance(d, list):
|
||||
d = list(set(d))
|
||||
d = d[0]
|
||||
|
|
@ -423,23 +420,23 @@ def lmerge(loranames,loraratioss,settings,filename,dim,save_precision,calc_preci
|
|||
if "LyCORIS" in ld:
|
||||
if len(ld) !=1:
|
||||
return "multiple merge of LyCORIS is not supported"
|
||||
sd = lycomerge(ln[0], lr[0], calc_precision)
|
||||
sd = lycomerge(ln[0], lr[0], calc_precision, device)
|
||||
elif dim > 0:
|
||||
print("change demension to ", dim)
|
||||
sd = merge_lora_models_dim(ln, lr, dim,settings,device,calc_precision)
|
||||
sd = merge_lora_models_dim(ln, lr, dim,settings,device,calc_precision, device)
|
||||
elif auto and ld.count(ld[0]) != len(ld):
|
||||
print("change demension to ",dmax)
|
||||
sd = merge_lora_models_dim(ln, lr, dmax,settings,device,calc_precision)
|
||||
sd = merge_lora_models_dim(ln, lr, dmax,settings,device,calc_precision, device)
|
||||
else:
|
||||
sd = merge_lora_models(ln, lr, settings, False, calc_precision)
|
||||
sd = merge_lora_models(ln, lr, settings, False, calc_precision, device)
|
||||
|
||||
if os.path.isfile(filename) and not "overwrite" in settings:
|
||||
_err_msg = f"Output file ({filename}) existed and was not saved"
|
||||
print(_err_msg)
|
||||
return _err_msg
|
||||
else:
|
||||
a = merge_lora_models(ln[0:1], lr[0:1], settings, False, calc_precision)
|
||||
b = merge_lora_models(ln[1:2], lr[1:2], settings, False, calc_precision)
|
||||
a = merge_lora_models(ln[0:1], lr[0:1], settings, False, calc_precision, device)
|
||||
b = merge_lora_models(ln[1:2], lr[1:2], settings, False, calc_precision, device)
|
||||
sd = extract_two(a,b,alpha,beta,smooth)
|
||||
|
||||
# マージ後のメタデータを取得
|
||||
|
|
@ -457,7 +454,7 @@ def lmerge(loranames,loraratioss,settings,filename,dim,save_precision,calc_preci
|
|||
traceback.print_exc()
|
||||
return exc_value
|
||||
|
||||
def merge_lora_models(models, ratios, sets, locon, calc_precision):
|
||||
def merge_lora_models(models, ratios, sets, locon, calc_precision, device):
|
||||
base_alphas = {} # alpha for merged model
|
||||
base_dims = {}
|
||||
merge_dtype = str_to_dtype(calc_precision)
|
||||
|
|
@ -467,7 +464,7 @@ def merge_lora_models(models, ratios, sets, locon, calc_precision):
|
|||
keylist = LBLCOKS26
|
||||
|
||||
print(f"merging {model}: {ratios}")
|
||||
lora_sd, metadata, isv2 = load_state_dict(model, merge_dtype)
|
||||
lora_sd, metadata, isv2 = load_state_dict(model, merge_dtype, device)
|
||||
|
||||
# get alpha and dim
|
||||
alphas = {} # alpha for current model
|
||||
|
|
@ -489,8 +486,6 @@ def merge_lora_models(models, ratios, sets, locon, calc_precision):
|
|||
alpha = alphas[lora_module_name]
|
||||
|
||||
ratio = ratios[blockfromkey(key, keylist, isv2)]
|
||||
#print(key,blockfromkey(key, keylist, isv2))
|
||||
|
||||
if "same to Strength" in sets:
|
||||
ratio, fugou = (ratio ** 0.5, 1) if ratio > 0 else (abs(ratio) ** 0.5, -1)
|
||||
|
||||
|
|
@ -654,9 +649,9 @@ def extract_two(a,b,pa,pb,ps):
|
|||
|
||||
return merged_sd
|
||||
|
||||
def lycomerge(filename,ratios,calc_precision):
|
||||
def lycomerge(filename, ratios, calc_precision, device):
|
||||
merge_dtype = str_to_dtype(calc_precision)
|
||||
sd, metadata, isv2 = load_state_dict(filename, merge_dtype)
|
||||
sd, metadata, isv2 = load_state_dict(filename, merge_dtype, device)
|
||||
|
||||
if len(ratios) == 17:
|
||||
r0 = 1
|
||||
|
|
@ -756,7 +751,7 @@ def pluslora(lnames,loraratios,settings,output,model,save_precision,calc_precisi
|
|||
|
||||
isxl = "conditioner.embedders.1.model.transformer.resblocks.9.mlp.c_proj.weight" in theta_0.keys()
|
||||
isv2 = "cond_stage_model.model.transformer.resblocks.0.attn.out_proj.weight" in theta_0.keys()
|
||||
|
||||
|
||||
try:
|
||||
import networks
|
||||
is15 = True
|
||||
|
|
@ -783,14 +778,14 @@ def pluslora(lnames,loraratios,settings,output,model,save_precision,calc_precisi
|
|||
checkpoint_info = sd_models.get_closet_checkpoint_match(model)
|
||||
if orig_checkpoint != checkpoint_info:
|
||||
sd_models.reload_model_weights(info=checkpoint_info)
|
||||
theta_0 = newpluslora(theta_0,filenames,lweis,names, isxl,isv2, keychanger)
|
||||
theta_0 = newpluslora(theta_0,filenames,lweis,names, calc_precision, isxl,isv2, keychanger)
|
||||
|
||||
if orig_checkpoint:
|
||||
sd_models.reload_model_weights(info=orig_checkpoint)
|
||||
else:
|
||||
for name,filename, lwei in zip(names,filenames, lweis):
|
||||
print(f"loading: {name}")
|
||||
lora_sd, metadata, isv2 = load_state_dict(filename, torch.float)
|
||||
lora_sd, metadata, isv2 = load_state_dict(filename, torch.float, device)
|
||||
|
||||
print(f"merging..." ,lwei)
|
||||
for key in lora_sd.keys():
|
||||
|
|
@ -859,8 +854,8 @@ def pluslora(lnames,loraratios,settings,output,model,save_precision,calc_precisi
|
|||
gc.collect()
|
||||
return result + add
|
||||
|
||||
def newpluslora(theta_0,filenames,lweis,names, isxl,isv2, keychanger):
|
||||
nets.load_networks(names, [1]* len(names),[1]* len(names), [1]* len(names), isxl, isv2)
|
||||
def newpluslora(theta_0,filenames,lweis,names, calc_precision, isxl,isv2, keychanger):
|
||||
nets.load_networks(names, [1]* len(names), [1]* len(names), isxl = isxl, isv2 = isv2)
|
||||
|
||||
for l, loaded in enumerate(nets.loaded_networks):
|
||||
for n, name in enumerate(names):
|
||||
|
|
@ -869,7 +864,7 @@ def newpluslora(theta_0,filenames,lweis,names, isxl,isv2, keychanger):
|
|||
lbw(nets.loaded_networks[l],to26(lweis[n]),isv2)
|
||||
changed = True
|
||||
if not changed: "ERROR: {name}weight is not changed"
|
||||
|
||||
|
||||
for net in nets.loaded_networks:
|
||||
net.dyn_dim = None
|
||||
for name,module in tqdm(net.modules.items(), desc=f"{net.name}"):
|
||||
|
|
@ -1084,7 +1079,7 @@ def load_metadata_from_safetensors(safetensors_file: str) -> dict:
|
|||
metadata = {}
|
||||
return metadata
|
||||
|
||||
def dimgetter(filename):
|
||||
def dimgetter(filename, device = "cpu"):
|
||||
lora_sd = load_state_header(filename, torch.float)
|
||||
alpha = None
|
||||
dim = None
|
||||
|
|
@ -1093,13 +1088,13 @@ def dimgetter(filename):
|
|||
if "lora_unet_down_blocks_0_resnets_0_conv1.lora_down.weight" in lora_sd.keys():
|
||||
ltype = "LoCon"
|
||||
if type(lora_sd["lora_unet_down_blocks_0_resnets_0_conv1.lora_down.weight"]) is dict:
|
||||
lora_sd, _, _ = load_state_dict(filename, torch.float)
|
||||
lora_sd, _, _ = load_state_dict(filename, torch.float, device)
|
||||
_, _, dim, _ = dimalpha(lora_sd)
|
||||
|
||||
if "lora_unet_input_blocks_4_1_transformer_blocks_1_attn1_to_k.lora_down.weight" in lora_sd.keys():
|
||||
sdx = "XL"
|
||||
if type(lora_sd["lora_unet_input_blocks_4_1_transformer_blocks_1_attn1_to_k.lora_down.weight"]) is dict:
|
||||
lora_sd, _, _ = load_state_dict(filename, torch.float)
|
||||
lora_sd, _, _ = load_state_dict(filename, torch.float, device)
|
||||
_, _, dim, _ = dimalpha(lora_sd)
|
||||
elif "lora_unet_input_blocks_4_1_transformer_blocks_1_attn1_to_k.hada_w1_a" in lora_sd.keys():
|
||||
sdx = "XL"
|
||||
|
|
@ -1144,12 +1139,10 @@ def blockfromkey(key,keylist,isv2 = False):
|
|||
fullkey = fullkey.replace("lora_unet", "diffusion_model")
|
||||
elif "lora_te1_text_model" in fullkey:
|
||||
fullkey = fullkey.replace("lora_te1_text_model", "0_transformer_text_model")
|
||||
|
||||
if "1_model_transformer_resblocks_" in fullkey:return 0
|
||||
|
||||
|
||||
for i,n in enumerate(keylist):
|
||||
if n in fullkey: return i
|
||||
|
||||
if n in fullkey: return i
|
||||
if "1_model_transformer_resblocks_" in fullkey:return 0
|
||||
print(f"ERROR:Block is not deteced:{fullkey}")
|
||||
return 0
|
||||
|
||||
|
|
@ -1578,4 +1571,4 @@ def load_model(checkpoint_info, reload = False):
|
|||
model_data.forge_loading_parameters = forge_model_params
|
||||
forge_model_reload()
|
||||
else:
|
||||
sd_models.load_model(checkpoint_info)
|
||||
sd_models.load_model(checkpoint_info)
|
||||
Loading…
Reference in New Issue