From a5182880f9334f9fc3d46f153aca0cb4e8895cad Mon Sep 17 00:00:00 2001
From: aria1th <35677394+aria1th@users.noreply.github.com>
Date: Fri, 20 Jan 2023 21:31:53 +0900
Subject: [PATCH] add noise scheduler for training
move to LatentDiffusion forward
unbind scheduler
unbind noise scheduler by option
minor fix?
minor fix for extreme case
---
patches/ddpm_hijack.py | 63 +++++++++++++++++++++++++++++
patches/external_pr/hypernetwork.py | 57 ++++++++++++++++++++------
patches/external_pr/ui.py | 22 +++++++++-
patches/hypernetwork.py | 4 ++
4 files changed, 132 insertions(+), 14 deletions(-)
create mode 100644 patches/ddpm_hijack.py
diff --git a/patches/ddpm_hijack.py b/patches/ddpm_hijack.py
new file mode 100644
index 0000000..e69c4a5
--- /dev/null
+++ b/patches/ddpm_hijack.py
@@ -0,0 +1,63 @@
+import torch
+import ldm.models.diffusion.ddpm
+from modules import shared
+
+
class Scheduler:
    """Proportional noise-timestep scheduler for training.

    Scales a maximum timestep value by the training progress within a cycle
    of ``cycle_step`` steps, so early training draws low noise timesteps and
    later training approaches the full range.
    """

    def __init__(self, cycle_step=128, repeat=True):
        # cycle_step: number of steps over which the cap ramps up to `value`.
        # repeat: if True, progress wraps every cycle_step steps (restarts);
        #         if False, the cap stays at the full value once step >= cycle_step.
        self.cycle_step = int(cycle_step)
        self.repeat = repeat
        self.run_assertion()

    def __call__(self, value, step):
        """Return the scheduled cap for `value` at training `step` (always >= 1)."""
        if self.repeat:
            step %= self.cycle_step
            return max(1, int(value * step / self.cycle_step))
        else:
            # Linear ramp, then constant at the full value once the cycle completes.
            return value if step >= self.cycle_step else max(1, int(value * step / self.cycle_step))

    def run_assertion(self):
        """Validate configuration.

        Raises ValueError instead of using `assert`, so the checks survive
        running Python with optimizations (-O), which strips assert statements.
        Exact type checks are kept deliberately (bool is a subclass of int).
        """
        if type(self.cycle_step) is not int:
            raise ValueError("cycle_step must be an int")
        if type(self.repeat) is not bool:
            raise ValueError("repeat must be a bool")
        if self.repeat and self.cycle_step <= 0:
            raise ValueError("cycle_step must be positive when repeat is enabled")

    def set(self, cycle_step=-1, repeat=-1):
        """Update settings in place; a -1 sentinel leaves that field unchanged."""
        if cycle_step >= 0:
            self.cycle_step = int(cycle_step)
        if repeat != -1:
            self.repeat = repeat
        self.run_assertion()
+
+
+training_scheduler = Scheduler(cycle_step=-1, repeat=False)
+
+
def get_current(value, step=None):
    """Return the (possibly reduced) noise-timestep cap for the current step.

    value: the full number of timesteps (e.g. model.num_timesteps).
    step: explicit training step; when None, the step is read from the
        currently loaded hypernetwork, but only while it is training.

    Returns an int in [1, value].
    """
    if step is None:
        hypernetwork = shared.loaded_hypernetwork
        # getattr guards: loaded_hypernetwork may be None, and hypernetwork
        # implementations that predate this patch may define 'step' without
        # 'training' — the original hasattr check only covered 'step', so
        # accessing .training could raise AttributeError. Fall back to the
        # unscheduled value in every such case.
        if getattr(hypernetwork, 'training', False) and getattr(hypernetwork, 'step', None) is not None:
            return training_scheduler(value, hypernetwork.step)
        return value
    return max(1, training_scheduler(value, step))
+
+
def set_scheduler(cycle_step, repeat):
    """Configure the module-level training scheduler.

    Pass cycle_step=-1 to leave the cycle length unchanged; see Scheduler.set
    for the sentinel semantics.
    """
    # No `global` declaration needed: the existing Scheduler object is mutated
    # in place via its set() method; the module-level name is never rebound.
    training_scheduler.set(cycle_step, repeat)
+
+
def forward(self, x, c, *args, **kwargs):
    """Hijacked LatentDiffusion.forward that caps the sampled noise timestep.

    Mirrors the upstream implementation except that the upper bound of the
    random timestep draw is routed through get_current(), so the noise
    scheduler can restrict timesteps during hypernetwork training.
    """
    # Upper bound is scheduled instead of the fixed self.num_timesteps.
    t = torch.randint(0, get_current(self.num_timesteps), (x.shape[0],), device=self.device).long()
    if self.model.conditioning_key is not None:
        assert c is not None
        if self.cond_stage_trainable:
            c = self.get_learned_conditioning(c)
        if self.shorten_cond_schedule:  # TODO: drop this option
            tc = self.cond_ids[t].to(self.device)
            c = self.q_sample(x_start=c, t=tc, noise=torch.randn_like(c.float()))
    return self.p_losses(x, c, t, *args, **kwargs)


# Monkey-patch the upstream class so all training forward passes use the
# scheduled timestep draw above.
ldm.models.diffusion.ddpm.LatentDiffusion.forward = forward
diff --git a/patches/external_pr/hypernetwork.py b/patches/external_pr/hypernetwork.py
index 240bc70..3decbba 100644
--- a/patches/external_pr/hypernetwork.py
+++ b/patches/external_pr/hypernetwork.py
@@ -22,15 +22,18 @@ from ..hnutil import optim_to
from ..ui import create_hypernetwork_load
from ..scheduler import CosineAnnealingWarmUpRestarts
from .dataset import PersonalizedBase, PersonalizedDataLoader
+from ..ddpm_hijack import set_scheduler
def get_training_option(filename):
print(filename)
- if os.path.exists(os.path.join(shared.cmd_opts.hypernetwork_dir, filename)) and os.path.isfile(os.path.join(shared.cmd_opts.hypernetwork_dir, filename)):
+ if os.path.exists(os.path.join(shared.cmd_opts.hypernetwork_dir, filename)) and os.path.isfile(
+ os.path.join(shared.cmd_opts.hypernetwork_dir, filename)):
filename = os.path.join(shared.cmd_opts.hypernetwork_dir, filename)
elif os.path.exists(filename) and os.path.isfile(filename):
filename = filename
- elif os.path.exists(os.path.join(shared.cmd_opts.hypernetwork_dir, filename + '.json')) and os.path.isfile(os.path.join(shared.cmd_opts.hypernetwork_dir, filename + '.json')):
+ elif os.path.exists(os.path.join(shared.cmd_opts.hypernetwork_dir, filename + '.json')) and os.path.isfile(
+ os.path.join(shared.cmd_opts.hypernetwork_dir, filename + '.json')):
filename = os.path.join(shared.cmd_opts.hypernetwork_dir, filename + '.json')
else:
return False
@@ -52,7 +55,9 @@ def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradi
adamw_eps=1e-8,
use_grad_opts=False, gradient_clip_opt='None', optional_gradient_clip_value=1e01,
optional_gradient_norm_type=2, latent_sampling_std=-1,
- load_training_options=''):
+ noise_training_scheduler_enabled=False, noise_training_scheduler_repeat=False, noise_training_scheduler_cycle=128,
+ load_training_options=''
+ ):
# images allows training previews to have infotext. Importing it at the top causes a circular import problem.
from modules import images
if load_training_options != '':
@@ -87,6 +92,9 @@ def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradi
optional_gradient_clip_value = dump['optional_gradient_clip_value']
optional_gradient_norm_type = dump['optional_gradient_norm_type']
latent_sampling_std = dump.get('latent_sampling_std', -1)
+ noise_training_scheduler_enabled = dump.get('noise_training_scheduler_enabled', False)
+ noise_training_scheduler_repeat = dump.get('noise_training_scheduler_repeat', False)
+ noise_training_scheduler_cycle = dump.get('noise_training_scheduler_cycle', 128)
try:
if use_adamw_parameter:
adamw_weight_decay, adamw_beta_1, adamw_beta_2, adamw_eps = [float(x) for x in
@@ -132,6 +140,11 @@ def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradi
create_when_converge = False
except ValueError:
raise RuntimeError("Cannot use advanced LR scheduler settings!")
+ if noise_training_scheduler_enabled:
+ set_scheduler(noise_training_scheduler_cycle, noise_training_scheduler_repeat)
+ print(f"Noise training scheduler is now ready for {noise_training_scheduler_cycle}, {noise_training_scheduler_repeat}!")
+ else:
+ set_scheduler(-1, False)
if use_grad_opts and gradient_clip_opt != "None":
try:
optional_gradient_clip_value = float(optional_gradient_clip_value)
@@ -324,7 +337,7 @@ def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradi
c = stack_conds(batch.cond).to(devices.device, non_blocking=pin_memory)
loss = shared.sd_model(x, c)[0]
for filenames in batch.filename:
- loss_dict[filenames].append(loss.item())
+ loss_dict[filenames].append(loss.detach().item())
loss /= gradient_step
del x
del c
@@ -382,10 +395,12 @@ def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradi
epoch_step = hypernetwork.step - (epoch_num * len(ds)) + 1
mean_loss = sum(sum(x) for x in loss_dict.values()) / sum(len(x) for x in loss_dict.values())
tensorboard_add(tensorboard_writer, loss=mean_loss, global_step=hypernetwork.step, step=epoch_step,
- learn_rate=scheduler.learn_rate if not use_beta_scheduler else optimizer.param_groups[0]['lr'], epoch_num=epoch_num)
+ learn_rate=scheduler.learn_rate if not use_beta_scheduler else
+ optimizer.param_groups[0]['lr'], epoch_num=epoch_num)
if images_dir is not None and (
use_beta_scheduler and scheduler_beta.is_EOC(hypernetwork.step) and create_when_converge) or (
create_image_every > 0 and steps_done % create_image_every == 0):
+ set_scheduler(-1, False)
forced_filename = f'{hypernetwork_name}-{steps_done}'
last_saved_image = os.path.join(images_dir, forced_filename)
rng_state = torch.get_rng_state()
@@ -437,6 +452,8 @@ def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradi
hypernetwork.train()
if move_optimizer:
optim_to(optimizer, devices.device)
+ if noise_training_scheduler_enabled:
+ set_scheduler(noise_training_scheduler_cycle, noise_training_scheduler_repeat)
if image is not None:
if hasattr(shared.state, 'assign_current_image'):
shared.state.assign_current_image(image)
@@ -469,6 +486,7 @@ Last saved image: {html.escape(last_saved_image)}
shared.parallel_processing_allowed = old_parallel_processing_allowed
if hasattr(sd_hijack_checkpoint, 'remove'):
sd_hijack_checkpoint.remove()
+ set_scheduler(-1, False)
report_statistics(loss_dict)
filename = os.path.join(shared.cmd_opts.hypernetwork_dir, f'{hypernetwork_name}.pt')
hypernetwork.optimizer_name = optimizer_name
@@ -488,7 +506,8 @@ def internal_clean_training(hypernetwork_name, data_root, log_directory,
preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps,
preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height,
move_optimizer=True,
- load_hypernetworks_option='', load_training_options='', manual_dataset_seed=-1, setting_tuple=None):
+ load_hypernetworks_option='', load_training_options='', manual_dataset_seed=-1,
+ setting_tuple=None):
# images allows training previews to have infotext. Importing it at the top causes a circular import problem.
from modules import images
base_hypernetwork_name = hypernetwork_name
@@ -514,10 +533,11 @@ def internal_clean_training(hypernetwork_name, data_root, log_directory,
skip_connection = dump_hyper['skip_connection']
hypernetwork = create_hypernetwork_load(hypernetwork_name, enable_sizes, overwrite_old, layer_structure,
activation_func, weight_init, add_layer_norm, use_dropout,
- dropout_structure, optional_info, weight_init_seed, normal_std, skip_connection)
+ dropout_structure, optional_info, weight_init_seed, normal_std,
+ skip_connection)
else:
load_hypernetwork(hypernetwork_name)
- hypernetwork_name = hypernetwork_name.rsplit('(',1)[0] + setting_suffix
+ hypernetwork_name = hypernetwork_name.rsplit('(', 1)[0] + setting_suffix
shared.loaded_hypernetwork.save(os.path.join(shared.cmd_opts.hypernetwork_dir, f"{hypernetwork_name}.pt"))
shared.reload_hypernetworks()
load_hypernetwork(hypernetwork_name)
@@ -552,6 +572,9 @@ def internal_clean_training(hypernetwork_name, data_root, log_directory,
optional_gradient_clip_value = dump['optional_gradient_clip_value']
optional_gradient_norm_type = dump['optional_gradient_norm_type']
latent_sampling_std = dump.get('latent_sampling_std', -1)
+ noise_training_scheduler_enabled = dump.get('noise_training_scheduler_enabled', False)
+ noise_training_scheduler_repeat = dump.get('noise_training_scheduler_repeat', False)
+ noise_training_scheduler_cycle = dump.get('noise_training_scheduler_cycle', 128)
else:
raise RuntimeError(f"Cannot load from {load_training_options}!")
else:
@@ -627,6 +650,11 @@ def internal_clean_training(hypernetwork_name, data_root, log_directory,
else:
def gradient_clipping(arg1):
return
+ if noise_training_scheduler_enabled:
+ set_scheduler(noise_training_scheduler_cycle, noise_training_scheduler_repeat)
+ print(f"Noise training scheduler is now ready for {noise_training_scheduler_cycle}, {noise_training_scheduler_repeat}!")
+ else:
+ set_scheduler(-1, False)
save_hypernetwork_every = save_hypernetwork_every or 0
create_image_every = create_image_every or 0
validate_train_inputs(hypernetwork_name, learn_rate, batch_size, gradient_step, data_root,
@@ -794,7 +822,7 @@ def internal_clean_training(hypernetwork_name, data_root, log_directory,
c = stack_conds(batch.cond).to(devices.device, non_blocking=pin_memory)
loss = shared.sd_model(x, c)[0]
for filenames in batch.filename:
- loss_dict[filenames].append(loss.item())
+ loss_dict[filenames].append(loss.detach().item())
loss /= gradient_step
del x
del c
@@ -852,10 +880,12 @@ def internal_clean_training(hypernetwork_name, data_root, log_directory,
epoch_step = hypernetwork.step - (epoch_num * len(ds)) + 1
mean_loss = sum(sum(x) for x in loss_dict.values()) / sum(len(x) for x in loss_dict.values())
tensorboard_add(tensorboard_writer, loss=mean_loss, global_step=hypernetwork.step, step=epoch_step,
- learn_rate=scheduler.learn_rate if not use_beta_scheduler else optimizer.param_groups[0]['lr'], epoch_num=epoch_num,base_name=hypernetwork_name)
+ learn_rate=scheduler.learn_rate if not use_beta_scheduler else
+ optimizer.param_groups[0]['lr'], epoch_num=epoch_num, base_name=hypernetwork_name)
if images_dir is not None and (
use_beta_scheduler and scheduler_beta.is_EOC(hypernetwork.step) and create_when_converge) or (
create_image_every > 0 and steps_done % create_image_every == 0):
+ set_scheduler(-1, False)
forced_filename = f'{hypernetwork_name}-{steps_done}'
last_saved_image = os.path.join(images_dir, forced_filename)
rng_state = torch.get_rng_state()
@@ -907,6 +937,8 @@ def internal_clean_training(hypernetwork_name, data_root, log_directory,
hypernetwork.train()
if move_optimizer:
optim_to(optimizer, devices.device)
+ if noise_training_scheduler_enabled:
+ set_scheduler(noise_training_scheduler_cycle, noise_training_scheduler_repeat)
if image is not None:
if hasattr(shared.state, 'assign_current_image'):
shared.state.assign_current_image(image)
@@ -936,10 +968,12 @@ Last saved image: {html.escape(last_saved_image)}
pbar.leave = False
pbar.close()
hypernetwork.eval()
+ set_scheduler(-1, False)
shared.parallel_processing_allowed = old_parallel_processing_allowed
if hasattr(sd_hijack_checkpoint, 'remove'):
sd_hijack_checkpoint.remove()
if shared.opts.training_enable_tensorboard:
+ mean_loss = sum(sum(x) for x in loss_dict.values()) / sum(len(x) for x in loss_dict.values()) if sum(len(x) for x in loss_dict.values()) > 0 else 0
tensorboard_log_hyperparameter(tensorboard_writer, lr=learn_rate,
GA_steps=gradient_step,
batch_size=batch_size,
@@ -965,7 +999,7 @@ Last saved image: {html.escape(last_saved_image)}
gradient_clip_value=optional_gradient_clip_value,
gradient_clip_norm_type=optional_gradient_norm_type,
loss=mean_loss,
- base_hypernetwork_name= hypernetwork_name
+ base_hypernetwork_name=hypernetwork_name
)
report_statistics(loss_dict)
filename = os.path.join(shared.cmd_opts.hypernetwork_dir, f'{hypernetwork_name}.pt')
@@ -1012,4 +1046,3 @@ def train_hypernetwork_tuning(id_task, hypernetwork_name, data_root, log_directo
load_hypernetworks_option, load_training_option, manual_dataset_seed, setting_tuple=(_i, _j))
if shared.state.interrupted:
return None, None
-
diff --git a/patches/external_pr/ui.py b/patches/external_pr/ui.py
index 497a296..a444908 100644
--- a/patches/external_pr/ui.py
+++ b/patches/external_pr/ui.py
@@ -67,7 +67,8 @@ def save_training_setting(*args):
template_file, use_beta_scheduler, beta_repeat_epoch, epoch_mult, warmup, min_lr, \
gamma_rate, use_beta_adamW_checkbox, save_when_converge, create_when_converge, \
adamw_weight_decay, adamw_beta_1, adamw_beta_2, adamw_eps, show_gradient_clip_checkbox, \
- gradient_clip_opt, optional_gradient_clip_value, optional_gradient_norm_type, latent_sampling_std = args
+ gradient_clip_opt, optional_gradient_clip_value, optional_gradient_norm_type, latent_sampling_std,\
+ noise_training_scheduler_enabled, noise_training_scheduler_repeat, noise_training_scheduler_cycle = args
dumped_locals = locals()
dumped_locals.pop('args')
filename = (str(random.randint(0, 1024)) if save_file_name == '' else save_file_name) + '_train_' + '.json'
@@ -121,6 +122,8 @@ def on_train_gamma_tab(params=None):
label='Show advanced adamW parameter options)')
show_gradient_clip_checkbox = gr.Checkbox(
label='Show Gradient Clipping Options(for both)')
+ show_noise_options = gr.Checkbox(
+ label='Show Noise Scheduler Options(for both)')
with gr.Row(visible=False) as adamW_options:
adamw_weight_decay = gr.Textbox(label="AdamW weight decay parameter", placeholder="default = 0.01",
value="0.01")
@@ -146,7 +149,16 @@ def on_train_gamma_tab(params=None):
gradient_clip_opt = gr.Radio(label="Gradient Clipping Options", choices=["None", "limit", "norm"])
optional_gradient_clip_value = gr.Textbox(label="Limiting value", value="1e-1")
optional_gradient_norm_type = gr.Textbox(label="Norm type", value="2")
+ with gr.Row(visible=False) as noise_scheduler_options:
+ noise_training_scheduler_enabled = gr.Checkbox(label="Use Noise training scheduler(test)")
+ noise_training_scheduler_repeat = gr.Checkbox(label="Restarts noise scheduler, or linear")
+ noise_training_scheduler_cycle = gr.Number(label="Restarts noise scheduler every nth epoch")
# change by feedback
+ show_noise_options.change(
+ fn = lambda show:gr_show(show),
+ inputs = [show_noise_options],
+ outputs = [noise_scheduler_options]
+ )
use_beta_adamW_checkbox.change(
fn=lambda show: gr_show(show),
inputs=[use_beta_adamW_checkbox],
@@ -239,7 +251,10 @@ def on_train_gamma_tab(params=None):
gradient_clip_opt,
optional_gradient_clip_value,
optional_gradient_norm_type,
- latent_sampling_std_value],
+ latent_sampling_std_value,
+ noise_training_scheduler_enabled,
+ noise_training_scheduler_repeat,
+ noise_training_scheduler_cycle],
outputs=[
ti_output,
ti_outcome,
@@ -335,6 +350,9 @@ def on_train_gamma_tab(params=None):
optional_gradient_clip_value,
optional_gradient_norm_type,
latent_sampling_std_value,
+ noise_training_scheduler_enabled,
+ noise_training_scheduler_repeat,
+ noise_training_scheduler_cycle,
load_training_option
],
diff --git a/patches/hypernetwork.py b/patches/hypernetwork.py
index cbb0fc5..ddc2d15 100644
--- a/patches/hypernetwork.py
+++ b/patches/hypernetwork.py
@@ -270,6 +270,7 @@ class Hypernetwork:
self.optional_info = kwargs.get('optional_info', None)
self.skip_connection = kwargs.get('skip_connection', False)
self.upsample_linear = kwargs.get('upsample_linear', None)
+ self.training = False
generation_seed = kwargs.get('generation_seed', None)
normal_std = kwargs.get('normal_std', 0.01)
if self.dropout_structure is None:
@@ -287,6 +288,7 @@ class Hypernetwork:
self.eval()
def weights(self, train=False):
+ self.training = train
res = []
for k, layers in self.layers.items():
for layer in layers:
@@ -294,12 +296,14 @@ class Hypernetwork:
return res
def eval(self):
+ self.training = False
for k, layers in self.layers.items():
for layer in layers:
layer.eval()
layer.set_train(False)
def train(self, mode=True):
+ self.training = mode
for k, layers in self.layers.items():
for layer in layers:
layer.set_train(mode)