Merge pull request #203 from wkpark/fix-isxl

auto detect SDXL
ver15
hako-mikan 2023-09-02 19:08:34 +09:00 committed by GitHub
commit ad96ef008b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 28 additions and 1 deletions

View File

@ -206,7 +206,7 @@ def on_ui_tabs():
with gr.Row():
dd_preset_weight = gr.Dropdown(label="Load preset", choices=preset_name_list(weights_presets), interactive=True, elem_id="refresh_presets")
preset_refresh = gr.Button(value='\U0001f504', elem_classes=["tool"])
isxl = gr.Radio(label = "type",choices = ["1.X or 2.X", "XL"], value = "1.X or 2.X", type="index")
isxl = gr.Radio(label = "type",choices = ["1.X or 2.X", "XL"], value = "1.X or 2.X", type="index", visible=False)
with gr.Column():
with gr.Row():
dd_preset_weight_r = gr.Dropdown(label="Load Romdom preset", choices=preset_name_list(weights_presets,True), interactive=True, elem_id="refresh_presets")
@ -437,6 +437,8 @@ def on_ui_tabs():
setbeta.click(fn=slider2text,inputs=[*menbers,wpresets, dd_preset_weight,isxl],outputs=[weights_b])
setx.click(fn=add_to_seq,inputs=[xgrid,weights_a],outputs=[xgrid])
model_a.change(fn=lambda model: gr.update(value="XL" if is_xl(model) is True else "1.X or 2.X", visible=True if is_xl(model) is None else False), inputs=[model_a], outputs=[isxl])
def addblockweights(val, blockopt, *blocks):
if val == "none":
val = 0
@ -607,6 +609,31 @@ def loadmetadata(model):
if sdict == {}: return "no metadata"
return json.dumps(sdict,indent=4)
def get_safetensors_header(filename):
with open(filename, mode="rb") as file:
metadata_len = file.read(8)
metadata_len = int.from_bytes(metadata_len, "little")
json_start = file.read(2)
if metadata_len > 2 and json_start in (b'{"', b"{'"):
json_data = json_start + file.read(metadata_len-2)
return json.loads(json_data)
# invalid safetensors
return None
def is_xl(modelname):
checkpointinfo = sd_models.get_closet_checkpoint_match(modelname)
if checkpointinfo is None:
return None
header = get_safetensors_header(checkpointinfo.filename)
if header is not None:
if "conditioner.embedders.1.model.transformer.resblocks.9.mlp.c_proj.weight" in header:
return True
return False
return None
def load_historyf(data, count=20, reload=False):
filepath = os.path.join(path_root,"mergehistory.csv")
global mlist,msearch