# automatic/modules/ui_train.py

import os
import gradio as gr
from modules import script_callbacks, shared
from modules.ui_common import create_refresh_button
from modules.ui_sections import create_sampler_inputs
from modules.call_queue import wrap_gradio_gpu_call
def _gr_show(visible=True):
    """Return a gradio update payload that only sets a component's visibility."""
    return {"visible": visible, "__type__": "update"}


def _train_tab_change(tab):
    """Return visibility updates for the (preprocess, embedding, hypernetwork) action rows.

    Args:
        tab: short id of the selected tab: 'pr' (preview settings), 'pp' (preprocess
            images), 'ti' (train embedding) or 'hn' (train hypernetwork).

    Returns:
        A 3-tuple of gradio update dicts in the order preprocess row, embedding row,
        hypernetwork row. Unknown ids hide all three rows.
    """
    # The preprocess tab must show its own Preprocess/Stop buttons. The preview tab
    # also shows them, matching that row's initial visible=True state (preview is
    # the first tab).
    if tab in ('pp', 'pr'):
        return _gr_show(True), _gr_show(False), _gr_show(False)
    if tab == 'ti':
        return _gr_show(False), _gr_show(True), _gr_show(False)
    if tab == 'hn':
        return _gr_show(False), _gr_show(False), _gr_show(True)
    return _gr_show(False), _gr_show(False), _gr_show(False)


def _get_textual_inversion_template_names():
    """Return the names of the known textual-inversion prompt templates, sorted."""
    from modules.textual_inversion import textual_inversion
    return sorted(textual_inversion.textual_inversion_templates)


def create_ui():
    """Build the training UI.

    Layout: a left column with the shared output area and one action row (run + stop
    buttons) per task, and a right column of input tabs:
      - preview settings
      - image preprocessing
      - textual-inversion embedding training and hypernetwork training, only when
        ``shared.backend`` is ``shared.Backend.ORIGINAL``

    Extensions can then add tabs through ``script_callbacks.ui_train_tabs_callback``.

    NOTE(review): presumably called inside the caller's ``gr.Blocks`` context. The
    event bindings created here need one.
    """
    from modules.textual_inversion import textual_inversion
    from modules.textual_inversion import ui as ti_ui
    import modules.hypernetworks.ui

    dummy_component = gr.Label(visible=False)
    with gr.Row(elem_id="train_tab"):

        # Left column: shared outputs plus one action row per task. Only the row
        # that belongs to the selected tab is visible.
        with gr.Column(elem_id='train_output_container', scale=1):
            train_output = gr.Text(elem_id="train_output", value="", show_label=False)
            gr.Gallery(label='Output', show_label=False, elem_id='train_gallery', columns=1)
            gr.HTML(elem_id="train_progress", value="")
            train_outcome = gr.HTML(elem_id="train_error", value="")
            with gr.Row(visible=True) as action_pp:
                process_run = gr.Button(value="Preprocess", variant='primary')
                process_stop = gr.Button("Stop")
            with gr.Row(visible=False) as action_ti:
                ti_train = gr.Button(value="Train embedding", variant='primary')
                ti_stop = gr.Button(value="Stop")
            with gr.Row(visible=False) as action_hn:
                hn_train = gr.Button(value="Train hypernetwork", variant='primary')
                hn_stop = gr.Button(value="Stop")
        # Order must match the tuple returned by _train_tab_change.
        action_rows = [action_pp, action_ti, action_hn]

        # Right column: input tabs.
        with gr.Column(elem_id='train_input_container', scale=3):
            with gr.Tabs(elem_id="train_tabs"):

                ### preview tab
                with gr.Tab(label="Preview settings", id="train_preview_tab") as tab_preview:
                    tab_preview.select(fn=lambda: _train_tab_change('pr'), inputs=[], outputs=action_rows)
                    prompt = gr.Textbox(label="Prompt", value="", placeholder="Prompt to be used for previews", lines=2)
                    negative = gr.Textbox(label="Negative prompt", value="", placeholder="Negative prompt to be used for previews", lines=2)
                    steps, sampler_index = create_sampler_inputs('train', accordion=False)
                    cfg_scale = gr.Slider(minimum=0.0, maximum=30.0, step=0.1, label='CFG scale', value=6.0)
                    seed = gr.Number(label='Initial seed', value=-1)
                    with gr.Row():
                        width = gr.Slider(minimum=64, maximum=8192, step=8, label="Width", value=512)
                        height = gr.Slider(minimum=64, maximum=8192, step=8, label="Height", value=512)
                # Appended to both training calls' inputs and handed to extensions below.
                txt2img_preview_params = [prompt, negative, steps, sampler_index, cfg_scale, seed, width, height]

                ### preprocess tab
                with gr.Tab(label="Preprocess images", id="preprocess_images") as tab_preprocess:
                    tab_preprocess.select(fn=lambda: _train_tab_change('pp'), inputs=[], outputs=action_rows)
                    process_src = gr.Textbox(label='Source directory')
                    process_dst = gr.Textbox(label='Destination directory')
                    with gr.Row():
                        process_width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512)
                        process_height = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512)
                    preprocess_txt_action = gr.Dropdown(label='Existing caption text action', value="ignore", choices=["ignore", "copy", "prepend", "append"])
                    with gr.Box():
                        gr.HTML('<h2>Preprocessing steps</h2>')
                        process_keep_original_size = gr.Checkbox(label='Keep original size')
                        process_keep_channels = gr.Checkbox(label='Keep original image channels')
                        process_flip = gr.Checkbox(label='Create flipped copies')
                        process_split = gr.Checkbox(label='Split oversized images')
                        process_focal_crop = gr.Checkbox(label='Auto focal point crop')
                        process_multicrop = gr.Checkbox(label='Auto-sized crop')
                        process_caption_only = gr.Checkbox(label='Create captions only')
                        process_caption = gr.Checkbox(label='Create BLIP captions')
                        process_caption_deepbooru = gr.Checkbox(label='Create Deepbooru captions')
                    # Option groups that are only shown while their checkbox is ticked.
                    with gr.Row(visible=False) as process_split_extra_row:
                        process_split_threshold = gr.Slider(label='Split image threshold', value=0.5, minimum=0.0, maximum=1.0, step=0.05)
                        process_overlap_ratio = gr.Slider(label='Split image overlap ratio', value=0.2, minimum=0.0, maximum=0.9, step=0.05)
                    with gr.Row(visible=False) as process_focal_crop_row:
                        process_focal_crop_face_weight = gr.Slider(label='Focal point face weight', value=0.9, minimum=0.0, maximum=1.0, step=0.05)
                        process_focal_crop_entropy_weight = gr.Slider(label='Focal point entropy weight', value=0.15, minimum=0.0, maximum=1.0, step=0.05)
                        process_focal_crop_edges_weight = gr.Slider(label='Focal point edges weight', value=0.5, minimum=0.0, maximum=1.0, step=0.05)
                        process_focal_crop_debug = gr.Checkbox(label='Create debug image')
                    with gr.Column(visible=False) as process_multicrop_col:
                        gr.HTML('<h2>Each image is center-cropped with an automatically chosen width and height</h2>')
                        with gr.Row():
                            process_multicrop_mindim = gr.Slider(minimum=64, maximum=2048, step=8, label="Dimension lower bound", value=384)
                            process_multicrop_maxdim = gr.Slider(minimum=64, maximum=2048, step=8, label="Dimension upper bound", value=768)
                        with gr.Row():
                            process_multicrop_minarea = gr.Slider(minimum=64*64, maximum=2048*2048, step=1, label="Area lower bound", value=64*64)
                            process_multicrop_maxarea = gr.Slider(minimum=64*64, maximum=2048*2048, step=1, label="Area upper bound", value=640*640)
                        with gr.Row():
                            process_multicrop_objective = gr.Radio(["Maximize area", "Minimize error"], value="Maximize area", label="Resizing objective")
                            process_multicrop_threshold = gr.Slider(minimum=0, maximum=1, step=0.01, label="Error threshold", value=0.1)
                    process_split.change(fn=lambda show: _gr_show(show), inputs=[process_split], outputs=[process_split_extra_row])
                    process_focal_crop.change(fn=lambda show: _gr_show(show), inputs=[process_focal_crop], outputs=[process_focal_crop_row])
                    process_multicrop.change(fn=lambda show: _gr_show(show), inputs=[process_multicrop], outputs=[process_multicrop_col])
                    process_stop.click(fn=lambda: shared.state.interrupt(), inputs=[], outputs=[])
                    process_run.click(
                        fn=wrap_gradio_gpu_call(ti_ui.preprocess, extra_outputs=[gr.update()]),
                        _js="startTrainMonitor",
                        # Positional order must match ti_ui.preprocess's signature.
                        inputs=[
                            dummy_component,
                            process_src,
                            process_dst,
                            process_width,
                            process_height,
                            preprocess_txt_action,
                            process_keep_original_size,
                            process_keep_channels,
                            process_flip,
                            process_split,
                            process_caption_only,
                            process_caption,
                            process_caption_deepbooru,
                            process_split_threshold,
                            process_overlap_ratio,
                            process_focal_crop,
                            process_focal_crop_face_weight,
                            process_focal_crop_entropy_weight,
                            process_focal_crop_edges_weight,
                            process_focal_crop_debug,
                            process_multicrop,
                            process_multicrop_mindim,
                            process_multicrop_maxdim,
                            process_multicrop_minarea,
                            process_multicrop_maxarea,
                            process_multicrop_objective,
                            process_multicrop_threshold,
                        ],
                        outputs=[
                            train_output,
                            train_outcome,
                        ],
                    )

                # Embedding and hypernetwork training are only offered on the original backend.
                if shared.backend == shared.Backend.ORIGINAL:
                    from modules import sd_hijack

                    ### train embedding tab
                    with gr.Tab(label="Train embedding", id="train_embedding_tab") as tab_ti:
                        tab_ti.select(fn=lambda: _train_tab_change('ti'), inputs=[], outputs=action_rows)
                        gr.HTML('<h2>Select existing embedding to continue training or create a new one</h2>')
                        with gr.Row():
                            with gr.Column():
                                with gr.Row():
                                    ti_name = gr.Dropdown(label='Select embedding', choices=sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys()))
                                    # The choices lambda re-reads model_hijack.embedding_db on every refresh.
                                    create_refresh_button(ti_name, sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings, lambda: {"choices": sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys())}, "refresh_train_embedding_name")
                            with gr.Column():
                                ti_new_name = gr.Textbox(label="Create embedding")
                                ti_init_text = gr.Textbox(label="Initialization text", value="*")
                                ti_vectors = gr.Slider(label="Number of vectors per token", minimum=1, maximum=75, step=1, value=1)
                                ti_overwrite = gr.Checkbox(value=False, label="Overwrite Old Embedding")
                                with gr.Row():
                                    ti_create = gr.Button(value="Create embedding", variant='secondary')
                        with gr.Box():
                            gr.HTML('<h2>Training parameters</h2>')
                            ti_learn_rate = gr.Textbox(label='Embedding Learning rate', placeholder="Embedding Learning rate", value="0.005")
                            with gr.Row():
                                ti_clip_grad_mode = gr.Dropdown(value="disabled", label="Gradient Clipping", choices=["disabled", "value", "norm"])
                                ti_clip_grad_value = gr.Number(label="Gradient clip value", value=0.1)
                            ti_batch_size = gr.Number(label='Batch size', value=1, precision=0)
                            ti_gradient_step = gr.Number(label='Gradient accumulation steps', value=1, precision=0)
                            ti_steps = gr.Number(label='Max steps', value=1000, precision=0)
                        with gr.Box():
                            gr.HTML('<h2>Training images</h2>')
                            ti_dataset_directory = gr.Textbox(label='Dataset directory', placeholder="Path to directory with input images")
                            with gr.Row():
                                ti_varsize = gr.Checkbox(label="Do not resize images", value=False)
                                ti_width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512)
                                ti_height = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512)
                            ti_use_weight = gr.Checkbox(label="Use PNG alpha channel as loss weight", value=False)
                        with gr.Box():
                            gr.HTML('<h2>Dataset processing</h2>')
                            with gr.Row():
                                ti_template = gr.Dropdown(label='Prompt template', value="style_filewords.txt", choices=_get_textual_inversion_template_names())
                                create_refresh_button(ti_template, textual_inversion.list_textual_inversion_templates, lambda: {"choices": _get_textual_inversion_template_names()}, "refrsh_train_template_file")
                            ti_shuffle = gr.Checkbox(label="Shuffle tags", value=False)
                            ti_tag_drop_out = gr.Slider(minimum=0, maximum=1, step=0.1, label="Drop out tags when creating prompts", value=0)
                            ti_latent_sampling_method = gr.Radio(label='Choose latent sampling method', value="once", choices=['once', 'deterministic', 'random'])
                        with gr.Box():
                            gr.HTML('<h2>Training outputs</h2>')
                            with gr.Row():
                                ti_create_every = gr.Number(label='Create interim images', value=500, precision=0)
                                ti_save_every = gr.Number(label='Create interim embeddings', value=500, precision=0)
                            ti_save_image_with_stored_embedding = gr.Checkbox(label='Save images with embedding in PNG chunks', value=True)
                            ti_preview_from_txt2img = gr.Checkbox(label='Use current settings for previews', value=False)
                            ti_log_directory = gr.Textbox(label='Log directory', placeholder="Defaults to train/log/embedding", value="")
                        ti_stop.click(fn=lambda: shared.state.interrupt(), inputs=[], outputs=[])
                        ti_create.click(
                            fn=ti_ui.create_embedding,
                            inputs=[
                                ti_new_name,
                                ti_init_text,
                                ti_vectors,
                                ti_overwrite,
                            ],
                            outputs=[
                                ti_name,
                                train_output,
                                train_outcome,
                            ],
                        )
                        ti_train.click(
                            fn=wrap_gradio_gpu_call(ti_ui.train_embedding, extra_outputs=[gr.update()]),
                            _js="startTrainMonitor",
                            # Positional order must match ti_ui.train_embedding's signature.
                            inputs=[
                                dummy_component,
                                ti_name,
                                ti_learn_rate,
                                ti_batch_size,
                                ti_gradient_step,
                                ti_dataset_directory,
                                ti_log_directory,
                                ti_width,
                                ti_height,
                                ti_varsize,
                                ti_steps,
                                ti_clip_grad_mode,
                                ti_clip_grad_value,
                                ti_shuffle,
                                ti_tag_drop_out,
                                ti_latent_sampling_method,
                                ti_use_weight,
                                ti_create_every,
                                ti_save_every,
                                ti_template,
                                ti_save_image_with_stored_embedding,
                                ti_preview_from_txt2img,
                                *txt2img_preview_params,
                            ],
                            outputs=[
                                train_output,
                                train_outcome,
                            ],
                        )

                    ### train hypernetwork tab
                    with gr.Tab(label="Train hypernetwork", id="train_hypernetwork_tab") as tab_hn:
                        tab_hn.select(fn=lambda: _train_tab_change('hn'), inputs=[], outputs=action_rows)
                        gr.HTML('<h2>Select existing hypernetwork to continue training or create a new one</h2>')
                        with gr.Row():
                            with gr.Column():
                                with gr.Row():
                                    hn_name = gr.Dropdown(label='Hypernetwork', choices=sorted(shared.hypernetworks))
                                    create_refresh_button(hn_name, shared.reload_hypernetworks, lambda: {"choices": sorted(shared.hypernetworks)}, "refresh_train_hypernetwork_name")
                            with gr.Column():
                                hn_new_name = gr.Textbox(label="Name")
                                hn_new_sizes = gr.CheckboxGroup(label="Modules", value=["768", "320", "640", "1280"], choices=["768", "1024", "320", "640", "1280"])
                                hn_new_layer_structure = gr.Textbox("1, 2, 1", label="Enter hypernetwork layer structure", placeholder="1st and last digit must be 1. ex:'1, 2, 1'")
                                with gr.Row():
                                    hn_new_activation_func = gr.Dropdown(value="linear", label="Select activation function of hypernetwork", choices=modules.hypernetworks.ui.keys)
                                    hn_new_initialization_option = gr.Dropdown(value="Normal", label="Select Layer weights initialization", choices=["Normal", "KaimingUniform", "KaimingNormal", "XavierUniform", "XavierNormal"])
                                hn_new_add_layer_norm = gr.Checkbox(label="Add layer normalization")
                                hn_new_use_dropout = gr.Checkbox(label="Use dropout")
                                hn_new_dropout_structure = gr.Textbox("0, 0, 0", label="Enter hypernetwork Dropout structure", placeholder="1st and last digit must be 0 and values should be between 0 and 1. ex:'0, 0.01, 0'")
                                hn_overwrite = gr.Checkbox(value=False, label="Overwrite Old Hypernetwork")
                                with gr.Row():
                                    hn_create = gr.Button(value="Create hypernetwork", variant='secondary')
                        with gr.Box():
                            gr.HTML('<h2>Training parameters</h2>')
                            hn_learn_rate = gr.Textbox(label='Hypernetwork Learning rate', placeholder="Hypernetwork Learning rate", value="0.00001")
                            with gr.Row():
                                hn_clip_grad_mode = gr.Dropdown(value="disabled", label="Gradient Clipping", choices=["disabled", "value", "norm"])
                                hn_clip_grad_value = gr.Number(label="Gradient clip value", value=0.1)
                            hn_batch_size = gr.Number(label='Batch size', value=1, precision=0)
                            hn_gradient_step = gr.Number(label='Gradient accumulation steps', value=1, precision=0)
                            hn_steps = gr.Number(label='Max steps', value=1000, precision=0)
                        with gr.Box():
                            gr.HTML('<h2>Training images</h2>')
                            hn_dataset_directory = gr.Textbox(label='Dataset directory', placeholder="Path to directory with input images")
                            with gr.Row():
                                hn_varsize = gr.Checkbox(label="Do not resize images", value=False)
                                hn_width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512)
                                hn_height = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512)
                            hn_use_weight = gr.Checkbox(label="Use PNG alpha channel as loss weight", value=False)
                        with gr.Box():
                            gr.HTML('<h2>Dataset processing</h2>')
                            with gr.Row():
                                hn_template = gr.Dropdown(label='Prompt template', value="style_filewords.txt", choices=_get_textual_inversion_template_names())
                                # Needs its own elem_id. Reusing the embedding tab's id would create
                                # duplicate DOM ids.
                                create_refresh_button(hn_template, textual_inversion.list_textual_inversion_templates, lambda: {"choices": _get_textual_inversion_template_names()}, "refresh_train_hypernetwork_template_file")
                            hn_shuffle_tags = gr.Checkbox(label="Shuffle tags by ',' when creating prompts.", value=False)
                            hn_tag_drop_out = gr.Slider(minimum=0, maximum=1, step=0.1, label="Drop out tags when creating prompts", value=0)
                            hn_latent_sampling_method = gr.Radio(label='Choose latent sampling method', value="once", choices=['once', 'deterministic', 'random'])
                        with gr.Box():
                            gr.HTML('<h2>Training outputs</h2>')
                            with gr.Row():
                                hn_create_every = gr.Number(label='Create interim images', value=500, precision=0)
                                hn_save_every = gr.Number(label='Create interim hypernetworks', value=500, precision=0)
                            hn_preview_from_txt2img = gr.Checkbox(label='Use current settings for previews', value=False)
                            # The default lives under the configured data directory (the real
                            # shared.cmd_opts.data_dir value, not a literal string).
                            # NOTE(review): the folder is named 'embeddings' on the hypernetwork tab.
                            # Kept unchanged; confirm that this is intended.
                            hn_log_directory = gr.Textbox(label='Log directory', placeholder="Path to directory where to write outputs", value=os.path.join(shared.cmd_opts.data_dir, 'train/log/embeddings'))
                        hn_stop.click(fn=lambda: shared.state.interrupt(), inputs=[], outputs=[])
                        hn_create.click(
                            fn=modules.hypernetworks.ui.create_hypernetwork,
                            inputs=[
                                hn_new_name,
                                hn_new_sizes,
                                hn_overwrite,
                                hn_new_layer_structure,
                                hn_new_activation_func,
                                hn_new_initialization_option,
                                hn_new_add_layer_norm,
                                hn_new_use_dropout,
                                hn_new_dropout_structure,
                            ],
                            outputs=[
                                hn_name,
                                train_output,
                                train_outcome,
                            ],
                        )
                        hn_train.click(
                            fn=wrap_gradio_gpu_call(modules.hypernetworks.ui.train_hypernetwork, extra_outputs=[gr.update()]),
                            _js="startTrainMonitor",
                            # Positional order must match train_hypernetwork's signature.
                            inputs=[
                                dummy_component,
                                hn_name,
                                hn_learn_rate,
                                hn_batch_size,
                                hn_gradient_step,
                                hn_dataset_directory,
                                hn_log_directory,
                                hn_width,
                                hn_height,
                                hn_varsize,
                                hn_steps,
                                hn_clip_grad_mode,
                                hn_clip_grad_value,
                                hn_shuffle_tags,
                                hn_tag_drop_out,
                                hn_latent_sampling_method,
                                hn_use_weight,
                                hn_create_every,
                                hn_save_every,
                                hn_template,
                                hn_preview_from_txt2img,
                                *txt2img_preview_params,
                            ],
                            outputs=[
                                train_output,
                                train_outcome,
                            ],
                        )

                # Let extensions add their own tabs to this Tabs container.
                params = script_callbacks.UiTrainTabParams(txt2img_preview_params)
                script_callbacks.ui_train_tabs_callback(params)