hako-mikan 2025-01-15 00:30:08 +09:00
parent f6bf9b5c24
commit c69467885e
3 changed files with 18 additions and 12 deletions

View File

@ -7,8 +7,8 @@ def make_weight_cp(t, wa, wb):
def rebuild_conventional(up, down, shape, dyn_dim=None):
up = up.reshape(up.size(0), -1)
down = down.reshape(down.size(0), -1)
up = cpufloat(up.reshape(up.size(0), -1))
down = cpufloat(down.reshape(down.size(0), -1))
if dyn_dim is not None:
up = up[:, :dyn_dim]
down = down[:dyn_dim, :]
@ -16,8 +16,9 @@ def rebuild_conventional(up, down, shape, dyn_dim=None):
def rebuild_cp_decomposition(up, down, mid):
up = up.reshape(up.size(0), -1)
down = down.reshape(down.size(0), -1)
up = cpufloat(up.reshape(up.size(0), -1))
down = cpufloat(down.reshape(down.size(0), -1))
mid = cpufloat(mid)
return torch.einsum('n m k l, i n, m j -> i j k l', mid, up, down)
@ -66,3 +67,6 @@ def factorization(dimension: int, factor:int=-1) -> tuple[int, int]:
n, m = m, n
return m, n
def cpufloat(module):
if not module: return module #None対策
return module.to(torch.float) if module.device.type == "cpu" else module

View File

@ -128,13 +128,10 @@ def smergegen(weights_a,weights_b,model_a,model_b,model_c,base_alpha,base_beta,m
save = True if SAVEMODES[0] in save_sets else False
if not forge:
result = savemodel(theta_0,currentmodel,custom_name,save_sets,metadata) if save else "Merged model loaded:"+currentmodel
result = savemodel(theta_0,currentmodel,custom_name,save_sets,metadata) if save else "Merged model loaded:"+currentmodel
model_loader(checkpoint_info, theta_0, metadata, currentmodel)
if forge and save:
result = forge_save(custom_name if custom_name else currentmodel.replace(" ","").replace(",","_").replace("(","_").replace(")","_"))
cachedealer(False)
@ -308,7 +305,7 @@ def smerge(weights_a,weights_b,model_a,model_b,model_c,base_alpha,base_beta,mode
if not(len(weights_b) == 25 or len(weights_b) == 19 or len(weights_a) == 60): return f"ERROR: weights beta value must be 20 or 26 or 61.",*NON4
caster("model load start",hearm)
printstart(model_a,model_b,model_c,base_alpha,base_beta,weights_a,weights_b,mode,useblocks,calcmode,deep,lucks['ceed'],fine,inex,ex_blocks,ex_elems)
printstart(model_a,model_b,model_c,base_alpha,base_beta,weights_a,weights_b,mode,useblocks,calcmode,deep,lucks['ceed'],fine,inex,ex_blocks,ex_elems,device)
theta_1=load_model_weights_m(model_b,2,cachetarget,device).copy()
@ -417,6 +414,7 @@ def smerge(weights_a,weights_b,model_a,model_b,model_c,base_alpha,base_beta,mode
theta_0[key] = theta_0[key].to(device)
theta_1[key] = theta_1[key].to(device)
try:
theta_2[key] = theta_2[key].to(device)
except Exception as e:
@ -567,12 +565,14 @@ def smerge(weights_a,weights_b,model_a,model_b,model_c,base_alpha,base_beta,mode
if flux and not calcmode == "smoothAdd MT":
theta_1[key] = None
del theta_1[key]
theta_0[key] = theta_0[key].to("cpu")
try:
theta_1[key] = theta_1[key].to("cpu")
except:
pass
#flux
if qtype[0]:
dellist = []
@ -950,9 +950,9 @@ def elementals(key,weight_index,deep,randomer,num,lucks,deepprint,current_alpha)
def forkforker(filename,device):
if forge:
return load_torch_file(filename)
return load_torch_file(filename, device = torch.device(device))
try:
return sd_models.read_state_dict(filename,map_location = device)
return sd_models.read_state_dict(filename, map_location = device)
except:
return sd_models.read_state_dict(filename)
@ -1550,7 +1550,7 @@ def getcachelist():
################################################
##### print
def printstart(model_a,model_b,model_c,base_alpha,base_beta,weights_a,weights_b,mode,useblocks,calcmode,deep,lucks,fine,inex,ex_blocks,ex_elems):
def printstart(model_a,model_b,model_c,base_alpha,base_beta,weights_a,weights_b,mode,useblocks,calcmode,deep,lucks,fine,inex,ex_blocks,ex_elems,device):
print(f" model A \t: {model_a}")
print(f" model B \t: {model_b}")
print(f" model C \t: {model_c}")
@ -1564,6 +1564,7 @@ def printstart(model_a,model_b,model_c,base_alpha,base_beta,weights_a,weights_b,
print(f" Weights Seed\t: {lucks}")
print(f" {inex} \t: {ex_blocks,ex_elems}")
print(f" Adjust \t: {fine}")
print(f" Device \t: {device}")
def caster(news,hear):
if hear: print(news)

View File

@ -296,6 +296,7 @@ def makelora(model_a,model_b,dim,saveto,settings,alpha,beta,save_precision,calc_
except:
currentinfo = None
lowvram.module_in_gpu = None #web-uiのバグ対策
checkpoint_info = sd_models.get_closet_checkpoint_match(model_a)
load_model(checkpoint_info)