diff --git a/patches/external_pr/hypernetwork.py b/patches/external_pr/hypernetwork.py
index 3f7eabc..103b3ac 100644
--- a/patches/external_pr/hypernetwork.py
+++ b/patches/external_pr/hypernetwork.py
@@ -103,7 +103,7 @@ def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradi
                        use_grad_opts=False, gradient_clip_opt='None',
                        optional_gradient_clip_value=1e01, optional_gradient_norm_type=2, latent_sampling_std=-1,
                        noise_training_scheduler_enabled=False, noise_training_scheduler_repeat=False, noise_training_scheduler_cycle=128,
-                       load_training_options=''
+                       load_training_options='', loss_opt='loss_simple'
                        ):
     # images allows training previews to have infotext. Importing it at the top causes a circular import problem.
     from modules import images
@@ -142,6 +142,7 @@ def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradi
         noise_training_scheduler_enabled = dump.get('noise_training_scheduler_enabled', False)
         noise_training_scheduler_repeat = dump.get('noise_training_scheduler_repeat', False)
         noise_training_scheduler_cycle = dump.get('noise_training_scheduler_cycle', 128)
+        loss_opt = dump.get('loss_opt', 'loss_simple')
     try:
         if use_adamw_parameter:
             adamw_weight_decay, adamw_beta_1, adamw_beta_2, adamw_eps = [float(x) for x in
@@ -358,7 +359,8 @@ def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradi
                         shared.sd_model.cond_stage_model.to(devices.cpu)
                     else:
                         c = stack_conds(batch.cond).to(devices.device, non_blocking=pin_memory)
-                    loss = shared.sd_model(x, c)[0]
+                    _, losses = shared.sd_model(x, c)
+                    loss = losses['val/' + loss_opt]
                     for filenames in batch.filename:
                         loss_dict[filenames].append(loss.detach().item())
                     loss /= gradient_step
@@ -607,6 +609,7 @@ def internal_clean_training(hypernetwork_name, data_root, log_directory,
             noise_training_scheduler_enabled = dump.get('noise_training_scheduler_enabled', False)
             noise_training_scheduler_repeat = dump.get('noise_training_scheduler_repeat', False)
             noise_training_scheduler_cycle = dump.get('noise_training_scheduler_cycle', 128)
+            loss_opt = dump.get('loss_opt', 'loss_simple')
         else:
             raise RuntimeError(f"Cannot load from {load_training_options}!")
     else:
@@ -854,7 +857,8 @@ def internal_clean_training(hypernetwork_name, data_root, log_directory,
                         shared.sd_model.cond_stage_model.to(devices.cpu)
                     else:
                         c = stack_conds(batch.cond).to(devices.device, non_blocking=pin_memory)
-                    loss = shared.sd_model(x, c)[0]
+                    _, losses = shared.sd_model(x, c)
+                    loss = losses['val/' + loss_opt]
                     for filenames in batch.filename:
                         loss_dict[filenames].append(loss.detach().item())
                     loss /= gradient_step
diff --git a/patches/external_pr/ui.py b/patches/external_pr/ui.py
index f73ec2a..9581d88 100644
--- a/patches/external_pr/ui.py
+++ b/patches/external_pr/ui.py
@@ -83,7 +83,7 @@ def save_training_setting(*args):
         gamma_rate, use_beta_adamW_checkbox, save_when_converge, create_when_converge, \
         adamw_weight_decay, adamw_beta_1, adamw_beta_2, adamw_eps, show_gradient_clip_checkbox, \
         gradient_clip_opt, optional_gradient_clip_value, optional_gradient_norm_type, latent_sampling_std,\
-        noise_training_scheduler_enabled, noise_training_scheduler_repeat, noise_training_scheduler_cycle = args
+        noise_training_scheduler_enabled, noise_training_scheduler_repeat, noise_training_scheduler_cycle, loss_opt = args
     dumped_locals = locals()
     dumped_locals.pop('args')
     filename = (str(random.randint(0, 1024)) if save_file_name == '' else save_file_name) + '_train_' + '.json'
@@ -222,6 +222,9 @@ def on_train_gamma_tab(params=None):
         latent_sampling_method = gr.Radio(label='Choose latent sampling method', value="once", choices=['once', 'deterministic', 'random'])
         latent_sampling_std_value = gr.Number(label="Standard deviation for sampling", value=-1)
+        with gr.Row():
+            loss_opt = gr.Radio(label="loss type", value="loss",
+                                choices=['loss', 'loss_simple', 'loss_vlb'])
         with gr.Row():
             save_training_option = gr.Button(value="Save training setting")
             save_file_name = gr.Textbox(label="File name to save setting as", value="")
@@ -269,7 +272,8 @@ def on_train_gamma_tab(params=None):
                 latent_sampling_std_value,
                 noise_training_scheduler_enabled,
                 noise_training_scheduler_repeat,
-                noise_training_scheduler_cycle],
+                noise_training_scheduler_cycle,
+                loss_opt],
             outputs=[
                 ti_output,
                 ti_outcome,