mirror of https://github.com/vladmandic/automatic
Fixes
parent
55363254d4
commit
db20da7b1a
|
|
@ -9,11 +9,12 @@ import safetensors.torch
|
|||
import torch
|
||||
from tqdm import tqdm
|
||||
|
||||
import modules.memstats
|
||||
from modules.shared import log
|
||||
from modules.merging import merge_methods
|
||||
from modules.merging.utils import WeightClass
|
||||
from modules.merging.model import SDModel
|
||||
from modules.merging.rebasin import (
|
||||
from modules.merging.merge_utils import WeightClass
|
||||
from modules.merging.merge_model import SDModel
|
||||
from modules.merging.merge_rebasin import (
|
||||
apply_permutation,
|
||||
sdunet_permutation_spec,
|
||||
update_model_a,
|
||||
|
|
@ -95,8 +96,7 @@ def restore_sd_model(original_model: Dict, merged_model: Dict) -> Dict:
|
|||
|
||||
|
||||
def log_vram(txt=""):
|
||||
alloc = torch.cuda.memory_allocated(0)
|
||||
log.info(f"{txt} VRAM: {alloc*1e-9:5.3f}GB")
|
||||
log.info(f"{txt} VRAM: {modules.memstats.memory_stats()}")
|
||||
|
||||
|
||||
def load_thetas(
|
||||
|
|
|
|||
|
|
@ -2,9 +2,12 @@ import inspect
|
|||
# import logging
|
||||
import re
|
||||
from modules.merging import merge_methods
|
||||
from modules.merging.presets import BLOCK_WEIGHTS_PRESETS, SDXL_BLOCK_WEIGHTS_PRESETS
|
||||
from modules.merging.merge_presets import BLOCK_WEIGHTS_PRESETS, SDXL_BLOCK_WEIGHTS_PRESETS
|
||||
|
||||
ALL_PRESETS = {}
|
||||
ALL_PRESETS.update(BLOCK_WEIGHTS_PRESETS)
|
||||
ALL_PRESETS.update(SDXL_BLOCK_WEIGHTS_PRESETS)
|
||||
|
||||
BLOCK_WEIGHTS_PRESETS |= SDXL_BLOCK_WEIGHTS_PRESETS
|
||||
MERGE_METHODS = dict(inspect.getmembers(merge_methods, inspect.isfunction))
|
||||
BETA_METHODS = [
|
||||
name
|
||||
|
|
@ -11,8 +11,8 @@ from modules.shared import opts, log, req
|
|||
import modules.errors
|
||||
import modules.hashes
|
||||
from modules.merging import merge_methods
|
||||
from modules.merging.utils import BETA_METHODS, TRIPLE_METHODS, interpolate
|
||||
from modules.merging.presets import BLOCK_WEIGHTS_PRESETS, SDXL_BLOCK_WEIGHTS_PRESETS
|
||||
from modules.merging.merge_utils import BETA_METHODS, TRIPLE_METHODS, interpolate
|
||||
from modules.merging.merge_presets import BLOCK_WEIGHTS_PRESETS, SDXL_BLOCK_WEIGHTS_PRESETS
|
||||
|
||||
search_metadata_civit = None
|
||||
|
||||
|
|
@ -242,7 +242,7 @@ def create_ui():
|
|||
with FormRow():
|
||||
precision = gr.Radio(choices=["fp16", "fp32"], value="fp16", label="Model precision")
|
||||
with FormRow():
|
||||
device = gr.Radio(choices=["cpu", "cuda"], value="cpu", label="Device")
|
||||
work_device = gr.Radio(choices=["cpu", "cuda"], value="cpu", label="Device")
|
||||
with FormRow():
|
||||
bake_in_vae = gr.Dropdown(choices=["None"] + list(sd_vae.vae_dict), value="None",
|
||||
interactive=True, label="Bake in VAE")
|
||||
|
|
@ -281,7 +281,7 @@ def create_ui():
|
|||
prune,
|
||||
re_basin,
|
||||
re_basin_iterations,
|
||||
device,
|
||||
work_device,
|
||||
bake_in_vae):
|
||||
kwargs = {}
|
||||
for x in inspect.getfullargspec(MEHmodelmerger)[0]:
|
||||
|
|
@ -290,8 +290,6 @@ def create_ui():
|
|||
if kwargs[key] in [None, "None", "", 0, []]:
|
||||
del kwargs[key]
|
||||
|
||||
# return [*[gr.Dropdown.update(choices=sd_models.checkpoint_tiles()) for _ in range(4)], f"{kwargs}"]
|
||||
|
||||
try:
|
||||
results = extras.run_MEHmodelmerger(dummy_component, **kwargs)
|
||||
except Exception as e:
|
||||
|
|
@ -386,7 +384,7 @@ def create_ui():
|
|||
prune,
|
||||
re_basin,
|
||||
re_basin_iterations,
|
||||
device,
|
||||
work_device,
|
||||
bake_in_vae,
|
||||
],
|
||||
outputs=[
|
||||
|
|
|
|||
Loading…
Reference in New Issue