58 lines
2.0 KiB
Python
58 lines
2.0 KiB
Python
# Copyright (c) OpenMMLab. All rights reserved.
|
|
import os.path as osp
|
|
|
|
from annotator.mmpkg.mmcv.utils import TORCH_VERSION, digit_version
|
|
from ...dist_utils import master_only
|
|
from ..hook import HOOKS
|
|
from .base import LoggerHook
|
|
|
|
|
|
@HOOKS.register_module()
class TensorboardLoggerHook(LoggerHook):
    """Logger hook that forwards loggable tags to a ``SummaryWriter``.

    The writer class is chosen in :meth:`before_run`: ``tensorboardX`` is
    used when ``TORCH_VERSION`` is ``'parrots'`` or older than ``1.1``,
    otherwise ``torch.utils.tensorboard`` is used.

    Args:
        log_dir (str, optional): Directory handed to ``SummaryWriter``.
            When ``None``, ``before_run`` replaces it with
            ``<runner.work_dir>/tf_logs``. Default: None.
        interval: Forwarded unchanged to ``LoggerHook.__init__``.
            Default: 10.
        ignore_last: Forwarded unchanged to ``LoggerHook.__init__``.
            Default: True.
        reset_flag: Forwarded unchanged to ``LoggerHook.__init__``.
            Default: False.
        by_epoch: Forwarded unchanged to ``LoggerHook.__init__``.
            Default: True.
    """

    def __init__(self,
                 log_dir=None,
                 interval=10,
                 ignore_last=True,
                 reset_flag=False,
                 by_epoch=True):
        super().__init__(interval, ignore_last, reset_flag, by_epoch)
        self.log_dir = log_dir

    @master_only
    def before_run(self, runner):
        """Import a ``SummaryWriter`` implementation and create the writer.

        Raises:
            ImportError: If the backend selected for the current
                ``TORCH_VERSION`` cannot be imported.
        """
        super().before_run(runner)

        # The 'parrots' comparison must come first so that digit_version is
        # never evaluated for that value (short-circuit).
        needs_tensorboardx = (
            TORCH_VERSION == 'parrots'
            or digit_version(TORCH_VERSION) < digit_version('1.1'))

        if not needs_tensorboardx:
            try:
                from torch.utils.tensorboard import SummaryWriter
            except ImportError:
                raise ImportError(
                    'Please run "pip install future tensorboard" to install '
                    'the dependencies to use torch.utils.tensorboard '
                    '(applicable to PyTorch 1.1 or higher)')
        else:
            try:
                from tensorboardX import SummaryWriter
            except ImportError:
                raise ImportError('Please install tensorboardX to use '
                                  'TensorboardLoggerHook.')

        # Fall back to a sub-directory of the runner's work dir when the
        # caller did not supply an explicit log directory.
        if self.log_dir is None:
            self.log_dir = osp.join(runner.work_dir, 'tf_logs')
        self.writer = SummaryWriter(self.log_dir)

    @master_only
    def log(self, runner):
        """Write every loggable tag of ``runner`` to the writer.

        String values go through ``add_text``; everything else goes
        through ``add_scalar``. The step is ``self.get_iter(runner)``.
        """
        loggable = self.get_loggable_tags(runner, allow_text=True)
        for name, value in loggable.items():
            emit = (self.writer.add_text
                    if isinstance(value, str) else self.writer.add_scalar)
            emit(name, value, self.get_iter(runner))

    @master_only
    def after_run(self, runner):
        """Close the writer created in :meth:`before_run`."""
        self.writer.close()
|