Implement required GUI changes

pull/1688/head
bmaltais 2023-11-11 21:13:11 -05:00
parent 3c9c3f9e7b
commit fed7d7f997
10 changed files with 341 additions and 252 deletions

View File

@ -652,25 +652,29 @@ masterpiece, best quality, 1boy, in business suit, standing at street, looking b
## Change History
* 2023/11/03 (v22.2.0)
- `sdxl_train.py` now supports different learning rates for each Text Encoder.
- Example:
- `--learning_rate 1e-6`: train U-Net only
- `--train_text_encoder --learning_rate 1e-6`: train U-Net and two Text Encoders with the same learning rate (same as the previous version)
- `--train_text_encoder --learning_rate 1e-6 --learning_rate_te1 1e-6 --learning_rate_te2 1e-6`: train U-Net and two Text Encoders with different learning rates
- `--train_text_encoder --learning_rate 0 --learning_rate_te1 1e-6 --learning_rate_te2 1e-6`: train two Text Encoders only
- `--train_text_encoder --learning_rate 1e-6 --learning_rate_te1 1e-6 --learning_rate_te2 0`: train U-Net and one Text Encoder only
- `--train_text_encoder --learning_rate 0 --learning_rate_te1 0 --learning_rate_te2 1e-6`: train one Text Encoder only
- sd-scripts code base update:
- `sdxl_train.py` now supports different learning rates for each Text Encoder.
- Example:
- `--learning_rate 1e-6`: train U-Net only
- `--train_text_encoder --learning_rate 1e-6`: train U-Net and two Text Encoders with the same learning rate (same as the previous version)
- `--train_text_encoder --learning_rate 1e-6 --learning_rate_te1 1e-6 --learning_rate_te2 1e-6`: train U-Net and two Text Encoders with different learning rates
- `--train_text_encoder --learning_rate 0 --learning_rate_te1 1e-6 --learning_rate_te2 1e-6`: train two Text Encoders only
- `--train_text_encoder --learning_rate 1e-6 --learning_rate_te1 1e-6 --learning_rate_te2 0`: train U-Net and one Text Encoder only
- `--train_text_encoder --learning_rate 0 --learning_rate_te1 0 --learning_rate_te2 1e-6`: train one Text Encoder only
- `train_db.py` and `fine_tune.py` now support different learning rates for Text Encoder. Specify with `--learning_rate_te` option.
- To train Text Encoder with `fine_tune.py`, specify `--train_text_encoder` option too. `train_db.py` trains Text Encoder by default.
- `train_db.py` and `fine_tune.py` now support different learning rates for Text Encoder. Specify with `--learning_rate_te` option.
- To train Text Encoder with `fine_tune.py`, specify `--train_text_encoder` option too. `train_db.py` trains Text Encoder by default.
- Fixed the bug that Text Encoder is not trained when block lr is specified in `sdxl_train.py`.
- Debiased Estimation loss is added to each training script. Thanks to sdbds!
- Specify `--debiased_estimation_loss` option to enable it. See PR [#889](https://github.com/kohya-ss/sd-scripts/pull/889) for details.
- Training of Text Encoder is improved in `train_network.py` and `sdxl_train_network.py`. Thanks to KohakuBlueleaf! PR [#895](https://github.com/kohya-ss/sd-scripts/pull/895)
- The moving average of the loss is now displayed in the progress bar in each training script. Thanks to shirayu! PR [#899](https://github.com/kohya-ss/sd-scripts/pull/899)
- PagedAdamW32bit optimizer is supported. Specify `--optimizer_type=PagedAdamW32bit`. Thanks to xzuyn! PR [#900](https://github.com/kohya-ss/sd-scripts/pull/900)
- Other bug fixes and improvements.
- Fixed the bug that Text Encoder is not trained when block lr is specified in `sdxl_train.py`.
- Debiased Estimation loss is added to each training script. Thanks to sdbds!
- Specify `--debiased_estimation_loss` option to enable it. See PR [#889](https://github.com/kohya-ss/sd-scripts/pull/889) for details.
- Training of Text Encoder is improved in `train_network.py` and `sdxl_train_network.py`. Thanks to KohakuBlueleaf! PR [#895](https://github.com/kohya-ss/sd-scripts/pull/895)
- The moving average of the loss is now displayed in the progress bar in each training script. Thanks to shirayu! PR [#899](https://github.com/kohya-ss/sd-scripts/pull/899)
- PagedAdamW32bit optimizer is supported. Specify `--optimizer_type=PagedAdamW32bit`. Thanks to xzuyn! PR [#900](https://github.com/kohya-ss/sd-scripts/pull/900)
- Other bug fixes and improvements.
- kohya_ss gui updates:
- Implement GUI support for the SDXL finetune TE1 and TE2 training LR parameters and for the non-SDXL finetune TE training LR parameter
- Implement GUI support for Dreambooth TE LR parameter
- Implement Debiased Estimation loss at the bottom of the Advanced Parameters tab.

View File

@ -67,6 +67,9 @@ def save_configuration(
output_dir,
max_resolution,
learning_rate,
learning_rate_te,
learning_rate_te1,
learning_rate_te2,
lr_scheduler,
lr_warmup,
train_batch_size,
@ -146,17 +149,17 @@ def save_configuration(
original_file_path = file_path
save_as_bool = True if save_as.get('label') == 'True' else False
save_as_bool = True if save_as.get("label") == "True" else False
if save_as_bool:
log.info('Save as...')
log.info("Save as...")
file_path = get_saveasfile_path(file_path)
else:
log.info('Save...')
if file_path == None or file_path == '':
log.info("Save...")
if file_path == None or file_path == "":
file_path = get_saveasfile_path(file_path)
if file_path == None or file_path == '':
if file_path == None or file_path == "":
return original_file_path # In case a file_path was provided and the user decide to cancel the open action
# Extract the destination directory from the file path
@ -169,7 +172,7 @@ def save_configuration(
SaveConfigFile(
parameters=parameters,
file_path=file_path,
exclusion=['file_path', 'save_as'],
exclusion=["file_path", "save_as"],
)
return file_path
@ -188,6 +191,9 @@ def open_configuration(
output_dir,
max_resolution,
learning_rate,
learning_rate_te,
learning_rate_te1,
learning_rate_te2,
lr_scheduler,
lr_warmup,
train_batch_size,
@ -265,18 +271,18 @@ def open_configuration(
# Get list of function parameters and values
parameters = list(locals().items())
ask_for_file = True if ask_for_file.get('label') == 'True' else False
ask_for_file = True if ask_for_file.get("label") == "True" else False
original_file_path = file_path
if ask_for_file:
file_path = get_file_path(file_path)
if not file_path == '' and not file_path == None:
if not file_path == "" and not file_path == None:
# load variables from JSON file
with open(file_path, 'r') as f:
with open(file_path, "r") as f:
my_data = json.load(f)
log.info('Loading config...')
log.info("Loading config...")
# Update values to fix deprecated use_8bit_adam checkbox and set appropriate optimizer if it is set to True
my_data = update_my_data(my_data)
else:
@ -286,7 +292,7 @@ def open_configuration(
values = [file_path]
for key, value in parameters:
# Set the value in the dictionary to the corresponding value in `my_data`, or the default value if not found
if not key in ['ask_for_file', 'file_path']:
if not key in ["ask_for_file", "file_path"]:
values.append(my_data.get(key, value))
return tuple(values)
@ -304,6 +310,9 @@ def train_model(
output_dir,
max_resolution,
learning_rate,
learning_rate_te,
learning_rate_te1,
learning_rate_te2,
lr_scheduler,
lr_warmup,
train_batch_size,
@ -381,36 +390,32 @@ def train_model(
# Get list of function parameters and values
parameters = list(locals().items())
print_only_bool = True if print_only.get('label') == 'True' else False
log.info(f'Start training Dreambooth...')
print_only_bool = True if print_only.get("label") == "True" else False
log.info(f"Start training Dreambooth...")
headless_bool = True if headless.get('label') == 'True' else False
headless_bool = True if headless.get("label") == "True" else False
if pretrained_model_name_or_path == '':
if pretrained_model_name_or_path == "":
output_message(
msg='Source model information is missing', headless=headless_bool
msg="Source model information is missing", headless=headless_bool
)
return
if train_data_dir == '':
output_message(
msg='Image folder path is missing', headless=headless_bool
)
if train_data_dir == "":
output_message(msg="Image folder path is missing", headless=headless_bool)
return
if not os.path.exists(train_data_dir):
output_message(
msg='Image folder does not exist', headless=headless_bool
)
output_message(msg="Image folder does not exist", headless=headless_bool)
return
if not verify_image_folder_pattern(train_data_dir):
return
if reg_data_dir != '':
if reg_data_dir != "":
if not os.path.exists(reg_data_dir):
output_message(
msg='Regularisation folder does not exist',
msg="Regularisation folder does not exist",
headless=headless_bool,
)
return
@ -418,10 +423,8 @@ def train_model(
if not verify_image_folder_pattern(reg_data_dir):
return
if output_dir == '':
output_message(
msg='Output folder path is missing', headless=headless_bool
)
if output_dir == "":
output_message(msg="Output folder path is missing", headless=headless_bool)
return
if check_if_model_exist(
@ -448,15 +451,12 @@ def train_model(
subfolders = [
f
for f in os.listdir(train_data_dir)
if os.path.isdir(os.path.join(train_data_dir, f))
and not f.startswith('.')
if os.path.isdir(os.path.join(train_data_dir, f)) and not f.startswith(".")
]
# Check if subfolders are present. If not let the user know and return
if not subfolders:
log.info(
f"No {subfolders} were found in train_data_dir can't train..."
)
log.info(f"No {subfolders} were found in train_data_dir can't train...")
return
total_steps = 0
@ -465,7 +465,7 @@ def train_model(
for folder in subfolders:
# Extract the number of repeats from the folder name
try:
repeats = int(folder.split('_')[0])
repeats = int(folder.split("_")[0])
except ValueError:
log.info(
f"Subfolder {folder} does not have a proper repeat value, please correct the name or remove it... can't train..."
@ -478,42 +478,38 @@ def train_model(
f
for f, lower_f in (
(file, file.lower())
for file in os.listdir(
os.path.join(train_data_dir, folder)
)
for file in os.listdir(os.path.join(train_data_dir, folder))
)
if lower_f.endswith(('.jpg', '.jpeg', '.png', '.webp'))
if lower_f.endswith((".jpg", ".jpeg", ".png", ".webp"))
]
)
if num_images == 0:
log.info(f'{folder} folder contain no images, skipping...')
log.info(f"{folder} folder contain no images, skipping...")
else:
# Calculate the total number of steps for this folder
steps = repeats * num_images
total_steps += steps
# Print the result
log.info(f'Folder {folder} : steps {steps}')
log.info(f"Folder {folder} : steps {steps}")
if total_steps == 0:
log.info(
f'No images were found in folder {train_data_dir}... please rectify!'
)
log.info(f"No images were found in folder {train_data_dir}... please rectify!")
return
# Print the result
# log.info(f"{total_steps} total steps")
if reg_data_dir == '':
if reg_data_dir == "":
reg_factor = 1
else:
log.info(
f'Regularisation images are used... Will double the number of steps required...'
f"Regularisation images are used... Will double the number of steps required..."
)
reg_factor = 2
if max_train_steps == '' or max_train_steps == '0':
if max_train_steps == "" or max_train_steps == "0":
# calculate max_train_steps
max_train_steps = int(
math.ceil(
@ -525,7 +521,7 @@ def train_model(
)
)
log.info(
f'max_train_steps ({total_steps} / {train_batch_size} / {gradient_accumulation_steps} * {epoch} * {reg_factor}) = {max_train_steps}'
f"max_train_steps ({total_steps} / {train_batch_size} / {gradient_accumulation_steps} * {epoch} * {reg_factor}) = {max_train_steps}"
)
# calculate stop encoder training
@ -537,70 +533,72 @@ def train_model(
stop_text_encoder_training = math.ceil(
float(max_train_steps) / 100 * int(stop_text_encoder_training_pct)
)
log.info(f'stop_text_encoder_training = {stop_text_encoder_training}')
log.info(f"stop_text_encoder_training = {stop_text_encoder_training}")
lr_warmup_steps = round(float(int(lr_warmup) * int(max_train_steps) / 100))
log.info(f'lr_warmup_steps = {lr_warmup_steps}')
log.info(f"lr_warmup_steps = {lr_warmup_steps}")
# run_cmd = f'accelerate launch --num_cpu_threads_per_process={num_cpu_threads_per_process} "train_db.py"'
run_cmd = f'accelerate launch --num_cpu_threads_per_process={num_cpu_threads_per_process}'
run_cmd = (
f"accelerate launch --num_cpu_threads_per_process={num_cpu_threads_per_process}"
)
if sdxl:
run_cmd += f' "./sdxl_train.py"'
else:
run_cmd += f' "./train_db.py"'
if v2:
run_cmd += ' --v2'
run_cmd += " --v2"
if v_parameterization:
run_cmd += ' --v_parameterization'
run_cmd += " --v_parameterization"
if enable_bucket:
run_cmd += f' --enable_bucket --min_bucket_reso={min_bucket_reso} --max_bucket_reso={max_bucket_reso}'
run_cmd += f" --enable_bucket --min_bucket_reso={min_bucket_reso} --max_bucket_reso={max_bucket_reso}"
if no_token_padding:
run_cmd += ' --no_token_padding'
run_cmd += " --no_token_padding"
if weighted_captions:
run_cmd += ' --weighted_captions'
run_cmd += (
f' --pretrained_model_name_or_path="{pretrained_model_name_or_path}"'
)
run_cmd += " --weighted_captions"
run_cmd += f' --pretrained_model_name_or_path="{pretrained_model_name_or_path}"'
run_cmd += f' --train_data_dir="{train_data_dir}"'
if len(reg_data_dir):
run_cmd += f' --reg_data_dir="{reg_data_dir}"'
run_cmd += f' --resolution="{max_resolution}"'
run_cmd += f' --output_dir="{output_dir}"'
if not logging_dir == '':
if not logging_dir == "":
run_cmd += f' --logging_dir="{logging_dir}"'
if not stop_text_encoder_training == 0:
run_cmd += (
f' --stop_text_encoder_training={stop_text_encoder_training}'
)
if not save_model_as == 'same as source model':
run_cmd += f' --save_model_as={save_model_as}'
run_cmd += f" --stop_text_encoder_training={stop_text_encoder_training}"
if not save_model_as == "same as source model":
run_cmd += f" --save_model_as={save_model_as}"
# if not resume == '':
# run_cmd += f' --resume={resume}'
if not float(prior_loss_weight) == 1.0:
run_cmd += f' --prior_loss_weight={prior_loss_weight}'
run_cmd += f" --prior_loss_weight={prior_loss_weight}"
if full_bf16:
run_cmd += ' --full_bf16'
if not vae == '':
run_cmd += " --full_bf16"
if not vae == "":
run_cmd += f' --vae="{vae}"'
if not output_name == '':
if not output_name == "":
run_cmd += f' --output_name="{output_name}"'
if not lr_scheduler_num_cycles == '':
if not lr_scheduler_num_cycles == "":
run_cmd += f' --lr_scheduler_num_cycles="{lr_scheduler_num_cycles}"'
else:
run_cmd += f' --lr_scheduler_num_cycles="{epoch}"'
if not lr_scheduler_power == '':
if not lr_scheduler_power == "":
run_cmd += f' --lr_scheduler_power="{lr_scheduler_power}"'
if int(max_token_length) > 75:
run_cmd += f' --max_token_length={max_token_length}'
if not max_train_epochs == '':
run_cmd += f" --max_token_length={max_token_length}"
if not max_train_epochs == "":
run_cmd += f' --max_train_epochs="{max_train_epochs}"'
if not max_data_loader_n_workers == '':
run_cmd += (
f' --max_data_loader_n_workers="{max_data_loader_n_workers}"'
)
if not max_data_loader_n_workers == "":
run_cmd += f' --max_data_loader_n_workers="{max_data_loader_n_workers}"'
if int(gradient_accumulation_steps) > 1:
run_cmd += f' --gradient_accumulation_steps={int(gradient_accumulation_steps)}'
run_cmd += f" --gradient_accumulation_steps={int(gradient_accumulation_steps)}"
if sdxl:
run_cmd += f' --learning_rate_te1="{learning_rate_te1}"'
run_cmd += f' --learning_rate_te2="{learning_rate_te2}"'
else:
run_cmd += f' --learning_rate_te="{learning_rate_te}"'
run_cmd += run_cmd_training(
learning_rate=learning_rate,
@ -670,7 +668,7 @@ def train_model(
if print_only_bool:
log.warning(
'Here is the trainer command as a reference. It will not be executed:\n'
"Here is the trainer command as a reference. It will not be executed:\n"
)
print(run_cmd)
@ -678,17 +676,15 @@ def train_model(
else:
# Saving config file for model
current_datetime = datetime.now()
formatted_datetime = current_datetime.strftime('%Y%m%d-%H%M%S')
file_path = os.path.join(
output_dir, f'{output_name}_{formatted_datetime}.json'
)
formatted_datetime = current_datetime.strftime("%Y%m%d-%H%M%S")
file_path = os.path.join(output_dir, f"{output_name}_{formatted_datetime}.json")
log.info(f'Saving training config to {file_path}...')
log.info(f"Saving training config to {file_path}...")
SaveConfigFile(
parameters=parameters,
file_path=file_path,
exclusion=['file_path', 'save_as', 'headless', 'print_only'],
exclusion=["file_path", "save_as", "headless", "print_only"],
)
log.info(run_cmd)
@ -698,13 +694,11 @@ def train_model(
executor.execute_command(run_cmd=run_cmd)
# check if output_dir/last is a folder... therefore it is a diffuser model
last_dir = pathlib.Path(f'{output_dir}/{output_name}')
last_dir = pathlib.Path(f"{output_dir}/{output_name}")
if not last_dir.is_dir():
# Copy inference model for v2 if required
save_inference_file(
output_dir, v2, v_parameterization, output_name
)
save_inference_file(output_dir, v2, v_parameterization, output_name)
def dreambooth_tab(
@ -718,30 +712,30 @@ def dreambooth_tab(
dummy_db_false = gr.Label(value=False, visible=False)
dummy_headless = gr.Label(value=headless, visible=False)
with gr.Tab('Training'):
gr.Markdown(
'Train a custom model using kohya dreambooth python code...'
)
with gr.Tab("Training"):
gr.Markdown("Train a custom model using kohya dreambooth python code...")
# Setup Configuration Files Gradio
config = ConfigurationFile(headless)
source_model = SourceModel(headless=headless)
with gr.Tab('Folders'):
with gr.Tab("Folders"):
folders = Folders(headless=headless)
with gr.Tab('Parameters'):
with gr.Tab('Basic', elem_id='basic_tab'):
with gr.Tab("Parameters"):
with gr.Tab("Basic", elem_id="basic_tab"):
basic_training = BasicTraining(
learning_rate_value='1e-5',
lr_scheduler_value='cosine',
lr_warmup_value='10',
learning_rate_value="1e-5",
lr_scheduler_value="cosine",
lr_warmup_value="10",
dreambooth=True,
sdxl_checkbox=source_model.sdxl_checkbox,
)
# # Add SDXL Parameters
# sdxl_params = SDXLParameters(source_model.sdxl_checkbox, show_sdxl_cache_text_encoder_outputs=False)
with gr.Tab('Advanced', elem_id='advanced_tab'):
with gr.Tab("Advanced", elem_id="advanced_tab"):
advanced_training = AdvancedTraining(headless=headless)
advanced_training.color_aug.change(
color_aug_changed,
@ -749,12 +743,12 @@ def dreambooth_tab(
outputs=[basic_training.cache_latents],
)
with gr.Tab('Samples', elem_id='samples_tab'):
with gr.Tab("Samples", elem_id="samples_tab"):
sample = SampleImages()
with gr.Tab('Dataset Preparation'):
with gr.Tab("Dataset Preparation"):
gr.Markdown(
'This section provide Dreambooth tools to help setup your dataset...'
"This section provide Dreambooth tools to help setup your dataset..."
)
gradio_dreambooth_folder_creation_tab(
train_data_dir_input=folders.train_data_dir,
@ -766,11 +760,11 @@ def dreambooth_tab(
gradio_dataset_balancing_tab(headless=headless)
with gr.Row():
button_run = gr.Button('Start training', variant='primary')
button_run = gr.Button("Start training", variant="primary")
button_stop_training = gr.Button('Stop training')
button_stop_training = gr.Button("Stop training")
button_print = gr.Button('Print training command')
button_print = gr.Button("Print training command")
# Setup gradio tensorboard buttons
(
@ -800,6 +794,9 @@ def dreambooth_tab(
folders.output_dir,
basic_training.max_resolution,
basic_training.learning_rate,
basic_training.learning_rate_te,
basic_training.learning_rate_te1,
basic_training.learning_rate_te2,
basic_training.lr_scheduler,
basic_training.lr_warmup,
basic_training.train_batch_size,
@ -925,30 +922,28 @@ def dreambooth_tab(
def UI(**kwargs):
add_javascript(kwargs.get('language'))
css = ''
add_javascript(kwargs.get("language"))
css = ""
headless = kwargs.get('headless', False)
log.info(f'headless: {headless}')
headless = kwargs.get("headless", False)
log.info(f"headless: {headless}")
if os.path.exists('./style.css'):
with open(os.path.join('./style.css'), 'r', encoding='utf8') as file:
log.info('Load CSS...')
css += file.read() + '\n'
if os.path.exists("./style.css"):
with open(os.path.join("./style.css"), "r", encoding="utf8") as file:
log.info("Load CSS...")
css += file.read() + "\n"
interface = gr.Blocks(
css=css, title='Kohya_ss GUI', theme=gr.themes.Default()
)
interface = gr.Blocks(css=css, title="Kohya_ss GUI", theme=gr.themes.Default())
with interface:
with gr.Tab('Dreambooth'):
with gr.Tab("Dreambooth"):
(
train_data_dir_input,
reg_data_dir_input,
output_dir_input,
logging_dir_input,
) = dreambooth_tab(headless=headless)
with gr.Tab('Utilities'):
with gr.Tab("Utilities"):
utilities_tab(
train_data_dir_input=train_data_dir_input,
reg_data_dir_input=reg_data_dir_input,
@ -960,57 +955,53 @@ def UI(**kwargs):
# Show the interface
launch_kwargs = {}
username = kwargs.get('username')
password = kwargs.get('password')
server_port = kwargs.get('server_port', 0)
inbrowser = kwargs.get('inbrowser', False)
share = kwargs.get('share', False)
server_name = kwargs.get('listen')
username = kwargs.get("username")
password = kwargs.get("password")
server_port = kwargs.get("server_port", 0)
inbrowser = kwargs.get("inbrowser", False)
share = kwargs.get("share", False)
server_name = kwargs.get("listen")
launch_kwargs['server_name'] = server_name
launch_kwargs["server_name"] = server_name
if username and password:
launch_kwargs['auth'] = (username, password)
launch_kwargs["auth"] = (username, password)
if server_port > 0:
launch_kwargs['server_port'] = server_port
launch_kwargs["server_port"] = server_port
if inbrowser:
launch_kwargs['inbrowser'] = inbrowser
launch_kwargs["inbrowser"] = inbrowser
if share:
launch_kwargs['share'] = share
launch_kwargs["share"] = share
interface.launch(**launch_kwargs)
if __name__ == '__main__':
if __name__ == "__main__":
# torch.cuda.set_per_process_memory_fraction(0.48)
parser = argparse.ArgumentParser()
parser.add_argument(
'--listen',
"--listen",
type=str,
default='127.0.0.1',
help='IP to listen on for connections to Gradio',
default="127.0.0.1",
help="IP to listen on for connections to Gradio",
)
parser.add_argument(
'--username', type=str, default='', help='Username for authentication'
"--username", type=str, default="", help="Username for authentication"
)
parser.add_argument(
'--password', type=str, default='', help='Password for authentication'
"--password", type=str, default="", help="Password for authentication"
)
parser.add_argument(
'--server_port',
"--server_port",
type=int,
default=0,
help='Port to run the server listener on',
help="Port to run the server listener on",
)
parser.add_argument("--inbrowser", action="store_true", help="Open in browser")
parser.add_argument("--share", action="store_true", help="Share the gradio UI")
parser.add_argument(
"--headless", action="store_true", help="Is the server headless"
)
parser.add_argument(
'--inbrowser', action='store_true', help='Open in browser'
)
parser.add_argument(
'--share', action='store_true', help='Share the gradio UI'
)
parser.add_argument(
'--headless', action='store_true', help='Is the server headless'
)
parser.add_argument(
'--language', type=str, default=None, help='Set custom language'
"--language", type=str, default=None, help="Set custom language"
)
args = parser.parse_args()

View File

@ -82,6 +82,9 @@ def save_configuration(
save_precision,
seed,
num_cpu_threads_per_process,
learning_rate_te,
learning_rate_te1,
learning_rate_te2,
train_text_encoder,
full_bf16,
create_caption,
@ -209,6 +212,9 @@ def open_configuration(
save_precision,
seed,
num_cpu_threads_per_process,
learning_rate_te,
learning_rate_te1,
learning_rate_te2,
train_text_encoder,
full_bf16,
create_caption,
@ -345,6 +351,9 @@ def train_model(
save_precision,
seed,
num_cpu_threads_per_process,
learning_rate_te,
learning_rate_te1,
learning_rate_te2,
train_text_encoder,
full_bf16,
generate_caption_database,
@ -536,6 +545,11 @@ def train_model(
run_cmd += ' --v_parameterization'
if train_text_encoder:
run_cmd += ' --train_text_encoder'
if sdxl_checkbox:
run_cmd += f' --learning_rate_te1="{learning_rate_te1}"'
run_cmd += f' --learning_rate_te2="{learning_rate_te2}"'
else:
run_cmd += f' --learning_rate_te="{learning_rate_te}"'
if full_bf16:
run_cmd += ' --full_bf16'
if weighted_captions:
@ -552,7 +566,6 @@ def train_model(
if not logging_dir == '':
run_cmd += f' --logging_dir="{logging_dir}"'
run_cmd += f' --dataset_repeats={dataset_repeats}'
run_cmd += f' --learning_rate={learning_rate}'
run_cmd += ' --enable_bucket'
run_cmd += f' --resolution="{max_resolution}"'
@ -853,7 +866,7 @@ def finetune_tab(headless=False):
with gr.Tab('Basic', elem_id='basic_tab'):
basic_training = BasicTraining(
learning_rate_value='1e-5', finetuning=True
learning_rate_value='1e-5', finetuning=True, sdxl_checkbox=source_model.sdxl_checkbox,
)
# Add SDXL Parameters
@ -942,6 +955,9 @@ def finetune_tab(headless=False):
basic_training.save_precision,
basic_training.seed,
basic_training.num_cpu_threads_per_process,
basic_training.learning_rate_te,
basic_training.learning_rate_te1,
basic_training.learning_rate_te2,
train_text_encoder,
advanced_training.full_bf16,
create_caption,

View File

@ -305,3 +305,8 @@ class AdvancedTraining:
value=False,
info='Only for SD v2 models. By scaling the loss according to the time step, the weights of global noise prediction and local noise prediction become the same, and the improvement of details may be expected.',
)
self.debiased_estimation_loss = gr.Checkbox(
label='Debiased Estimation loss',
value=False,
info='Automates the processing of noise, allowing for faster model fitting, as well as balancing out color issues',
)

View File

@ -5,127 +5,167 @@ import os
class BasicTraining:
def __init__(
self,
learning_rate_value='1e-6',
lr_scheduler_value='constant',
lr_warmup_value='0',
sdxl_checkbox: gr.Checkbox,
learning_rate_value="1e-6",
lr_scheduler_value="constant",
lr_warmup_value="0",
finetuning: bool = False,
dreambooth: bool = False,
):
self.learning_rate_value = learning_rate_value
self.lr_scheduler_value = lr_scheduler_value
self.lr_warmup_value = lr_warmup_value
self.finetuning = finetuning
self.dreambooth = dreambooth
self.sdxl_checkbox = sdxl_checkbox
with gr.Row():
self.train_batch_size = gr.Slider(
minimum=1,
maximum=64,
label='Train batch size',
label="Train batch size",
value=1,
step=1,
)
self.epoch = gr.Number(label='Epoch', value=1, precision=0)
self.epoch = gr.Number(label="Epoch", value=1, precision=0)
self.max_train_epochs = gr.Textbox(
label='Max train epoch',
placeholder='(Optional) Enforce number of epoch',
label="Max train epoch",
placeholder="(Optional) Enforce number of epoch",
)
self.max_train_steps = gr.Textbox(
label='Max train steps',
placeholder='(Optional) Enforce number of steps',
label="Max train steps",
placeholder="(Optional) Enforce number of steps",
)
self.save_every_n_epochs = gr.Number(
label='Save every N epochs', value=1, precision=0
label="Save every N epochs", value=1, precision=0
)
self.caption_extension = gr.Textbox(
label='Caption Extension',
placeholder='(Optional) Extension for caption files. default: .caption',
label="Caption Extension",
placeholder="(Optional) Extension for caption files. default: .caption",
)
with gr.Row():
self.mixed_precision = gr.Dropdown(
label='Mixed precision',
label="Mixed precision",
choices=[
'no',
'fp16',
'bf16',
"no",
"fp16",
"bf16",
],
value='fp16',
value="fp16",
)
self.save_precision = gr.Dropdown(
label='Save precision',
label="Save precision",
choices=[
'float',
'fp16',
'bf16',
"float",
"fp16",
"bf16",
],
value='fp16',
value="fp16",
)
self.num_cpu_threads_per_process = gr.Slider(
minimum=1,
maximum=os.cpu_count(),
step=1,
label='Number of CPU threads per core',
label="Number of CPU threads per core",
value=2,
)
self.seed = gr.Textbox(
label='Seed', placeholder='(Optional) eg:1234'
)
self.cache_latents = gr.Checkbox(label='Cache latents', value=True)
self.seed = gr.Textbox(label="Seed", placeholder="(Optional) eg:1234")
self.cache_latents = gr.Checkbox(label="Cache latents", value=True)
self.cache_latents_to_disk = gr.Checkbox(
label='Cache latents to disk', value=False
label="Cache latents to disk", value=False
)
with gr.Row():
self.lr_scheduler = gr.Dropdown(
label='LR Scheduler',
label="LR Scheduler",
choices=[
'adafactor',
'constant',
'constant_with_warmup',
'cosine',
'cosine_with_restarts',
'linear',
'polynomial',
"adafactor",
"constant",
"constant_with_warmup",
"cosine",
"cosine_with_restarts",
"linear",
"polynomial",
],
value=lr_scheduler_value,
)
self.optimizer = gr.Dropdown(
label='Optimizer',
label="Optimizer",
choices=[
'AdamW',
'AdamW8bit',
'Adafactor',
'DAdaptation',
'DAdaptAdaGrad',
'DAdaptAdam',
'DAdaptAdan',
'DAdaptAdanIP',
'DAdaptAdamPreprint',
'DAdaptLion',
'DAdaptSGD',
'Lion',
'Lion8bit',
'PagedAdamW8bit',
'PagedLion8bit',
'Prodigy',
'SGDNesterov',
'SGDNesterov8bit',
"AdamW",
"AdamW8bit",
"Adafactor",
"DAdaptation",
"DAdaptAdaGrad",
"DAdaptAdam",
"DAdaptAdan",
"DAdaptAdanIP",
"DAdaptAdamPreprint",
"DAdaptLion",
"DAdaptSGD",
"Lion",
"Lion8bit",
"PagedAdamW8bit",
"PagedLion8bit",
"Prodigy",
"SGDNesterov",
"SGDNesterov8bit",
],
value='AdamW8bit',
value="AdamW8bit",
interactive=True,
)
with gr.Row():
self.lr_scheduler_args = gr.Textbox(
label='LR scheduler extra arguments',
label="LR scheduler extra arguments",
placeholder='(Optional) eg: "lr_end=5e-5"',
)
self.optimizer_args = gr.Textbox(
label='Optimizer extra arguments',
placeholder='(Optional) eg: relative_step=True scale_parameter=True warmup_init=True',
label="Optimizer extra arguments",
placeholder="(Optional) eg: relative_step=True scale_parameter=True warmup_init=True",
)
with gr.Row():
self.learning_rate = gr.Number(
label='Learning rate', value=learning_rate_value
# Original GLOBAL LR
if finetuning or dreambooth:
self.learning_rate = gr.Number(
label="Learning rate Unet", value=learning_rate_value,
minimum=0,
maximum=1,
info="Set to 0 to not train the Unet"
)
else:
self.learning_rate = gr.Number(
label="Learning rate", value=learning_rate_value,
minimum=0,
maximum=1
)
# New TE LR for non SDXL models
self.learning_rate_te = gr.Number(
label="Learning rate TE",
value=learning_rate_value,
visible=finetuning or dreambooth,
minimum=0,
maximum=1,
info="Set to 0 to not train the Text Encoder"
)
# New TE LR for SDXL models
self.learning_rate_te1 = gr.Number(
label="Learning rate TE1",
value=learning_rate_value,
visible=False,
minimum=0,
maximum=1,
info="Set to 0 to not train the Text Encoder 1"
)
# New TE LR for SDXL models
self.learning_rate_te2 = gr.Number(
label="Learning rate TE2",
value=learning_rate_value,
visible=False,
minimum=0,
maximum=1,
info="Set to 0 to not train the Text Encoder 2"
)
self.lr_warmup = gr.Slider(
label='LR warmup (% of steps)',
label="LR warmup (% of steps)",
value=lr_warmup_value,
minimum=0,
maximum=100,
@ -133,44 +173,59 @@ class BasicTraining:
)
with gr.Row(visible=not finetuning):
self.lr_scheduler_num_cycles = gr.Textbox(
label='LR number of cycles',
placeholder='(Optional) For Cosine with restart and polynomial only',
label="LR number of cycles",
placeholder="(Optional) For Cosine with restart and polynomial only",
)
self.lr_scheduler_power = gr.Textbox(
label='LR power',
placeholder='(Optional) For Cosine with restart and polynomial only',
label="LR power",
placeholder="(Optional) For Cosine with restart and polynomial only",
)
with gr.Row(visible=not finetuning):
self.max_resolution = gr.Textbox(
label='Max resolution',
value='512,512',
placeholder='512,512',
label="Max resolution",
value="512,512",
placeholder="512,512",
)
self.stop_text_encoder_training = gr.Slider(
minimum=-1,
maximum=100,
value=0,
step=1,
label='Stop text encoder training',
label="Stop text encoder training",
)
with gr.Row(visible=not finetuning):
self.enable_bucket = gr.Checkbox(
label='Enable buckets', value=True
)
self.enable_bucket = gr.Checkbox(label="Enable buckets", value=True)
self.min_bucket_reso = gr.Slider(
label='Minimum bucket resolution',
label="Minimum bucket resolution",
value=256,
minimum=64,
maximum=4096,
step=64,
info='Minimum size in pixel a bucket can be (>= 64)',
info="Minimum size in pixel a bucket can be (>= 64)",
)
self.max_bucket_reso = gr.Slider(
label='Maximum bucket resolution',
label="Maximum bucket resolution",
value=2048,
minimum=64,
maximum=4096,
step=64,
info='Maximum size in pixel a bucket can be (>= 64)',
info="Maximum size in pixel a bucket can be (>= 64)",
)
def update_learning_rate_te(sdxl_checkbox, finetuning, dreambooth):
    """Compute visibility updates for the three Text Encoder LR fields.

    Returns a 3-tuple of ``gr.Number.update`` results, in the order the
    callback is wired to its outputs: the single ``learning_rate_te``
    field first, then ``learning_rate_te1`` and ``learning_rate_te2``.

    The single field is made visible only when ``sdxl_checkbox`` is falsy;
    the two numbered fields only when it is truthy. In both cases at least
    one of ``finetuning`` / ``dreambooth`` must also be truthy.
    """
    # Shared precondition for every field: one of the two modes is active.
    te_lr_applicable = finetuning or dreambooth

    show_single_te = not sdxl_checkbox and te_lr_applicable
    show_sdxl_te = sdxl_checkbox and te_lr_applicable

    single_te_update = gr.Number.update(visible=show_single_te)
    te1_update = gr.Number.update(visible=show_sdxl_te)
    te2_update = gr.Number.update(visible=show_sdxl_te)
    return (single_te_update, te1_update, te2_update)
self.sdxl_checkbox.change(
update_learning_rate_te,
inputs=[self.sdxl_checkbox, gr.Checkbox(value=finetuning, visible=False), gr.Checkbox(value=dreambooth, visible=False)],
outputs=[
self.learning_rate_te,
self.learning_rate_te1,
self.learning_rate_te2,
],
)

View File

@ -550,7 +550,7 @@ def set_pretrained_model_name_or_path_input(
# Check if the given model_list is in the list of V1 models
if str(model_list) in V1_MODELS:
log.info('SD v1.4 model selected.')
log.info(f'{model_list} model selected.')
v2 = gr.Checkbox.update(value=False, visible=False)
v_parameterization = gr.Checkbox.update(value=False, visible=False)
sdxl = gr.Checkbox.update(value=False, visible=False)

View File

@ -171,6 +171,7 @@ def save_configuration(
min_timestep,
max_timestep,
vae,
debiased_estimation_loss,
):
# Get list of function parameters and values
parameters = list(locals().items())
@ -325,6 +326,7 @@ def open_configuration(
max_timestep,
training_preset,
vae,
debiased_estimation_loss,
):
# Get list of function parameters and values
parameters = list(locals().items())
@ -497,6 +499,7 @@ def train_model(
min_timestep,
max_timestep,
vae,
debiased_estimation_loss,
):
# Get list of function parameters and values
parameters = list(locals().items())
@ -932,6 +935,9 @@ def train_model(
if full_bf16:
run_cmd += f" --full_bf16"
if debiased_estimation_loss:
run_cmd += " --debiased_estimation_loss"
run_cmd += run_cmd_training(
learning_rate=learning_rate,
@ -1134,6 +1140,7 @@ def lora_tab(
learning_rate_value="0.0001",
lr_scheduler_value="cosine",
lr_warmup_value="10",
sdxl_checkbox=source_model.sdxl_checkbox,
)
with gr.Row():
@ -1643,6 +1650,7 @@ def lora_tab(
advanced_training.min_timestep,
advanced_training.max_timestep,
advanced_training.vae,
advanced_training.debiased_estimation_loss,
]
config.button_open_config.click(

Binary file not shown.

View File

@ -2,6 +2,7 @@
"adaptive_noise_scale": 0,
"additional_parameters": "",
"batch_size": "8",
"block_lr": "",
"bucket_no_upscale": true,
"bucket_reso_steps": 1,
"cache_latents": true,
@ -17,6 +18,7 @@
"dataset_repeats": "50",
"epoch": 2,
"flip_aug": false,
"full_bf16": false,
"full_fp16": false,
"full_path": true,
"gradient_accumulation_steps": 1.0,
@ -25,19 +27,25 @@
"keep_tokens": 0,
"latent_metadata_filename": "meta-1_lat.json",
"learning_rate": 1e-05,
"learning_rate_te": 5e-06,
"learning_rate_te1": 5e-06,
"learning_rate_te2": 0.0,
"logging_dir": "./test/ft",
"lr_scheduler": "cosine_with_restarts",
"lr_scheduler_args": "",
"lr_warmup": 10,
"max_bucket_reso": "1024",
"max_data_loader_n_workers": "0",
"max_resolution": "512,512",
"max_timestep": 1000,
"max_token_length": "75",
"max_train_epochs": "",
"mem_eff_attn": false,
"min_bucket_reso": "256",
"min_snr_gamma": 0,
"min_timestep": 0,
"mixed_precision": "bf16",
"model_list": "stabilityai/stable-diffusion-xl-base-1.0",
"model_list": "runwayml/stable-diffusion-v1-5",
"multires_noise_discount": 0,
"multires_noise_iterations": 0,
"noise_offset": 0,
@ -48,7 +56,7 @@
"output_dir": "./test/output",
"output_name": "test_ft",
"persistent_data_loader_workers": false,
"pretrained_model_name_or_path": "stabilityai/stable-diffusion-xl-base-1.0",
"pretrained_model_name_or_path": "runwayml/stable-diffusion-v1-5",
"random_crop": false,
"resume": "",
"sample_every_n_epochs": 0,
@ -64,7 +72,7 @@
"save_state": false,
"scale_v_pred_loss_like_noise_pred": false,
"sdxl_cache_text_encoder_outputs": false,
"sdxl_checkbox": true,
"sdxl_checkbox": false,
"sdxl_no_half_vae": false,
"seed": "1234",
"shuffle_caption": false,
@ -73,10 +81,11 @@
"train_text_encoder": true,
"use_latent_files": "No",
"use_wandb": false,
"v2": true,
"v_parameterization": true,
"v2": false,
"v_parameterization": false,
"v_pred_like_loss": 0,
"vae_batch_size": 0,
"wandb_api_key": "",
"weighted_captions": false,
"xformers": true
"xformers": "xformers"
}

View File

@ -792,6 +792,7 @@ def ti_tab(
learning_rate_value='1e-5',
lr_scheduler_value='cosine',
lr_warmup_value='10',
sdxl_checkbox=source_model.sdxl_checkbox,
)
# Add SDXL Parameters