import os
import time
import json
import inspect
from datetime import datetime
import gradio as gr
from modules import sd_models, sd_vae, extras
from modules.ui_components import ToolButton
from modules.ui_common import create_refresh_button
from modules.call_queue import wrap_gradio_gpu_call
from modules.shared import opts, log, req, readfile, max_workers
import modules.ui_symbols
import modules.errors
import modules.hashes
from modules.merging import merge_methods
from modules.merging.merge_utils import BETA_METHODS, TRIPLE_METHODS, interpolate
from modules.merging.merge_presets import BLOCK_WEIGHTS_PRESETS, SDXL_BLOCK_WEIGHTS_PRESETS
search_metadata_civit = None
def create_ui():
dummy_component = gr.Label(visible=False)
with gr.Row(elem_id="models_tab"):
with gr.Column(elem_id='models_output_container', scale=1):
# models_output = gr.Text(elem_id="models_output", value="", show_label=False)
gr.HTML(elem_id="models_progress", value="")
models_image = gr.Image(elem_id="models_image", show_label=False, interactive=False, type='pil')
models_outcome = gr.HTML(elem_id="models_error", value="")
with gr.Column(elem_id='models_input_container', scale=3):
def gr_show(visible=True):
    """Build a raw Gradio update payload that sets a component's visibility."""
    payload = dict(visible=visible)
    payload["__type__"] = "update"  # marks the dict as an update rather than a value
    return payload
with gr.Tab(label="Convert"):
with gr.Row():
model_name = gr.Dropdown(sd_models.checkpoint_tiles(), label="Original model")
create_refresh_button(model_name, sd_models.list_models, lambda: {"choices": sd_models.checkpoint_tiles()}, "refresh_checkpoint_Z")
with gr.Row():
custom_name = gr.Textbox(label="Output model name")
with gr.Row():
precision = gr.Radio(choices=["fp32", "fp16", "bf16"], value="fp16", label="Model precision")
m_type = gr.Radio(choices=["disabled", "no-ema", "ema-only"], value="disabled", label="Model pruning methods")
with gr.Row():
checkpoint_formats = gr.CheckboxGroup(choices=["ckpt", "safetensors"], value=["safetensors"], label="Model Format")
with gr.Row():
show_extra_options = gr.Checkbox(label="Show extra options", value=False)
fix_clip = gr.Checkbox(label="Fix clip", value=False)
with gr.Row(visible=False) as extra_options:
specific_part_conv = ["copy", "convert", "delete"]
unet_conv = gr.Dropdown(specific_part_conv, value="convert", label="unet")
text_encoder_conv = gr.Dropdown(specific_part_conv, value="convert", label="text encoder")
vae_conv = gr.Dropdown(specific_part_conv, value="convert", label="vae")
others_conv = gr.Dropdown(specific_part_conv, value="convert", label="others")
show_extra_options.change(fn=lambda x: gr_show(x), inputs=[show_extra_options], outputs=[extra_options])
# The button caption is passed as `value`, matching the other buttons in this file
# ("Merge", "List model details"); `label` does not set the text shown on a gr.Button.
model_converter_convert = gr.Button(value="Convert", variant='primary')
# Input order must match the positional parameters expected by extras.run_modelconvert;
# the status text it returns is rendered in the shared outcome panel.
model_converter_convert.click(
    fn=extras.run_modelconvert,
    inputs=[
        model_name,
        checkpoint_formats,
        precision, m_type, custom_name,
        unet_conv,
        text_encoder_conv,
        vae_conv,
        others_conv,
        fix_clip
    ],
    outputs=[models_outcome]
)
with gr.Tab(label="Merge"):
def sd_model_choices():
    """Return the checkpoint titles prefixed with a 'None' placeholder entry."""
    choices = ['None']
    choices += sd_models.checkpoint_tiles()
    return choices
with gr.Row(equal_height=False):
with gr.Column(variant='compact'):
with gr.Row():
custom_name = gr.Textbox(label="New model name")
with gr.Row():
merge_mode = gr.Dropdown(choices=merge_methods.__all__, value="weighted_sum", label="Interpolation Method")
# Initial help text for the default merge method. A missing method or a method without
# a docstring yields an empty panel instead of raising (or showing an unrelated docstring).
default_merge_method = getattr(merge_methods, "weighted_sum", None)
default_merge_doc = (default_merge_method.__doc__ or "") if default_merge_method is not None else ""
# The text is rendered as HTML, so docstring newlines become explicit line breaks.
merge_mode_docs = gr.HTML(value=default_merge_doc.replace("\n", "<br>"))
with gr.Row():
primary_model_name = gr.Dropdown(sd_model_choices(), label="Primary model", value="None")
create_refresh_button(primary_model_name, sd_models.list_models, lambda: {"choices": sd_model_choices()}, "refresh_checkpoint_A")
secondary_model_name = gr.Dropdown(sd_model_choices(), label="Secondary model", value="None")
create_refresh_button(secondary_model_name, sd_models.list_models, lambda: {"choices": sd_model_choices()}, "refresh_checkpoint_B")
tertiary_model_name = gr.Dropdown(sd_model_choices(), label="Tertiary model", value="None", visible=False)
tertiary_refresh = create_refresh_button(tertiary_model_name, sd_models.list_models, lambda: {"choices": sd_model_choices()}, "refresh_checkpoint_C", visible=False)
with gr.Row():
with gr.Tabs() as tabs:
with gr.TabItem(label="Simple Merge", id=0):
with gr.Row():
alpha = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, label='Alpha Ratio', value=0.5)
beta = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, label='Beta Ratio', value=None, visible=False)
with gr.TabItem(label="Preset Block Merge", id=1):
with gr.Row():
sdxl = gr.Checkbox(label="SDXL")
with gr.Row():
alpha_preset = gr.Dropdown(
choices=["None"] + list(BLOCK_WEIGHTS_PRESETS.keys()), value=None,
label="ALPHA Block Weight Preset", multiselect=True, max_choices=2)
alpha_preset_lambda = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, label='Preset Interpolation Ratio', value=None, visible=False)
apply_preset = ToolButton('⇨', visible=True)
with gr.Row():
beta_preset = gr.Dropdown(choices=["None"] + list(BLOCK_WEIGHTS_PRESETS.keys()), value=None, label="BETA Block Weight Preset", multiselect=True, max_choices=2, interactive=True, visible=False)
beta_preset_lambda = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, label='Preset Interpolation Ratio', value=None, interactive=True, visible=False)
beta_apply_preset = ToolButton('⇨', interactive=True, visible=False)
with gr.TabItem(label="Manual Block Merge", id=2):
with gr.Row():
alpha_label = gr.Markdown("# Alpha")
with gr.Row():
alpha_base = gr.Textbox(value=None, label="Base", min_width=70, scale=1)
alpha_in_blocks = gr.Textbox(value=None, label="In Blocks", scale=15)
alpha_mid_block = gr.Textbox(value=None, label="Mid Block", min_width=80, scale=1)
alpha_out_blocks = gr.Textbox(value=None, label="Out Block", scale=15)
with gr.Row():
beta_label = gr.Markdown("# Beta", visible=False)
with gr.Row():
beta_base = gr.Textbox(value=None, label="Base", min_width=70, scale=1, interactive=True, visible=False)
beta_in_blocks = gr.Textbox(value=None, label="In Blocks", interactive=True, scale=15, visible=False)
beta_mid_block = gr.Textbox(value=None, label="Mid Block", min_width=80, interactive=True, scale=1, visible=False)
beta_out_blocks = gr.Textbox(value=None, label="Out Block", interactive=True, scale=15, visible=False)
with gr.Row():
overwrite = gr.Checkbox(label="Overwrite model")
with gr.Row():
save_metadata = gr.Checkbox(value=True, label="Save metadata")
with gr.Row():
weights_clip = gr.Checkbox(label="Weights clip")
prune = gr.Checkbox(label="Prune", value=True, visible=False)
with gr.Row():
re_basin = gr.Checkbox(label="ReBasin")
re_basin_iterations = gr.Slider(minimum=0, maximum=25, step=1, label='Number of ReBasin Iterations', value=None, visible=False)
with gr.Row():
checkpoint_format = gr.Radio(choices=["ckpt", "safetensors"], value="safetensors", visible=False, label="Model format")
with gr.Row():
precision = gr.Radio(choices=["fp16", "fp32"], value="fp16", label="Model precision")
with gr.Row():
device = gr.Radio(choices=["cpu", "shuffle", "gpu"], value="cpu", label="Merge Device")
unload = gr.Checkbox(label="Unload Current Model from VRAM", value=False, visible=False)
with gr.Row():
bake_in_vae = gr.Dropdown(choices=["None"] + list(sd_vae.vae_dict), value="None", interactive=True, label="Replace VAE")
create_refresh_button(bake_in_vae, sd_vae.refresh_vae_list,
lambda: {"choices": ["None"] + list(sd_vae.vae_dict)},
"modelmerger_refresh_bake_in_vae")
with gr.Row():
modelmerger_merge = gr.Button(value="Merge", variant='primary')
def modelmerger(dummy_component, # dummy function just to get argspec later
                overwrite, # pylint: disable=unused-argument
                primary_model_name, # pylint: disable=unused-argument
                secondary_model_name, # pylint: disable=unused-argument
                tertiary_model_name, # pylint: disable=unused-argument
                merge_mode, # pylint: disable=unused-argument
                alpha, # pylint: disable=unused-argument
                beta, # pylint: disable=unused-argument
                alpha_preset, # pylint: disable=unused-argument
                alpha_preset_lambda, # pylint: disable=unused-argument
                alpha_base, # pylint: disable=unused-argument
                alpha_in_blocks, # pylint: disable=unused-argument
                alpha_mid_block, # pylint: disable=unused-argument
                alpha_out_blocks, # pylint: disable=unused-argument
                beta_preset, # pylint: disable=unused-argument
                beta_preset_lambda, # pylint: disable=unused-argument
                beta_base, # pylint: disable=unused-argument
                beta_in_blocks, # pylint: disable=unused-argument
                beta_mid_block, # pylint: disable=unused-argument
                beta_out_blocks, # pylint: disable=unused-argument
                precision, # pylint: disable=unused-argument
                custom_name, # pylint: disable=unused-argument
                checkpoint_format, # pylint: disable=unused-argument
                save_metadata, # pylint: disable=unused-argument
                weights_clip, # pylint: disable=unused-argument
                prune, # pylint: disable=unused-argument
                re_basin, # pylint: disable=unused-argument
                re_basin_iterations, # pylint: disable=unused-argument
                device, # pylint: disable=unused-argument
                unload, # pylint: disable=unused-argument
                bake_in_vae): # pylint: disable=unused-argument
    """Validate the merge form values and forward the non-empty ones to extras.run_modelmerger.

    Every parameter is collected by name into a keyword dict. Entries whose value equals
    None, "None", "", 0 or [] are dropped; because the comparison is by equality this also
    drops False and 0.0. `dummy_component` is passed positionally and never as a keyword.

    Returns:
        On success, whatever extras.run_modelmerger returns. On a validation failure or an
        exception during merging, a list of four dropdown updates followed by a message.
    """
    # Snapshot the arguments first, while the parameters are the only locals that exist.
    call_args = dict(locals())
    ignored_values = [None, "None", "", 0, []]
    kwargs = {
        name: call_args[name]
        for name in inspect.getfullargspec(modelmerger)[0]
        if call_args[name] not in ignored_values
    }
    # The filter above may already have removed this key (e.g. an empty value),
    # so discard it without assuming it is still present.
    kwargs.pop('dummy_component', None)

    def failure(message):
        # NOTE(review): the model dropdowns are created with a leading 'None' choice
        # (sd_model_choices) but are refreshed here without it - confirm this is intended.
        return [*[gr.Dropdown.update(choices=sd_models.checkpoint_tiles()) for _ in range(4)], message]

    if kwargs.get("custom_name", None) is None:
        log.error('Merge: no output model specified')
        return failure("No output model specified")
    if kwargs.get("primary_model_name", None) is None or kwargs.get("secondary_model_name", None) is None:
        log.error('Merge: no models selected')
        return failure("No models selected")

    log.debug(f'Merge start: {kwargs}')
    try:
        results = extras.run_modelmerger(dummy_component, **kwargs)
    except Exception as e:
        modules.errors.display(e, 'Merge')
        sd_models.list_models() # to remove the potentially missing models from the list
        return failure(f"Error merging checkpoints: {e}")
    return results
def tertiary(mode):
    """Return two visibility updates, shown only when `mode` is one of TRIPLE_METHODS."""
    needs_third_model = mode in TRIPLE_METHODS
    return [gr.update(visible=needs_third_model) for _ in range(2)]
def beta_visibility(mode):
    """Return nine visibility updates, shown only when `mode` is one of BETA_METHODS."""
    uses_beta = mode in BETA_METHODS
    return [gr.update(visible=uses_beta) for _ in range(9)]
def show_iters(show):
    """Return a slider update: visible with value 5 when `show` is truthy, else hidden and cleared."""
    if not show:
        return gr.Slider.update(value=None, visible=False)
    return gr.Slider.update(value=5, visible=True)
def show_help(mode):
    """Return an update that shows the docstring of merge method `mode` as HTML.

    An empty or unknown mode, or a method without a docstring, produces an empty
    (but still visible) help panel instead of raising.
    """
    method = getattr(merge_methods, mode, None) if mode else None
    doc = (method.__doc__ or "") if method is not None else ""
    # The target is an HTML component, so docstring newlines become explicit line breaks.
    return gr.update(value=doc.replace("\n", "<br>"), visible=True)
def show_unload(device):
    """Return a visibility update that is shown only when `device` equals "gpu"."""
    on_gpu = device == "gpu"
    return gr.update(visible=True) if on_gpu else gr.update(visible=False)
def preset_visiblility(x):
    """Return a slider update: visible at 0.5 when exactly two items are in `x`, else hidden and cleared."""
    if len(x) != 2:
        return gr.Slider.update(value=None, visible=False)
    return gr.Slider.update(value=0.5, visible=True)
def load_presets(presets, ratio):
    """Turn the selected block-weight preset(s) into text for the manual block-merge fields.

    Args:
        presets: selected preset names; the "None" placeholder choice is ignored.
        ratio: interpolation ratio handed to `interpolate` when two presets are selected.

    Returns:
        Five updates: base, in-blocks, mid-block, out-blocks text values, plus an update
        selecting tab id 2. When no real preset is selected, five no-op updates.
    """
    # Resolve into a new list so the caller's selection list is not overwritten in place.
    # NOTE(review): only BLOCK_WEIGHTS_PRESETS is consulted, although preset_choices can
    # offer SDXL_BLOCK_WEIGHTS_PRESETS names - confirm SDXL presets are meant to load here.
    weights = [BLOCK_WEIGHTS_PRESETS[name] for name in (presets or []) if name != "None"]
    if not weights:
        return [gr.update() for _ in range(5)]
    if len(weights) == 2:
        preset = interpolate(weights, ratio)
    else:
        preset = weights[0]
    # Whole numbers keep their plain str() form; everything else is fixed to 3 decimals.
    preset = ['%.3f' % x if int(x) != x else str(x) for x in preset] # pylint: disable=consider-using-f-string
    # Layout: index 0 = base, 1..12 = in blocks, 13 = mid block, 14.. = out blocks.
    preset = [preset[0], ",".join(preset[1:13]), preset[13], ",".join(preset[14:])]
    return [gr.update(value=x) for x in preset] + [gr.update(selected=2)]
def preset_choices(sdxl):
    """Return two dropdown updates listing "None" plus the SDXL or the regular preset names."""
    source = SDXL_BLOCK_WEIGHTS_PRESETS if sdxl else BLOCK_WEIGHTS_PRESETS
    return [gr.update(choices=["None"] + list(source.keys())) for _ in range(2)]
# Merge tab: wire control changes to the helper callbacks defined above.
device.change(fn=show_unload, inputs=device, outputs=unload)
merge_mode.change(fn=show_help, inputs=merge_mode, outputs=merge_mode_docs)
sdxl.change(fn=preset_choices, inputs=sdxl, outputs=[alpha_preset, beta_preset])
# Each preset dropdown drives its own interpolation slider; the beta slider must read
# the beta selection (it previously read alpha_preset and so tracked the wrong dropdown).
alpha_preset.change(fn=preset_visiblility, inputs=alpha_preset, outputs=alpha_preset_lambda)
beta_preset.change(fn=preset_visiblility, inputs=beta_preset, outputs=beta_preset_lambda)
merge_mode.input(fn=tertiary, inputs=merge_mode, outputs=[tertiary_model_name, tertiary_refresh])
# beta_visibility returns nine updates, one per component listed here.
merge_mode.input(fn=beta_visibility, inputs=merge_mode, outputs=[beta, alpha_label, beta_label, beta_apply_preset, beta_preset, beta_base, beta_in_blocks, beta_mid_block, beta_out_blocks])
re_basin.change(fn=show_iters, inputs=re_basin, outputs=re_basin_iterations)
# load_presets returns four text values plus a tab-selection update, hence `tabs` as the fifth output.
apply_preset.click(fn=load_presets, inputs=[alpha_preset, alpha_preset_lambda], outputs=[alpha_base, alpha_in_blocks, alpha_mid_block, alpha_out_blocks, tabs])
beta_apply_preset.click(fn=load_presets, inputs=[beta_preset, beta_preset_lambda], outputs=[beta_base, beta_in_blocks, beta_mid_block, beta_out_blocks, tabs])
modelmerger_merge.click(
fn=wrap_gradio_gpu_call(modelmerger, extra_outputs=lambda: [gr.update() for _ in range(4)]),
_js='modelmerger',
inputs=[
dummy_component,
overwrite,
primary_model_name,
secondary_model_name,
tertiary_model_name,
merge_mode,
alpha,
beta,
alpha_preset,
alpha_preset_lambda,
alpha_base,
alpha_in_blocks,
alpha_mid_block,
alpha_out_blocks,
beta_preset,
beta_preset_lambda,
beta_base,
beta_in_blocks,
beta_mid_block,
beta_out_blocks,
precision,
custom_name,
checkpoint_format,
save_metadata,
weights_clip,
prune,
re_basin,
re_basin_iterations,
device,
unload,
bake_in_vae,
],
outputs=[
primary_model_name,
secondary_model_name,
tertiary_model_name,
dummy_component,
models_outcome,
]
)
with gr.Tab(label="Validate"):
model_headers = ['name', 'type', 'filename', 'hash', 'added', 'size', 'metadata']
model_data = []
with gr.Row():
model_list_btn = gr.Button(value="List model details", variant='primary')
model_checkhash_btn = gr.Button(value="Calculate hash for all models", variant='primary')
model_checkhash_btn.click(fn=sd_models.update_model_hashes, inputs=[], outputs=[models_outcome])
with gr.Row():
model_table = gr.DataFrame(
value=None,
headers=model_headers,
label='Model data',
show_label=True,
interactive=False,
wrap=True,
overflow_row_behaviour='paginate',
max_rows=50,
)
def list_models():
    """Rebuild the model details table from sd_models.checkpoints_list.

    The shared `model_data` list is cleared and refilled in place with one row per
    checkpoint that could be inspected; checkpoints that raise are reported in the
    status text instead of aborting the listing.

    Returns:
        A tuple of (model_data, txt) where txt is an HTML status string.
    """
    gib = 1024 * 1024 * 1024
    total_size = 0
    model_data.clear()
    txt = ''
    for m in sd_models.checkpoints_list.values():
        try:
            stat = os.stat(m.filename)
            m_name = m.name.replace('.ckpt', '').replace('.safetensors', '')
            m_type = 'ckpt' if m.name.endswith('.ckpt') else 'safe'
            # Length of the serialized metadata minus 2, so an empty "{}" counts as 0.
            m_meta = len(json.dumps(m.metadata)) - 2
            m_size = round(stat.st_size / gib, 3)
            m_time = datetime.fromtimestamp(stat.st_mtime)
            model_data.append([m_name, m_type, m.filename, m.shorthash, m_time, m_size, m_meta])
            total_size += stat.st_size
        except Exception as e:
            # Best effort: keep listing the remaining models and surface the failure in the status text.
            txt += f"Error: {m.name} {e}<br>"
    # The status text is shown in an HTML component, so lines end with an explicit break.
    txt += f"Model list enumerated {len(sd_models.checkpoints_list)} models in {round(total_size / gib, 3)} GB<br>"
    return model_data, txt
model_list_btn.click(fn=list_models, inputs=[], outputs=[model_table, models_outcome])
with gr.Tab(label="Huggingface"):
data = []
os.environ.setdefault('HF_HUB_DISABLE_EXPERIMENTAL_WARNING', '1')
os.environ.setdefault('HF_HUB_DISABLE_SYMLINKS_WARNING', '1')
os.environ.setdefault('HF_HUB_DISABLE_IMPLICIT_TOKEN', '1')
os.environ.setdefault('HUGGINGFACE_HUB_VERBOSITY', 'warning')
def hf_search(keyword):
    """Query the Hugging Face hub for diffusers models matching `keyword` and refill the shared `data` list.

    Requests at most 50 models sorted by downloads, and returns the same `data` list
    with one row per model: id, pipeline tag, filtered tags, downloads, last-modified, URL.
    """
    import huggingface_hub as hf
    excluded_prefixes = ('diffusers', 'license', 'arxiv')

    def to_row(model):
        # Drop bookkeeping tags and anything of two characters or fewer.
        tags = [t for t in model.tags if not t.startswith(excluded_prefixes) and len(t) > 2]
        return [model.modelId, model.pipeline_tag, tags, model.downloads, model.lastModified, f'https://huggingface.co/{model.modelId}']

    api = hf.HfApi()
    query = hf.ModelFilter(model_name=keyword, library=['diffusers'])
    found = api.list_models(filter=query, full=True, limit=50, sort="downloads", direction=-1)
    data.clear()
    data.extend(to_row(model) for model in found)
    return data
def hf_select(evt: gr.SelectData, data):
    """Return the first column of the row in `data` addressed by the selection event."""
    row_index = evt.index[0]
    selected_row = data[row_index]
    return selected_row[0]
def hf_download_model(hub_id: str, token, variant, revision, mirror, custom_pipeline):
    """Download a diffusers model into opts.diffusers_dir, refresh the model list, and return a status message."""
    from modules.modelloader import download_diffusers_model
    download_diffusers_model(
        hub_id,
        cache_dir=opts.diffusers_dir,
        token=token,
        variant=variant,
        revision=revision,
        mirror=mirror,
        custom_pipeline=custom_pipeline,
    )
    # Imported under an alias so it does not shadow another `list_models` in the enclosing scope.
    from modules.sd_models import list_models as refresh_model_list
    refresh_model_list()
    message = f'Diffuser model downloaded: model="{hub_id}"'
    log.info(message)
    return message
with gr.Column(scale=6):
gr.HTML('