diff --git a/modules/merging/merge.py b/modules/merging/merge.py index 6b00afe2b..fe0097da9 100644 --- a/modules/merging/merge.py +++ b/modules/merging/merge.py @@ -9,11 +9,12 @@ import safetensors.torch import torch from tqdm import tqdm +import modules.memstats from modules.shared import log from modules.merging import merge_methods -from modules.merging.utils import WeightClass -from modules.merging.model import SDModel -from modules.merging.rebasin import ( +from modules.merging.merge_utils import WeightClass +from modules.merging.merge_model import SDModel +from modules.merging.merge_rebasin import ( apply_permutation, sdunet_permutation_spec, update_model_a, @@ -95,8 +96,7 @@ def restore_sd_model(original_model: Dict, merged_model: Dict) -> Dict: def log_vram(txt=""): - alloc = torch.cuda.memory_allocated(0) - log.info(f"{txt} VRAM: {alloc*1e-9:5.3f}GB") + log.info(f"{txt} VRAM: {modules.memstats.memory_stats()}") def load_thetas( diff --git a/modules/merging/model.py b/modules/merging/merge_model.py similarity index 100% rename from modules/merging/model.py rename to modules/merging/merge_model.py diff --git a/modules/merging/presets.py b/modules/merging/merge_presets.py similarity index 100% rename from modules/merging/presets.py rename to modules/merging/merge_presets.py diff --git a/modules/merging/rebasin.py b/modules/merging/merge_rebasin.py similarity index 100% rename from modules/merging/rebasin.py rename to modules/merging/merge_rebasin.py diff --git a/modules/merging/utils.py b/modules/merging/merge_utils.py similarity index 95% rename from modules/merging/utils.py rename to modules/merging/merge_utils.py index 7feff1393..41b6f7442 100644 --- a/modules/merging/utils.py +++ b/modules/merging/merge_utils.py @@ -2,9 +2,12 @@ import inspect # import logging import re from modules.merging import merge_methods -from modules.merging.presets import BLOCK_WEIGHTS_PRESETS, SDXL_BLOCK_WEIGHTS_PRESETS +from modules.merging.merge_presets import BLOCK_WEIGHTS_PRESETS, SDXL_BLOCK_WEIGHTS_PRESETS + +ALL_PRESETS = {} +ALL_PRESETS.update(BLOCK_WEIGHTS_PRESETS) +ALL_PRESETS.update(SDXL_BLOCK_WEIGHTS_PRESETS) -BLOCK_WEIGHTS_PRESETS |= SDXL_BLOCK_WEIGHTS_PRESETS MERGE_METHODS = dict(inspect.getmembers(merge_methods, inspect.isfunction)) BETA_METHODS = [ name diff --git a/modules/ui_models.py b/modules/ui_models.py index 87c1c2898..a9c99e8e3 100644 --- a/modules/ui_models.py +++ b/modules/ui_models.py @@ -11,8 +11,8 @@ from modules.shared import opts, log, req import modules.errors import modules.hashes from modules.merging import merge_methods -from modules.merging.utils import BETA_METHODS, TRIPLE_METHODS, interpolate -from modules.merging.presets import BLOCK_WEIGHTS_PRESETS, SDXL_BLOCK_WEIGHTS_PRESETS +from modules.merging.merge_utils import BETA_METHODS, TRIPLE_METHODS, interpolate +from modules.merging.merge_presets import BLOCK_WEIGHTS_PRESETS, SDXL_BLOCK_WEIGHTS_PRESETS search_metadata_civit = None @@ -242,7 +242,7 @@ def create_ui(): with FormRow(): precision = gr.Radio(choices=["fp16", "fp32"], value="fp16", label="Model precision") with FormRow(): - device = gr.Radio(choices=["cpu", "cuda"], value="cpu", label="Device") + work_device = gr.Radio(choices=["cpu", "cuda"], value="cpu", label="Device") with FormRow(): bake_in_vae = gr.Dropdown(choices=["None"] + list(sd_vae.vae_dict), value="None", interactive=True, label="Bake in VAE") @@ -281,7 +281,7 @@ def create_ui(): prune, re_basin, re_basin_iterations, - device, + work_device, bake_in_vae): kwargs = {} for x in inspect.getfullargspec(MEHmodelmerger)[0]: @@ -290,8 +290,6 @@ def create_ui(): if kwargs[key] in [None, "None", "", 0, []]: del kwargs[key] - # return [*[gr.Dropdown.update(choices=sd_models.checkpoint_tiles()) for _ in range(4)], f"{kwargs}"] - try: results = extras.run_MEHmodelmerger(dummy_component, **kwargs) except Exception as e: @@ -386,7 +384,7 @@ def create_ui(): prune, re_basin, re_basin_iterations, - device, + work_device, bake_in_vae, ], outputs=[