From a87e90e0dd8a22c530a52ccdbf2eb440b02ae212 Mon Sep 17 00:00:00 2001
From: aria1th <35677394+aria1th@users.noreply.github.com>
Date: Sun, 27 Nov 2022 12:36:21 +0900
Subject: [PATCH] clean up code partially

---
 patches/external_pr/hypernetwork.py | 4 ++--
 patches/hypernetwork.py             | 4 ++--
 2 files changed, 4 insertions(+), 4 deletions(-)

diff --git a/patches/external_pr/hypernetwork.py b/patches/external_pr/hypernetwork.py
index 69cd736..5b22281 100644
--- a/patches/external_pr/hypernetwork.py
+++ b/patches/external_pr/hypernetwork.py
@@ -121,11 +121,11 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, gradient_step,
 
     # Here we use optimizer from saved HN, or we can specify as UI option.
     if hypernetwork.optimizer_name in optimizer_dict:
-        optimizer = optimizer_dict[hypernetwork.optimizer_name](params=weights, lr=0 if use_beta_scheduler else scheduler.learn_rate)
+        optimizer = optimizer_dict[hypernetwork.optimizer_name](params=weights, lr=scheduler.learn_rate)
         optimizer_name = hypernetwork.optimizer_name
     else:
         print(f"Optimizer type {hypernetwork.optimizer_name} is not defined!")
-        optimizer = torch.optim.AdamW(params=weights, lr=0 if use_beta_scheduler else scheduler.learn_rate)
+        optimizer = torch.optim.AdamW(params=weights, lr=scheduler.learn_rate)
         optimizer_name = 'AdamW'
 
     if hypernetwork.optimizer_state_dict:  # This line must be changed if Optimizer type can be different from saved optimizer.
diff --git a/patches/hypernetwork.py b/patches/hypernetwork.py
index 28a0093..3447739 100644
--- a/patches/hypernetwork.py
+++ b/patches/hypernetwork.py
@@ -460,11 +460,11 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
 
     # Here we use optimizer from saved HN, or we can specify as UI option.
     if hypernetwork.optimizer_name in optimizer_dict:
-        optimizer = optimizer_dict[hypernetwork.optimizer_name](params=weights, lr=0 if use_beta_scheduler else scheduler.learn_rate)
+        optimizer = optimizer_dict[hypernetwork.optimizer_name](params=weights, lr=scheduler.learn_rate)
         optimizer_name = hypernetwork.optimizer_name
     else:
         print(f"Optimizer type {hypernetwork.optimizer_name} is not defined!")
-        optimizer: torch.optim.Optimizer = torch.optim.AdamW(params=weights, lr=0 if use_beta_scheduler else scheduler.learn_rate)
+        optimizer: torch.optim.Optimizer = torch.optim.AdamW(params=weights, lr=scheduler.learn_rate)
         optimizer_name = 'AdamW'
 
     if hypernetwork.optimizer_state_dict:  # This line must be changed if Optimizer type can be different from saved optimizer.
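
For reference, below is a minimal runnable sketch of the optimizer-selection logic as it reads after this cleanup. build_optimizer and the toy weights/learn_rate values are hypothetical names introduced for illustration; optimizer_dict, the unknown-name fallback to torch.optim.AdamW, and passing scheduler.learn_rate unconditionally are taken from the patch itself. The old `lr=0 if use_beta_scheduler else scheduler.learn_rate` guard is dropped, presumably because the beta scheduler sets the effective learning rate during training anyway.

    import torch

    # optimizer_dict mirrors the mapping used in the patched files:
    # optimizer name -> torch.optim constructor. Only constructors known
    # to exist in torch.optim are listed here.
    optimizer_dict = {
        "AdamW": torch.optim.AdamW,
        "Adam": torch.optim.Adam,
        "SGD": torch.optim.SGD,
    }

    def build_optimizer(optimizer_name, weights, learn_rate):
        # Look up the saved optimizer type; fall back to AdamW when the
        # name is unknown, exactly as both patched branches do. After this
        # patch the learning rate is always the scheduler's value, with no
        # `use_beta_scheduler` special case at construction time.
        if optimizer_name in optimizer_dict:
            return optimizer_dict[optimizer_name](params=weights, lr=learn_rate), optimizer_name
        print(f"Optimizer type {optimizer_name} is not defined!")
        return torch.optim.AdamW(params=weights, lr=learn_rate), "AdamW"

    # Toy usage: one trainable tensor stands in for the hypernetwork
    # weights, and 5e-3 stands in for scheduler.learn_rate.
    weights = [torch.nn.Parameter(torch.zeros(4))]
    optimizer, optimizer_name = build_optimizer("AdamW", weights, learn_rate=5e-3)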