pull/2443/head
AI-Casanova 2023-11-16 19:01:03 -06:00
parent 55363254d4
commit db20da7b1a
6 changed files with 15 additions and 14 deletions

View File

@ -9,11 +9,12 @@ import safetensors.torch
import torch
from tqdm import tqdm
import modules.memstats
from modules.shared import log
from modules.merging import merge_methods
from modules.merging.utils import WeightClass
from modules.merging.model import SDModel
from modules.merging.rebasin import (
from modules.merging.merge_utils import WeightClass
from modules.merging.merge_model import SDModel
from modules.merging.merge_rebasin import (
apply_permutation,
sdunet_permutation_spec,
update_model_a,
@ -95,8 +96,7 @@ def restore_sd_model(original_model: Dict, merged_model: Dict) -> Dict:
def log_vram(txt=""):
alloc = torch.cuda.memory_allocated(0)
log.info(f"{txt} VRAM: {alloc*1e-9:5.3f}GB")
log.info(f"{txt} VRAM: {modules.memstats.memory_stats()}")
def load_thetas(

View File

@ -2,9 +2,12 @@ import inspect
# import logging
import re
from modules.merging import merge_methods
from modules.merging.presets import BLOCK_WEIGHTS_PRESETS, SDXL_BLOCK_WEIGHTS_PRESETS
from modules.merging.merge_presets import BLOCK_WEIGHTS_PRESETS, SDXL_BLOCK_WEIGHTS_PRESETS
ALL_PRESETS = {}
ALL_PRESETS.update(BLOCK_WEIGHTS_PRESETS)
ALL_PRESETS.update(SDXL_BLOCK_WEIGHTS_PRESETS)
BLOCK_WEIGHTS_PRESETS |= SDXL_BLOCK_WEIGHTS_PRESETS
MERGE_METHODS = dict(inspect.getmembers(merge_methods, inspect.isfunction))
BETA_METHODS = [
name

View File

@ -11,8 +11,8 @@ from modules.shared import opts, log, req
import modules.errors
import modules.hashes
from modules.merging import merge_methods
from modules.merging.utils import BETA_METHODS, TRIPLE_METHODS, interpolate
from modules.merging.presets import BLOCK_WEIGHTS_PRESETS, SDXL_BLOCK_WEIGHTS_PRESETS
from modules.merging.merge_utils import BETA_METHODS, TRIPLE_METHODS, interpolate
from modules.merging.merge_presets import BLOCK_WEIGHTS_PRESETS, SDXL_BLOCK_WEIGHTS_PRESETS
search_metadata_civit = None
@ -242,7 +242,7 @@ def create_ui():
with FormRow():
precision = gr.Radio(choices=["fp16", "fp32"], value="fp16", label="Model precision")
with FormRow():
device = gr.Radio(choices=["cpu", "cuda"], value="cpu", label="Device")
work_device = gr.Radio(choices=["cpu", "cuda"], value="cpu", label="Device")
with FormRow():
bake_in_vae = gr.Dropdown(choices=["None"] + list(sd_vae.vae_dict), value="None",
interactive=True, label="Bake in VAE")
@ -281,7 +281,7 @@ def create_ui():
prune,
re_basin,
re_basin_iterations,
device,
work_device,
bake_in_vae):
kwargs = {}
for x in inspect.getfullargspec(MEHmodelmerger)[0]:
@ -290,8 +290,6 @@ def create_ui():
if kwargs[key] in [None, "None", "", 0, []]:
del kwargs[key]
# return [*[gr.Dropdown.update(choices=sd_models.checkpoint_tiles()) for _ in range(4)], f"{kwargs}"]
try:
results = extras.run_MEHmodelmerger(dummy_component, **kwargs)
except Exception as e:
@ -386,7 +384,7 @@ def create_ui():
prune,
re_basin,
re_basin_iterations,
device,
work_device,
bake_in_vae,
],
outputs=[