diff --git a/patches/external_pr/hypernetwork.py b/patches/external_pr/hypernetwork.py index 69cd736..5b22281 100644 --- a/patches/external_pr/hypernetwork.py +++ b/patches/external_pr/hypernetwork.py @@ -121,11 +121,11 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, gradient_step, # Here we use optimizer from saved HN, or we can specify as UI option. if hypernetwork.optimizer_name in optimizer_dict: - optimizer = optimizer_dict[hypernetwork.optimizer_name](params=weights, lr=0 if use_beta_scheduler else scheduler.learn_rate) + optimizer = optimizer_dict[hypernetwork.optimizer_name](params=weights, lr=scheduler.learn_rate) optimizer_name = hypernetwork.optimizer_name else: print(f"Optimizer type {hypernetwork.optimizer_name} is not defined!") - optimizer = torch.optim.AdamW(params=weights, lr=0 if use_beta_scheduler else scheduler.learn_rate) + optimizer = torch.optim.AdamW(params=weights, lr=scheduler.learn_rate) optimizer_name = 'AdamW' if hypernetwork.optimizer_state_dict: # This line must be changed if Optimizer type can be different from saved optimizer. diff --git a/patches/hypernetwork.py b/patches/hypernetwork.py index 28a0093..3447739 100644 --- a/patches/hypernetwork.py +++ b/patches/hypernetwork.py @@ -460,11 +460,11 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log # Here we use optimizer from saved HN, or we can specify as UI option. if hypernetwork.optimizer_name in optimizer_dict: - optimizer = optimizer_dict[hypernetwork.optimizer_name](params=weights, lr=0 if use_beta_scheduler else scheduler.learn_rate) + optimizer = optimizer_dict[hypernetwork.optimizer_name](params=weights, lr=scheduler.learn_rate) optimizer_name = hypernetwork.optimizer_name else: print(f"Optimizer type {hypernetwork.optimizer_name} is not defined!") - optimizer: torch.optim.Optimizer = torch.optim.AdamW(params=weights, lr=0 if use_beta_scheduler else scheduler.learn_rate) + optimizer: torch.optim.Optimizer = torch.optim.AdamW(params=weights, lr=scheduler.learn_rate) optimizer_name = 'AdamW' if hypernetwork.optimizer_state_dict: # This line must be changed if Optimizer type can be different from saved optimizer.