ver21
parent
f6bf9b5c24
commit
c69467885e
|
|
@ -7,8 +7,8 @@ def make_weight_cp(t, wa, wb):
|
|||
|
||||
|
||||
def rebuild_conventional(up, down, shape, dyn_dim=None):
|
||||
up = up.reshape(up.size(0), -1)
|
||||
down = down.reshape(down.size(0), -1)
|
||||
up = cpufloat(up.reshape(up.size(0), -1))
|
||||
down = cpufloat(down.reshape(down.size(0), -1))
|
||||
if dyn_dim is not None:
|
||||
up = up[:, :dyn_dim]
|
||||
down = down[:dyn_dim, :]
|
||||
|
|
@ -16,8 +16,9 @@ def rebuild_conventional(up, down, shape, dyn_dim=None):
|
|||
|
||||
|
||||
def rebuild_cp_decomposition(up, down, mid):
    """Recombine CP-decomposed LoRA factors into a full 4-D conv weight.

    Both factor matrices are flattened to 2-D, routed through ``cpufloat``
    (promotes to float32 when resident on CPU), and contracted with the
    core tensor ``mid`` via einsum.

    NOTE(review): assumes ``mid`` is a 4-D core tensor ``(n, m, k, l)``
    and the factors contract on its first two axes — confirm with callers.
    """
    flat_up = cpufloat(up.reshape(up.size(0), -1))
    flat_down = cpufloat(down.reshape(down.size(0), -1))
    core = cpufloat(mid)
    return torch.einsum('n m k l, i n, m j -> i j k l', core, flat_up, flat_down)
|
||||
|
||||
|
||||
|
|
@ -66,3 +67,6 @@ def factorization(dimension: int, factor:int=-1) -> tuple[int, int]:
|
|||
n, m = m, n
|
||||
return m, n
|
||||
|
||||
def cpufloat(module):
    """Return *module* promoted to float32 when it lives on the CPU.

    Falsy inputs (e.g. ``None`` for an absent optional weight) pass
    through unchanged — guard against None (was: "None対策").
    Tensors already on an accelerator are returned as-is.
    """
    if not module:
        return module
    if module.device.type == "cpu":
        # Half-precision math is slow/unsupported on CPU; widen to float32.
        return module.to(torch.float)
    return module
|
||||
|
|
@ -128,13 +128,10 @@ def smergegen(weights_a,weights_b,model_a,model_b,model_c,base_alpha,base_beta,m
|
|||
|
||||
save = True if SAVEMODES[0] in save_sets else False
|
||||
|
||||
if not forge:
|
||||
result = savemodel(theta_0,currentmodel,custom_name,save_sets,metadata) if save else "Merged model loaded:"+currentmodel
|
||||
result = savemodel(theta_0,currentmodel,custom_name,save_sets,metadata) if save else "Merged model loaded:"+currentmodel
|
||||
|
||||
model_loader(checkpoint_info, theta_0, metadata, currentmodel)
|
||||
|
||||
if forge and save:
|
||||
result = forge_save(custom_name if custom_name else currentmodel.replace(" ","").replace(",","_").replace("(","_").replace(")","_"))
|
||||
|
||||
cachedealer(False)
|
||||
|
||||
|
|
@ -308,7 +305,7 @@ def smerge(weights_a,weights_b,model_a,model_b,model_c,base_alpha,base_beta,mode
|
|||
if not(len(weights_b) == 25 or len(weights_b) == 19 or len(weights_a) == 60): return f"ERROR: weights beta value must be 20 or 26 or 61.",*NON4
|
||||
|
||||
caster("model load start",hearm)
|
||||
printstart(model_a,model_b,model_c,base_alpha,base_beta,weights_a,weights_b,mode,useblocks,calcmode,deep,lucks['ceed'],fine,inex,ex_blocks,ex_elems)
|
||||
printstart(model_a,model_b,model_c,base_alpha,base_beta,weights_a,weights_b,mode,useblocks,calcmode,deep,lucks['ceed'],fine,inex,ex_blocks,ex_elems,device)
|
||||
|
||||
theta_1=load_model_weights_m(model_b,2,cachetarget,device).copy()
|
||||
|
||||
|
|
@ -417,6 +414,7 @@ def smerge(weights_a,weights_b,model_a,model_b,model_c,base_alpha,base_beta,mode
|
|||
|
||||
theta_0[key] = theta_0[key].to(device)
|
||||
theta_1[key] = theta_1[key].to(device)
|
||||
|
||||
try:
|
||||
theta_2[key] = theta_2[key].to(device)
|
||||
except Exception as e:
|
||||
|
|
@ -567,12 +565,14 @@ def smerge(weights_a,weights_b,model_a,model_b,model_c,base_alpha,base_beta,mode
|
|||
if flux and not calcmode == "smoothAdd MT":
|
||||
theta_1[key] = None
|
||||
del theta_1[key]
|
||||
|
||||
theta_0[key] = theta_0[key].to("cpu")
|
||||
try:
|
||||
theta_1[key] = theta_1[key].to("cpu")
|
||||
except:
|
||||
pass
|
||||
|
||||
|
||||
#flux
|
||||
if qtype[0]:
|
||||
dellist = []
|
||||
|
|
@ -950,9 +950,9 @@ def elementals(key,weight_index,deep,randomer,num,lucks,deepprint,current_alpha)
|
|||
|
||||
def forkforker(filename, device):
    """Load a checkpoint state dict from *filename* onto *device*.

    On Forge, delegate to ``load_torch_file`` with an explicit
    ``torch.device``; otherwise use web-ui's ``sd_models.read_state_dict``.

    Fix: the retry handler used a bare ``except:``, which also swallows
    ``SystemExit``/``KeyboardInterrupt`` during a long model load — narrowed
    to ``except Exception``.
    """
    if forge:
        return load_torch_file(filename, device=torch.device(device))
    try:
        return sd_models.read_state_dict(filename, map_location=device)
    except Exception:
        # Some checkpoint formats reject map_location; retry with defaults.
        return sd_models.read_state_dict(filename)
|
||||
|
||||
|
|
@ -1550,7 +1550,7 @@ def getcachelist():
|
|||
################################################
|
||||
##### print
|
||||
|
||||
def printstart(model_a,model_b,model_c,base_alpha,base_beta,weights_a,weights_b,mode,useblocks,calcmode,deep,lucks,fine,inex,ex_blocks,ex_elems):
|
||||
def printstart(model_a,model_b,model_c,base_alpha,base_beta,weights_a,weights_b,mode,useblocks,calcmode,deep,lucks,fine,inex,ex_blocks,ex_elems,device):
|
||||
print(f" model A \t: {model_a}")
|
||||
print(f" model B \t: {model_b}")
|
||||
print(f" model C \t: {model_c}")
|
||||
|
|
@ -1564,6 +1564,7 @@ def printstart(model_a,model_b,model_c,base_alpha,base_beta,weights_a,weights_b,
|
|||
print(f" Weights Seed\t: {lucks}")
|
||||
print(f" {inex} \t: {ex_blocks,ex_elems}")
|
||||
print(f" Adjust \t: {fine}")
|
||||
print(f" Device \t: {device}")
|
||||
|
||||
def caster(news, hear):
    """Verbosity gate: print *news* only when *hear* is truthy."""
    if not hear:
        return
    print(news)
|
||||
|
|
|
|||
|
|
@ -296,6 +296,7 @@ def makelora(model_a,model_b,dim,saveto,settings,alpha,beta,save_precision,calc_
|
|||
except:
|
||||
currentinfo = None
|
||||
|
||||
lowvram.module_in_gpu = None #web-uiのバグ対策
|
||||
|
||||
checkpoint_info = sd_models.get_closet_checkpoint_match(model_a)
|
||||
load_model(checkpoint_info)
|
||||
|
|
|
|||
Loading…
Reference in New Issue