20 lines
603 B
Python
20 lines
603 B
Python
import torch
|
|
import torch.distributed as dist
|
|
|
|
def setup_dist(local_rank, backend='nccl'):
    """Initialize the default torch.distributed process group (idempotent).

    Binds the current process to the GPU given by ``local_rank`` and joins
    the process group using environment-variable rendezvous (``env://``,
    i.e. MASTER_ADDR/MASTER_PORT/RANK/WORLD_SIZE must be set by the
    launcher, e.g. torchrun).

    Args:
        local_rank: GPU index on this node; passed to
            ``torch.cuda.set_device``.
        backend: distributed backend name; defaults to ``'nccl'`` (the
            original hard-coded value), kept as a keyword for flexibility.

    Returns:
        None. No-op if the process group is already initialized.
    """
    # Idempotency guard: safe to call from multiple entry points.
    if dist.is_initialized():
        return
    # Must pin the device before NCCL init so collectives use the right GPU.
    torch.cuda.set_device(local_rank)
    # Use the `dist` alias consistently (original mixed it with the
    # fully-qualified `torch.distributed.init_process_group`).
    dist.init_process_group(backend, init_method='env://')
|
|
|
|
def gather_data(data, return_np=True):
    """Gather a tensor from every process into a list (one entry per rank).

    Uses ``all_gather`` because the one-to-one ``gather`` collective is not
    supported by the NCCL backend; every rank therefore receives the full
    list. Requires an initialized process group, and ``data`` must have the
    same shape/dtype on all ranks (an ``all_gather`` precondition).

    Args:
        data: tensor to contribute from this rank.
        return_np: if True, convert each gathered tensor to a NumPy array
            on CPU before returning.

    Returns:
        list of length ``world_size`` holding each rank's tensor (or its
        NumPy copy when ``return_np`` is True), ordered by rank.
    """
    # Pre-allocate one receive buffer per rank, matching `data`'s shape/dtype.
    gathered = [torch.zeros_like(data) for _ in range(dist.get_world_size())]
    dist.all_gather(gathered, data)  # gather not supported with NCCL
    if return_np:
        # Renamed the loop variable: the original shadowed the `data` parameter.
        gathered = [t.cpu().numpy() for t in gathered]
    return gathered
|