aria1th 2022-11-28 03:54:31 +09:00
parent 6f54eeb65e
commit 191eaf8d3d
1 changed files with 4 additions and 1 deletions

View File

@ -210,7 +210,10 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, gradient_step,
# print(f"grad:{weights[0].grad.detach().cpu().abs().mean().item():.15f}") # print(f"grad:{weights[0].grad.detach().cpu().abs().mean().item():.15f}")
# torch.nn.utils.clip_grad_norm_(weights, max_norm=1.0) # torch.nn.utils.clip_grad_norm_(weights, max_norm=1.0)
# print(f"grad:{weights[0].grad.detach().cpu().abs().mean().item():.15f}") # print(f"grad:{weights[0].grad.detach().cpu().abs().mean().item():.15f}")
try:
scaler.step(optimizer) scaler.step(optimizer)
except AssertionError:
optimizer.param_groups[0]['capturable'] = True
scaler.update() scaler.update()
hypernetwork.step += 1 hypernetwork.step += 1
pbar.update() pbar.update()