clean up code partially

beta-apply-bigger-batch-sizes
aria1th 2022-11-27 12:36:21 +09:00
parent ad395f11b8
commit a87e90e0dd
2 changed files with 4 additions and 4 deletions


@@ -121,11 +121,11 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, gradient_step,
 # Here we use optimizer from saved HN, or we can specify as UI option.
 if hypernetwork.optimizer_name in optimizer_dict:
-    optimizer = optimizer_dict[hypernetwork.optimizer_name](params=weights, lr=0 if use_beta_scheduler else scheduler.learn_rate)
+    optimizer = optimizer_dict[hypernetwork.optimizer_name](params=weights, lr=scheduler.learn_rate)
     optimizer_name = hypernetwork.optimizer_name
 else:
     print(f"Optimizer type {hypernetwork.optimizer_name} is not defined!")
-    optimizer = torch.optim.AdamW(params=weights, lr=0 if use_beta_scheduler else scheduler.learn_rate)
+    optimizer = torch.optim.AdamW(params=weights, lr=scheduler.learn_rate)
     optimizer_name = 'AdamW'
 if hypernetwork.optimizer_state_dict: # This line must be changed if Optimizer type can be different from saved optimizer.


@@ -460,11 +460,11 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
 # Here we use optimizer from saved HN, or we can specify as UI option.
 if hypernetwork.optimizer_name in optimizer_dict:
-    optimizer = optimizer_dict[hypernetwork.optimizer_name](params=weights, lr=0 if use_beta_scheduler else scheduler.learn_rate)
+    optimizer = optimizer_dict[hypernetwork.optimizer_name](params=weights, lr=scheduler.learn_rate)
     optimizer_name = hypernetwork.optimizer_name
 else:
     print(f"Optimizer type {hypernetwork.optimizer_name} is not defined!")
-    optimizer: torch.optim.Optimizer = torch.optim.AdamW(params=weights, lr=0 if use_beta_scheduler else scheduler.learn_rate)
+    optimizer: torch.optim.Optimizer = torch.optim.AdamW(params=weights, lr=scheduler.learn_rate)
     optimizer_name = 'AdamW'
 if hypernetwork.optimizer_state_dict: # This line must be changed if Optimizer type can be different from saved optimizer.
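
Both hunks make the same change: the optimizer is now always constructed with scheduler.learn_rate instead of the conditional lr=0 if use_beta_scheduler else scheduler.learn_rate. For context, below is a minimal, self-contained sketch of the name-based optimizer selection with an AdamW fallback that these hunks touch. The optimizer_dict entries, the toy weights, the learn_rate value, and the build_optimizer helper are assumptions for illustration only; they are not code from this commit.

import torch

# Name-to-constructor map, mirroring the optimizer_dict lookup in the diff
# (these entries are illustrative, not the project's actual list).
optimizer_dict = {
    "AdamW": torch.optim.AdamW,
    "SGD": torch.optim.SGD,
}

# Stand-in trainable parameters and learning rate (assumed values).
weights = [torch.nn.Parameter(torch.zeros(4, 4))]
learn_rate = 1e-3

def build_optimizer(optimizer_name, saved_state_dict=None):
    # Pick the optimizer by its saved name, falling back to AdamW when the
    # name is unknown, as in the if/else shown in the hunks above.
    if optimizer_name in optimizer_dict:
        optimizer = optimizer_dict[optimizer_name](params=weights, lr=learn_rate)
    else:
        print(f"Optimizer type {optimizer_name} is not defined!")
        optimizer = torch.optim.AdamW(params=weights, lr=learn_rate)
        optimizer_name = "AdamW"
    # Restoring a saved state is only valid when the optimizer type matches
    # the one that produced the state dict, which is the caveat noted in the
    # trailing comment of the diff.
    if saved_state_dict is not None:
        optimizer.load_state_dict(saved_state_dict)
    return optimizer, optimizer_name

optimizer, optimizer_name = build_optimizer("AdamW")
print(optimizer_name, type(optimizer).__name__)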