From 191eaf8d3db0ba3d7373b0371286f7e0d954ae2d Mon Sep 17 00:00:00 2001 From: aria1th <35677394+aria1th@users.noreply.github.com> Date: Mon, 28 Nov 2022 03:54:31 +0900 Subject: [PATCH] Try fix https://github.com/aria1th/Hypernetwork-MonkeyPatch-Extension/issues/6 --- patches/external_pr/hypernetwork.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/patches/external_pr/hypernetwork.py b/patches/external_pr/hypernetwork.py index 4b6b051..6755cc0 100644 --- a/patches/external_pr/hypernetwork.py +++ b/patches/external_pr/hypernetwork.py @@ -210,7 +210,10 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, gradient_step, # print(f"grad:{weights[0].grad.detach().cpu().abs().mean().item():.15f}") # torch.nn.utils.clip_grad_norm_(weights, max_norm=1.0) # print(f"grad:{weights[0].grad.detach().cpu().abs().mean().item():.15f}") - scaler.step(optimizer) + try: + scaler.step(optimizer) + except AssertionError: + optimizer.param_groups[0]['capturable'] = True scaler.update() hypernetwork.step += 1 pbar.update()