170 lines
9.3 KiB
Python
170 lines
9.3 KiB
Python
import os
|
|
|
|
from modules.hypernetworks.ui import keys
|
|
import modules.scripts as scripts
|
|
from modules import script_callbacks, shared, sd_hijack
|
|
import gradio as gr
|
|
|
|
from modules.paths import script_path
|
|
from modules.ui import create_refresh_button, gr_show
|
|
import patches.textual_inversion as textual_inversion
|
|
import patches.ui as ui
|
|
import patches.shared as shared_patch
|
|
import patches.external_pr.ui as external_patch_ui
|
|
from webui import wrap_gradio_gpu_call
|
|
|
|
# Force the ``pin_memory`` option off as soon as this script is imported.
# NOTE(review): presumably consumed by the patched training/dataset code —
# confirm which reader depends on this before removing it.
shared.opts.pin_memory = False
|
|
|
|
|
|
def create_training_tab(params: script_callbacks.UiTrainTabParams = None):
    """Build the "Train_Beta" hypernetwork training tab.

    Creates the gradio inputs for hypernetwork training (target hypernetwork,
    learning rate, optional scheduler settings, dataset/log paths, image size,
    step counts and preview options) and wires the "Interrupt" and
    "Train Hypernetwork" buttons.

    NOTE(review): component nesting was reconstructed from a source whose
    indentation was lost — confirm the layout against the rendered UI.

    Args:
        params: Train-tab parameters. Its ``txt2img_preview_params`` components
            are forwarded as extra inputs to the training call, so a real
            object is required in practice despite the ``None`` default.

    Returns:
        A one-element list ``[(tab, "Train_beta", "train_beta")]``.

    Raises:
        AttributeError: If ``params`` is ``None``.
    """
    # Read this up front: a missing ``params`` then fails before any component
    # is created instead of after the whole tab has been built.
    txt2img_preview_params = params.txt2img_preview_params

    with gr.Tab(label="Train_Beta") as train_beta:
        gr.HTML(
            value="<p style='margin-bottom: 0.7em'>Train an embedding or Hypernetwork; you must specify a directory with a set of 1:1 ratio images <a href=\"https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Textual-Inversion\" style=\"font-weight:bold;\">[wiki]</a></p>")
        with gr.Row():
            # Sorted so the initial list matches what the refresh button produces.
            train_hypernetwork_name = gr.Dropdown(
                label='Hypernetwork',
                elem_id="train_hypernetwork",
                choices=sorted(shared.hypernetworks.keys()),
            )
            create_refresh_button(
                train_hypernetwork_name,
                shared.reload_hypernetworks,
                lambda: {"choices": sorted(shared.hypernetworks.keys())},
                "refresh_train_hypernetwork_name",
            )
        with gr.Row():
            hypernetwork_learn_rate = gr.Textbox(
                label='Hypernetwork Learning rate',
                placeholder="Hypernetwork Learning rate",
                value="0.00001",
            )
            use_beta_scheduler_checkbox = gr.Checkbox(
                label='Show advanced learn rate scheduler options(for Hypernetworks)')
        # Hidden until the checkbox above is ticked (see the change handler below).
        with gr.Row(visible=False) as beta_scheduler_options:
            use_beta_scheduler = gr.Checkbox(label='Uses CosineAnnealingWarmRestarts Scheduler')
            beta_repeat_epoch = gr.Textbox(
                label='Epoch for cycle',
                placeholder="Cycles every nth epoch",
                value="4000",
            )
            epoch_mult = gr.Textbox(
                label='Epoch multiplier per cycle',
                placeholder="Cycles length multiplier every cycle",
                value="1",
            )
            warmup = gr.Textbox(
                label='Warmup step per cycle',
                placeholder="CosineAnnealing lr increase step",
                value="1",
            )
            min_lr = gr.Textbox(
                label='Minimum learning rate for beta scheduler',
                placeholder="restricts decay value, but does not restrict gamma rate decay",
                value="1e-7",
            )
            gamma_rate = gr.Textbox(
                label='Separate learning rate decay for ExponentialLR',
                placeholder="Value should be in (0-1]",
                value="1",
            )
        batch_size = gr.Number(label='Batch size', value=1, precision=0)
        dataset_directory = gr.Textbox(
            label='Dataset directory',
            placeholder="Path to directory with input images",
        )
        log_directory = gr.Textbox(
            label='Log directory',
            placeholder="Path to directory where to write outputs",
            value="textual_inversion",
        )
        template_file = gr.Textbox(
            label='Prompt template file',
            value=os.path.join(script_path, "textual_inversion_templates", "style_filewords.txt"),
        )
        training_width = gr.Slider(minimum=64, maximum=2048, step=64, label="Width", value=512)
        training_height = gr.Slider(minimum=64, maximum=2048, step=64, label="Height", value=512)
        steps = gr.Number(label='Max steps', value=100000, precision=0)
        create_image_every = gr.Number(
            label='Save an image to log directory every N steps, 0 to disable',
            value=500,
            precision=0,
        )
        save_embedding_every = gr.Number(
            label='Save a copy of embedding to log directory every N steps, 0 to disable',
            value=500,
            precision=0,
        )
        preview_from_txt2img = gr.Checkbox(
            label='Read parameters (prompt, etc...) from txt2img tab when making previews',
            value=False,
        )

        with gr.Row():
            interrupt_training = gr.Button(value="Interrupt")
            train_hypernetwork = gr.Button(value="Train Hypernetwork", variant='primary')
        ti_output = gr.Text(elem_id="ti_output2", value="", show_label=False)
        ti_outcome = gr.HTML(elem_id="ti_error2", value="")

    # Toggle the advanced scheduler row with the checkbox state.
    use_beta_scheduler_checkbox.change(
        fn=gr_show,
        inputs=[use_beta_scheduler_checkbox],
        outputs=[beta_scheduler_options],
    )
    # Kept as a lambda so ``shared.state`` is looked up when the button is clicked.
    interrupt_training.click(
        fn=lambda: shared.state.interrupt(),
        inputs=[],
        outputs=[],
    )
    # The input order is the positional argument order of ui.train_hypernetwork_ui.
    train_hypernetwork.click(
        fn=wrap_gradio_gpu_call(ui.train_hypernetwork_ui, extra_outputs=[gr.update()]),
        _js="start_training_textual_inversion",
        inputs=[
            train_hypernetwork_name,
            hypernetwork_learn_rate,
            batch_size,
            dataset_directory,
            log_directory,
            training_width,
            training_height,
            steps,
            create_image_every,
            save_embedding_every,
            template_file,
            preview_from_txt2img,
            *txt2img_preview_params,
            use_beta_scheduler,
            beta_repeat_epoch,
            epoch_mult,
            warmup,
            min_lr,
            gamma_rate,
        ],
        outputs=[
            ti_output,
            ti_outcome,
        ],
    )
    return [(train_beta, "Train_beta", "train_beta")]
|
|
|
|
def create_extension_tab(params=None):
    """Build the "Create Beta hypernetwork" tab.

    Lays out the inputs describing a new hypernetwork (name, module sizes,
    layer structure, activation, weight initialization, layer norm, dropout,
    free-form notes, overwrite flag) and connects the "Create hypernetwork"
    button to ``ui.create_hypernetwork``.

    NOTE(review): component nesting was reconstructed from a source whose
    indentation was lost — confirm the layout against the rendered UI.

    Args:
        params: Accepted but not used by this function.

    Returns:
        A one-element list ``[(tab, "Create_beta", "create_beta")]``.
    """
    module_sizes = ("768", "320", "640", "1280")
    init_options = ("Normal", "KaimingUniform", "KaimingNormal", "XavierUniform", "XavierNormal")

    with gr.Tab(label="Create Beta hypernetwork") as create_beta:
        name_box = gr.Textbox(label="Name")
        sizes_group = gr.CheckboxGroup(
            label="Modules",
            value=list(module_sizes),
            choices=list(module_sizes),
        )
        layer_structure_box = gr.Textbox(
            "1, 2, 1",
            label="Enter hypernetwork layer structure",
            placeholder="1st and last digit must be 1. ex:'1, 2, 1'",
        )
        activation_dropdown = gr.Dropdown(
            value="linear",
            label="Select activation function of hypernetwork. Recommended : Swish / Linear(none)",
            choices=keys,
        )
        initialization_dropdown = gr.Dropdown(
            value="Normal",
            label="Select Layer weights initialization. Recommended: Kaiming for relu-like, Xavier for sigmoid-like, Normal otherwise",
            choices=list(init_options),
        )
        layer_norm_checkbox = gr.Checkbox(label="Add layer normalization")
        dropout_checkbox = gr.Checkbox(
            label="Use dropout. Might improve training when dataset is small / limited.")
        dropout_structure_box = gr.Textbox(
            "0, 0, 0",
            label="Enter hypernetwork Dropout structure (or empty). Recommended : 0~0.35 incrementing sequence: 0, 0.05, 0.15",
            placeholder="1st and last digit must be 0 and values should be between 0 and 1. ex:'0, 0.01, 0'",
        )
        info_box = gr.Textbox(
            "",
            label="Optional information about Hypernetwork",
            placeholder="Training information, dateset, etc",
        )
        overwrite_checkbox = gr.Checkbox(value=False, label="Overwrite Old Hypernetwork")

        with gr.Row():
            # Empty wide column placed ahead of the button column.
            with gr.Column(scale=3):
                gr.HTML(value="")

            with gr.Column():
                create_button = gr.Button(value="Create hypernetwork", variant='primary')
        status_text = gr.Text(elem_id="ti_output2", value="", show_label=False)
        status_html = gr.HTML(elem_id="ti_error2", value="")

    # The input order is the positional argument order of ui.create_hypernetwork.
    creation_inputs = [
        name_box,
        sizes_group,
        overwrite_checkbox,
        layer_structure_box,
        activation_dropdown,
        initialization_dropdown,
        layer_norm_checkbox,
        dropout_checkbox,
        dropout_structure_box,
        info_box,
    ]
    # The name box is also an output, so the callback can update it.
    creation_outputs = [name_box, status_text, status_html]
    create_button.click(
        fn=ui.create_hypernetwork,
        inputs=creation_inputs,
        outputs=creation_outputs,
    )
    return [(create_beta, "Create_beta", "create_beta")]
|
|
|
|
# Register the tab builders with the train-tabs callback.
# create_training_tab is deliberately NOT registered: Beta training is deprecated.
script_callbacks.on_ui_train_tabs(create_extension_tab)
script_callbacks.on_ui_train_tabs(external_patch_ui.on_train_gamma_tab)
|
|
|
|
class Script(scripts.Script):
    """``scripts.Script`` subclass titled "Hypernetwork Monkey Patch" that is always visible."""

    def title(self):
        """Return the display title of this script."""
        return "Hypernetwork Monkey Patch"

    def show(self, _):
        """Return ``scripts.AlwaysVisible``; the argument is ignored."""
        return scripts.AlwaysVisible
|