add noise scheduler for training
move to LatentDiffusion forward; unbind scheduler; unbind noise scheduler by option; minor fix; minor fix for extreme case (noise-scheduler)
parent
4b5c5cf800
commit
a5182880f9
|
|
@ -0,0 +1,63 @@
|
|||
import torch
|
||||
import ldm.models.diffusion.ddpm
|
||||
from modules import shared
|
||||
|
||||
|
||||
class Scheduler:
    """Proportional noise-step scheduler.

    Scales a maximum timestep count in proportion to the current training
    step: either cyclically (``repeat=True``, restarting every
    ``cycle_step`` steps) or as a single linear ramp (``repeat=False``,
    passing the value through unchanged once the ramp is finished).
    """

    def __init__(self, cycle_step=128, repeat=True):
        # Length of one ramp, in training steps.
        self.cycle_step = int(cycle_step)
        # True: restart the ramp each cycle; False: ramp once, then pass through.
        self.repeat = repeat
        self.run_assertion()

    def __call__(self, value, step):
        """Return the scheduled count for `step`, never less than 1."""
        if not self.repeat:
            # Linear mode: past the ramp, the full value is used unchanged.
            if step >= self.cycle_step:
                return value
        else:
            # Cyclic mode: position within the current cycle (cycle_step > 0
            # is guaranteed by run_assertion).
            step %= self.cycle_step
        return max(1, int(value * step / self.cycle_step))

    def run_assertion(self):
        """Validate internal state; raises AssertionError when inconsistent."""
        assert type(self.cycle_step) is int
        assert type(self.repeat) is bool
        # A repeating scheduler uses cycle_step as a modulus, so it must be positive.
        assert not self.repeat or self.cycle_step > 0

    def set(self, cycle_step=-1, repeat=-1):
        """Update settings in place.

        A negative `cycle_step` or `repeat` left at the sentinel -1 keeps the
        current value for that field.
        """
        if cycle_step >= 0:
            self.cycle_step = int(cycle_step)
        if repeat != -1:
            self.repeat = repeat
        self.run_assertion()
|
||||
|
||||
|
||||
# Module-level scheduler instance shared by the hijacked forward pass below.
# cycle_step=-1 with repeat=False makes it a pass-through (effectively
# disabled) until set_scheduler() is called with real training options.
training_scheduler = Scheduler(cycle_step=-1, repeat=False)
|
||||
|
||||
|
||||
def get_current(value, step=None):
    """Return the scheduled timestep count for the current training step.

    With an explicit `step`, the shared scheduler is applied directly.
    When `step` is omitted, the step is taken from the currently loaded
    hypernetwork, but only while it is actively training with a known step;
    otherwise `value` is returned unchanged.
    """
    if step is not None:
        # Explicit step: schedule it, clamped to at least one timestep.
        return max(1, training_scheduler(value, step))
    hn = shared.loaded_hypernetwork
    # NOTE(review): hn is presumably None or lacks a `step` attribute outside
    # of hypernetwork training, in which case scheduling is skipped — verify.
    if hasattr(hn, 'step') and hn.training and hn.step is not None:
        return training_scheduler(value, hn.step)
    return value
|
||||
|
||||
|
||||
def set_scheduler(cycle_step, repeat):
    """Reconfigure the shared training noise scheduler.

    Pass cycle_step=-1, repeat=False to disable scheduling (pass-through).
    """
    # The shared instance is mutated in place, never rebound, so the original
    # `global training_scheduler` declaration was unnecessary and is removed.
    training_scheduler.set(cycle_step, repeat)
|
||||
|
||||
|
||||
def forward(self, x, c, *args, **kwargs):
    """Replacement for LatentDiffusion.forward that draws diffusion timesteps
    from the scheduled range returned by get_current() instead of the full
    [0, num_timesteps) range; otherwise it mirrors the upstream ldm logic.
    """
    # get_current() shrinks the upper bound early in training when the noise
    # scheduler is enabled; with the scheduler disabled it returns
    # num_timesteps unchanged. One timestep is drawn per batch element.
    t = torch.randint(0, get_current(self.num_timesteps), (x.shape[0],), device=self.device).long()
    if self.model.conditioning_key is not None:
        assert c is not None
        if self.cond_stage_trainable:
            # Encode raw conditioning (e.g. text) into learned embeddings.
            c = self.get_learned_conditioning(c)
        if self.shorten_cond_schedule:  # TODO: drop this option
            # Noise the conditioning at a remapped timestep, as upstream does.
            tc = self.cond_ids[t].to(self.device)
            c = self.q_sample(x_start=c, t=tc, noise=torch.randn_like(c.float()))
    return self.p_losses(x, c, t, *args, **kwargs)
|
||||
|
||||
|
||||
|
||||
|
||||
# Monkey-patch: route every LatentDiffusion forward pass through the
# scheduler-aware forward defined above. Applied at import time.
ldm.models.diffusion.ddpm.LatentDiffusion.forward = forward
|
||||
|
|
@ -22,15 +22,18 @@ from ..hnutil import optim_to
|
|||
from ..ui import create_hypernetwork_load
|
||||
from ..scheduler import CosineAnnealingWarmUpRestarts
|
||||
from .dataset import PersonalizedBase, PersonalizedDataLoader
|
||||
from ..ddpm_hijack import set_scheduler
|
||||
|
||||
|
||||
def get_training_option(filename):
|
||||
print(filename)
|
||||
if os.path.exists(os.path.join(shared.cmd_opts.hypernetwork_dir, filename)) and os.path.isfile(os.path.join(shared.cmd_opts.hypernetwork_dir, filename)):
|
||||
if os.path.exists(os.path.join(shared.cmd_opts.hypernetwork_dir, filename)) and os.path.isfile(
|
||||
os.path.join(shared.cmd_opts.hypernetwork_dir, filename)):
|
||||
filename = os.path.join(shared.cmd_opts.hypernetwork_dir, filename)
|
||||
elif os.path.exists(filename) and os.path.isfile(filename):
|
||||
filename = filename
|
||||
elif os.path.exists(os.path.join(shared.cmd_opts.hypernetwork_dir, filename + '.json')) and os.path.isfile(os.path.join(shared.cmd_opts.hypernetwork_dir, filename + '.json')):
|
||||
elif os.path.exists(os.path.join(shared.cmd_opts.hypernetwork_dir, filename + '.json')) and os.path.isfile(
|
||||
os.path.join(shared.cmd_opts.hypernetwork_dir, filename + '.json')):
|
||||
filename = os.path.join(shared.cmd_opts.hypernetwork_dir, filename + '.json')
|
||||
else:
|
||||
return False
|
||||
|
|
@ -52,7 +55,9 @@ def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradi
|
|||
adamw_eps=1e-8,
|
||||
use_grad_opts=False, gradient_clip_opt='None', optional_gradient_clip_value=1e01,
|
||||
optional_gradient_norm_type=2, latent_sampling_std=-1,
|
||||
load_training_options=''):
|
||||
noise_training_scheduler_enabled=False, noise_training_scheduler_repeat=False, noise_training_scheduler_cycle=128,
|
||||
load_training_options=''
|
||||
):
|
||||
# images allows training previews to have infotext. Importing it at the top causes a circular import problem.
|
||||
from modules import images
|
||||
if load_training_options != '':
|
||||
|
|
@ -87,6 +92,9 @@ def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradi
|
|||
optional_gradient_clip_value = dump['optional_gradient_clip_value']
|
||||
optional_gradient_norm_type = dump['optional_gradient_norm_type']
|
||||
latent_sampling_std = dump.get('latent_sampling_std', -1)
|
||||
noise_training_scheduler_enabled = dump.get('noise_training_scheduler_enabled', False)
|
||||
noise_training_scheduler_repeat = dump.get('noise_training_scheduler_repeat', False)
|
||||
noise_training_scheduler_cycle = dump.get('noise_training_scheduler_cycle', 128)
|
||||
try:
|
||||
if use_adamw_parameter:
|
||||
adamw_weight_decay, adamw_beta_1, adamw_beta_2, adamw_eps = [float(x) for x in
|
||||
|
|
@ -132,6 +140,11 @@ def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradi
|
|||
create_when_converge = False
|
||||
except ValueError:
|
||||
raise RuntimeError("Cannot use advanced LR scheduler settings!")
|
||||
if noise_training_scheduler_enabled:
|
||||
set_scheduler(noise_training_scheduler_cycle, noise_training_scheduler_repeat)
|
||||
print(f"Noise training scheduler is now ready for {noise_training_scheduler_cycle}, {noise_training_scheduler_repeat}!")
|
||||
else:
|
||||
set_scheduler(-1, False)
|
||||
if use_grad_opts and gradient_clip_opt != "None":
|
||||
try:
|
||||
optional_gradient_clip_value = float(optional_gradient_clip_value)
|
||||
|
|
@ -324,7 +337,7 @@ def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradi
|
|||
c = stack_conds(batch.cond).to(devices.device, non_blocking=pin_memory)
|
||||
loss = shared.sd_model(x, c)[0]
|
||||
for filenames in batch.filename:
|
||||
loss_dict[filenames].append(loss.item())
|
||||
loss_dict[filenames].append(loss.detach().item())
|
||||
loss /= gradient_step
|
||||
del x
|
||||
del c
|
||||
|
|
@ -382,10 +395,12 @@ def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradi
|
|||
epoch_step = hypernetwork.step - (epoch_num * len(ds)) + 1
|
||||
mean_loss = sum(sum(x) for x in loss_dict.values()) / sum(len(x) for x in loss_dict.values())
|
||||
tensorboard_add(tensorboard_writer, loss=mean_loss, global_step=hypernetwork.step, step=epoch_step,
|
||||
learn_rate=scheduler.learn_rate if not use_beta_scheduler else optimizer.param_groups[0]['lr'], epoch_num=epoch_num)
|
||||
learn_rate=scheduler.learn_rate if not use_beta_scheduler else
|
||||
optimizer.param_groups[0]['lr'], epoch_num=epoch_num)
|
||||
if images_dir is not None and (
|
||||
use_beta_scheduler and scheduler_beta.is_EOC(hypernetwork.step) and create_when_converge) or (
|
||||
create_image_every > 0 and steps_done % create_image_every == 0):
|
||||
set_scheduler(-1, False)
|
||||
forced_filename = f'{hypernetwork_name}-{steps_done}'
|
||||
last_saved_image = os.path.join(images_dir, forced_filename)
|
||||
rng_state = torch.get_rng_state()
|
||||
|
|
@ -437,6 +452,8 @@ def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradi
|
|||
hypernetwork.train()
|
||||
if move_optimizer:
|
||||
optim_to(optimizer, devices.device)
|
||||
if noise_training_scheduler_enabled:
|
||||
set_scheduler(noise_training_scheduler_cycle, noise_training_scheduler_repeat)
|
||||
if image is not None:
|
||||
if hasattr(shared.state, 'assign_current_image'):
|
||||
shared.state.assign_current_image(image)
|
||||
|
|
@ -469,6 +486,7 @@ Last saved image: {html.escape(last_saved_image)}<br/>
|
|||
shared.parallel_processing_allowed = old_parallel_processing_allowed
|
||||
if hasattr(sd_hijack_checkpoint, 'remove'):
|
||||
sd_hijack_checkpoint.remove()
|
||||
set_scheduler(-1, False)
|
||||
report_statistics(loss_dict)
|
||||
filename = os.path.join(shared.cmd_opts.hypernetwork_dir, f'{hypernetwork_name}.pt')
|
||||
hypernetwork.optimizer_name = optimizer_name
|
||||
|
|
@ -488,7 +506,8 @@ def internal_clean_training(hypernetwork_name, data_root, log_directory,
|
|||
preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps,
|
||||
preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height,
|
||||
move_optimizer=True,
|
||||
load_hypernetworks_option='', load_training_options='', manual_dataset_seed=-1, setting_tuple=None):
|
||||
load_hypernetworks_option='', load_training_options='', manual_dataset_seed=-1,
|
||||
setting_tuple=None):
|
||||
# images allows training previews to have infotext. Importing it at the top causes a circular import problem.
|
||||
from modules import images
|
||||
base_hypernetwork_name = hypernetwork_name
|
||||
|
|
@ -514,10 +533,11 @@ def internal_clean_training(hypernetwork_name, data_root, log_directory,
|
|||
skip_connection = dump_hyper['skip_connection']
|
||||
hypernetwork = create_hypernetwork_load(hypernetwork_name, enable_sizes, overwrite_old, layer_structure,
|
||||
activation_func, weight_init, add_layer_norm, use_dropout,
|
||||
dropout_structure, optional_info, weight_init_seed, normal_std, skip_connection)
|
||||
dropout_structure, optional_info, weight_init_seed, normal_std,
|
||||
skip_connection)
|
||||
else:
|
||||
load_hypernetwork(hypernetwork_name)
|
||||
hypernetwork_name = hypernetwork_name.rsplit('(',1)[0] + setting_suffix
|
||||
hypernetwork_name = hypernetwork_name.rsplit('(', 1)[0] + setting_suffix
|
||||
shared.loaded_hypernetwork.save(os.path.join(shared.cmd_opts.hypernetwork_dir, f"{hypernetwork_name}.pt"))
|
||||
shared.reload_hypernetworks()
|
||||
load_hypernetwork(hypernetwork_name)
|
||||
|
|
@ -552,6 +572,9 @@ def internal_clean_training(hypernetwork_name, data_root, log_directory,
|
|||
optional_gradient_clip_value = dump['optional_gradient_clip_value']
|
||||
optional_gradient_norm_type = dump['optional_gradient_norm_type']
|
||||
latent_sampling_std = dump.get('latent_sampling_std', -1)
|
||||
noise_training_scheduler_enabled = dump.get('noise_training_scheduler_enabled', False)
|
||||
noise_training_scheduler_repeat = dump.get('noise_training_scheduler_repeat', False)
|
||||
noise_training_scheduler_cycle = dump.get('noise_training_scheduler_cycle', 128)
|
||||
else:
|
||||
raise RuntimeError(f"Cannot load from {load_training_options}!")
|
||||
else:
|
||||
|
|
@ -627,6 +650,11 @@ def internal_clean_training(hypernetwork_name, data_root, log_directory,
|
|||
else:
|
||||
def gradient_clipping(arg1):
|
||||
return
|
||||
if noise_training_scheduler_enabled:
|
||||
set_scheduler(noise_training_scheduler_cycle, noise_training_scheduler_repeat)
|
||||
print(f"Noise training scheduler is now ready for {noise_training_scheduler_cycle}, {noise_training_scheduler_repeat}!")
|
||||
else:
|
||||
set_scheduler(-1, False)
|
||||
save_hypernetwork_every = save_hypernetwork_every or 0
|
||||
create_image_every = create_image_every or 0
|
||||
validate_train_inputs(hypernetwork_name, learn_rate, batch_size, gradient_step, data_root,
|
||||
|
|
@ -794,7 +822,7 @@ def internal_clean_training(hypernetwork_name, data_root, log_directory,
|
|||
c = stack_conds(batch.cond).to(devices.device, non_blocking=pin_memory)
|
||||
loss = shared.sd_model(x, c)[0]
|
||||
for filenames in batch.filename:
|
||||
loss_dict[filenames].append(loss.item())
|
||||
loss_dict[filenames].append(loss.detach().item())
|
||||
loss /= gradient_step
|
||||
del x
|
||||
del c
|
||||
|
|
@ -852,10 +880,12 @@ def internal_clean_training(hypernetwork_name, data_root, log_directory,
|
|||
epoch_step = hypernetwork.step - (epoch_num * len(ds)) + 1
|
||||
mean_loss = sum(sum(x) for x in loss_dict.values()) / sum(len(x) for x in loss_dict.values())
|
||||
tensorboard_add(tensorboard_writer, loss=mean_loss, global_step=hypernetwork.step, step=epoch_step,
|
||||
learn_rate=scheduler.learn_rate if not use_beta_scheduler else optimizer.param_groups[0]['lr'], epoch_num=epoch_num,base_name=hypernetwork_name)
|
||||
learn_rate=scheduler.learn_rate if not use_beta_scheduler else
|
||||
optimizer.param_groups[0]['lr'], epoch_num=epoch_num, base_name=hypernetwork_name)
|
||||
if images_dir is not None and (
|
||||
use_beta_scheduler and scheduler_beta.is_EOC(hypernetwork.step) and create_when_converge) or (
|
||||
create_image_every > 0 and steps_done % create_image_every == 0):
|
||||
set_scheduler(-1, False)
|
||||
forced_filename = f'{hypernetwork_name}-{steps_done}'
|
||||
last_saved_image = os.path.join(images_dir, forced_filename)
|
||||
rng_state = torch.get_rng_state()
|
||||
|
|
@ -907,6 +937,8 @@ def internal_clean_training(hypernetwork_name, data_root, log_directory,
|
|||
hypernetwork.train()
|
||||
if move_optimizer:
|
||||
optim_to(optimizer, devices.device)
|
||||
if noise_training_scheduler_enabled:
|
||||
set_scheduler(noise_training_scheduler_cycle, noise_training_scheduler_repeat)
|
||||
if image is not None:
|
||||
if hasattr(shared.state, 'assign_current_image'):
|
||||
shared.state.assign_current_image(image)
|
||||
|
|
@ -936,10 +968,12 @@ Last saved image: {html.escape(last_saved_image)}<br/>
|
|||
pbar.leave = False
|
||||
pbar.close()
|
||||
hypernetwork.eval()
|
||||
set_scheduler(-1, False)
|
||||
shared.parallel_processing_allowed = old_parallel_processing_allowed
|
||||
if hasattr(sd_hijack_checkpoint, 'remove'):
|
||||
sd_hijack_checkpoint.remove()
|
||||
if shared.opts.training_enable_tensorboard:
|
||||
mean_loss = sum(sum(x) for x in loss_dict.values()) / sum(len(x) for x in loss_dict.values()) if sum(len(x) for x in loss_dict.values()) > 0 else 0
|
||||
tensorboard_log_hyperparameter(tensorboard_writer, lr=learn_rate,
|
||||
GA_steps=gradient_step,
|
||||
batch_size=batch_size,
|
||||
|
|
@ -965,7 +999,7 @@ Last saved image: {html.escape(last_saved_image)}<br/>
|
|||
gradient_clip_value=optional_gradient_clip_value,
|
||||
gradient_clip_norm_type=optional_gradient_norm_type,
|
||||
loss=mean_loss,
|
||||
base_hypernetwork_name= hypernetwork_name
|
||||
base_hypernetwork_name=hypernetwork_name
|
||||
)
|
||||
report_statistics(loss_dict)
|
||||
filename = os.path.join(shared.cmd_opts.hypernetwork_dir, f'{hypernetwork_name}.pt')
|
||||
|
|
@ -1012,4 +1046,3 @@ def train_hypernetwork_tuning(id_task, hypernetwork_name, data_root, log_directo
|
|||
load_hypernetworks_option, load_training_option, manual_dataset_seed, setting_tuple=(_i, _j))
|
||||
if shared.state.interrupted:
|
||||
return None, None
|
||||
|
||||
|
|
|
|||
|
|
@ -67,7 +67,8 @@ def save_training_setting(*args):
|
|||
template_file, use_beta_scheduler, beta_repeat_epoch, epoch_mult, warmup, min_lr, \
|
||||
gamma_rate, use_beta_adamW_checkbox, save_when_converge, create_when_converge, \
|
||||
adamw_weight_decay, adamw_beta_1, adamw_beta_2, adamw_eps, show_gradient_clip_checkbox, \
|
||||
gradient_clip_opt, optional_gradient_clip_value, optional_gradient_norm_type, latent_sampling_std = args
|
||||
gradient_clip_opt, optional_gradient_clip_value, optional_gradient_norm_type, latent_sampling_std,\
|
||||
noise_training_scheduler_enabled, noise_training_scheduler_repeat, noise_training_scheduler_cycle = args
|
||||
dumped_locals = locals()
|
||||
dumped_locals.pop('args')
|
||||
filename = (str(random.randint(0, 1024)) if save_file_name == '' else save_file_name) + '_train_' + '.json'
|
||||
|
|
@ -121,6 +122,8 @@ def on_train_gamma_tab(params=None):
|
|||
label='Show advanced adamW parameter options)')
|
||||
show_gradient_clip_checkbox = gr.Checkbox(
|
||||
label='Show Gradient Clipping Options(for both)')
|
||||
show_noise_options = gr.Checkbox(
|
||||
label='Show Noise Scheduler Options(for both)')
|
||||
with gr.Row(visible=False) as adamW_options:
|
||||
adamw_weight_decay = gr.Textbox(label="AdamW weight decay parameter", placeholder="default = 0.01",
|
||||
value="0.01")
|
||||
|
|
@ -146,7 +149,16 @@ def on_train_gamma_tab(params=None):
|
|||
gradient_clip_opt = gr.Radio(label="Gradient Clipping Options", choices=["None", "limit", "norm"])
|
||||
optional_gradient_clip_value = gr.Textbox(label="Limiting value", value="1e-1")
|
||||
optional_gradient_norm_type = gr.Textbox(label="Norm type", value="2")
|
||||
with gr.Row(visible=False) as noise_scheduler_options:
|
||||
noise_training_scheduler_enabled = gr.Checkbox(label="Use Noise training scheduler(test)")
|
||||
noise_training_scheduler_repeat = gr.Checkbox(label="Restarts noise scheduler, or linear")
|
||||
noise_training_scheduler_cycle = gr.Number(label="Restarts noise scheduler every nth epoch")
|
||||
# change by feedback
|
||||
show_noise_options.change(
|
||||
fn = lambda show:gr_show(show),
|
||||
inputs = [show_noise_options],
|
||||
outputs = [noise_scheduler_options]
|
||||
)
|
||||
use_beta_adamW_checkbox.change(
|
||||
fn=lambda show: gr_show(show),
|
||||
inputs=[use_beta_adamW_checkbox],
|
||||
|
|
@ -239,7 +251,10 @@ def on_train_gamma_tab(params=None):
|
|||
gradient_clip_opt,
|
||||
optional_gradient_clip_value,
|
||||
optional_gradient_norm_type,
|
||||
latent_sampling_std_value],
|
||||
latent_sampling_std_value,
|
||||
noise_training_scheduler_enabled,
|
||||
noise_training_scheduler_repeat,
|
||||
noise_training_scheduler_cycle],
|
||||
outputs=[
|
||||
ti_output,
|
||||
ti_outcome,
|
||||
|
|
@ -335,6 +350,9 @@ def on_train_gamma_tab(params=None):
|
|||
optional_gradient_clip_value,
|
||||
optional_gradient_norm_type,
|
||||
latent_sampling_std_value,
|
||||
noise_training_scheduler_enabled,
|
||||
noise_training_scheduler_repeat,
|
||||
noise_training_scheduler_cycle,
|
||||
load_training_option
|
||||
|
||||
],
|
||||
|
|
|
|||
|
|
@ -270,6 +270,7 @@ class Hypernetwork:
|
|||
self.optional_info = kwargs.get('optional_info', None)
|
||||
self.skip_connection = kwargs.get('skip_connection', False)
|
||||
self.upsample_linear = kwargs.get('upsample_linear', None)
|
||||
self.training = False
|
||||
generation_seed = kwargs.get('generation_seed', None)
|
||||
normal_std = kwargs.get('normal_std', 0.01)
|
||||
if self.dropout_structure is None:
|
||||
|
|
@ -287,6 +288,7 @@ class Hypernetwork:
|
|||
self.eval()
|
||||
|
||||
def weights(self, train=False):
|
||||
self.training = train
|
||||
res = []
|
||||
for k, layers in self.layers.items():
|
||||
for layer in layers:
|
||||
|
|
@ -294,12 +296,14 @@ class Hypernetwork:
|
|||
return res
|
||||
|
||||
def eval(self):
|
||||
self.training = False
|
||||
for k, layers in self.layers.items():
|
||||
for layer in layers:
|
||||
layer.eval()
|
||||
layer.set_train(False)
|
||||
|
||||
def train(self, mode=True):
|
||||
self.training = mode
|
||||
for k, layers in self.layers.items():
|
||||
for layer in layers:
|
||||
layer.set_train(mode)
|
||||
|
|
|
|||
Loading…
Reference in New Issue