83 lines
3.0 KiB
Python
83 lines
3.0 KiB
Python
# Copyright (c) OpenMMLab. All rights reserved.
|
||
from ...dist_utils import master_only
|
||
from ..hook import HOOKS
|
||
from .base import LoggerHook
|
||
|
||
|
||
@HOOKS.register_module()
class NeptuneLoggerHook(LoggerHook):
    """Class to log metrics to NeptuneAI_.

    It requires `neptune-client` to be installed.

    Args:
        init_kwargs (dict, optional): A dict containing the initialization
            keys below; it is forwarded as keyword arguments to
            ``neptune.init``:

            - project (str): Name of a project in the form of
              namespace/project_name. If None, the value of the
              NEPTUNE_PROJECT environment variable will be taken.
            - api_token (str): User's API token. If None, the value of the
              NEPTUNE_API_TOKEN environment variable will be taken. Note: it
              is strongly recommended to use the NEPTUNE_API_TOKEN
              environment variable rather than placing your API token in
              plain text in your source code.
            - name (str, optional, default is 'Untitled'): Editable name of
              the run. The name is displayed in the run's Details and in the
              Runs table as a column.

            Check https://docs.neptune.ai/api-reference/neptune#init for
            more init arguments. Defaults to None, in which case
            ``neptune.init()`` is called without arguments.
        interval (int): Logging interval (every k iterations). Default: 10.
        ignore_last (bool): Ignore the log of last iterations in each epoch
            if less than `interval`. Default: True.
        reset_flag (bool): Whether to clear the output buffer after logging.
            Default: True.
        with_step (bool): If True, every tag is logged with ``step`` set to
            the iteration returned by ``get_iter``. If False, tag values are
            logged without an explicit step, and that iteration is logged as
            an additional ``global_step`` series. Default: True.
        by_epoch (bool): Whether EpochBasedRunner is used. Default: True.

    .. _NeptuneAI:
        https://docs.neptune.ai/you-should-know/logging-metadata
    """

    def __init__(self,
                 init_kwargs=None,
                 interval=10,
                 ignore_last=True,
                 reset_flag=True,
                 with_step=True,
                 by_epoch=True):
        super().__init__(interval, ignore_last, reset_flag, by_epoch)
        # Fail fast at construction time if neptune-client is missing.
        self.import_neptune()
        self.init_kwargs = init_kwargs
        self.with_step = with_step

    def import_neptune(self):
        """Import ``neptune.new`` and reset the run handle to None.

        Raises:
            ImportError: If ``neptune-client`` is not installed.
        """
        try:
            import neptune.new as neptune
        except ImportError as err:
            raise ImportError(
                'Please run "pip install neptune-client" to install neptune'
            ) from err
        self.neptune = neptune
        self.run = None

    @master_only
    def before_run(self, runner):
        """Start a Neptune run, forwarding ``init_kwargs`` when provided."""
        # None and an empty dict both mean "call neptune.init() with no args".
        self.run = self.neptune.init(**(self.init_kwargs or {}))

    @master_only
    def log(self, runner):
        """Send the runner's loggable tags to the active Neptune run."""
        tags = self.get_loggable_tags(runner)
        if not tags:
            return
        # The iteration is fixed for the duration of this call.
        step = self.get_iter(runner)
        if self.with_step:
            for tag_name, tag_value in tags.items():
                self.run[tag_name].log(tag_value, step=step)
        else:
            # Build a new dict instead of inserting into ``tags`` while
            # iterating it (which raises RuntimeError), and log each value
            # under its own name rather than the whole dict.
            for tag_name, tag_value in {**tags, 'global_step': step}.items():
                self.run[tag_name].log(tag_value)

    @master_only
    def after_run(self, runner):
        """Stop the Neptune run started in ``before_run``, if there is one."""
        # ``run`` is still None if ``before_run`` never executed or failed.
        if self.run is not None:
            self.run.stop()
            self.run = None
|