From a5182880f9334f9fc3d46f153aca0cb4e8895cad Mon Sep 17 00:00:00 2001
From: aria1th <35677394+aria1th@users.noreply.github.com>
Date: Fri, 20 Jan 2023 21:31:53 +0900
Subject: [PATCH] add noise scheduler for training

move to LatentDiffusion forward
unbind scheduler
unbind noise scheduler by option
minor fix?
minor fix for extreme case
---
 patches/ddpm_hijack.py              | 63 +++++++++++++++++++++++++++++
 patches/external_pr/hypernetwork.py | 57 ++++++++++++++++++++------
 patches/external_pr/ui.py           | 22 +++++++++-
 patches/hypernetwork.py             |  4 ++
 4 files changed, 132 insertions(+), 14 deletions(-)
 create mode 100644 patches/ddpm_hijack.py

diff --git a/patches/ddpm_hijack.py b/patches/ddpm_hijack.py
new file mode 100644
index 0000000..e69c4a5
--- /dev/null
+++ b/patches/ddpm_hijack.py
@@ -0,0 +1,63 @@
+import torch
+import ldm.models.diffusion.ddpm
+from modules import shared
+
+
+class Scheduler:
+    """Proportional noise step scheduler."""
+    def __init__(self, cycle_step=128, repeat=True):
+        self.cycle_step = int(cycle_step)
+        self.repeat = repeat
+        self.run_assertion()
+
+    def __call__(self, value, step):
+        if self.repeat:
+            step %= self.cycle_step
+            return max(1, int(value * step / self.cycle_step))
+        else:
+            return value if step >= self.cycle_step else max(1, int(value * step / self.cycle_step))
+
+    def run_assertion(self):
+        assert type(self.cycle_step) is int
+        assert type(self.repeat) is bool
+        assert not self.repeat or self.cycle_step > 0
+
+    def set(self, cycle_step=-1, repeat=-1):
+        if cycle_step >= 0:
+            self.cycle_step = int(cycle_step)
+        if repeat != -1:
+            self.repeat = repeat
+        self.run_assertion()
+
+
+training_scheduler = Scheduler(cycle_step=-1, repeat=False)
+
+
+def get_current(value, step=None):
+    if step is None:
+        if hasattr(shared.loaded_hypernetwork, 'step') and shared.loaded_hypernetwork.training and shared.loaded_hypernetwork.step is not None:
+            return training_scheduler(value, shared.loaded_hypernetwork.step)
+        return value
+    return max(1, training_scheduler(value, step))
+
+
+def set_scheduler(cycle_step, repeat):
+    global training_scheduler
+    training_scheduler.set(cycle_step, repeat)
+
+
+def forward(self, x, c, *args, **kwargs):
+    t = torch.randint(0, get_current(self.num_timesteps), (x.shape[0],), device=self.device).long()
+    if self.model.conditioning_key is not None:
+        assert c is not None
+        if self.cond_stage_trainable:
+            c = self.get_learned_conditioning(c)
+        if self.shorten_cond_schedule:  # TODO: drop this option
+            tc = self.cond_ids[t].to(self.device)
+            c = self.q_sample(x_start=c, t=tc, noise=torch.randn_like(c.float()))
+    return self.p_losses(x, c, t, *args, **kwargs)
+
+
+
+ldm.models.diffusion.ddpm.LatentDiffusion.forward = forward
diff --git a/patches/external_pr/hypernetwork.py b/patches/external_pr/hypernetwork.py
index 240bc70..3decbba 100644
--- a/patches/external_pr/hypernetwork.py
+++ b/patches/external_pr/hypernetwork.py
@@ -22,15 +22,18 @@ from ..hnutil import optim_to
 from ..ui import create_hypernetwork_load
 from ..scheduler import CosineAnnealingWarmUpRestarts
 from .dataset import PersonalizedBase, PersonalizedDataLoader
+from ..ddpm_hijack import set_scheduler


 def get_training_option(filename):
     print(filename)
-    if os.path.exists(os.path.join(shared.cmd_opts.hypernetwork_dir, filename)) and os.path.isfile(os.path.join(shared.cmd_opts.hypernetwork_dir, filename)):
+    if os.path.exists(os.path.join(shared.cmd_opts.hypernetwork_dir, filename)) and os.path.isfile(
+            os.path.join(shared.cmd_opts.hypernetwork_dir, filename)):
         filename = os.path.join(shared.cmd_opts.hypernetwork_dir, filename)
     elif os.path.exists(filename) and os.path.isfile(filename):
         filename = filename
-    elif os.path.exists(os.path.join(shared.cmd_opts.hypernetwork_dir, filename + '.json')) and os.path.isfile(os.path.join(shared.cmd_opts.hypernetwork_dir, filename + '.json')):
+    elif os.path.exists(os.path.join(shared.cmd_opts.hypernetwork_dir, filename + '.json')) and os.path.isfile(
+            os.path.join(shared.cmd_opts.hypernetwork_dir, filename + '.json')):
         filename = os.path.join(shared.cmd_opts.hypernetwork_dir, filename + '.json')
     else:
         return False
@@ -52,7 +55,9 @@ def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradi
                        adamw_eps=1e-8, use_grad_opts=False, gradient_clip_opt='None',
                        optional_gradient_clip_value=1e01, optional_gradient_norm_type=2,
                        latent_sampling_std=-1,
-                       load_training_options=''):
+                       noise_training_scheduler_enabled=False, noise_training_scheduler_repeat=False, noise_training_scheduler_cycle=128,
+                       load_training_options=''
+                       ):
     # images allows training previews to have infotext. Importing it at the top causes a circular import problem.
     from modules import images
     if load_training_options != '':
@@ -87,6 +92,9 @@ def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradi
         optional_gradient_clip_value = dump['optional_gradient_clip_value']
         optional_gradient_norm_type = dump['optional_gradient_norm_type']
         latent_sampling_std = dump.get('latent_sampling_std', -1)
+        noise_training_scheduler_enabled = dump.get('noise_training_scheduler_enabled', False)
+        noise_training_scheduler_repeat = dump.get('noise_training_scheduler_repeat', False)
+        noise_training_scheduler_cycle = dump.get('noise_training_scheduler_cycle', 128)
     try:
         if use_adamw_parameter:
             adamw_weight_decay, adamw_beta_1, adamw_beta_2, adamw_eps = [float(x) for x in
@@ -132,6 +140,11 @@ def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradi
             create_when_converge = False
     except ValueError:
         raise RuntimeError("Cannot use advanced LR scheduler settings!")
+    if noise_training_scheduler_enabled:
+        set_scheduler(noise_training_scheduler_cycle, noise_training_scheduler_repeat)
+        print(f"Noise training scheduler enabled: cycle_step={noise_training_scheduler_cycle}, repeat={noise_training_scheduler_repeat}")
+    else:
+        set_scheduler(-1, False)
     if use_grad_opts and gradient_clip_opt != "None":
         try:
             optional_gradient_clip_value = float(optional_gradient_clip_value)
@@ -324,7 +337,7 @@ def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradi
                 c = stack_conds(batch.cond).to(devices.device, non_blocking=pin_memory)
                 loss = shared.sd_model(x, c)[0]
                 for filenames in batch.filename:
-                    loss_dict[filenames].append(loss.item())
+                    loss_dict[filenames].append(loss.detach().item())
                 loss /= gradient_step
                 del x
                 del c
@@ -382,10 +395,12 @@ def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradi
                 epoch_step = hypernetwork.step - (epoch_num * len(ds)) + 1
                 mean_loss = sum(sum(x) for x in loss_dict.values()) / sum(len(x) for x in loss_dict.values())
                 tensorboard_add(tensorboard_writer, loss=mean_loss, global_step=hypernetwork.step, step=epoch_step,
-                                learn_rate=scheduler.learn_rate if not use_beta_scheduler else optimizer.param_groups[0]['lr'], epoch_num=epoch_num)
+                                learn_rate=scheduler.learn_rate if not use_beta_scheduler else
+                                optimizer.param_groups[0]['lr'], epoch_num=epoch_num)
                 if images_dir is not None and (
                         use_beta_scheduler and scheduler_beta.is_EOC(hypernetwork.step) and create_when_converge) or (
                        create_image_every > 0 and steps_done % create_image_every == 0):
+                    set_scheduler(-1, False)
                     forced_filename = f'{hypernetwork_name}-{steps_done}'
                     last_saved_image = os.path.join(images_dir, forced_filename)
                     rng_state = torch.get_rng_state()
@@ -437,6 +452,8 @@ def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradi
                     hypernetwork.train()
                     if move_optimizer:
                         optim_to(optimizer, devices.device)
+                    if noise_training_scheduler_enabled:
+                        set_scheduler(noise_training_scheduler_cycle, noise_training_scheduler_repeat)
                     if image is not None:
                         if hasattr(shared.state, 'assign_current_image'):
                             shared.state.assign_current_image(image)
@@ -469,6 +486,7 @@ Last saved image: {html.escape(last_saved_image)}
     shared.parallel_processing_allowed = old_parallel_processing_allowed
     if hasattr(sd_hijack_checkpoint, 'remove'):
         sd_hijack_checkpoint.remove()
+    set_scheduler(-1, False)
     report_statistics(loss_dict)
     filename = os.path.join(shared.cmd_opts.hypernetwork_dir, f'{hypernetwork_name}.pt')
     hypernetwork.optimizer_name = optimizer_name
@@ -488,7 +506,8 @@ def internal_clean_training(hypernetwork_name, data_root, log_directory,
                             preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps,
                             preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height,
                             move_optimizer=True,
-                            load_hypernetworks_option='', load_training_options='', manual_dataset_seed=-1, setting_tuple=None):
+                            load_hypernetworks_option='', load_training_options='', manual_dataset_seed=-1,
+                            setting_tuple=None):
     # images allows training previews to have infotext. Importing it at the top causes a circular import problem.
     from modules import images
     base_hypernetwork_name = hypernetwork_name
@@ -514,10 +533,11 @@ def internal_clean_training(hypernetwork_name, data_root, log_directory,
         skip_connection = dump_hyper['skip_connection']
         hypernetwork = create_hypernetwork_load(hypernetwork_name, enable_sizes, overwrite_old, layer_structure,
                                                 activation_func, weight_init, add_layer_norm, use_dropout,
-                                                dropout_structure, optional_info, weight_init_seed, normal_std, skip_connection)
+                                                dropout_structure, optional_info, weight_init_seed, normal_std,
+                                                skip_connection)
     else:
         load_hypernetwork(hypernetwork_name)
-    hypernetwork_name = hypernetwork_name.rsplit('(',1)[0] + setting_suffix
+    hypernetwork_name = hypernetwork_name.rsplit('(', 1)[0] + setting_suffix
     shared.loaded_hypernetwork.save(os.path.join(shared.cmd_opts.hypernetwork_dir, f"{hypernetwork_name}.pt"))
     shared.reload_hypernetworks()
     load_hypernetwork(hypernetwork_name)
@@ -552,6 +572,9 @@ def internal_clean_training(hypernetwork_name, data_root, log_directory,
             optional_gradient_clip_value = dump['optional_gradient_clip_value']
             optional_gradient_norm_type = dump['optional_gradient_norm_type']
             latent_sampling_std = dump.get('latent_sampling_std', -1)
+            noise_training_scheduler_enabled = dump.get('noise_training_scheduler_enabled', False)
+            noise_training_scheduler_repeat = dump.get('noise_training_scheduler_repeat', False)
+            noise_training_scheduler_cycle = dump.get('noise_training_scheduler_cycle', 128)
         else:
             raise RuntimeError(f"Cannot load from {load_training_options}!")
     else:
@@ -627,6 +650,11 @@ def internal_clean_training(hypernetwork_name, data_root, log_directory,
     else:
         def gradient_clipping(arg1):
             return
+    if noise_training_scheduler_enabled:
+        set_scheduler(noise_training_scheduler_cycle, noise_training_scheduler_repeat)
+        print(f"Noise training scheduler enabled: cycle_step={noise_training_scheduler_cycle}, repeat={noise_training_scheduler_repeat}")
+    else:
+        set_scheduler(-1, False)
     save_hypernetwork_every = save_hypernetwork_every or 0
     create_image_every = create_image_every or 0
     validate_train_inputs(hypernetwork_name, learn_rate, batch_size, gradient_step, data_root,
@@ -794,7 +822,7 @@ def internal_clean_training(hypernetwork_name, data_root, log_directory,
                 c = stack_conds(batch.cond).to(devices.device, non_blocking=pin_memory)
                 loss = shared.sd_model(x, c)[0]
                 for filenames in batch.filename:
-                    loss_dict[filenames].append(loss.item())
+                    loss_dict[filenames].append(loss.detach().item())
                 loss /= gradient_step
                 del x
                 del c
@@ -852,10 +880,12 @@ def internal_clean_training(hypernetwork_name, data_root, log_directory,
                 epoch_step = hypernetwork.step - (epoch_num * len(ds)) + 1
                 mean_loss = sum(sum(x) for x in loss_dict.values()) / sum(len(x) for x in loss_dict.values())
                 tensorboard_add(tensorboard_writer, loss=mean_loss, global_step=hypernetwork.step, step=epoch_step,
-                                learn_rate=scheduler.learn_rate if not use_beta_scheduler else optimizer.param_groups[0]['lr'], epoch_num=epoch_num,base_name=hypernetwork_name)
+                                learn_rate=scheduler.learn_rate if not use_beta_scheduler else
+                                optimizer.param_groups[0]['lr'], epoch_num=epoch_num, base_name=hypernetwork_name)
                 if images_dir is not None and (
                         use_beta_scheduler and scheduler_beta.is_EOC(hypernetwork.step) and create_when_converge) or (
                         create_image_every > 0 and steps_done % create_image_every == 0):
+                    set_scheduler(-1, False)
                     forced_filename = f'{hypernetwork_name}-{steps_done}'
                     last_saved_image = os.path.join(images_dir, forced_filename)
                     rng_state = torch.get_rng_state()
@@ -907,6 +937,8 @@ def internal_clean_training(hypernetwork_name, data_root, log_directory,
                     hypernetwork.train()
                     if move_optimizer:
                         optim_to(optimizer, devices.device)
+                    if noise_training_scheduler_enabled:
+                        set_scheduler(noise_training_scheduler_cycle, noise_training_scheduler_repeat)
                     if image is not None:
                         if hasattr(shared.state, 'assign_current_image'):
                             shared.state.assign_current_image(image)
@@ -936,10 +968,12 @@ Last saved image: {html.escape(last_saved_image)}
     pbar.leave = False
     pbar.close()
     hypernetwork.eval()
+    set_scheduler(-1, False)
     shared.parallel_processing_allowed = old_parallel_processing_allowed
     if hasattr(sd_hijack_checkpoint, 'remove'):
         sd_hijack_checkpoint.remove()
     if shared.opts.training_enable_tensorboard:
+        mean_loss = sum(sum(x) for x in loss_dict.values()) / sum(len(x) for x in loss_dict.values()) if sum(len(x) for x in loss_dict.values()) > 0 else 0
         tensorboard_log_hyperparameter(tensorboard_writer, lr=learn_rate,
                                        GA_steps=gradient_step,
                                        batch_size=batch_size,
@@ -965,7 +999,7 @@ Last saved image: {html.escape(last_saved_image)}
                                        gradient_clip_value=optional_gradient_clip_value,
                                        gradient_clip_norm_type=optional_gradient_norm_type,
                                        loss=mean_loss,
-                                       base_hypernetwork_name= hypernetwork_name
+                                       base_hypernetwork_name=hypernetwork_name
                                        )
     report_statistics(loss_dict)
     filename = os.path.join(shared.cmd_opts.hypernetwork_dir, f'{hypernetwork_name}.pt')
@@ -1012,4 +1046,3 @@ def train_hypernetwork_tuning(id_task, hypernetwork_name, data_root, log_directo
                                       load_hypernetworks_option, load_training_option, manual_dataset_seed, setting_tuple=(_i, _j))
         if shared.state.interrupted:
             return None, None
-
diff --git a/patches/external_pr/ui.py b/patches/external_pr/ui.py
index 497a296..a444908 100644
--- a/patches/external_pr/ui.py
+++ b/patches/external_pr/ui.py
@@ -67,7 +67,8 @@ def save_training_setting(*args):
         template_file, use_beta_scheduler, beta_repeat_epoch, epoch_mult, warmup, min_lr, \
         gamma_rate, use_beta_adamW_checkbox, save_when_converge, create_when_converge, \
         adamw_weight_decay, adamw_beta_1, adamw_beta_2, adamw_eps, show_gradient_clip_checkbox, \
-        gradient_clip_opt, optional_gradient_clip_value, optional_gradient_norm_type, latent_sampling_std = args
+        gradient_clip_opt, optional_gradient_clip_value, optional_gradient_norm_type, latent_sampling_std, \
+        noise_training_scheduler_enabled, noise_training_scheduler_repeat, noise_training_scheduler_cycle = args
     dumped_locals = locals()
     dumped_locals.pop('args')
     filename = (str(random.randint(0, 1024)) if save_file_name == '' else save_file_name) + '_train_' + '.json'
@@ -121,6 +122,8 @@ def on_train_gamma_tab(params=None):
                         label='Show advanced adamW parameter options)')
                     show_gradient_clip_checkbox = gr.Checkbox(
                         label='Show Gradient Clipping Options(for both)')
+                    show_noise_options = gr.Checkbox(
+                        label='Show Noise Scheduler Options (for both)')
                 with gr.Row(visible=False) as adamW_options:
                     adamw_weight_decay = gr.Textbox(label="AdamW weight decay parameter", placeholder="default = 0.01",
                                                     value="0.01")
@@ -146,7 +149,16 @@ def on_train_gamma_tab(params=None):
                     gradient_clip_opt = gr.Radio(label="Gradient Clipping Options", choices=["None", "limit", "norm"])
                     optional_gradient_clip_value = gr.Textbox(label="Limiting value", value="1e-1")
                     optional_gradient_norm_type = gr.Textbox(label="Norm type", value="2")
+                with gr.Row(visible=False) as noise_scheduler_options:
+                    noise_training_scheduler_enabled = gr.Checkbox(label="Use noise training scheduler (experimental)")
+                    noise_training_scheduler_repeat = gr.Checkbox(label="Repeat noise scheduler every cycle (unchecked: ramp once)")
+                    noise_training_scheduler_cycle = gr.Number(label="Noise scheduler cycle length in steps")  # change by feedback
+                show_noise_options.change(
+                    fn=lambda show: gr_show(show),
+                    inputs=[show_noise_options],
+                    outputs=[noise_scheduler_options]
+                )
                 use_beta_adamW_checkbox.change(
                     fn=lambda show: gr_show(show),
                     inputs=[use_beta_adamW_checkbox],
@@ -239,7 +251,10 @@ def on_train_gamma_tab(params=None):
                         gradient_clip_opt,
                         optional_gradient_clip_value,
                         optional_gradient_norm_type,
-                        latent_sampling_std_value],
+                        latent_sampling_std_value,
+                        noise_training_scheduler_enabled,
+                        noise_training_scheduler_repeat,
+                        noise_training_scheduler_cycle],
                     outputs=[
                         ti_output,
                         ti_outcome,
@@ -335,6 +350,9 @@ def on_train_gamma_tab(params=None):
                         optional_gradient_clip_value,
                         optional_gradient_norm_type,
                         latent_sampling_std_value,
+                        noise_training_scheduler_enabled,
+                        noise_training_scheduler_repeat,
+                        noise_training_scheduler_cycle,
                         load_training_option
                     ],
diff --git a/patches/hypernetwork.py b/patches/hypernetwork.py
index cbb0fc5..ddc2d15 100644
--- a/patches/hypernetwork.py
+++ b/patches/hypernetwork.py
@@ -270,6 +270,7 @@ class Hypernetwork:
         self.optional_info = kwargs.get('optional_info', None)
         self.skip_connection = kwargs.get('skip_connection', False)
         self.upsample_linear = kwargs.get('upsample_linear', None)
+        self.training = False
         generation_seed = kwargs.get('generation_seed', None)
         normal_std = kwargs.get('normal_std', 0.01)
         if self.dropout_structure is None:
@@ -287,6 +288,7 @@ class Hypernetwork:
         self.eval()

     def weights(self, train=False):
+        self.training = train
         res = []
         for k, layers in self.layers.items():
             for layer in layers:
@@ -294,12 +296,14 @@ class Hypernetwork:
         return res

     def eval(self):
+        self.training = False
         for k, layers in self.layers.items():
             for layer in layers:
                 layer.eval()
                 layer.set_train(False)

     def train(self, mode=True):
+        self.training = mode
         for k, layers in self.layers.items():
             for layer in layers:
                 layer.set_train(mode)
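
--
Notes on the schedule (illustrative sketches; names below such as
scheduled_ceiling are hypothetical and not part of the patch):

get_current() caps the diffusion timestep that the hijacked
LatentDiffusion.forward may sample, and the cap grows in proportion to
hypernetwork.step. A minimal standalone rewrite of Scheduler.__call__:

    # Ceiling on the sampled timestep after `step` training steps.
    # repeat=True: ramps 1 -> value over `cycle_step` steps, then wraps.
    # repeat=False: ramps once, then stays at `value` from `cycle_step` on.
    def scheduled_ceiling(value, step, cycle_step=128, repeat=True):
        if repeat:
            step %= cycle_step
            return max(1, int(value * step / cycle_step))
        return value if step >= cycle_step else max(1, int(value * step / cycle_step))

    # With the DDPM default of 1000 timesteps and a 128-step cycle:
    for step in (0, 32, 64, 96, 128):
        print(step, scheduled_ceiling(1000, step))
    # prints 1, 250, 500, 750, then 1 again as the cycle wraps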
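
The hijacked forward only changes how t is drawn: the upper bound of
torch.randint passes through get_current(), so early in a cycle the
hypernetwork trains only on low-noise timesteps. A sketch of that
sampling, with illustrative values for num_timesteps and batch size:

    import torch

    def sample_timesteps(num_timesteps, batch_size, ceiling):
        # Uniform draw from [0, ceiling); ceiling == num_timesteps
        # reproduces stock DDPM behaviour.
        return torch.randint(0, min(ceiling, num_timesteps), (batch_size,)).long()

    t = sample_timesteps(num_timesteps=1000, batch_size=4, ceiling=250)
    assert int(t.max()) < 250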
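
Two guards keep the schedule from leaking outside training. First,
set_scheduler(-1, False) acts as the "unbind" call: Scheduler.set()
ignores negative cycle_step values, and with repeat=False any step at or
past cycle_step returns the value unchanged, so the module-level scheduler
(constructed with cycle_step=-1) passes num_timesteps through untouched.
Second, get_current() consults the new Hypernetwork.training flag, which
train(), eval() and weights() now maintain, so a loaded hypernetwork that
is not mid-training never truncates the timestep range. Continuing the
scheduled_ceiling sketch above:

    assert scheduled_ceiling(1000, step=0, cycle_step=-1, repeat=False) == 1000
    assert scheduled_ceiling(1000, step=500, cycle_step=-1, repeat=False) == 1000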