pull/464/head
parent
3a70561487
commit
d5c78c361d
|
|
@ -117,6 +117,18 @@ def convert_diffusers_name_to_compvis(key, is_sd2):
|
|||
else:
|
||||
return f"1_model_transformer_resblocks_{m[0]}_{m[1].replace('self_attn', 'attn')}"
|
||||
|
||||
if match(m, r"lora_te1_text_model_encoder_layers_(\d+)_(.+)"):
|
||||
return f"clip_l_transformer_text_model_encoder_layers_{m[0]}_{m[1]}"
|
||||
|
||||
if match(m, r"lora_te3_text_model_encoder_layers_(\d+)_(.+)"):
|
||||
if 'mlp_fc1' in m[1]:
|
||||
return f"t5xxl_transformer_text_model_encoder_layers_{m[0]}layer{m[1]}SelfAttention_k"
|
||||
|
||||
match_pattern = re.match(r"lora_te3_encoder_block_(\d+)_layer_(\d+)_(.+)", key)
|
||||
if match_pattern:
|
||||
block_num, layer_num, suffix = match_pattern.groups()
|
||||
return f"t5xxl_transformer_encoder_block_{block_num}_layer_{layer_num}_{suffix}"
|
||||
|
||||
return key
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -129,6 +129,7 @@ def on_ui_tabs():
|
|||
sml_loratypes = gr.CheckboxGroup(show_label=False, choices= ["LoRA", "LoCon", "Others"], value=["LoRA", "LoCon", "Others"])
|
||||
sml_dims = gr.CheckboxGroup(label = "1.X/2.X",choices=[],value = [],type="value",interactive=True,visible = False)
|
||||
sml_dims_xl = gr.CheckboxGroup(label = "XL",choices=[],value = [],type="value",interactive=True,visible = False)
|
||||
sml_dims_flux = gr.CheckboxGroup(label = "Flux",choices=[],value = [],type="value",interactive=True,visible = False)
|
||||
with gr.Row(equal_height=False):
|
||||
sml_calcdim = gr.Button(elem_id="calcloras", value="Calculate LoRA dimensions",variant='primary')
|
||||
sml_calcsets = gr.CheckboxGroup(choices=["Save as CSV","Load from CSV"],show_label=False)
|
||||
|
|
@ -255,21 +256,28 @@ def on_ui_tabs():
|
|||
|
||||
global selectable
|
||||
selectable = toselect(ldict)
|
||||
return (gr.update(choices=selectable, value=[]), gr.update(visible=True, choices=makedimlist("1.X/2.X")),
|
||||
gr.update(visible=True, choices=makedimlist("XL")))
|
||||
return (gr.update(choices=selectable, value=[]),
|
||||
gr.update(visible=True, choices=makedimlist("1.X/2.X")),
|
||||
gr.update(visible=True, choices=makedimlist("XL")),
|
||||
gr.update(visible=True, choices=makedimlist("Flux"))
|
||||
)
|
||||
|
||||
sml_calcdim.click(
|
||||
fn=calculatedim,
|
||||
inputs=[sml_calcsets, device],
|
||||
outputs=[sml_loras,sml_dims,sml_dims_xl]
|
||||
outputs=[sml_loras,sml_dims,sml_dims_xl,sml_dims_flux]
|
||||
)
|
||||
|
||||
def dimselector(dims, dims_xl, ltypes):
|
||||
def dimselector(dims, dims_xl, dims_flux, ltypes):
|
||||
rl={}
|
||||
if "Others" in ltypes:ltypes += ["LyCORIS", "unknown"]
|
||||
for name, vals in ldict.items():
|
||||
dim, ltype, sdver = vals
|
||||
if (dim in dims if sdver == "1.X/2.X" else dim in dims_xl) and ltype in ltypes:
|
||||
if sdver == "1.X/2.X" and dim in dims and ltype in ltypes:
|
||||
rl[name] = vals
|
||||
if sdver == "XL" and dim in dims_xl and ltype in ltypes:
|
||||
rl[name] = vals
|
||||
if sdver == "Flux" and dim in dims_flux and ltype in ltypes:
|
||||
rl[name] = vals
|
||||
|
||||
global selectable
|
||||
|
|
@ -287,9 +295,10 @@ def on_ui_tabs():
|
|||
|
||||
hidenb.change(fn=lambda x: False, outputs = [hidenb])
|
||||
sml_loras.change(fn=llister,inputs=[sml_loras,sml_lratio, hidenb],outputs=[sml_loranames])
|
||||
sml_dims.change(fn=dimselector,inputs=[sml_dims,sml_dims_xl,sml_loratypes],outputs=[sml_loras])
|
||||
sml_dims_xl.change(fn=dimselector,inputs=[sml_dims,sml_dims_xl,sml_loratypes],outputs=[sml_loras])
|
||||
sml_loratypes.change(fn=dimselector,inputs=[sml_dims,sml_dims_xl,sml_loratypes],outputs=[sml_loras])
|
||||
sml_dims.change(fn=dimselector,inputs=[sml_dims,sml_dims_xl,sml_dims_flux,sml_loratypes],outputs=[sml_loras])
|
||||
sml_dims_xl.change(fn=dimselector,inputs=[sml_dims,sml_dims_xl,sml_dims_flux,sml_loratypes],outputs=[sml_loras])
|
||||
sml_dims_flux.change(fn=dimselector,inputs=[sml_dims,sml_dims_xl,sml_dims_flux,sml_loratypes],outputs=[sml_loras])
|
||||
sml_loratypes.change(fn=dimselector,inputs=[sml_dims,sml_dims_xl,sml_dims_flux,sml_loratypes],outputs=[sml_loras])
|
||||
|
||||
##############################################################
|
||||
####### make LoRA from checkpoint
|
||||
|
|
@ -441,10 +450,10 @@ def lmerge(loranames,loraratioss,settings,filename,dim,save_precision,calc_preci
|
|||
sd = lycomerge(ln[0], lr[0], calc_precision, device)
|
||||
elif dim > 0:
|
||||
print("change demension to ", dim)
|
||||
sd = merge_lora_models_dim(ln, lr, dim,settings,device,calc_precision, device)
|
||||
sd = merge_lora_models_dim(ln, lr, dim,settings,device,calc_precision)
|
||||
elif auto and ld.count(ld[0]) != len(ld):
|
||||
print("change demension to ",dmax)
|
||||
sd = merge_lora_models_dim(ln, lr, dmax,settings,device,calc_precision, device)
|
||||
sd = merge_lora_models_dim(ln, lr, dmax,settings,device,calc_precision)
|
||||
else:
|
||||
sd = merge_lora_models(ln, lr, settings, False, calc_precision, device)
|
||||
|
||||
|
|
@ -789,15 +798,16 @@ def pluslora(lnames,loraratios,settings,output,model,save_precision,calc_precisi
|
|||
|
||||
keychanger = {}
|
||||
for key in theta_0.keys():
|
||||
if "model" in key:
|
||||
skey = key.replace(".","_").replace("_weight","")
|
||||
if "conditioner_embedders_" in skey:
|
||||
keychanger[skey.split("conditioner_embedders_",1)[1]] = key
|
||||
else:
|
||||
if "wrapped" in skey:
|
||||
keychanger[skey.split("wrapped_",1)[1]] = key
|
||||
else:
|
||||
keychanger[skey.split("model_",1)[1]] = key
|
||||
skey = key.replace(".","_").replace("_weight","")
|
||||
if "conditioner_embedders_" in skey:
|
||||
keychanger[skey.split("conditioner_embedders_",1)[1]] = key
|
||||
else:
|
||||
if "wrapped" in skey:
|
||||
keychanger[skey.split("wrapped_",1)[1]] = key
|
||||
elif "clip_l" in skey or "t5xxl" in skey:
|
||||
keychanger[skey.replace("text_encoders_","")] = key
|
||||
elif "model_" in skey:
|
||||
keychanger[skey.split("model_",1)[1]] = key
|
||||
|
||||
lowvram.module_in_gpu = None #web-uiのバグ対策
|
||||
|
||||
|
|
@ -922,6 +932,7 @@ def newpluslora(theta_0,filenames,lweis,names, calc_precision,isxl,isv2,isflux,
|
|||
for name,module in tqdm(net.modules.items(), desc=f"{net.name}"):
|
||||
fullkey = convert_diffusers_name_to_compvis(name,isv2)
|
||||
msd_key = fullkey.split(".")[0]
|
||||
|
||||
if isxl:
|
||||
if "lora_unet" in msd_key:
|
||||
msd_key = msd_key.replace("lora_unet", "diffusion_model")
|
||||
|
|
@ -1662,7 +1673,7 @@ def get_flux_blocks(key):
|
|||
return "VAE"
|
||||
if "t5xxl" in key:
|
||||
return "T5"
|
||||
if "text_encoders.clip" in key:
|
||||
if "clip_l" in key:
|
||||
return "CLIP"
|
||||
|
||||
match = re.search(r'\_(\d+)\_', key)
|
||||
|
|
|
|||
Loading…
Reference in New Issue