Add support for "save safetensors" and "save as fp16"

add: support safetensors
add: support save as fp16
refact: some fix
fix_update
bbc_mc 2023-01-09 00:00:00 +09:00
parent 6bbfdb35d0
commit a198c354ff
4 changed files with 137 additions and 50 deletions

View File

@ -26,7 +26,11 @@ def dprint(str, flg):
print(str)
def merge(weights:list, model_0, model_1, device="cpu", base_alpha=0.5, output_file="", allow_overwrite=False, verbose=False):
def merge(weights:list, model_0, model_1, device="cpu", base_alpha=0.5,
output_file="", allow_overwrite=False, verbose=False,
save_as_safetensors=False,
save_as_half=False
):
if weights is None:
weights = None
else:
@ -38,7 +42,7 @@ def merge(weights:list, model_0, model_1, device="cpu", base_alpha=0.5, output_f
device = device if device in ["cpu", "cuda"] else "cpu"
def load_model(_model, _device):
def load_model(_model, _device="cpu"):
model_info = sd_models.get_closet_checkpoint_match(_model)
if model_info:
model_file = model_info.filename
@ -51,10 +55,11 @@ def merge(weights:list, model_0, model_1, device="cpu", base_alpha=0.5, output_f
theta_1 = load_model(model_1, device)
alpha = base_alpha
_footer = "-half" if save_as_half else ""
_footer = f"{_footer}.safetensors" if save_as_safetensors else f"{_footer}.ckpt"
if not output_file or output_file == "":
output_file = f'bw-{model_0}-{model_1}-{str(alpha)[2:] + "0"}.ckpt'
else:
output_file = output_file if ".ckpt" in output_file else output_file + ".ckpt"
# check if output file already exists
if os.path.isfile(output_file) and not allow_overwrite:
@ -66,10 +71,12 @@ def merge(weights:list, model_0, model_1, device="cpu", base_alpha=0.5, output_f
re_mid = re.compile(r'\.middle_block\.(\d+)\.') # 1
re_out = re.compile(r'\.output_blocks\.(\d+)\.') # 12
print(" merging ...")
dprint(f"-- start Stage 1/2 --", verbose)
count_target_of_basealpha = 0
for key in (tqdm(theta_0.keys(), desc="Stage 1/2") if not verbose else theta_0.keys()):
for key in (tqdm(theta_0.keys(), desc="Stage 1/2")):
if "model" in key and key in theta_1:
dprint(f" key : {key}", verbose)
current_alpha = alpha
@ -109,20 +116,35 @@ def merge(weights:list, model_0, model_1, device="cpu", base_alpha=0.5, output_f
theta_0[key] = (1 - current_alpha) * theta_0[key] + current_alpha * theta_1[key]
if save_as_half:
theta_0[key] = theta_0[key].half()
else:
dprint(f" key - {key}", verbose)
dprint(f"-- start Stage 2/2 --", verbose)
for key in tqdm(theta_1.keys(), desc="Stage 2/2"):
if "model" in key and key not in theta_0:
dprint(f" key : {key}", verbose)
theta_0.update({key:theta_1[key]})
if save_as_half:
theta_0[key] = theta_0[key].half()
else:
dprint(f" key - {key}", verbose)
print("Saving...")
torch.save({"state_dict": theta_0}, output_file)
_, extension = os.path.splitext(output_file)
if extension.lower() == ".safetensors" or save_as_safetensors:
if save_as_safetensors and extension.lower() != ".safetensors":
output_file = output_file + ".safetensors"
import safetensors.torch
safetensors.torch.save_file(theta_0, output_file, metadata={"format": "pt"})
else:
torch.save({"state_dict": theta_0}, output_file)
print("Done!")

View File

@ -17,18 +17,22 @@ def on_ui_tabs():
with gr.Column():
with gr.Row():
with gr.Column(variant="panel"):
btn_do_merge_block_weighted = gr.Button(value="Run Merge", variant="primary")
btn_clear_weight = gr.Button(value="Clear values")
btn_reload_checkpoint_mbw = gr.Button(value="Reload checkpoint")
with gr.Row():
btn_do_merge_block_weighted = gr.Button(value="Run Merge", variant="primary")
btn_clear_weight = gr.Button(value="Clear values")
btn_reload_checkpoint_mbw = gr.Button(value="Reload checkpoint")
html_output_block_weight_info = gr.HTML()
with gr.Column():
dd_preset_weight = gr.Dropdown(label="Preset Weights", choices=presetWeights.get_preset_name_list())
txt_block_weight = gr.Text(label="Weight values", placeholder="Put weight sets. float number x 25")
btn_apply_block_weithg_from_txt = gr.Button(value="Apply block weight from text", variant="primary")
with gr.Row():
sl_base_alpha = gr.Slider(label="base_alpha", minimum=0, maximum=1, step=0.01, value=1)
sl_base_alpha = gr.Slider(label="base_alpha", minimum=0, maximum=1, step=0.01, value=0)
chk_verbose_mbw = gr.Checkbox(label="verbose console output", value=False)
chk_allow_overwrite = gr.Checkbox(label="Allow overwrite output-model", value=False)
with gr.Row():
chk_save_as_half = gr.Checkbox(label="Save as half", value=False)
chk_save_as_safetensors = gr.Checkbox(label="Save as safetensors", value=False)
with gr.Row():
model_A = gr.Dropdown(label="Model A", choices=sd_models.checkpoint_tiles())
model_B = gr.Dropdown(label="Model B", choices=sd_models.checkpoint_tiles())
@ -90,8 +94,13 @@ def on_ui_tabs():
sl_M_00,
sl_OUT_00, sl_OUT_01, sl_OUT_02, sl_OUT_03, sl_OUT_04, sl_OUT_05,
sl_OUT_06, sl_OUT_07, sl_OUT_08, sl_OUT_09, sl_OUT_10, sl_OUT_11,
txt_model_O, sl_base_alpha, chk_verbose_mbw, chk_allow_overwrite
txt_model_O, sl_base_alpha, chk_verbose_mbw, chk_allow_overwrite,
chk_save_as_safetensors, chk_save_as_half
):
# debug output
print( "#### Merge Block Weighted ####")
_weights = ",".join(
[str(x) for x in [
sl_IN_00, sl_IN_01, sl_IN_02, sl_IN_03, sl_IN_04, sl_IN_05,
@ -104,7 +113,11 @@ def on_ui_tabs():
if not model_A or not model_B:
return gr.update(value=f"ERROR: model not found. [{model_A}][{model_B}]")
ckpt_dir = shared.cmd_opts.ckpt_dir or sd_models.model_path
#
# Prepare params before run merge
#
# generate output file name from param
model_A_info = sd_models.get_closet_checkpoint_match(model_A)
if model_A_info:
_model_A_name = model_A_info.model_name
@ -116,33 +129,47 @@ def on_ui_tabs():
else:
_model_B_info = ""
txt_model_O = re.sub(r'[\\|:|?|"|<|>|\|\*]', '-', txt_model_O)
model_O = f"bw-merge-{_model_A_name}-{_model_B_info}-{sl_base_alpha}.ckpt" if txt_model_O == "" else txt_model_O
if ".ckpt" not in model_O:
model_O = model_O + ".ckpt"
def validate_output_filename(output_filename, save_as_safetensors=False, save_as_half=False):
output_filename = re.sub(r'[\\|:|?|"|<|>|\|\*]', '-', output_filename)
filename_body, filename_ext = os.path.splitext(output_filename)
_ret = output_filename
_footer = "-half" if save_as_half else ""
if filename_ext in [".safetensors", ".ckpt"]:
_ret = f"{filename_body}{_footer}{filename_ext}"
elif save_as_safetensors:
_ret = f"{output_filename}{_footer}.safetensors"
else:
_ret = f"{output_filename}{_footer}.ckpt"
return _ret
model_O = f"bw-merge-{_model_A_name}-{_model_B_info}-{sl_base_alpha}.ckpt" if txt_model_O == "" else txt_model_O
model_O = validate_output_filename(model_O, save_as_safetensors=chk_save_as_safetensors, save_as_half=chk_save_as_half)
_output = os.path.join(shared.cmd_opts.ckpt_dir or sd_models.model_path, model_O)
_output = os.path.join(ckpt_dir, model_O)
# debug output
print( "#### Merge Block Weighted ####")
if not chk_allow_overwrite:
if os.path.exists(_output):
_err_msg = f"ERROR: output_file already exists. overwrite not allowed. abort."
print(_err_msg)
return gr.update(value=f"{_err_msg} [{_output}]")
print(f"model_0 : {model_A}")
print(f"model_1 : {model_B}")
print(f"base_alpha : {sl_base_alpha}")
print(f"output_file: {_output}")
print(f"weights : {_weights}")
print(f" model_0 : {model_A}")
print(f" model_1 : {model_B}")
print(f" base_alpha : {sl_base_alpha}")
print(f" output_file: {_output}")
print(f" weights : {_weights}")
result, ret_message = merge(weights=_weights, model_0=model_A, model_1=model_B, allow_overwrite=chk_allow_overwrite,
base_alpha=sl_base_alpha, output_file=_output, verbose=chk_verbose_mbw
base_alpha=sl_base_alpha, output_file=_output, verbose=chk_verbose_mbw,
save_as_safetensors=chk_save_as_safetensors,
save_as_half=chk_save_as_half,
)
if result:
ret_html = "merged.<br>" + f"{model_A}<br>" + f"{model_B}<br>" + f"{model_O}"
print("merged.")
else:
ret_html = ret_message
print("merge failed.")
# save log to history.tsv
sd_models.list_models()
@ -160,7 +187,8 @@ def on_ui_tabs():
fn=onclick_btn_do_merge_block_weighted,
inputs=[model_A, model_B]
+ sl_IN + sl_MID + sl_OUT
+ [txt_model_O, sl_base_alpha, chk_verbose_mbw, chk_allow_overwrite],
+ [txt_model_O, sl_base_alpha, chk_verbose_mbw, chk_allow_overwrite]
+ [chk_save_as_safetensors, chk_save_as_half],
outputs=[html_output_block_weight_info]
)
@ -178,7 +206,7 @@ def on_ui_tabs():
def on_change_dd_preset_weight(dd_preset_weight):
_weights = presetWeights.find_weight_by_name(dd_preset_weight)
_ret = on_btn_apply_block_weithg_from_txt(_weights)
_ret = on_btn_apply_block_weight_from_txt(_weights)
return [gr.update(value=_weights)] + _ret
dd_preset_weight.change(
fn=on_change_dd_preset_weight,
@ -201,7 +229,7 @@ def on_ui_tabs():
outputs=[model_A, model_B]
)
def on_btn_apply_block_weithg_from_txt(txt_block_weight):
def on_btn_apply_block_weight_from_txt(txt_block_weight):
if not txt_block_weight or txt_block_weight == "":
return [gr.update() for _ in range(25)]
_list = [x.strip() for x in txt_block_weight.split(",")]
@ -209,7 +237,7 @@ def on_ui_tabs():
return [gr.update() for _ in range(25)]
return [gr.update(value=x) for x in _list]
btn_apply_block_weithg_from_txt.click(
fn=on_btn_apply_block_weithg_from_txt,
fn=on_btn_apply_block_weight_from_txt,
inputs=[txt_block_weight],
outputs=[
sl_IN_00, sl_IN_01, sl_IN_02, sl_IN_03, sl_IN_04, sl_IN_05,

View File

@ -27,7 +27,10 @@ def dprint(str, flg):
def merge(weight_A:list, weight_B:list, model_0, model_1, device="cpu", base_alpha=0.5,
output_file="", allow_overwrite=False, verbose=False):
output_file="", allow_overwrite=False, verbose=False,
save_as_safetensors=False,
save_as_half=False,
):
def _check_arg_weight(weight):
if weight is None:
@ -53,10 +56,11 @@ def merge(weight_A:list, weight_B:list, model_0, model_1, device="cpu", base_alp
device = device if device in ["cpu", "cuda"] else "cpu"
alpha = base_alpha
_footer = "-half" if save_as_half else ""
_footer = f"{_footer}.safetensors" if save_as_safetensors else f"{_footer}.ckpt"
if not output_file or output_file == "":
output_file = f'bw-{model_0}-{model_1}-{str(alpha)[2:] + "0"}.ckpt'
else:
output_file = output_file if ".ckpt" in output_file else output_file + ".ckpt"
output_file = f'bw-{model_0}-{model_1}-{str(alpha)[2:] + "0"}{_footer}'
# check if output file already exists
if os.path.isfile(output_file) and not allow_overwrite:
@ -139,7 +143,8 @@ def merge(weight_A:list, weight_B:list, model_0, model_1, device="cpu", base_alp
theta_0[key] = _var1 + _var2 + _var3
# theta_0[key] = (1 - current_alpha) * theta_0[key] + current_alpha * theta_1[key]
if save_as_half:
theta_0[key] = theta_0[key].half()
else:
dprint(f" key - {key}", verbose)
@ -148,12 +153,23 @@ def merge(weight_A:list, weight_B:list, model_0, model_1, device="cpu", base_alp
if "model" in key and key not in theta_0:
dprint(f" key : {key}", verbose)
theta_0.update({key:theta_1[key]})
if save_as_half:
theta_0[key] = theta_0[key].half()
else:
dprint(f" key - {key}", verbose)
print("Saving...")
torch.save({"state_dict": theta_0}, output_file)
_, extension = os.path.splitext(output_file)
if extension.lower() == ".safetensors" or save_as_safetensors:
if save_as_safetensors and extension.lower() != ".safetensors":
output_file = output_file + ".safetensors"
import safetensors.torch
safetensors.torch.save_file(theta_0, output_file, metadata={"format": "pt"})
else:
torch.save({"state_dict": theta_0}, output_file)
print("Done!")

View File

@ -32,6 +32,9 @@ def on_ui_tabs():
sl_base_alpha = gr.Slider(label="base_alpha", minimum=0, maximum=1, step=0.01, value=0)
chk_verbose_mbw = gr.Checkbox(label="verbose console output", value=False)
chk_allow_overwrite = gr.Checkbox(label="Allow overwrite output-model", value=False)
with gr.Row():
chk_save_as_half = gr.Checkbox(label="Save as half", value=False)
chk_save_as_safetensors = gr.Checkbox(label="Save as safetensors", value=False)
with gr.Row():
dd_model_A = gr.Dropdown(label="Model_A", choices=sd_models.checkpoint_tiles())
dd_model_B = gr.Dropdown(label="Model_B", choices=sd_models.checkpoint_tiles())
@ -157,7 +160,8 @@ def on_ui_tabs():
sl_M_B_00,
sl_OUT_B_00, sl_OUT_B_01, sl_OUT_B_02, sl_OUT_B_03, sl_OUT_B_04, sl_OUT_B_05,
sl_OUT_B_06, sl_OUT_B_07, sl_OUT_B_08, sl_OUT_B_09, sl_OUT_B_10, sl_OUT_B_11,
txt_model_O, sl_base_alpha, chk_verbose_mbw, chk_allow_overwrite
txt_model_O, sl_base_alpha, chk_verbose_mbw, chk_allow_overwrite,
chk_save_as_safetensors, chk_save_as_half
):
base_alpha = sl_base_alpha
_weight_A = ",".join(
@ -178,7 +182,7 @@ def on_ui_tabs():
]])
# debug output
print( "#### Merge Block Weighted Each ####")
print( "#### Merge Block Weighted : Each ####")
if (not dd_model_A or not dd_model_B) and txt_multi_process_cmd == "":
_err_msg = f"ERROR: model not found. [{dd_model_A}][{dd_model_B}]"
@ -200,7 +204,9 @@ def on_ui_tabs():
weight_A=_weight_A, weight_B=_weight_B, model_0=dd_model_A, model_1=dd_model_B,
allow_overwrite=chk_allow_overwrite, base_alpha=base_alpha, model_Output=txt_model_O,
verbose=chk_verbose_mbw,
params=_items
params=_items,
save_as_safetensors=chk_save_as_safetensors,
save_as_half=chk_save_as_half
)
else:
_ret = f" multi-merge text found, but invalid params. skipped :[{_line}]"
@ -211,7 +217,9 @@ def on_ui_tabs():
ret_html += _run_merge(
weight_A=_weight_A, weight_B=_weight_B, model_0=dd_model_A, model_1=dd_model_B,
allow_overwrite=chk_allow_overwrite, base_alpha=base_alpha, model_Output=txt_model_O,
verbose=chk_verbose_mbw
verbose=chk_verbose_mbw,
save_as_safetensors=chk_save_as_safetensors,
save_as_half=chk_save_as_half
)
sd_models.list_models()
@ -222,20 +230,28 @@ def on_ui_tabs():
fn=onclick_btn_do_merge_block_weighted,
inputs=[dd_model_A, dd_model_B, txt_multi_process_cmd]
+ sl_A_IN + sl_A_MID + sl_A_OUT + sl_B_IN + sl_B_MID + sl_B_OUT
+ [txt_model_O, sl_base_alpha, chk_verbose_mbw, chk_allow_overwrite],
+ [txt_model_O, sl_base_alpha, chk_verbose_mbw, chk_allow_overwrite]
+ [chk_save_as_safetensors, chk_save_as_half],
outputs=[html_output_block_weight_info]
)
def _run_merge(weight_A, weight_B, model_0, model_1, allow_overwrite=False, base_alpha=0,
model_Output="", verbose=False, params=[]
model_Output="", verbose=False, params=[],
save_as_safetensors=False,
save_as_half=False,
):
def validate_output_filename(output_filename):
def validate_output_filename(output_filename, save_as_safetensors=False, save_as_half=False):
output_filename = re.sub(r'[\\|:|?|"|<|>|\|\*]', '-', output_filename)
_, ext = os.path.splitext(output_filename)
filename_body, filename_ext = os.path.splitext(output_filename)
_ret = output_filename
if ext != ".safetensors" and ext != ".ckpt":
_ret = f"{output_filename}.ckpt"
_footer = "-half" if save_as_half else ""
if filename_ext in [".safetensors", ".ckpt"]:
_ret = f"{filename_body}{_footer}{filename_ext}"
elif save_as_safetensors:
_ret = f"{output_filename}{_footer}.safetensors"
else:
_ret = f"{output_filename}{_footer}.ckpt"
return _ret
model_O = ""
@ -284,7 +300,7 @@ def on_ui_tabs():
elif _item_l.upper() == "O":
if _item_r.strip() != "":
_ret = validate_output_filename(_item_r.strip())
_ret = validate_output_filename(_item_r.strip(), save_as_safetensors=save_as_safetensors, save_as_half=save_as_half)
print(f" * Output filename changed:[{model_O}] -> [{_ret}]")
model_O = _ret
@ -328,8 +344,8 @@ def on_ui_tabs():
if model_O == "":
_a = os.path.splitext(os.path.basename(_model_A_name))[0]
_b = os.path.splitext(os.path.basename(_model_B_name))[0]
model_O = f"bw-merge-{_a}-{_b}-{base_alpha}.ckpt" if model_Output == "" else model_Output
model_O = validate_output_filename(model_O)
model_O = f"bw-merge-{_a}-{_b}-{base_alpha}" if model_Output == "" else model_Output
model_O = validate_output_filename(model_O, save_as_safetensors=save_as_safetensors, save_as_half=save_as_half)
output_file = os.path.join(shared.cmd_opts.ckpt_dir or sd_models.model_path, model_O)
#
# Check params
@ -352,10 +368,15 @@ def on_ui_tabs():
print(f" output_file: {output_file}")
print(f" weight_A : {weight_A}")
print(f" weight_B : {weight_B}")
print(f" half : {save_as_half}")
result, ret_message = merge(
weight_A=weight_A, weight_B=weight_B, model_0=model_0, model_1=model_1,
allow_overwrite=allow_overwrite, base_alpha=base_alpha, output_file=output_file, verbose=verbose)
allow_overwrite=allow_overwrite, base_alpha=base_alpha, output_file=output_file,
verbose=verbose,
save_as_safetensors=save_as_safetensors,
save_as_half=save_as_half,
)
if result:
ret_html = f"merged. {model_0} + {model_1} = {model_O} <br>"
print("merged.")
@ -367,7 +388,7 @@ def on_ui_tabs():
# save log to history.tsv
sd_models.list_models()
model_O_info = sd_models.get_closet_checkpoint_match(os.path.basename(output_file))
model_O_hash = "" if not model_O_info else model_O_info.hash
model_O_hash = "" if model_O_info is None else model_O_info.hash
_names = presetWeights.find_names_by_weight(weight_A)
if _names and len(_names) > 0:
weight_name = _names[0]