control growth factor

parent 5411bcce11
commit d0abef3abf
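This commit adds a dadapt_growth_factor option and threads it from the training UI down to the D-Adaptation optimizer: the value is written into the saved training options, restored with a default of -1 for older files, and passed to the optimizer as growth_rate, where -1 is translated to float('inf') (no cap on how fast the adapted step size may grow). Short illustrative sketches follow the optimizer hunks and the UI hunks below.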
@@ -60,7 +60,7 @@ def get_training_option(filename):
     return obj
 
 
-def prepare_training_hypernetwork(hypernetwork_name, learn_rate=0.1, use_adamw_parameter=False, use_dadaptation=False, **adamW_kwarg_dict):
+def prepare_training_hypernetwork(hypernetwork_name, learn_rate=0.1, use_adamw_parameter=False, use_dadaptation=False, dadapt_growth_factor=-1, **adamW_kwarg_dict):
     """ returns hypernetwork object binded with optimizer"""
     hypernetwork = load_hypernetwork(hypernetwork_name)
     hypernetwork.to(devices.device)
@@ -86,7 +86,7 @@ def prepare_training_hypernetwork(hypernetwork_name, learn_rate=0.1, use_adamw_p
         optim_class = get_dadapt_adam(hypernetwork.optimizer_name)
         if optim_class != torch.optim.AdamW:
             print('Optimizer class is ' + str(optim_class))
-            optimizer = optim_class(params=weights, lr=learn_rate, decouple=True, **adamW_kwarg_dict)
+            optimizer = optim_class(params=weights, lr=learn_rate, decouple=True, growth_rate=float('inf') if dadapt_growth_factor < 0 else dadapt_growth_factor, **adamW_kwarg_dict)
             hypernetwork.optimizer_name = 'DAdaptAdamW'
         else:
             optimizer = torch.optim.AdamW(params=weights, lr=learn_rate, **adamW_kwarg_dict)
@@ -101,7 +101,7 @@ def prepare_training_hypernetwork(hypernetwork_name, learn_rate=0.1, use_adamw_p
         from .dadapt_test.install import get_dadapt_adam
         optim_class = get_dadapt_adam(hypernetwork.optimizer_name)
         if optim_class != torch.optim.AdamW:
-            optimizer = optim_class(params=weights, lr=learn_rate, decouple=True, **adamW_kwarg_dict)
+            optimizer = optim_class(params=weights, lr=learn_rate, decouple=True, growth_rate=float('inf') if dadapt_growth_factor < 0 else dadapt_growth_factor, **adamW_kwarg_dict)
             optimizer_name = 'DAdaptAdamW'
             hypernetwork.optimizer_name = 'DAdaptAdamW'
     if optimizer is None:
@@ -131,7 +131,7 @@ def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradi
                        use_grad_opts=False, gradient_clip_opt='None', optional_gradient_clip_value=1e01,
                        optional_gradient_norm_type=2, latent_sampling_std=-1,
                        noise_training_scheduler_enabled=False, noise_training_scheduler_repeat=False, noise_training_scheduler_cycle=128,
-                       load_training_options='', loss_opt='loss_simple', use_dadaptation=False
+                       load_training_options='', loss_opt='loss_simple', use_dadaptation=False, dadapt_growth_factor=-1
                        ):
     # images allows training previews to have infotext. Importing it at the top causes a circular import problem.
     from modules import images
@@ -172,6 +172,7 @@ def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradi
             noise_training_scheduler_cycle = dump.get('noise_training_scheduler_cycle', 128)
             loss_opt = dump.get('loss_opt', 'loss_simple')
             use_dadaptation = dump.get('use_dadaptation', False)
+            dadapt_growth_factor = dump.get('dadapt_growth_factor', -1)
     try:
         if use_adamw_parameter:
             adamw_weight_decay, adamw_beta_1, adamw_beta_2, adamw_eps = [float(x) for x in
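The .get default above is what keeps older option files working: a JSON dump written before this commit has no dadapt_growth_factor key, so loading falls back to -1 (uncapped) and training behaves exactly as before. A minimal sketch of that load path, with a hypothetical file name:

    import json

    # Hypothetical options file written by save_training_setting();
    # the real name is '<save_file_name>_train_.json'.
    with open('example_train_.json') as f:
        dump = json.load(f)

    # Missing key -> -1 -> no cap, i.e. the pre-commit behaviour.
    dadapt_growth_factor = dump.get('dadapt_growth_factor', -1)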
@@ -261,7 +262,7 @@ def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradi
     shared.state.textinfo = "Initializing hypernetwork training..."
     shared.state.job_count = steps
     tmp_scheduler = LearnRateScheduler(learn_rate, steps, 0)
-    hypernetwork, optimizer, weights, optimizer_name = prepare_training_hypernetwork(hypernetwork_name, tmp_scheduler.learn_rate, use_adamw_parameter, use_dadaptation, **adamW_kwarg_dict)
+    hypernetwork, optimizer, weights, optimizer_name = prepare_training_hypernetwork(hypernetwork_name, tmp_scheduler.learn_rate, use_adamw_parameter, use_dadaptation, dadapt_growth_factor, **adamW_kwarg_dict)
 
     hypernetwork_name = hypernetwork_name.rsplit('(', 1)[0]
     filename = os.path.join(shared.cmd_opts.hypernetwork_dir, f'{hypernetwork_name}.pt')
@@ -640,6 +641,7 @@ def internal_clean_training(hypernetwork_name, data_root, log_directory,
             noise_training_scheduler_cycle = dump.get('noise_training_scheduler_cycle', 128)
             loss_opt = dump.get('loss_opt', 'loss_simple')
             use_dadaptation = dump.get('use_dadaptation', False)
+            dadapt_growth_factor = dump.get('dadapt_growth_factor', -1)
         else:
             raise RuntimeError(f"Cannot load from {load_training_options}!")
     else:
@@ -816,7 +818,7 @@ def internal_clean_training(hypernetwork_name, data_root, log_directory,
             from .dadapt_test.install import get_dadapt_adam
             optim_class = get_dadapt_adam(hypernetwork.optimizer_name)
             if optim_class != torch.optim.AdamW:
-                optimizer = optim_class(params=weights, lr=scheduler.learn_rate, decouple=True, **adamW_kwarg_dict)
+                optimizer = optim_class(params=weights, lr=scheduler.learn_rate, growth_rate=float('inf') if dadapt_growth_factor < 0 else dadapt_growth_factor, decouple=True, **adamW_kwarg_dict)
             else:
                 optimizer = torch.optim.AdamW(params=weights, lr=scheduler.learn_rate, **adamW_kwarg_dict)
         else:
@@ -830,7 +832,7 @@ def internal_clean_training(hypernetwork_name, data_root, log_directory,
             from .dadapt_test.install import get_dadapt_adam
             optim_class = get_dadapt_adam(hypernetwork.optimizer_name)
             if optim_class != torch.optim.AdamW:
-                optimizer = optim_class(params=weights, lr=scheduler.learn_rate, decouple=True, **adamW_kwarg_dict)
+                optimizer = optim_class(params=weights, lr=scheduler.learn_rate, growth_rate=float('inf') if dadapt_growth_factor < 0 else dadapt_growth_factor, decouple=True, **adamW_kwarg_dict)
                 optimizer_name = 'DAdaptAdamW'
     if optimizer is None:
         optimizer = torch.optim.AdamW(params=weights, lr=scheduler.learn_rate, **adamW_kwarg_dict)
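In every optimizer hunk the UI value is mapped the same way: a negative dadapt_growth_factor becomes float('inf'), so the D-Adaptation step-size estimate may grow without limit (the library default), while a positive value such as 1.02 caps its growth to roughly 2% per step. A minimal, self-contained sketch of that call, assuming the dadaptation package that get_dadapt_adam appears to wrap (the class name and keyword arguments below come from that package, not from this repo):

    import torch
    import dadaptation  # pip install dadaptation

    def resolve_growth_rate(dadapt_growth_factor):
        # -1 (the UI default) means "no cap", which is dadaptation's own default.
        return float('inf') if dadapt_growth_factor < 0 else dadapt_growth_factor

    weights = [torch.nn.Parameter(torch.zeros(8))]  # stand-in for the hypernetwork weights
    optimizer = dadaptation.DAdaptAdam(
        params=weights,
        lr=1.0,                                  # D-Adaptation treats lr as a multiplier on the adapted step
        decouple=True,                           # decoupled weight decay, i.e. the AdamW-style variant
        growth_rate=resolve_growth_rate(1.02),   # limit per-step growth of the adapted step size
    )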
@@ -83,7 +83,7 @@ def save_training_setting(*args):
     gamma_rate, use_beta_adamW_checkbox, save_when_converge, create_when_converge, \
         adamw_weight_decay, adamw_beta_1, adamw_beta_2, adamw_eps, show_gradient_clip_checkbox, \
         gradient_clip_opt, optional_gradient_clip_value, optional_gradient_norm_type, latent_sampling_std,\
-        noise_training_scheduler_enabled, noise_training_scheduler_repeat, noise_training_scheduler_cycle, loss_opt, use_dadaptation = args
+        noise_training_scheduler_enabled, noise_training_scheduler_repeat, noise_training_scheduler_cycle, loss_opt, use_dadaptation, dadapt_growth_factor = args
     dumped_locals = locals()
     dumped_locals.pop('args')
     filename = (str(random.randint(0, 1024)) if save_file_name == '' else save_file_name) + '_train_' + '.json'
@@ -146,6 +146,8 @@ def on_train_gamma_tab(params=None):
                 adamw_beta_1 = gr.Textbox(label="AdamW beta1 parameter", placeholder="default = 0.9", value="0.9")
                 adamw_beta_2 = gr.Textbox(label="AdamW beta2 parameter", placeholder="default = 0.99", value="0.99")
                 adamw_eps = gr.Textbox(label="AdamW epsilon parameter", placeholder="default = 1e-8", value="1e-8")
+            with gr.Row(visible=False) as dadapt_growth_options:
+                dadapt_growth_factor = gr.Number(value=-1, label='Growth factor limiting, use value like 1.02 or leave it as -1')
             with gr.Row(visible=False) as beta_scheduler_options:
                 use_beta_scheduler = gr.Checkbox(label='Use CosineAnnealingWarmupRestarts Scheduler')
                 beta_repeat_epoch = gr.Textbox(label='Steps for cycle', placeholder="Cycles every nth Step", value="64")
@@ -170,6 +172,11 @@ def on_train_gamma_tab(params=None):
                 noise_training_scheduler_repeat = gr.Checkbox(label="Restarts noise scheduler, or linear")
                 noise_training_scheduler_cycle = gr.Number(label="Restarts noise scheduler every nth epoch")
             # change by feedback
+            use_dadaptation.change(
+                fn=lambda show: gr_show(show),
+                inputs=[use_dadaptation],
+                outputs=[dadapt_growth_options]
+            )
             show_noise_options.change(
                 fn = lambda show:gr_show(show),
                 inputs = [show_noise_options],
@@ -275,7 +282,8 @@ def on_train_gamma_tab(params=None):
                     noise_training_scheduler_repeat,
                     noise_training_scheduler_cycle,
                     loss_opt,
-                    use_dadaptation],
+                    use_dadaptation,
+                    dadapt_growth_factor],
                 outputs=[
                     ti_output,
                     ti_outcome,
@@ -376,7 +384,8 @@ def on_train_gamma_tab(params=None):
                     noise_training_scheduler_cycle,
                     load_training_option,
                     loss_opt,
-                    use_dadaptation
+                    use_dadaptation,
+                    dadapt_growth_factor
                 ],
                 outputs=[
                     ti_output,
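The growth-factor row is hidden by default and only shown while the D-Adaptation checkbox is ticked, via the use_dadaptation.change handler added above. A self-contained sketch of the same toggle pattern, using gr.update() directly instead of the repo's gr_show() helper:

    import gradio as gr

    with gr.Blocks() as demo:
        use_dadaptation = gr.Checkbox(label="Use D-Adaptation optimizer")
        with gr.Row(visible=False) as dadapt_growth_options:
            dadapt_growth_factor = gr.Number(value=-1, label="Growth factor limit (-1 = uncapped)")

        # Show the row only while the checkbox is ticked; hide it again when unticked.
        use_dadaptation.change(
            fn=lambda show: gr.update(visible=show),
            inputs=[use_dadaptation],
            outputs=[dadapt_growth_options],
        )

    demo.launch()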