From 7988953b2c3f384890d5fd08e2ca2d66f4002e54 Mon Sep 17 00:00:00 2001 From: bbc_mc Date: Sun, 15 Jan 2023 22:00:00 +0900 Subject: [PATCH] Some fix with new hash256 system add: log file now save sha256 fix: log file accept new model name component CheckpointInfo fix: log file now save preset name correctly --- scripts/mbw/ui_mbw.py | 32 +++++++++++++++++++++++++-- scripts/mbw_each/ui_mbw_each.py | 35 +++++++++++++++++++++++++++--- scripts/mbw_util/merge_history.py | 36 +++++++++++++++++++++++++------ 3 files changed, 91 insertions(+), 12 deletions(-) diff --git a/scripts/mbw/ui_mbw.py b/scripts/mbw/ui_mbw.py index a1a5027..ac71566 100644 --- a/scripts/mbw/ui_mbw.py +++ b/scripts/mbw/ui_mbw.py @@ -4,6 +4,11 @@ import re from modules import sd_models, shared from tqdm import tqdm +try: + from modules import hashes + from modules.sd_models import CheckpointInfo +except: + pass from scripts.mbw.merge_block_weighted import merge from scripts.mbw_util.preset_weights import PresetWeights @@ -185,14 +190,37 @@ def on_ui_tabs(): # save log to history.tsv sd_models.list_models() + model_A_info = sd_models.get_closet_checkpoint_match(model_A) + model_B_info = sd_models.get_closet_checkpoint_match(model_B) model_O_info = sd_models.get_closet_checkpoint_match(os.path.basename(_output)) - model_O_hash = "" if model_O_info is None else model_O_info.hash + if hasattr(model_O_info, "sha256") and model_O_info.sha256 is None: + model_O_info:CheckpointInfo = model_O_info + model_O_info.sha256 = hashes.sha256(model_O_info.filename, "checkpoint/" + model_O_info.title) _names = presetWeights.find_names_by_weight(_weights) if _names and len(_names) > 0: weight_name = _names[0] else: weight_name = "" - mergeHistory.add_history(model_A, model_B, model_O, model_O_hash, sl_base_alpha, _weights, "", weight_name) + + def model_name(model_info): + return model_info.name if hasattr(model_info, "name") else model_info.title + def model_sha256(model_info): + return model_info.sha256 if hasattr(model_info, "sha256") else "" + mergeHistory.add_history( + model_name(model_A_info), + model_A_info.hash, + model_sha256(model_A_info), + model_name(model_B_info), + model_B_info.hash, + model_sha256(model_B_info), + model_name(model_O_info), + model_O_info.hash, + model_sha256(model_O_info), + sl_base_alpha, + _weights, + "", + weight_name + ) return gr.update(value=f"{ret_html}") btn_do_merge_block_weighted.click( diff --git a/scripts/mbw_each/ui_mbw_each.py b/scripts/mbw_each/ui_mbw_each.py index 3066d0a..2a3f366 100644 --- a/scripts/mbw_each/ui_mbw_each.py +++ b/scripts/mbw_each/ui_mbw_each.py @@ -4,6 +4,11 @@ import re from modules import sd_models, shared from tqdm import tqdm +try: + from modules import hashes + from modules.sd_models import CheckpointInfo +except: + pass from scripts.mbw_each.merge_block_weighted_mod import merge from scripts.mbw_util.preset_weights import PresetWeights @@ -397,14 +402,38 @@ def on_ui_tabs(): # save log to history.tsv sd_models.list_models() + model_A_info = sd_models.get_closet_checkpoint_match(model_0) + model_B_info = sd_models.get_closet_checkpoint_match(model_1) model_O_info = sd_models.get_closet_checkpoint_match(os.path.basename(output_file)) - model_O_hash = "" if model_O_info is None else model_O_info.hash - _names = presetWeights.find_names_by_weight(weight_A) + if hasattr(model_O_info, "sha256") and model_O_info.sha256 is None: + model_O_info:CheckpointInfo = model_O_info + model_O_info.sha256 = hashes.sha256(model_O_info.filename, "checkpoint/" + model_O_info.title) + _names = presetWeights.find_names_by_weight(weight_B) if _names and len(_names) > 0: weight_name = _names[0] else: weight_name = "" - mergeHistory.add_history(model_0, model_1, model_O, model_O_hash, base_alpha, weight_A, weight_B, weight_name) + + def model_name(model_info): + return model_info.name if hasattr(model_info, "name") else model_info.title + def model_sha256(model_info): + return model_info.sha256 if hasattr(model_info, "sha256") else "" + mergeHistory.add_history( + model_name(model_A_info), + model_A_info.hash, + model_sha256(model_A_info), + model_name(model_B_info), + model_B_info.hash, + model_sha256(model_B_info), + model_name(model_O_info), + model_O_info.hash, + model_sha256(model_O_info), + base_alpha, + weight_A, + weight_B, + weight_name + ) + return ret_html btn_clear_weighted.click( diff --git a/scripts/mbw_util/merge_history.py b/scripts/mbw_util/merge_history.py index eabd16a..6bad8ea 100644 --- a/scripts/mbw_util/merge_history.py +++ b/scripts/mbw_util/merge_history.py @@ -4,30 +4,49 @@ import os import datetime from csv import DictWriter, DictReader +import shutil from modules import scripts +CSV_FILE_ROOT = "csv/" CSV_FILE_PATH = "csv/history.tsv" -HEADERS = ["model_A", "model_A_hash", "model_B", "model_B_hash", "model_O", "model_O_hash", "base_alpha", "weight_name", "weight_values", "weight_values2", "datetime"] +HEADERS = [ + "model_A", "model_A_hash", "model_A_sha256", + "model_B", "model_B_hash", "model_B_sha256", + "model_O", "model_O_hash", "model_O_sha256", + "base_alpha", "weight_name", "weight_values", "weight_values2", "datetime"] path_root = scripts.basedir() class MergeHistory(): def __init__(self): + self.fileroot = os.path.join(path_root, CSV_FILE_ROOT) self.filepath = os.path.join(path_root, CSV_FILE_PATH) + if not os.path.exists(self.fileroot): + os.mkdir(self.fileroot) if os.path.exists(self.filepath): self.update_header() - def add_history(self, model_A, model_B, model_O, model_O_hash, sl_base_alpha, weight_value_A, weight_value_B, weight_name=""): + def add_history(self, + model_A_name, model_A_hash, model_A_sha256, + model_B_name, model_B_hash, model_B_sha256, + model_O_name, model_O_hash, model_O_sha256, + sl_base_alpha, + weight_value_A, + weight_value_B, + weight_name=""): _history_dict = {} _history_dict.update({ - "model_A": f"{os.path.basename(model_A.split(' ')[0])}", - "model_A_hash": f"{model_A.split(' ')[1]}", - "model_B": f"{os.path.basename(model_B.split(' ')[0])}", - "model_B_hash": f"{model_B.split(' ')[1]}", - "model_O": model_O, + "model_A": model_A_name, + "model_A_hash": model_A_hash, + "model_A_sha256": model_A_sha256, + "model_B": model_B_name, + "model_B_hash": model_B_hash, + "model_B_sha256": model_B_sha256, + "model_O": model_O_name, "model_O_hash": model_O_hash, + "model_O_sha256": model_O_sha256, "base_alpha": sl_base_alpha, "weight_name": weight_name, "weight_values": weight_value_A, @@ -54,7 +73,10 @@ class MergeHistory(): if len(new_header) > 0: # need update. hist_data = [ x for x in dr] + # apply change if len(hist_data) > 0: + # backup before change + shutil.copy(self.filepath, self.filepath + ".bak") with open(self.filepath, "w", newline="", encoding="utf-8") as f: dw = DictWriter(f, fieldnames=HEADERS, delimiter='\t') dw.writeheader()