import bisect
import functools
import logging
import numbers
import os
import signal
import sys
import traceback
import warnings

import torch
from pytorch_lightning import seed_everything

LOGGER = logging.getLogger(__name__)


def check_and_warn_input_range(tensor, min_value, max_value, name):
    actual_min = tensor.min()
    actual_max = tensor.max()
    if actual_min < min_value or actual_max > max_value:
        warnings.warn(f"{name} must be in the {min_value}..{max_value} range, "
                      f"but it actually ranges from {actual_min} to {actual_max}")
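
# A minimal usage sketch (the tensor and bounds are illustrative):
#   img = torch.rand(3, 256, 256)                     # values in [0, 1]
#   check_and_warn_input_range(img, 0, 1, 'img')      # silent
#   check_and_warn_input_range(img * 2, 0, 1, 'img')  # emits a UserWarning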


def sum_dict_with_prefix(target, cur_dict, prefix, default=0):
    for k, v in cur_dict.items():
        target_key = prefix + k
        target[target_key] = target.get(target_key, default) + v
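
# For example (illustrative values; repeated calls accumulate into `target`):
#   metrics = {}
#   sum_dict_with_prefix(metrics, {'loss': 0.5, 'acc': 1.0}, 'val_')
#   sum_dict_with_prefix(metrics, {'loss': 0.25, 'acc': 0.5}, 'val_')
#   # metrics == {'val_loss': 0.75, 'val_acc': 1.5}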


def average_dicts(dict_list):
    result = {}
    norm = 1e-3  # small epsilon so an empty dict_list does not divide by zero
    for dct in dict_list:
        sum_dict_with_prefix(result, dct, '')
        norm += 1
    for k in list(result):
        result[k] /= norm
    return result
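
# A small sketch of the intended use (values are illustrative):
#   avg = average_dicts([{'loss': 1.0}, {'loss': 3.0}])
#   # avg['loss'] is ~2.0, marginally below because of the 1e-3 term in norm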


def add_prefix_to_keys(dct, prefix):
    return {prefix + k: v for k, v in dct.items()}


def set_requires_grad(module, value):
    for param in module.parameters():
        param.requires_grad = value


def flatten_dict(dct):
    result = {}
    for k, v in dct.items():
        if isinstance(k, tuple):
            k = '_'.join(k)
        if isinstance(v, dict):
            for sub_k, sub_v in flatten_dict(v).items():
                result[f'{k}_{sub_k}'] = sub_v
        else:
            result[k] = v
    return result
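
# For example, tuple keys are joined and nested dicts are flattened:
#   flatten_dict({('a', 'b'): 1, 'c': {'d': 2}})
#   # -> {'a_b': 1, 'c_d': 2}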


class LinearRamp:
    def __init__(self, start_value=0, end_value=1, start_iter=-1, end_iter=0):
        self.start_value = start_value
        self.end_value = end_value
        self.start_iter = start_iter
        self.end_iter = end_iter

    def __call__(self, i):
        if i < self.start_iter:
            return self.start_value
        if i >= self.end_iter:
            return self.end_value
        part = (i - self.start_iter) / (self.end_iter - self.start_iter)
        return self.start_value * (1 - part) + self.end_value * part
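
# For example, a ramp from 0 to 1 over iterations 100..200:
#   ramp = LinearRamp(start_value=0, end_value=1, start_iter=100, end_iter=200)
#   ramp(50)   # -> 0   (before the ramp)
#   ramp(150)  # -> 0.5 (halfway)
#   ramp(250)  # -> 1   (after the ramp)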


class LadderRamp:
    def __init__(self, start_iters, values):
        self.start_iters = start_iters
        self.values = values
        assert len(values) == len(start_iters) + 1, (len(values), len(start_iters))

    def __call__(self, i):
        segment_i = bisect.bisect_right(self.start_iters, i)
        return self.values[segment_i]
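
# For example, a step schedule: 0.0 until iteration 1000, then 0.5, then 1.0:
#   ramp = LadderRamp(start_iters=[1000, 2000], values=[0.0, 0.5, 1.0])
#   ramp(999)   # -> 0.0
#   ramp(1000)  # -> 0.5
#   ramp(2500)  # -> 1.0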


def get_ramp(kind='ladder', **kwargs):
    if kind == 'linear':
        return LinearRamp(**kwargs)
    if kind == 'ladder':
        return LadderRamp(**kwargs)
    raise ValueError(f'Unexpected ramp kind: {kind}')
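
# This makes ramps constructible straight from config kwargs, e.g.:
#   get_ramp(kind='linear', start_iter=0, end_iter=10000)
#   get_ramp(kind='ladder', start_iters=[1000], values=[0.0, 1.0])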


def print_traceback_handler(sig, frame):
    LOGGER.warning(f'Received signal {sig}')
    bt = ''.join(traceback.format_stack())
    LOGGER.warning(f'Requested stack trace:\n{bt}')


def register_debug_signal_handlers(sig=None, handler=print_traceback_handler):
    if sig is None:
        sig = signal.SIGUSR1  # POSIX-only; pass an explicit signal on Windows
    LOGGER.warning(f'Setting signal {sig} handler {handler}')
    signal.signal(sig, handler)
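
# Typical use: dump stack traces of a hung run on demand (POSIX-only signal):
#   register_debug_signal_handlers(signal.SIGUSR1)
#   # ...then, from a shell: kill -USR1 <pid>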


def handle_deterministic_config(config):
    seed = dict(config).get('seed', None)
    if seed is None:
        return False

    seed_everything(seed)
    return True
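
# For example, with any dict-like config (an OmegaConf config also works):
#   handle_deterministic_config({'seed': 42})  # seeds everything, returns True
#   handle_deterministic_config({})            # does nothing, returns False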


def get_shape(t):
    if torch.is_tensor(t):
        return tuple(t.shape)
    elif isinstance(t, dict):
        return {n: get_shape(q) for n, q in t.items()}
    elif isinstance(t, (list, tuple)):
        return [get_shape(q) for q in t]
    elif isinstance(t, numbers.Number):
        return type(t)
    else:
        raise ValueError(f'unexpected type {type(t)}')
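
# For example, the structure is mirrored while tensors become shape tuples:
#   get_shape({'img': torch.zeros(1, 3, 4), 'meta': [1.0, 2]})
#   # -> {'img': (1, 3, 4), 'meta': [<class 'float'>, <class 'int'>]}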


def get_has_ddp_rank():
    # Lightning's DDP launcher sets these variables in spawned worker processes
    return any(os.environ.get(var) is not None
               for var in ('MASTER_PORT', 'NODE_RANK', 'LOCAL_RANK', 'WORLD_SIZE'))


def handle_ddp_subprocess():
    def main_decorator(main_func):
        @functools.wraps(main_func)
        def new_main(*args, **kwargs):
            # Trainer sets MASTER_PORT, NODE_RANK, LOCAL_RANK, WORLD_SIZE
            parent_cwd = os.environ.get('TRAINING_PARENT_WORK_DIR', None)
            has_parent = parent_cwd is not None
            has_rank = get_has_ddp_rank()
            assert has_parent == has_rank, f'Inconsistent state: has_parent={has_parent}, has_rank={has_rank}'

            if has_parent:
                # we are in a DDP worker: make hydra reuse the parent's work dir
                sys.argv.extend([
                    f'hydra.run.dir={parent_cwd}',
                    # 'hydra/hydra_logging=disabled',
                    # 'hydra/job_logging=disabled'
                ])
            # do nothing if this is a top-level process;
            # TRAINING_PARENT_WORK_DIR is set in handle_ddp_parent_process after hydra initialization

            return main_func(*args, **kwargs)
        return new_main
    return main_decorator


def handle_ddp_parent_process():
    parent_cwd = os.environ.get('TRAINING_PARENT_WORK_DIR', None)
    has_parent = parent_cwd is not None
    has_rank = get_has_ddp_rank()
    assert has_parent == has_rank, f'Inconsistent state: has_parent={has_parent}, has_rank={has_rank}'

    if parent_cwd is None:
        os.environ['TRAINING_PARENT_WORK_DIR'] = os.getcwd()

    return has_parent
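
# A sketch of how the two DDP helpers are meant to compose in a hydra entry
# point (decorator order matters: argv must be extended before hydra parses it;
# the config names here are hypothetical):
#
#   @handle_ddp_subprocess()
#   @hydra.main(config_path='configs', config_name='train')
#   def main(config):
#       is_ddp_subprocess = handle_ddp_parent_process()
#       ...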