Source code for ares.utils.dist

import os
import torch
import numpy as np
import random

def distributed_init(args):
    """Initialize the distributed process group based on the launch environment."""
    if args.distributed:
        if args.local_rank != -1:
            # for distributed launch
            args.rank = args.local_rank
            args.device_id = args.local_rank
        elif 'SLURM_PROCID' in os.environ:
            # for slurm scheduler
            args.rank = int(os.environ['SLURM_PROCID'])
            args.device_id = args.rank % torch.cuda.device_count()
        torch.cuda.set_device(args.device_id)
        torch.distributed.init_process_group(
            backend=args.dist_backend,
            init_method=args.dist_url,
            world_size=args.world_size,
            rank=args.rank,
        )
        setup_for_distributed(args.rank == 0)
    else:
        # single-process fallback
        args.local_rank = 0
        args.world_size = 1
        args.rank = 0
        args.device_id = 0
        torch.cuda.set_device(args.device_id)
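
distributed_init reads several attributes from the argument namespace (distributed, local_rank, dist_backend, dist_url, world_size) and fills in rank and device_id as a side effect. A minimal sketch of an argparse setup that provides these attributes, with hypothetical single-node defaults, might look like this:

import argparse

# Hypothetical argparse setup; flag names mirror the attributes that
# distributed_init reads. Defaults assume a single-node launch.
parser = argparse.ArgumentParser()
parser.add_argument('--distributed', action='store_true')
parser.add_argument('--local_rank', type=int, default=-1)
parser.add_argument('--dist_backend', type=str, default='nccl')
parser.add_argument('--dist_url', type=str, default='env://')
parser.add_argument('--world_size', type=int, default=1)
args = parser.parse_args()

distributed_init(args)  # sets args.rank and args.device_id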

def setup_for_distributed(is_master):
    """Disable printing when not in the master process."""
    import builtins as __builtin__
    builtin_print = __builtin__.print

    def print(*args, **kwargs):
        force = kwargs.pop('force', False)
        if is_master or force:
            builtin_print(*args, **kwargs)

    __builtin__.print = print
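
Once setup_for_distributed has replaced the builtin, ordinary print calls are silently dropped on non-master ranks; passing force=True restores output, which is handy for per-rank debugging. A short usage sketch, assuming args.rank was set by distributed_init:

setup_for_distributed(is_master=(args.rank == 0))

print('only rank 0 sees this')                      # suppressed on other ranks
print(f'rank {args.rank} is alive', force=True)     # printed on every rank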

def random_seed(seed=0, rank=0):
    """Seed the Python, NumPy and torch RNGs, offset by the process rank."""
    seed = seed + rank
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)
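
Offsetting the seed by the rank keeps runs reproducible while giving each process a distinct random stream (e.g. for data shuffling or augmentation). A typical call after distributed_init, assuming a hypothetical args.seed option supplied by the user, would be:

random_seed(seed=args.seed, rank=args.rank)  # args.seed is a hypothetical CLI option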