sd_dreambooth_extension/scripts/main.py

2055 lines
85 KiB
Python

import importlib
import json
import time
from typing import List
import gradio as gr
from dreambooth.dataclasses.db_config import from_file, save_config
from dreambooth.diff_to_sd import compile_checkpoint
from dreambooth.diff_to_sdxl import compile_checkpoint as compile_checkpoint_sdxl
from dreambooth.secret import (
get_secret,
create_secret,
clear_secret,
)
from dreambooth.shared import (
status,
get_launch_errors,
)
from dreambooth.ui_functions import (
performance_wizard,
training_wizard,
training_wizard_person,
load_model_params,
ui_classifiers,
debug_buckets,
create_model,
generate_samples,
load_params,
start_training,
update_extension,
start_crop,
)
from dreambooth.utils.image_utils import (
get_scheduler_names,
)
from dreambooth.utils.model_utils import (
get_db_models,
get_sorted_lora_models,
get_model_snapshots,
get_shared_models,
)
from dreambooth.utils.utils import (
list_attention,
list_precisions,
wrap_gpu_call,
printm,
list_optimizer,
list_schedulers, select_precision, select_attention,
)
from dreambooth.webhook import save_and_test_webhook
from helpers.log_parser import LogParser
from helpers.version_helper import check_updates
from modules import script_callbacks, sd_models
from modules.ui import gr_show, create_refresh_button
from preprocess.preprocess_utils import check_preprocess_path, load_image_caption
# Module-level lists; how they are filled and consumed is not visible in this
# part of the file. NOTE(review): presumably populated while the UI is built
# and used to load/save settings - confirm against the rest of on_ui_tabs.
preprocess_params = []
params_to_save = []
params_to_load = []
# Unicode glyphs used as captions for small icon buttons in the UI.
refresh_symbol = "\U0001f504" # 🔄
delete_symbol = "\U0001F5D1" # 🗑️
update_symbol = "\U0001F51D" # 🔝
# Single LogParser instance, created once when this module is imported.
log_parser = LogParser()
# NOTE(review): looks like the default state of the advanced-settings toggle;
# its readers are outside this view - confirm.
show_advanced = True
def read_metadata_from_safetensors(filename):
    """
    Read the ``__metadata__`` section from the JSON header of a safetensors file.

    The file is read as: an 8-byte little-endian integer giving the header
    length, followed by that many bytes of UTF-8 encoded JSON.

    @param filename: Path of the file to read.
    @return: A dict of metadata entries. String values starting with ``{`` are
        parsed as nested JSON when they are valid JSON; every other value is
        returned unchanged. An empty dict is returned when the header is not
        valid UTF-8 JSON, is not a JSON object, or ``__metadata__`` is not a dict.
    @raise OSError: If the file cannot be opened or read.
    """
    with open(filename, mode="rb") as file:
        # Header length prefix, then the header itself.
        header_len = int.from_bytes(file.read(8), "little")
        raw_header = file.read(header_len)

    res = {}
    try:
        header = json.loads(raw_header.decode("utf-8"))
    except (UnicodeDecodeError, json.JSONDecodeError):
        # Corrupt or non-JSON header: treat as "no metadata".
        return res

    # Valid JSON is not necessarily an object (it could be a list or scalar).
    if not isinstance(header, dict):
        return res

    metadata = header.get("__metadata__", {})
    if not isinstance(metadata, dict):
        return res

    for k, v in metadata.items():
        # Only strings can hold an embedded JSON document; anything else is
        # passed through untouched instead of crashing on ``startswith``.
        if isinstance(v, str) and v.startswith("{"):
            try:
                res[k] = json.loads(v)
            except (ValueError, RecursionError):
                # Looked like JSON but was not; keep the raw string.
                res[k] = v
        else:
            res[k] = v
    return res
def get_sd_models():
    """
    Refresh the checkpoint listing and return its entries.

    Calls ``sd_models.list_models()`` first, then iterates
    ``sd_models.checkpoints_list``.

    @return: A new list containing every item yielded by iterating
        ``sd_models.checkpoints_list``.
    """
    sd_models.list_models()
    return list(sd_models.checkpoints_list)
def calc_time_left(progress, threshold, label, force_display):
    """
    Estimate the remaining time of the current job and format it for display.

    The estimate is a linear extrapolation: time elapsed since
    ``status.time_start`` divided by ``progress`` gives the projected total,
    from which the elapsed time is subtracted.

    @param progress: Fraction of the job completed; ``0`` yields an empty string.
    @param threshold: The estimate is only shown when the remaining seconds
        exceed this value and ``progress`` is above 0.02 (unless forced).
    @param label: Text prepended to the formatted time.
    @param force_display: When truthy, the estimate is shown regardless of
        ``threshold`` and ``progress``.
    @return: ``""`` when nothing should be shown, otherwise ``label`` followed
        by ``D:HH:MM:SS`` (more than a day), ``HH:MM:SS`` (more than an hour),
        ``MM:SS`` (more than a minute) or ``SSs``.
    """
    if progress == 0:
        return ""

    if status.time_start is None:
        elapsed = 0
    else:
        elapsed = time.time() - status.time_start

    remaining = elapsed / progress - elapsed

    should_show = (remaining > threshold and progress > 0.02) or force_display
    if not should_show:
        return ""

    # progress > 1 makes the extrapolation negative; time.gmtime() on a
    # negative value wraps to the previous day (or raises on some platforms),
    # so clamp to zero before formatting.
    remaining = max(remaining, 0)

    if remaining > 86400:
        # Floor division on a float returns a float ("2.0"); show whole days.
        days, remaining = divmod(remaining, 86400)
        return f"{label}{int(days)}:{time.strftime('%H:%M:%S', time.gmtime(remaining))}"
    if remaining > 3600:
        return label + time.strftime("%H:%M:%S", time.gmtime(remaining))
    if remaining > 60:
        return label + time.strftime("%M:%S", time.gmtime(remaining))
    return label + time.strftime("%Ss", time.gmtime(remaining))
def has_face_swap():
    """
    Report whether a ``batch_face_swap.py`` script exposing ``Script`` is available.

    Looks through ``list_scripts("scripts", ".py")`` for a file named
    ``batch_face_swap.py``, loads it as a module (which executes that file)
    and reads its ``Script`` attribute. Any failure along the way is printed
    and treated as "not available".

    @return: ``True`` if the script was found and its ``Script`` attribute is
        not ``None``, otherwise ``False``.
    """
    # ``import importlib`` alone does not guarantee that the ``util`` submodule
    # is loaded; import it explicitly so the lookup below cannot fail with an
    # AttributeError that the broad handler would silently turn into False.
    import importlib.util

    script_class = None
    try:
        from modules.scripts import list_scripts

        for script_file in list_scripts("scripts", ".py"):
            if script_file.filename != "batch_face_swap.py":
                continue
            spec = importlib.util.spec_from_file_location("batch_face_swap", script_file.path)
            module = importlib.util.module_from_spec(spec)
            spec.loader.exec_module(module)
            script_class = module.Script
            break
    except Exception as f:
        # Best-effort probe: a missing or broken script must not break UI creation.
        print(f"Can't check face swap: {f}")
    return script_class is not None
def check_progress_call():
    """
    Read the shared ``status`` object and build the updates for the progress widgets.

    @return: A 7-tuple, in this order:
        active: Update for the checkbox mirroring ``status.active``
        pspan: Progress bar HTML (empty string when nothing is active)
        preview: Update for the single-image preview
        gallery: Update for the gallery
        textinfo_result: Status text (``textinfo`` plus ``textinfo2`` when set)
        sample_prompts: Update carrying the sample prompts joined with ``<br>``
        check_progress_initial: Visibility update for the manual refresh button
    """
    active_box = gr.update(value=status.active)

    # Nothing running: clear the bar, hide the preview, show the refresh button.
    if not status.active:
        return (
            active_box,
            "",
            gr.update(visible=False, value=None),
            gr.update(visible=True),
            gr_show(True),
            gr_show(True),
            gr_show(False),
        )

    progress = status.job_no / status.job_count if status.job_count > 0 else 0

    # Once an ETA has been shown, keep showing it on later polls.
    time_left = calc_time_left(progress, 1, " ETA: ", status.time_left_force_display)
    if time_left:
        status.time_left_force_display = True

    progress = min(progress, 1)
    bar_text = "&nbsp;" * 2 + str(int(progress * 100)) + "%" + time_left if progress > 0.01 else ""
    progressbar = f"""<div class='progressDiv'><div class='progress' style="overflow:visible;width:{progress * 100}%;white-space:nowrap;">{bar_text}</div></div>"""

    status.set_current_image()
    image = status.current_image

    # An empty list matches none of the branches and leaves both as None.
    preview = None
    gallery = None
    if image is None:
        preview = gr.update(visible=False, value=None)
        gallery = gr.update(visible=True)
    elif not isinstance(image, List):
        preview = gr.update(visible=True, value=image)
        gallery = gr.update(visible=True, value=None)
    elif len(image) > 1:
        # Several images go to the gallery and are consumed from the status.
        status.current_image = None
        preview = gr.update(visible=False, value=None)
        gallery = gr.update(visible=True, value=image)
    elif len(image) == 1:
        preview = gr.update(visible=True, value=image[0])
        gallery = gr.update(visible=True, value=None)

    textinfo_result = "" if status.textinfo is None else status.textinfo
    if status.textinfo2 is not None:
        textinfo_result = f"{textinfo_result}<br>{status.textinfo2}"

    sample_prompts = status.sample_prompts
    prompt_count = len(sample_prompts)
    if prompt_count > 1:
        prompts = "<br>".join(sample_prompts)
    elif prompt_count == 1:
        prompts = sample_prompts[0]
    else:
        prompts = ""

    # The hidden span carries the current timestamp alongside the bar.
    pspan = f"<span id='db_progress_span' style='display: none'>{time.time()}</span><p>{progressbar}</p>"

    return (
        active_box,
        pspan,
        preview,
        gallery,
        textinfo_result,
        gr.update(value=prompts),
        gr_show(False),
    )
def check_progress_call_initial():
    """
    Call ``status.begin()`` and return a first set of progress updates.

    Only the active flag, the progress span and the status text are taken from
    ``check_progress_call()``; the preview, gallery, prompt list and refresh
    button slots are replaced with fixed "hidden / empty" updates.

    @return: A 7-tuple in the same order as ``check_progress_call()``.
    """
    status.begin()
    active_box, pspan, _, _, textinfo_result, _, _ = check_progress_call()
    return (
        active_box,
        pspan,
        gr_show(False),
        gr.update(value=[]),
        textinfo_result,
        gr.update(value=[]),
        gr_show(False),
    )
def ui_gen_ckpt(model_name: str):
    """
    Compile a checkpoint for the selected model from the UI.

    @param model_name: Name of the model, or a list whose first element is the
        name. An empty string, ``None`` or an empty list counts as no selection.
    @return: ``"Please select a model."`` when nothing is selected, otherwise
        whatever the checkpoint compiler returns.
    """
    if isinstance(model_name, List):
        # An empty selection list would raise IndexError on [0]; treat it the
        # same as no selection.
        model_name = model_name[0] if model_name else None
    if model_name == "" or model_name is None:
        return "Please select a model."

    config = from_file(model_name)
    printm("Config loaded")
    lora_path = config.lora_model_name
    print(f"Lora path: {lora_path}")

    # SDXL models use their own compiler. The two boolean flags are passed
    # through as in the original call sites; their meaning is defined by the
    # compile functions, which are not visible here.
    if config.model_type == "SDXL":
        res = compile_checkpoint_sdxl(model_name, lora_path, True, False, config.snapshot)
    else:
        res = compile_checkpoint(model_name, lora_path, True, True, config.snapshot)
    return res
def on_ui_tabs():
with gr.Blocks() as dreambooth_interface:
# Top button row
with gr.Row(equal_height=True, elem_id="DbTopRow"):
db_load_params = gr.Button(value="Load Settings", elem_id="db_load_params", size="sm")
db_save_params = gr.Button(value="Save Settings", elem_id="db_save_config", size="sm")
db_train_model = gr.Button(
value="Train", variant="primary", elem_id="db_train", size="sm"
)
db_generate_checkpoint = gr.Button(
value="Generate Ckpt", elem_id="db_gen_ckpt", size="sm"
)
db_generate_checkpoint_during = gr.Button(
value="Save Weights", elem_id="db_gen_ckpt_during", size="sm"
)
db_train_sample = gr.Button(
value="Generate Samples", elem_id="db_train_sample", size="sm"
)
db_cancel = gr.Button(value="Cancel", elem_id="db_cancel", size="sm")
with gr.Row():
gr.HTML(value="Select or create a model to begin.", elem_id="hint_row")
with gr.Row(elem_id="ModelDetailRow", visible=False, variant="compact") as db_model_info:
with gr.Column():
with gr.Row(variant="compact"):
with gr.Column():
with gr.Row(variant="compact"):
gr.HTML(value="Loaded Model:")
db_model_path = gr.HTML()
with gr.Row(variant="compact"):
gr.HTML(value="Source Checkpoint:")
db_src = gr.HTML()
with gr.Column():
with gr.Row(variant="compact"):
gr.HTML(value="Model Epoch:")
db_epochs = gr.HTML(elem_id="db_epochs")
with gr.Row(variant="compact"):
gr.HTML(value="Model Revision:")
db_revision = gr.HTML(elem_id="db_revision")
with gr.Column():
with gr.Row(variant="compact"):
gr.HTML(value="Model type:")
db_model_type = gr.HTML(elem_id="db_model_type")
with gr.Row(variant="compact"):
gr.HTML(value="Has EMA:")
db_has_ema = gr.HTML(elem_id="db_has_ema")
with gr.Row(variant="compact", visible=False):
gr.HTML(value="Experimental Shared Source:")
db_shared_diffusers_path = gr.HTML()
with gr.Row(equal_height=False):
with gr.Column(variant="panel", elem_id="SettingsPanel"):
with gr.Row():
with gr.Column(scale=1, min_width=100, elem_classes="halfElement"):
gr.HTML(value="<span class='hh'>Settings</span>")
with gr.Column(scale=1, min_width=100, elem_classes="halfElement"):
db_show_advanced = gr.Button(value="Show Advanced", size="sm", elem_classes="advBtn", visible=False)
db_hide_advanced = gr.Button(value="Hide Advanced", variant="primary", size="sm", elem_id="db_hide_advanced", elem_classes="advBtn")
with gr.Tab("Model", elem_id="ModelPanel"):
with gr.Column():
with gr.Tab("Select"):
with gr.Row():
db_model_name = gr.Dropdown(
label="Model", choices=sorted(get_db_models())
)
create_refresh_button(
db_model_name,
get_db_models,
lambda: {"choices": sorted(get_db_models())},
"refresh_db_models",
)
with gr.Row() as db_snapshot_row:
db_snapshot = gr.Dropdown(
label="Snapshot to Resume",
choices=sorted(get_model_snapshots()),
)
create_refresh_button(
db_snapshot,
get_model_snapshots,
lambda: {"choices": sorted(get_model_snapshots())},
"refresh_db_snapshots",
)
with gr.Row(visible=False) as lora_model_row:
db_lora_model_name = gr.Dropdown(
label="Lora Model", choices=get_sorted_lora_models()
)
create_refresh_button(
db_lora_model_name,
get_sorted_lora_models,
lambda: {"choices": get_sorted_lora_models()},
"refresh_lora_models",
)
with gr.Tab("Create"):
with gr.Column():
db_create_model = gr.Button(
value="Create Model", variant="primary"
)
db_new_model_name = gr.Textbox(label="Name")
with gr.Row():
db_create_from_hub = gr.Checkbox(
label="Create From Hub", value=False
)
db_model_type_select = gr.Dropdown(label="Model Type",
choices=["v1x", "v2x-512", "v2x", "SDXL",
"ControlNet"], value="v1x")
db_use_shared_src = gr.Checkbox(
label="Experimental Shared Src", value=False, visible=False
)
with gr.Column(visible=False) as hub_row:
db_new_model_url = gr.Textbox(
label="Model Path",
placeholder="runwayml/stable-diffusion-v1-5",
)
db_new_model_token = gr.Textbox(
label="HuggingFace Token", value=""
)
with gr.Column(visible=True) as local_row:
with gr.Row():
db_new_model_src = gr.Dropdown(
label="Source Checkpoint",
choices=sorted(get_sd_models()),
)
create_refresh_button(
db_new_model_src,
get_sd_models,
lambda: {"choices": sorted(get_sd_models())},
"refresh_sd_models",
)
with gr.Column(visible=False) as shared_row:
with gr.Row():
db_new_model_shared_src = gr.Dropdown(
label="EXPERIMENTAL: LoRA Shared Diffusers Source",
choices=sorted(get_shared_models()),
value="",
visible=False
)
create_refresh_button(
db_new_model_shared_src,
get_shared_models,
lambda: {"choices": sorted(get_shared_models())},
"refresh_shared_models",
)
db_new_model_extract_ema = gr.Checkbox(
label="Extract EMA Weights", value=False
)
db_train_unfrozen = gr.Checkbox(label="Unfreeze Model", value=True)
with gr.Column():
with gr.Accordion(open=False, label="Resources"):
with gr.Column():
gr.HTML(
value="<a class=\"hyperlink\" href=\"https://github.com/d8ahazard/sd_dreambooth_extension/wiki/ELI5-Training\">Beginners guide</a>",
)
gr.HTML(
value="<a class=\"hyperlink\" href=\"https://github.com/d8ahazard/sd_dreambooth_extension/releases/latest\">Release notes</a>",
)
# with gr.Tab("Preprocess", elem_id="PreprocessPanel", visible=False):
# with gr.Row():
# with gr.Column(scale=2, variant="compact"):
# db_preprocess_path = gr.Textbox(
# label="Image Path", value="", placeholder="Enter the path to your images"
# )
# with gr.Column(variant="compact"):
# db_preprocess_recursive = gr.Checkbox(
# label="Recursive", value=False, container=True, elem_classes=["singleCheckbox"]
# )
# with gr.Row():
# with gr.Tab("Auto-Caption"):
# with gr.Row():
# gr.HTML(value="Auto-Caption")
# with gr.Tab("Edit Captions"):
# with gr.Row():
# db_preprocess_autosave = gr.Checkbox(
# label="Autosave", value=False
# )
# with gr.Row():
# gr.HTML(value="Edit Captions")
# with gr.Tab("Edit Images"):
# with gr.Row():
# gr.HTML(value="Edit Images")
# with gr.Row():
# db_preprocess = gr.Button(
# value="Preprocess", variant="primary"
# )
# db_preprocess_all = gr.Button(
# value="Preprocess All", variant="primary"
# )
# with gr.Row():
# db_preprocess_all = gr.Button(
# value="Preprocess All", variant="primary"
# )
with gr.Tab("Concepts", elem_id="TabConcepts") as concept_tab:
with gr.Column(variant="panel"):
with gr.Accordion(open=False, label="Concept 1"):
(
c1_instance_data_dir,
c1_class_data_dir,
c1_instance_prompt,
c1_class_prompt,
c1_save_sample_prompt,
c1_save_sample_template,
c1_instance_token,
c1_class_token,
c1_num_class_images_per,
c1_class_negative_prompt,
c1_class_guidance_scale,
c1_class_infer_steps,
c1_save_sample_negative_prompt,
c1_n_save_sample,
c1_sample_seed,
c1_save_guidance_scale,
c1_save_infer_steps,
) = build_concept_panel(1)
with gr.Accordion(open=False, label="Concept 2"):
(
c2_instance_data_dir,
c2_class_data_dir,
c2_instance_prompt,
c2_class_prompt,
c2_save_sample_prompt,
c2_save_sample_template,
c2_instance_token,
c2_class_token,
c2_num_class_images_per,
c2_class_negative_prompt,
c2_class_guidance_scale,
c2_class_infer_steps,
c2_save_sample_negative_prompt,
c2_n_save_sample,
c2_sample_seed,
c2_save_guidance_scale,
c2_save_infer_steps,
) = build_concept_panel(2)
with gr.Accordion(open=False, label="Concept 3"):
(
c3_instance_data_dir,
c3_class_data_dir,
c3_instance_prompt,
c3_class_prompt,
c3_save_sample_prompt,
c3_save_sample_template,
c3_instance_token,
c3_class_token,
c3_num_class_images_per,
c3_class_negative_prompt,
c3_class_guidance_scale,
c3_class_infer_steps,
c3_save_sample_negative_prompt,
c3_n_save_sample,
c3_sample_seed,
c3_save_guidance_scale,
c3_save_infer_steps,
) = build_concept_panel(3)
with gr.Accordion(open=False, label="Concept 4"):
(
c4_instance_data_dir,
c4_class_data_dir,
c4_instance_prompt,
c4_class_prompt,
c4_save_sample_prompt,
c4_save_sample_template,
c4_instance_token,
c4_class_token,
c4_num_class_images_per,
c4_class_negative_prompt,
c4_class_guidance_scale,
c4_class_infer_steps,
c4_save_sample_negative_prompt,
c4_n_save_sample,
c4_sample_seed,
c4_save_guidance_scale,
c4_save_infer_steps,
) = build_concept_panel(4)
with gr.Tab("Parameters", elem_id="TabSettings"):
db_performance_wizard = gr.Button(value="Performance Wizard (WIP)", visible=False)
with gr.Accordion(open=False, label="Performance"):
db_use_ema = gr.Checkbox(
label="Use EMA", value=False
)
db_optimizer = gr.Dropdown(
label="Optimizer",
value="8bit AdamW",
choices=list_optimizer(),
)
db_mixed_precision = gr.Dropdown(
label="Mixed Precision",
value=select_precision(),
choices=list_precisions(),
)
db_full_mixed_precision = gr.Checkbox(
label="Full Mixed Precision", value=True
)
db_attention = gr.Dropdown(
label="Memory Attention",
value=select_attention(),
choices=list_attention(),
)
db_cache_latents = gr.Checkbox(
label="Cache Latents", value=True
)
db_train_unet = gr.Checkbox(
label="Train UNET", value=True
)
db_stop_text_encoder = gr.Slider(
label="Step Ratio of Text Encoder Training",
minimum=0,
maximum=1,
step=0.05,
value=1.0,
visible=True,
)
db_offset_noise = gr.Slider(
label="Offset Noise",
minimum=-1,
maximum=1,
step=0.01,
value=0,
)
db_freeze_clip_normalization = gr.Checkbox(
label="Freeze CLIP Normalization Layers",
visible=True,
value=False,
)
db_clip_skip = gr.Slider(
label="Clip Skip",
value=2,
minimum=1,
maximum=12,
step=1,
)
db_weight_decay = gr.Slider(
label="Weight Decay",
minimum=0,
maximum=1,
step=0.001,
value=0.01,
visible=True,
)
db_tenc_weight_decay = gr.Slider(
label="TENC Weight Decay",
minimum=0,
maximum=1,
step=0.001,
value=0.01,
visible=True,
)
db_tenc_grad_clip_norm = gr.Slider(
label="TENC Gradient Clip Norm",
minimum=0,
maximum=128,
step=0.25,
value=0,
visible=True,
)
db_min_snr_gamma = gr.Slider(
label="Min SNR Gamma",
minimum=0,
maximum=10,
step=0.1,
visible=True,
)
db_use_dream = gr.Checkbox(
label="Use DREAM", value=False
)
db_dream_detail_preservation = gr.Slider(
label="DREAM detail preservation",
minimum=0,
maximum=1,
step=0.01,
value=0.5,
visible=True,
)
db_freeze_spectral_norm = gr.Checkbox(
label="Freeze Spectral Norm", value=False
)
db_pad_tokens = gr.Checkbox(
label="Pad Tokens", value=True
)
db_strict_tokens = gr.Checkbox(
label="Strict Tokens", value=False
)
db_shuffle_tags = gr.Checkbox(
label="Shuffle Tags", value=True
)
db_max_token_length = gr.Slider(
label="Max Token Length",
minimum=75,
maximum=300,
step=75,
)
with gr.Accordion(open=False, label="Intervals"):
db_num_train_epochs = gr.Slider(
label="Training Steps Per Image (Epochs)",
value=100,
maximum=1000,
step=1,
)
db_epoch_pause_frequency = gr.Slider(
label="Pause After N Epochs",
value=0,
maximum=100,
step=1,
)
db_epoch_pause_time = gr.Slider(
label="Amount of time to pause between Epochs (s)",
value=0,
maximum=3600,
step=1,
)
db_save_embedding_every = gr.Slider(
label="Save Model Frequency (Epochs)",
value=25,
maximum=1000,
step=1,
)
db_save_preview_every = gr.Slider(
label="Save Preview(s) Frequency (Epochs)",
value=5,
maximum=1000,
step=1,
)
with gr.Accordion(open=False, label="Batch Sizes") as db_batch_size_view:
db_train_batch_size = gr.Slider(
label="Batch Size",
value=1,
minimum=1,
maximum=100,
step=1,
)
db_gradient_accumulation_steps = gr.Slider(
label="Gradient Accumulation Steps",
value=1,
minimum=1,
maximum=100,
step=1,
)
db_sample_batch_size = gr.Slider(
label="Class Batch Size",
minimum=1,
maximum=100,
value=1,
step=1,
)
db_gradient_set_to_none = gr.Checkbox(
label="Set Gradients to None When Zeroing", value=True
)
db_gradient_checkpointing = gr.Checkbox(
label="Gradient Checkpointing", value=True
)
with gr.Accordion(open=False, label="Learning Rate"):
with gr.Row(visible=False) as lora_lr_row:
db_lora_learning_rate = gr.Number(
label="Lora UNET Learning Rate", value=1e-4
)
db_lora_txt_learning_rate = gr.Number(
label="Lora Text Encoder Learning Rate", value=5e-5
)
with gr.Row() as standard_lr_row:
db_learning_rate = gr.Number(
label="Learning Rate", value=2e-6
)
db_txt_learning_rate = gr.Number(
label="Text Encoder Learning Rate", value=1e-6
)
db_lr_scheduler = gr.Dropdown(
label="Learning Rate Scheduler",
value="constant_with_warmup",
choices=list_schedulers(),
)
db_learning_rate_min = gr.Number(
label="Min Learning Rate", value=1e-6, visible=False
)
db_lr_cycles = gr.Number(
label="Number of Hard Resets",
value=1,
precision=0,
visible=False,
)
db_lr_factor = gr.Number(
label="Constant/Linear Starting Factor",
value=0.5,
precision=2,
visible=False,
)
db_lr_power = gr.Number(
label="Polynomial Power",
value=1.0,
precision=1,
visible=False,
)
db_lr_scale_pos = gr.Slider(
label="Scale Position",
value=0.5,
minimum=0,
maximum=1,
step=0.05,
visible=False,
)
db_lr_warmup_steps = gr.Slider(
label="Learning Rate Warmup Steps",
value=500,
step=5,
maximum=1000,
)
with gr.Accordion(open=False, label="Lora"):
db_use_lora = gr.Checkbox(label="Use LORA", value=False)
db_use_lora_extended = gr.Checkbox(
label="Use Lora Extended",
value=False,
visible=False,
)
db_train_imagic = gr.Checkbox(label="Train Imagic Only", value=False, visible=False)
db_train_inpainting = gr.Checkbox(
label="Train Inpainting Model",
value=False,
visible=False,
)
with gr.Column(visible=False) as lora_rank_col:
db_lora_unet_rank = gr.Slider(
label="Lora UNET Rank",
value=4,
minimum=2,
maximum=128,
step=2,
)
db_lora_txt_rank = gr.Slider(
label="Lora Text Encoder Rank",
value=4,
minimum=2,
maximum=128,
step=2,
)
db_lora_weight = gr.Slider(
label="Lora Weight (Alpha)",
value=0.8,
minimum=0.1,
maximum=1,
step=0.1,
)
with gr.Accordion(open=False, label="Image Processing"):
db_resolution = gr.Slider(
label="Max Resolution",
step=64,
minimum=128,
value=512,
maximum=2048,
elem_id="max_res",
)
db_hflip = gr.Checkbox(
label="Apply Horizontal Flip", value=False
)
db_dynamic_img_norm = gr.Checkbox(
label="Dynamic Image Normalization", value=False
)
with gr.Accordion(open=False, label="Prior Loss") as db_prior_loss_view:
db_prior_loss_scale = gr.Checkbox(
label="Scale Prior Loss", value=False
)
db_prior_loss_weight = gr.Slider(
label="Prior Loss Weight",
minimum=0.01,
maximum=1,
step=0.01,
value=0.75,
)
db_prior_loss_target = gr.Number(
label="Prior Loss Target",
value=100,
visible=False,
)
db_prior_loss_weight_min = gr.Slider(
label="Minimum Prior Loss Weight",
minimum=0.01,
maximum=1,
step=0.01,
value=0.1,
visible=False,
)
with gr.Accordion(open=False, label="Saving", elme_id="TabSave") as db_save_tab:
with gr.Column():
gr.HTML("General")
db_custom_model_name = gr.Textbox(
label="Custom Model Name",
value="",
placeholder="Enter a model name for saving checkpoints and lora models.",
)
db_save_safetensors = gr.Checkbox(
label="Save in .safetensors format",
value=True,
visible=False,
)
db_save_ema = gr.Checkbox(
label="Save EMA Weights to Generated Models", value=True
)
db_infer_ema = gr.Checkbox(
label="Use EMA Weights for Inference", value=False
)
with gr.Column():
gr.HTML("Checkpoints")
db_half_model = gr.Checkbox(label="Half Model", value=False)
db_use_subdir = gr.Checkbox(
label="Save Checkpoint to Subdirectory", value=True
)
db_save_ckpt_during = gr.Checkbox(
label="Generate a .ckpt file when saving during training."
)
db_save_ckpt_after = gr.Checkbox(
label="Generate a .ckpt file when training completes.",
value=True,
)
db_save_ckpt_cancel = gr.Checkbox(
label="Generate a .ckpt file when training is canceled."
)
with gr.Column(visible=False) as lora_save_col:
db_save_lora_during = gr.Checkbox(
label="Generate lora weights when saving during training."
)
db_save_lora_after = gr.Checkbox(
label="Generate lora weights when training completes.",
value=True,
)
db_save_lora_cancel = gr.Checkbox(
label="Generate lora weights when training is canceled."
)
db_save_lora_for_extra_net = gr.Checkbox(
label="Generate lora weights for extra networks."
)
with gr.Column():
gr.HTML("Diffusion Weights (training snapshots)")
db_save_state_during = gr.Checkbox(
label="Save separate diffusers snapshots when saving during training."
)
db_save_state_after = gr.Checkbox(
label="Save separate diffusers snapshots when training completes."
)
db_save_state_cancel = gr.Checkbox(
label="Save separate diffusers snapshots when training is canceled."
)
with gr.Accordion(open=False, label="Image Generation", elem_id="TabGenerate") as db_generate_tab:
gr.HTML(value="Class Generation Schedulers")
db_class_gen_method = gr.Dropdown(
label="Image Generation Library",
value="Native Diffusers",
choices=[
"A1111 txt2img (Euler a)",
"Native Diffusers",
]
)
db_scheduler = gr.Dropdown(
label="Image Generation Scheduler",
value="DEISMultistep",
choices=get_scheduler_names(),
)
gr.HTML(value="Manual Class Generation")
with gr.Column():
db_generate_classes = gr.Button(value="Generate Class Images")
db_generate_graph = gr.Button(value="Generate Graph")
db_graph_smoothing = gr.Slider(
value=50,
label="Graph Smoothing Steps",
minimum=10,
maximum=500,
)
db_debug_buckets = gr.Button(value="Debug Buckets")
db_bucket_epochs = gr.Slider(
value=10,
step=1,
minimum=1,
maximum=1000,
label="Epochs to Simulate",
)
db_bucket_batch = gr.Slider(
value=1,
step=1,
minimum=1,
maximum=500,
label="Batch Size to Simulate",
)
db_generate_sample = gr.Button(value="Generate Sample Images")
db_sample_prompt = gr.Textbox(label="Sample Prompt")
db_sample_negative = gr.Textbox(label="Sample Negative Prompt")
db_sample_prompt_file = gr.Textbox(label="Sample Prompt File")
db_sample_width = gr.Slider(
label="Sample Width",
value=512,
step=64,
minimum=128,
maximum=2048,
)
db_sample_height = gr.Slider(
label="Sample Height",
value=512,
step=64,
minimum=128,
maximum=2048,
)
db_sample_seed = gr.Number(
label="Sample Seed", value=-1, precision=0
)
db_num_samples = gr.Slider(
label="Number of Samples to Generate",
value=1,
minimum=1,
maximum=1000,
step=1,
)
db_gen_sample_batch_size = gr.Slider(
label="Sample Batch Size",
value=1,
step=1,
minimum=1,
maximum=100,
interactive=True,
)
db_sample_steps = gr.Slider(
label="Sample Steps",
value=20,
minimum=1,
maximum=500,
step=1,
)
db_sample_scale = gr.Slider(
label="Sample CFG Scale",
value=7.5,
step=0.1,
minimum=1,
maximum=20,
)
with gr.Column(variant="panel", visible=has_face_swap()):
db_swap_faces = gr.Checkbox(label="Swap Sample Faces")
db_swap_prompt = gr.Textbox(label="Swap Prompt")
db_swap_negative = gr.Textbox(label="Swap Negative Prompt")
db_swap_steps = gr.Slider(label="Swap Steps", value=40)
db_swap_batch = gr.Slider(label="Swap Batch", value=40)
db_sample_txt2img = gr.Checkbox(
label="Use txt2img",
value=False,
visible=False # db_sample_txt2img not implemented yet
)
with gr.Accordion(open=False, label="Extras"):
with gr.Column():
gr.HTML(value="Sanity Samples")
db_sanity_prompt = gr.Textbox(
label="Sanity Sample Prompt",
placeholder="A generic prompt used to generate a sample image "
"to verify model fidelity.",
)
db_sanity_negative_prompt = gr.Textbox(
label="Sanity Sample Negative Prompt",
placeholder="A negative prompt for the generic sample image.",
)
db_sanity_seed = gr.Number(
label="Sanity Sample Seed", value=420420
)
with gr.Column() as db_misc_view:
gr.HTML(value="Miscellaneous")
db_pretrained_vae_name_or_path = gr.Textbox(
label="Pretrained VAE Name or Path",
placeholder="Leave blank to use base model VAE.",
value="",
)
db_use_concepts = gr.Checkbox(
label="Use Concepts List", value=False
)
db_concepts_path = gr.Textbox(
label="Concepts List",
placeholder="Path to JSON file with concepts to train.",
)
with gr.Row():
db_secret = gr.Textbox(
label="API Key", value=get_secret, interactive=False
)
db_refresh_button = gr.Button(
value=refresh_symbol, elem_id="refresh_secret"
)
db_clear_secret = gr.Button(
value=delete_symbol, elem_id="clear_secret"
)
with gr.Column() as db_hook_view:
gr.HTML(value="Webhooks")
# In the future change this to something more generic and list the supported types
# from DreamboothWebhookTarget enum; for now, Discord is what I use ;)
# Add options to include notifications on training complete and exceptions that halt training
db_notification_webhook_url = gr.Textbox(
label="Discord Webhook",
placeholder="https://discord.com/api/webhooks/XXX/XXXX",
value="",
)
notification_webhook_test_btn = gr.Button(
value="Save and Test Webhook"
)
with gr.Column() as db_test_tab:
gr.HTML(value="Experimental Settings")
db_tomesd = gr.Slider(
value=0,
label="Token Merging (ToMe)",
minimum=0,
maximum=1,
step=0.1,
)
db_split_loss = gr.Checkbox(
label="Calculate Split Loss", value=True
)
db_disable_class_matching = gr.Checkbox(label="Disable Class Matching")
db_disable_logging = gr.Checkbox(label="Disable Logging")
db_deterministic = gr.Checkbox(label="Deterministic")
db_ema_predict = gr.Checkbox(label="Use EMA for prediction")
db_lora_use_buggy_requires_grad = gr.Checkbox(label="LoRA use buggy requires grad")
db_noise_scheduler = gr.Dropdown(
label="Noise scheduler",
value="DDPM",
choices=[
"DDPM",
"DEIS",
"UniPC"
]
)
db_update_extension = gr.Button(
value="Update Extension and Restart"
)
with gr.Column(variant="panel"):
gr.HTML(value="Bucket Cropping")
db_crop_src_path = gr.Textbox(label="Source Path")
db_crop_dst_path = gr.Textbox(label="Dest Path")
db_crop_max_res = gr.Slider(
label="Max Res", value=512, step=64, maximum=2048
)
db_crop_bucket_step = gr.Slider(
label="Bucket Steps", value=8, step=8, maximum=512
)
db_crop_dry = gr.Checkbox(label="Dry Run")
db_start_crop = gr.Button("Start Cropping")
with gr.Column(variant="panel"):
with gr.Row():
with gr.Column(scale=1, min_width=110):
gr.HTML(value="<span class='hh'>Output</span>")
with gr.Column(scale=1, min_width=110):
db_check_progress_initial = gr.Button(
value=update_symbol,
elem_id="db_check_progress_initial",
visible=False,
)
# These two should be updated while doing things
db_active = gr.Checkbox(elem_id="db_active", value=False, visible=False)
ui_check_progress_initial = gr.Button(
value="Refresh", elem_id="ui_check_progress_initial", elem_classes="advBtn", size="sm"
)
db_status = gr.HTML(elem_id="db_status", value="")
db_progressbar = gr.HTML(elem_id="db_progressbar")
db_gallery = gr.Gallery(
label="Output", show_label=False, elem_id="db_gallery", columns=4
)
db_preview = gr.Image(elem_id="db_preview", visible=False)
db_prompt_list = gr.HTML(
elem_id="db_prompt_list", value="", visible=False
)
db_gallery_prompt = gr.HTML(elem_id="db_gallery_prompt", value="")
db_check_progress = gr.Button(
"Check Progress", elem_id=f"db_check_progress", visible=False
)
db_update_params = gr.Button(
"Update Parameters", elem_id="db_update_params", visible=False
)
db_launch_error = gr.HTML(
elem_id="launch_errors", visible=False, value=get_launch_errors
)
def check_toggles(
        use_lora, class_gen_method, lr_scheduler, train_unet, scale_prior
):
    """
    Recompute every toggle-dependent UI update in one call.

    Delegates to ``update_stop_tenc``, ``disable_lora``,
    ``lr_scheduler_changed``, ``class_gen_method_changed`` and
    ``toggle_loss_items`` and flattens their results.

    @param use_lora: Current "Use LORA" value, forwarded to ``disable_lora``.
    @param class_gen_method: Current class generation method, forwarded to
        ``class_gen_method_changed``.
    @param lr_scheduler: Current scheduler name, forwarded to
        ``lr_scheduler_changed``.
    @param train_unet: Current "Train UNET" value, forwarded to
        ``update_stop_tenc``.
    @param scale_prior: Current "Scale Prior Loss" value, forwarded to
        ``toggle_loss_items``.
    @return: A 17-tuple whose order matches the ``outputs`` list wired to
        ``db_update_params.click``.
    """
    tenc_update = update_stop_tenc(train_unet)

    # disable_lora yields ten values; the last three are not used here.
    (
        ema_update,
        lora_extended_update,
        lora_save_update,
        lora_rank_update,
        lora_lr_update,
        standard_lr_update,
        lora_model_update,
        _,
        _,
        _
    ) = disable_lora(use_lora)

    (
        power_update,
        cycles_update,
        scale_pos_update,
        factor_update,
        lr_min_update,
        warmup_update,
    ) = lr_scheduler_changed(lr_scheduler)

    scheduler_update = class_gen_method_changed(class_gen_method)
    loss_min_update, loss_target_update = toggle_loss_items(scale_prior)

    return (
        tenc_update,
        ema_update,
        lora_extended_update,
        lora_save_update,
        lora_rank_update,
        lora_lr_update,
        lora_model_update,
        scheduler_update,
        power_update,
        cycles_update,
        scale_pos_update,
        factor_update,
        lr_min_update,
        warmup_update,
        loss_min_update,
        loss_target_update,
        standard_lr_update
    )
db_start_crop.click(
_js="db_start_crop",
fn=start_crop,
inputs=[
db_crop_src_path,
db_crop_dst_path,
db_crop_max_res,
db_crop_bucket_step,
db_crop_dry,
],
outputs=[db_status, db_gallery],
)
db_update_params.click(
fn=check_toggles,
inputs=[
db_use_lora,
db_class_gen_method,
db_lr_scheduler,
db_train_unet,
db_prior_loss_scale,
],
outputs=[
db_stop_text_encoder,
db_use_ema,
db_use_lora_extended,
lora_save_col,
lora_rank_col,
lora_lr_row,
lora_model_row,
db_scheduler,
db_lr_power,
db_lr_cycles,
db_lr_scale_pos,
db_lr_factor,
db_learning_rate_min,
db_lr_warmup_steps,
db_prior_loss_weight_min,
db_prior_loss_target,
standard_lr_row,
],
)
db_update_extension.click(fn=update_extension, inputs=[], outputs=[])
notification_webhook_test_btn.click(
fn=save_and_test_webhook,
inputs=[db_notification_webhook_url],
outputs=[db_status],
)
db_refresh_button.click(
fn=create_secret, inputs=[], outputs=[db_secret]
)
def update_stop_tenc(train_unet):
    """
    Make the text-encoder step-ratio control editable only while UNET training is on.

    @param train_unet: Current value of the "Train UNET" checkbox.
    @return: A gradio update setting ``interactive`` to ``True`` when
        ``train_unet`` is truthy and to ``False`` otherwise.
    """
    return gr.update(interactive=bool(train_unet))
db_train_unet.change(
fn=update_stop_tenc,
inputs=[db_train_unet],
outputs=[db_stop_text_encoder],
)
def toggle_full_mixed_precision(full_mixed_precision):
    """
    Show the "Full Mixed Precision" checkbox only for fp16.

    @param full_mixed_precision: The value passed by the change handler of the
        mixed-precision dropdown.
    @return: A gradio update with ``visible=True`` when the value equals
        ``"fp16"`` and ``visible=False`` otherwise.
    """
    return gr.update(visible=(full_mixed_precision == "fp16"))
db_mixed_precision.change(
fn=toggle_full_mixed_precision,
inputs=[db_mixed_precision],
outputs=[db_full_mixed_precision],
)
def update_model_options(model_type):
    """
    Pick the resolution value that goes with the selected model type.

    @param model_type: Value of the model type dropdown.
    @return: A gradio update with value ``1024`` for ``"SDXL"`` and ``512``
        for every other type.
    """
    resolution = 1024 if model_type == "SDXL" else 512
    return gr.update(value=resolution)
db_model_type_select.change(
fn=update_model_options,
inputs=[db_model_type_select],
outputs=[db_resolution]
)
db_clear_secret.click(fn=clear_secret, inputs=[], outputs=[db_secret])
# Elements to update when progress changes
progress_elements = [
db_active,
db_progressbar,
db_preview,
db_gallery,
db_status,
db_prompt_list,
ui_check_progress_initial,
]
db_check_progress.click(
fn=lambda: check_progress_call(),
show_progress=False,
inputs=[],
outputs=progress_elements
)
db_check_progress_initial.click(
fn=lambda: check_progress_call_initial(),
show_progress=False,
inputs=[],
outputs=progress_elements,
)
ui_check_progress_initial.click(
fn=lambda: check_progress_call(),
show_progress=False,
inputs=[],
outputs=progress_elements,
)
def format_updates():
    """Render the result of ``check_updates()`` as changelog HTML.

    ``check_updates()`` is expected to return either ``None`` or a mapping
    of revision -> sequence whose first four items are title, author, date
    and URL (that is how the entries are indexed below).

    Returns:
        One ``<div class='commitDiv'>`` per entry, joined with newlines, or
        an empty string when there is nothing to show.

    Every interpolated field is HTML-escaped so that commit titles or
    author names containing ``<``, ``&`` or quotes cannot break the markup
    or the single-quoted ``href`` attribute. The ``<span>`` is also closed
    explicitly; the previous markup left it open.
    """
    import html

    updates = check_updates()
    if updates is None:
        return ""

    entries = []
    for rev, details in updates.items():
        # Escape with quote=True because `url` lands inside a quoted attribute.
        title, author, date, url = (
            html.escape(str(details[index]), quote=True) for index in range(4)
        )
        rev_text = html.escape(str(rev), quote=True)
        entries.append(
            f"<div class='commitDiv'><h3>{title}</h3>"
            f"<span>{author} - {date} - <a href='{url}'>{rev_text}</a><br></span></div>"
        )
    return "\n".join(entries)
with gr.Row(variant="panel", elem_id="change_modal"):
with gr.Row():
modal_title = gr.HTML("<h2>Changelog</h2>", elem_id="modal_title")
close_modal = gr.Button(value="X", elem_id="close_modal")
with gr.Row():
modal_release_notes = gr.HTML(
"<h3><a href='https://github.com/d8ahazard/sd_dreambooth_extension/releases/latest'>Release notes</a></h3>",
elem_id="modal_notes",
)
with gr.Column():
change_log = gr.HTML(format_updates(), elem_id="change_log")
advanced_elements = [
db_snapshot_row,
db_create_from_hub,
db_new_model_extract_ema,
db_train_unfrozen,
db_use_ema,
db_freeze_clip_normalization,
db_full_mixed_precision,
db_offset_noise,
db_weight_decay,
db_tenc_weight_decay,
db_tenc_grad_clip_norm,
db_min_snr_gamma,
db_use_dream,
db_dream_detail_preservation,
db_freeze_spectral_norm,
db_pad_tokens,
db_strict_tokens,
db_max_token_length,
db_epoch_pause_frequency,
db_epoch_pause_time,
db_batch_size_view,
db_lr_scheduler,
db_lr_warmup_steps,
db_hflip,
db_prior_loss_view,
db_misc_view,
db_hook_view,
db_save_tab,
db_generate_tab,
db_test_tab,
db_dynamic_img_norm,
db_tomesd,
db_split_loss,
db_disable_class_matching,
db_disable_logging,
db_deterministic,
db_ema_predict,
db_lora_use_buggy_requires_grad,
db_noise_scheduler,
c1_class_guidance_scale,
c1_class_infer_steps,
c1_save_sample_negative_prompt,
c1_sample_seed,
c1_save_guidance_scale,
c1_save_infer_steps,
c2_class_guidance_scale,
c2_class_infer_steps,
c2_save_sample_negative_prompt,
c2_sample_seed,
c2_save_guidance_scale,
c2_save_infer_steps,
c3_class_guidance_scale,
c3_class_infer_steps,
c3_save_sample_negative_prompt,
c3_sample_seed,
c3_save_guidance_scale,
c3_save_infer_steps,
c4_class_guidance_scale,
c4_class_infer_steps,
c4_save_sample_negative_prompt,
c4_sample_seed,
c4_save_guidance_scale,
c4_save_infer_steps,
]
def toggle_advanced():
    """Flip the module-level ``show_advanced`` flag and build the UI updates.

    The returned list starts with two fixed updates (visible, then hidden)
    for the pair of toggle buttons; which button receives which update is
    decided by the order of ``outputs`` at each ``.click`` registration.
    They are followed by one visibility update per entry in
    ``advanced_elements``, all set to the new flag value.

    NOTE(review): the flag is a process-wide global, so it is shared by
    every browser session and is not reset on page reload.
    """
    global show_advanced
    show_advanced = not show_advanced
    print(f"Advanced elements visible: {show_advanced}")

    button_updates = [gr.update(visible=True), gr.update(visible=False)]
    element_updates = [
        gr.update(visible=show_advanced) for _ in advanced_elements
    ]
    return button_updates + element_updates
# Merge db_show advanced, db_hide_advanced, and advanced elements into one list
db_show_advanced.click(
fn=toggle_advanced,
inputs=[],
outputs=[db_hide_advanced, db_show_advanced, *advanced_elements]
)
db_hide_advanced.click(
fn=toggle_advanced,
inputs=[],
outputs=[db_show_advanced, db_hide_advanced, *advanced_elements]
)
global preprocess_params
# preprocess_params = [
# db_preprocess_path,
# db_preprocess_recursive
# ]
#
# db_preprocess_path.change(
# fn=check_preprocess_path,
# inputs=[db_preprocess_path, db_preprocess_recursive],
# outputs=[db_status, db_gallery]
# )
db_gallery.select(load_image_caption, None, db_status)
global params_to_save
global params_to_load
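# List of all the things that we need to save.
# NOTE(review): an earlier note here said db_model_name must be first due to
# save_config() parsing, but the list below starts with db_weight_decay.
# Confirm against save_config() whether the ordering still matters.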
params_to_save = [
db_weight_decay,
db_attention,
db_cache_latents,
db_clip_skip,
db_concepts_path,
db_custom_model_name,
db_deterministic,
db_disable_class_matching,
db_disable_logging,
db_ema_predict,
db_tomesd,
db_epoch_pause_frequency,
db_epoch_pause_time,
db_epochs,
db_freeze_clip_normalization,
db_full_mixed_precision,
db_gradient_accumulation_steps,
db_gradient_checkpointing,
db_gradient_set_to_none,
db_graph_smoothing,
db_half_model,
db_hflip,
db_infer_ema,
db_learning_rate,
db_learning_rate_min,
db_lora_learning_rate,
db_lora_model_name,
db_lora_txt_learning_rate,
db_lora_txt_rank,
db_lora_unet_rank,
db_lora_use_buggy_requires_grad,
db_lora_weight,
db_lr_cycles,
db_lr_factor,
db_lr_power,
db_lr_scale_pos,
db_lr_scheduler,
db_lr_warmup_steps,
db_max_token_length,
db_min_snr_gamma,
db_use_dream,
db_dream_detail_preservation,
db_freeze_spectral_norm,
db_mixed_precision,
db_model_name,
db_model_path,
db_noise_scheduler,
db_num_train_epochs,
db_offset_noise,
db_optimizer,
db_pad_tokens,
db_pretrained_vae_name_or_path,
db_prior_loss_scale,
db_prior_loss_target,
db_prior_loss_weight,
db_prior_loss_weight_min,
db_resolution,
db_revision,
db_sample_batch_size,
db_sanity_prompt,
db_sanity_seed,
db_save_ckpt_after,
db_save_ckpt_cancel,
db_save_ckpt_during,
db_save_ema,
db_save_embedding_every,
db_save_lora_after,
db_save_lora_cancel,
db_save_lora_during,
db_save_lora_for_extra_net,
db_save_preview_every,
db_save_safetensors,
db_save_state_after,
db_save_state_cancel,
db_save_state_during,
db_scheduler,
db_shared_diffusers_path,
db_shuffle_tags,
db_snapshot,
db_split_loss,
db_src,
db_stop_text_encoder,
db_strict_tokens,
db_dynamic_img_norm,
db_tenc_grad_clip_norm,
db_tenc_weight_decay,
db_train_batch_size,
db_train_imagic,
db_train_unet,
db_train_unfrozen,
db_txt_learning_rate,
db_use_concepts,
db_use_ema,
db_use_lora,
db_use_lora_extended,
db_use_shared_src,
db_use_subdir,
c1_class_data_dir,
c1_class_guidance_scale,
c1_class_infer_steps,
c1_class_negative_prompt,
c1_class_prompt,
c1_class_token,
c1_instance_data_dir,
c1_instance_prompt,
c1_instance_token,
c1_n_save_sample,
c1_num_class_images_per,
c1_sample_seed,
c1_save_guidance_scale,
c1_save_infer_steps,
c1_save_sample_negative_prompt,
c1_save_sample_prompt,
c1_save_sample_template,
c2_class_data_dir,
c2_class_guidance_scale,
c2_class_infer_steps,
c2_class_negative_prompt,
c2_class_prompt,
c2_class_token,
c2_instance_data_dir,
c2_instance_prompt,
c2_instance_token,
c2_n_save_sample,
c2_num_class_images_per,
c2_sample_seed,
c2_save_guidance_scale,
c2_save_infer_steps,
c2_save_sample_negative_prompt,
c2_save_sample_prompt,
c2_save_sample_template,
c3_class_data_dir,
c3_class_guidance_scale,
c3_class_infer_steps,
c3_class_negative_prompt,
c3_class_prompt,
c3_class_token,
c3_instance_data_dir,
c3_instance_prompt,
c3_instance_token,
c3_n_save_sample,
c3_num_class_images_per,
c3_sample_seed,
c3_save_guidance_scale,
c3_save_infer_steps,
c3_save_sample_negative_prompt,
c3_save_sample_prompt,
c3_save_sample_template,
c4_class_data_dir,
c4_class_guidance_scale,
c4_class_infer_steps,
c4_class_negative_prompt,
c4_class_prompt,
c4_class_token,
c4_instance_data_dir,
c4_instance_prompt,
c4_instance_token,
c4_n_save_sample,
c4_num_class_images_per,
c4_sample_seed,
c4_save_guidance_scale,
c4_save_infer_steps,
c4_save_sample_negative_prompt,
c4_save_sample_prompt,
c4_save_sample_template,
]
for element in params_to_save:
setattr(element, "do_not_save_to_config", True)
# Do not load these values when 'load settings' is clicked
params_to_exclude = [
db_model_name,
db_epochs,
db_model_path,
db_revision,
db_src,
db_model_type,
db_shared_diffusers_path,
]
# Populated by the loop below and handed out to other elements.
# For every component in params_to_save, the loop recovers the name of the
# local variable bound to it and records that name in save_keys; components
# not listed in params_to_exclude are also added to ui_keys/params_to_load.
params_to_load = []
save_keys = []
ui_keys = []
for param in params_to_save:
# Identity scan over locals(): collects every local name bound to this exact
# object. Only the first match (var_name[0]) is used below, so the result
# depends on locals() listing the component's own variable before any alias
# such as the loop variable `param`.
var_name = [var_name for var_name, var in locals().items() if var is param]
save_keys.append(var_name[0])
if param not in params_to_exclude:
ui_keys.append(var_name[0])
params_to_load.append(param)
# db_status is always appended last as an extra load target.
ui_keys.append("db_status")
params_to_load.append(db_status)
# Publish the discovered key lists as attributes on the db_config module.
from dreambooth.dataclasses import db_config
db_config.save_keys = save_keys
db_config.ui_keys = ui_keys
db_save_params.click(
_js="check_save", fn=save_config, inputs=params_to_save, outputs=[]
)
db_load_params.click(
_js="db_start_load_params",
fn=load_params,
inputs=[db_model_name],
outputs=params_to_load,
)
def toggle_new_rows(create_from):
    """Return two opposite visibility updates driven by ``create_from``.

    The first update is visible when ``create_from`` is set; the second is
    visible when it is not.
    """
    first_row = gr.update(visible=create_from)
    second_row = gr.update(visible=not create_from)
    return first_row, second_row
def toggle_loss_items(scale):
    """Return two visibility updates that both follow ``scale``."""
    first_item = gr.update(visible=scale)
    second_item = gr.update(visible=scale)
    return first_item, second_item
db_create_from_hub.change(
fn=toggle_new_rows,
inputs=[db_create_from_hub],
outputs=[hub_row, local_row],
)
def toggle_shared_row(row):
    """Return a visibility update following ``row`` plus a value reset.

    The second update always sets its target's value to the empty string,
    regardless of ``row``.
    """
    row_update = gr.update(visible=row)
    cleared_value = gr.update(value="")
    return row_update, cleared_value
db_use_shared_src.change(
fn=toggle_shared_row,
inputs=[db_use_shared_src],
outputs=[shared_row, db_new_model_shared_src],
)
db_prior_loss_scale.change(
fn=toggle_loss_items,
inputs=[db_prior_loss_scale],
outputs=[db_prior_loss_weight_min, db_prior_loss_target],
)
def disable_lora(x):
    """Build the ten UI updates that follow the "use LoRA" switch.

    ``x`` is the switch state. The returned tuple is, in order:
    EMA interactivity (enabled only without LoRA), the extended-LoRA
    control (always hidden), three LoRA-only visibility updates, the
    standard learning-rate visibility (shown only without LoRA), one more
    LoRA-only visibility update, and three relabelled save controls
    (during / after / on cancel).
    """
    if x:
        label_texts = (
            "Save LORA during training",
            "Save LORA after training",
            "Save LORA on cancel",
        )
    else:
        label_texts = (
            "Save .safetensors during training",
            "Save .safetensors after training",
            "Save .safetensors on cancel",
        )
    save_during, save_after, save_cancel = (
        gr.update(label=text) for text in label_texts
    )

    return (
        gr.update(interactive=not x),  # use_ema
        gr.update(visible=False),  # use_lora_extended
        gr.update(visible=x),  # lora_save
        gr.update(visible=x),  # lora_rank
        gr.update(visible=x),  # lora_lr
        gr.update(visible=not x),  # standard_lr
        gr.update(visible=x),  # lora_model
        save_during,
        save_after,
        save_cancel,
    )
def lr_scheduler_changed(sched):
    """Decide which learning-rate controls are visible for ``sched``.

    Returns six visibility updates, in order: power, cycles, scale
    position, factor, minimum learning rate, warmup.

    * power is shown for ``"polynomial"``;
    * cycles is shown for ``"cosine_with_restarts"``;
    * scale position is shown for constant, linear and both
      cosine-annealing schedulers, and warmup is shown for everything else;
    * minimum learning rate is shown for both cosine-annealing schedulers;
    * factor is shown for linear and constant.
    """
    annealing_scheds = ("cosine_annealing", "cosine_annealing_with_restarts")
    factor_scheds = ("linear", "constant")
    uses_scale_pos = sched in ("constant", "linear") + annealing_scheds

    return (
        gr.update(visible=sched == "polynomial"),
        gr.update(visible=sched == "cosine_with_restarts"),
        gr.update(visible=uses_scale_pos),
        gr.update(visible=sched in factor_scheds),
        gr.update(visible=sched in annealing_scheds),
        gr.update(visible=not uses_scale_pos),
    )
def optimizer_changed(opti):
    """Return a visibility update that is shown when ``opti`` contains "adapt"."""
    return gr.update(visible="adapt" in opti)
def class_gen_method_changed(method):
    """Show the target control only for the "Native Diffusers" method."""
    return gr.update(visible=method == "Native Diffusers")
db_use_lora.change(
fn=disable_lora,
inputs=[db_use_lora],
outputs=[
db_use_ema,
db_use_lora_extended,
lora_save_col,
lora_rank_col,
lora_lr_row,
standard_lr_row,
lora_model_row,
db_save_ckpt_during,
db_save_ckpt_after,
db_save_ckpt_cancel
],
)
db_lr_scheduler.change(
fn=lr_scheduler_changed,
inputs=[db_lr_scheduler],
outputs=[
db_lr_power,
db_lr_cycles,
db_lr_scale_pos,
db_lr_factor,
db_learning_rate_min,
db_lr_warmup_steps,
],
)
db_class_gen_method.change(
fn=class_gen_method_changed,
inputs=[db_class_gen_method],
outputs=[db_scheduler],
)
db_model_name.change(
_js="clear_loaded",
fn=load_model_params,
inputs=[db_model_name],
outputs=[
db_model_info,
db_model_path,
db_revision,
db_epochs,
db_model_type,
db_has_ema,
db_src,
db_shared_diffusers_path,
db_snapshot,
db_lora_model_name,
db_status,
],
)
db_use_concepts.change(
fn=lambda x: {concept_tab: gr_show(x is True)},
inputs=[db_use_concepts],
outputs=[concept_tab],
)
db_generate_graph.click(
_js="db_start_logs",
fn=log_parser.parse_logs,
inputs=[db_model_name, gr.Checkbox(value=True, visible=False)],
outputs=[db_gallery, db_prompt_list],
)
db_debug_buckets.click(
_js="db_start_buckets",
fn=debug_buckets,
inputs=[db_model_name, db_bucket_epochs, db_bucket_batch],
outputs=[db_status, db_status],
)
db_performance_wizard.click(
fn=performance_wizard,
_js="db_start_pwizard",
inputs=[db_model_name],
outputs=[
db_attention,
db_gradient_checkpointing,
db_gradient_accumulation_steps,
db_mixed_precision,
db_cache_latents,
db_optimizer,
db_sample_batch_size,
db_train_batch_size,
db_stop_text_encoder,
db_use_lora,
db_use_ema,
db_save_preview_every,
db_save_embedding_every,
db_status,
],
)
db_generate_sample.click(
fn=wrap_gpu_call(generate_samples),
_js="db_start_sample",
inputs=[
db_model_name,
db_sample_prompt,
db_sample_prompt_file,
db_sample_negative,
db_sample_width,
db_sample_height,
db_num_samples,
db_sample_batch_size,
db_sample_seed,
db_sample_steps,
db_sample_scale,
db_sample_txt2img,
db_scheduler,
db_swap_faces,
db_swap_prompt,
db_swap_negative,
db_swap_steps,
db_swap_batch,
],
outputs=[db_gallery, db_prompt_list, db_status],
)
db_generate_checkpoint.click(
_js="db_start_checkpoint",
fn=wrap_gpu_call(ui_gen_ckpt),
inputs=[db_model_name],
outputs=[db_status],
)
def set_gen_ckpt():
    """Raise the ``do_save_model`` flag on the shared ``status`` object.

    Takes no inputs and returns nothing; the flag is presumably polled by
    the running training job (not visible here).
    """
    status.do_save_model = True
def set_gen_sample():
    """Raise the ``do_save_samples`` flag on the shared ``status`` object.

    Takes no inputs and returns nothing; the flag is presumably polled by
    the running training job (not visible here).
    """
    status.do_save_samples = True
db_generate_checkpoint_during.click(fn=set_gen_ckpt, inputs=[], outputs=[])
db_train_sample.click(fn=set_gen_sample, inputs=[], outputs=[])
db_create_model.click(
fn=wrap_gpu_call(create_model),
_js="db_start_create",
inputs=[
db_new_model_name,
db_new_model_src,
db_new_model_shared_src,
db_create_from_hub,
db_new_model_url,
db_new_model_token,
db_new_model_extract_ema,
db_train_unfrozen,
db_model_type_select
],
outputs=[
db_model_name,
db_model_path,
db_revision,
db_epochs,
db_src,
db_shared_diffusers_path,
db_has_ema,
db_model_type,
db_resolution,
db_status,
],
)
db_train_model.click(
fn=wrap_gpu_call(start_training),
_js="db_start_train",
inputs=[db_model_name, db_class_gen_method],
outputs=[db_lora_model_name, db_revision, db_epochs, db_gallery, db_status],
)
db_generate_classes.click(
_js="db_start_classes",
fn=wrap_gpu_call(ui_classifiers),
inputs=[db_model_name, db_class_gen_method],
outputs=[db_gallery, db_status],
)
db_cancel.click(
fn=lambda: status.interrupt(),
inputs=[],
outputs=[],
)
return ((dreambooth_interface, "Dreambooth", "dreambooth_v2"),)
# Build the Gradio controls for a single training concept.
#
# Three tabs are created ("Instance Images", "Class Images", "Sample Images"),
# each holding text boxes, sliders and inline hint HTML. `concept` is only used
# to make the elem_id of the two directory text boxes unique (idd<N> / cdd<N>).
#
# Returns a list of the 17 created components in a fixed order (see the return
# statement at the bottom); hint gr.HTML elements are not returned.
def build_concept_panel(concept: int):
# --- Tab 1: instance (training) images and their prompt/tokens ---
with gr.Tab(label="Instance Images"):
instance_data_dir = gr.Textbox(
label="Directory",
placeholder="Path to directory with input images",
elem_id=f"idd{concept}",
)
instance_prompt = gr.Textbox(label="Prompt", value="[filewords]")
gr.HTML(value="Use [filewords] here to read prompts from caption files/filename, or a prompt to describe your training images.<br>"
"If using [filewords], your instance and class tokens will be inserted into the prompt as necessary for training.", elem_classes="hintHtml")
instance_token = gr.Textbox(label="Instance Token")
gr.HTML(value="If using [filewords] above, this is the unique word used for your subject, like 'fydodog' or 'ohwx'.",
elem_classes="hintHtml")
class_token = gr.Textbox(label="Class Token")
gr.HTML(value="If using [filewords] above, this is the generic word used for your subject, like 'dog' or 'person'.",
elem_classes="hintHtml")
# --- Tab 2: class (regularization) images and their generation settings ---
with gr.Tab(label="Class Images"):
class_data_dir = gr.Textbox(
label="Directory",
placeholder="(Optional) Path to directory with "
"classification/regularization images",
elem_id=f"cdd{concept}",
)
class_prompt = gr.Textbox(label="Prompt", value="[filewords]")
gr.HTML(
value="Use [filewords] here to read prompts from caption files/filename, or a prompt to describe your training images.<br>"
"If using [filewords], your class token will be inserted into the file prompts if it is not found.",
elem_classes="hintHtml")
class_negative_prompt = gr.Textbox(
label="Negative Prompt"
)
# NOTE(review): `precision` is passed to gr.Slider here; confirm the installed
# Gradio version accepts or ignores it.
num_class_images_per = gr.Slider(
label="Class Images Per Instance Image", value=0, precision=0
)
gr.HTML(value="For every instance image, this many classification images will be used/generated. Leave at 0 to disable.",
elem_classes="hintHtml")
class_guidance_scale = gr.Slider(
label="Classification CFG Scale", value=7.5, maximum=12, minimum=1, step=0.1
)
class_infer_steps = gr.Slider(
label="Classification Steps", value=40, minimum=10, maximum=200, step=1
)
# --- Tab 3: settings for sample/preview image generation ---
with gr.Tab(label="Sample Images"):
save_sample_prompt = gr.Textbox(
label="Sample Image Prompt",
value='[filewords]'
)
gr.HTML(
value="A prompt to generate samples from, or use [filewords] here to randomly select prompts from the existing instance prompt(s).<br>"
"If using [filewords], your instance token will be inserted into the file prompts if it is not found.",
elem_classes="hintHtml")
save_sample_negative_prompt = gr.Textbox(
label="Sample Negative Prompt"
)
sample_template = gr.Textbox(
label="Sample Prompt Template File",
placeholder="Enter the path to a txt file containing sample prompts.",
)
gr.HTML(value="When enabled the above prompt and negative prompt will be ignored.",
elem_classes="hintHtml")
n_save_sample = gr.Slider(
label="Number of Samples to Generate", value=1, maximum=100, step=1
)
sample_seed = gr.Number(label="Sample Seed", value=-1, precision=0)
save_guidance_scale = gr.Slider(
label="Sample CFG Scale", value=7.5, maximum=12, minimum=1, step=0.1
)
save_infer_steps = gr.Slider(
label="Sample Steps", value=20, minimum=10, maximum=200, step=1
)
# The order below mixes components from all three tabs; callers presumably
# unpack this list positionally, so it must not be reordered.
return [
instance_data_dir,
class_data_dir,
instance_prompt,
class_prompt,
save_sample_prompt,
sample_template,
instance_token,
class_token,
num_class_images_per,
class_negative_prompt,
class_guidance_scale,
class_infer_steps,
save_sample_negative_prompt,
n_save_sample,
sample_seed,
save_guidance_scale,
save_infer_steps,
]
# Hand the on_ui_tabs builder to the host web UI's script_callbacks hook at
# import time; its return value is the (interface, title, elem_id) tuple.
script_callbacks.on_ui_tabs(on_ui_tabs)