diff --git a/patches/external_pr/hypernetwork.py b/patches/external_pr/hypernetwork.py
index 4b6b051..6755cc0 100644
--- a/patches/external_pr/hypernetwork.py
+++ b/patches/external_pr/hypernetwork.py
@@ -210,7 +210,13 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, gradient_step,
                 # print(f"grad:{weights[0].grad.detach().cpu().abs().mean().item():.15f}")
                 # torch.nn.utils.clip_grad_norm_(weights, max_norm=1.0)
                 # print(f"grad:{weights[0].grad.detach().cpu().abs().mean().item():.15f}")
-                scaler.step(optimizer)
+                try:
+                    scaler.step(optimizer)
+                except AssertionError:
+                    # PyTorch 1.12 Adam asserts ("If capturable=False, state_steps should not
+                    # be CUDA tensors") when optimizer state was restored onto the GPU. Flip
+                    # the flag and RETRY the step so this iteration's update is not dropped.
+                    optimizer.param_groups[0]['capturable'] = True
+                    scaler.step(optimizer)
                 scaler.update()
                 hypernetwork.step += 1
                 pbar.update()