57 lines
1.7 KiB
Python
57 lines
1.7 KiB
Python
# Copyright (c) OpenMMLab. All rights reserved.
|
|
from ...dist_utils import master_only
|
|
from ..hook import HOOKS
|
|
from .base import LoggerHook
|
|
|
|
|
|
@HOOKS.register_module()
class WandbLoggerHook(LoggerHook):
    """Logger hook that forwards the runner's loggable tags to wandb.

    Args:
        init_kwargs (dict, optional): Keyword arguments unpacked into
            ``wandb.init`` in :meth:`before_run`. When ``None`` or empty,
            ``wandb.init()`` is called without arguments. Default: None.
        interval (int): Forwarded unchanged to ``LoggerHook``. Default: 10.
        ignore_last (bool): Forwarded unchanged to ``LoggerHook``.
            Default: True.
        reset_flag (bool): Forwarded unchanged to ``LoggerHook``.
            Default: False.
        commit (bool): Passed as the ``commit`` argument of every
            ``wandb.log`` call. Default: True.
        by_epoch (bool): Forwarded unchanged to ``LoggerHook``.
            Default: True.
        with_step (bool): If True, the value of ``self.get_iter(runner)`` is
            passed to ``wandb.log`` as ``step``. If False, it is stored in
            the logged tags under the key ``'global_step'`` instead and no
            ``step`` argument is given. Default: True.
    """

    def __init__(self,
                 init_kwargs=None,
                 interval=10,
                 ignore_last=True,
                 reset_flag=False,
                 commit=True,
                 by_epoch=True,
                 with_step=True):
        super(WandbLoggerHook, self).__init__(interval, ignore_last,
                                              reset_flag, by_epoch)
        # Import eagerly so a missing wandb installation is reported when the
        # hook is constructed rather than when the first log call happens.
        self.import_wandb()
        self.init_kwargs = init_kwargs
        self.commit = commit
        self.with_step = with_step

    def import_wandb(self):
        """Import ``wandb`` and keep the module on ``self.wandb``.

        Raises:
            ImportError: If ``wandb`` cannot be imported.
        """
        try:
            import wandb
        except ImportError:
            raise ImportError(
                'Please run "pip install wandb" to install wandb')
        self.wandb = wandb

    @master_only
    def before_run(self, runner):
        """Run the parent ``before_run`` and then call ``wandb.init``."""
        super(WandbLoggerHook, self).before_run(runner)
        if self.wandb is None:
            self.import_wandb()
        if self.init_kwargs:
            self.wandb.init(**self.init_kwargs)
        else:
            self.wandb.init()

    @master_only
    def log(self, runner):
        """Send the current loggable tags to wandb.

        Nothing is sent when ``get_loggable_tags`` returns an empty result.
        """
        tags = self.get_loggable_tags(runner)
        if not tags:
            return
        if self.with_step:
            self.wandb.log(
                tags, step=self.get_iter(runner), commit=self.commit)
        else:
            # Without an explicit ``step`` the iteration is recorded as an
            # ordinary logged value.
            tags['global_step'] = self.get_iter(runner)
            self.wandb.log(tags, commit=self.commit)

    @master_only
    def after_run(self, runner):
        """Close the wandb run."""
        # ``wandb.join`` is documented by wandb as a deprecated alias of
        # ``wandb.finish``. Prefer the current name and fall back to ``join``
        # only for old wandb releases that do not provide ``finish``.
        finish = getattr(self.wandb, 'finish', None)
        if finish is None:
            finish = self.wandb.join
        finish()
|