add loss options

* loss_simple is same as basic loss, without VLB.
beta-dadaptation
aria1th 2023-01-24 23:45:49 +09:00
parent 07ac06beb2
commit a62a343643
2 changed files with 13 additions and 5 deletions

View File

@ -103,7 +103,7 @@ def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradi
use_grad_opts=False, gradient_clip_opt='None', optional_gradient_clip_value=1e01,
optional_gradient_norm_type=2, latent_sampling_std=-1,
noise_training_scheduler_enabled=False, noise_training_scheduler_repeat=False, noise_training_scheduler_cycle=128,
load_training_options=''
load_training_options='', loss_opt='loss_simple'
):
# images allows training previews to have infotext. Importing it at the top causes a circular import problem.
from modules import images
@ -142,6 +142,7 @@ def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradi
noise_training_scheduler_enabled = dump.get('noise_training_scheduler_enabled', False)
noise_training_scheduler_repeat = dump.get('noise_training_scheduler_repeat', False)
noise_training_scheduler_cycle = dump.get('noise_training_scheduler_cycle', 128)
loss_opt = dump.get('loss_opt', 'loss_simple')
try:
if use_adamw_parameter:
adamw_weight_decay, adamw_beta_1, adamw_beta_2, adamw_eps = [float(x) for x in
@ -358,7 +359,8 @@ def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradi
shared.sd_model.cond_stage_model.to(devices.cpu)
else:
c = stack_conds(batch.cond).to(devices.device, non_blocking=pin_memory)
loss = shared.sd_model(x, c)[0]
_, losses = shared.sd_model(x, c)
loss = losses['val/' + loss_opt]
for filenames in batch.filename:
loss_dict[filenames].append(loss.detach().item())
loss /= gradient_step
@ -607,6 +609,7 @@ def internal_clean_training(hypernetwork_name, data_root, log_directory,
noise_training_scheduler_enabled = dump.get('noise_training_scheduler_enabled', False)
noise_training_scheduler_repeat = dump.get('noise_training_scheduler_repeat', False)
noise_training_scheduler_cycle = dump.get('noise_training_scheduler_cycle', 128)
loss_opt = dump.get('loss_opt', 'loss_simple')
else:
raise RuntimeError(f"Cannot load from {load_training_options}!")
else:
@ -854,7 +857,8 @@ def internal_clean_training(hypernetwork_name, data_root, log_directory,
shared.sd_model.cond_stage_model.to(devices.cpu)
else:
c = stack_conds(batch.cond).to(devices.device, non_blocking=pin_memory)
loss = shared.sd_model(x, c)[0]
_, losses = shared.sd_model(x, c)
loss = losses['val/' + loss_opt]
for filenames in batch.filename:
loss_dict[filenames].append(loss.detach().item())
loss /= gradient_step

View File

@ -83,7 +83,7 @@ def save_training_setting(*args):
gamma_rate, use_beta_adamW_checkbox, save_when_converge, create_when_converge, \
adamw_weight_decay, adamw_beta_1, adamw_beta_2, adamw_eps, show_gradient_clip_checkbox, \
gradient_clip_opt, optional_gradient_clip_value, optional_gradient_norm_type, latent_sampling_std,\
noise_training_scheduler_enabled, noise_training_scheduler_repeat, noise_training_scheduler_cycle = args
noise_training_scheduler_enabled, noise_training_scheduler_repeat, noise_training_scheduler_cycle, loss_opt = args
dumped_locals = locals()
dumped_locals.pop('args')
filename = (str(random.randint(0, 1024)) if save_file_name == '' else save_file_name) + '_train_' + '.json'
@ -222,6 +222,9 @@ def on_train_gamma_tab(params=None):
latent_sampling_method = gr.Radio(label='Choose latent sampling method', value="once",
choices=['once', 'deterministic', 'random'])
latent_sampling_std_value = gr.Number(label="Standard deviation for sampling", value=-1)
with gr.Row():
loss_opt = gr.Radio(label="loss type", value="loss",
choices=['loss', 'loss_simple', 'loss_vlb'])
with gr.Row():
save_training_option = gr.Button(value="Save training setting")
save_file_name = gr.Textbox(label="File name to save setting as", value="")
@ -269,7 +272,8 @@ def on_train_gamma_tab(params=None):
latent_sampling_std_value,
noise_training_scheduler_enabled,
noise_training_scheduler_repeat,
noise_training_scheduler_cycle],
noise_training_scheduler_cycle,
loss_opt],
outputs=[
ti_output,
ti_outcome,