virtex.utils.distributed
A collection of common utilities for distributed training. These are thin wrappers over utilities from the torch.distributed module, but they do not raise exceptions in the absence of distributed training / CPU-only training, and fall back to sensible default behavior.
- virtex.utils.distributed.launch(job_fn: Callable, num_machines: int = 1, num_gpus_per_machine: int = 1, machine_rank: int = 0, dist_url: str = 'tcp://127.0.0.1:23456', args=())[source]
Launch a job in a distributed fashion: given num_machines machines, each with num_gpus_per_machine GPUs, this utility will launch one process per GPU. This wrapper uses torch.multiprocessing.spawn().
The user has to launch one job on each machine, manually specifying a machine rank (incrementing integers from 0); this utility will adjust process ranks per machine. One process on machine_rank = 0 will be referred to as the master process, and the IP + a free port on this machine will serve as the distributed process communication URL.
Default arguments imply one machine with one GPU, and communication URL as localhost. A minimal usage sketch follows the parameter list below.
Note
This utility assumes the same number of GPUs per machine, with IDs as (0, 1, 2 ...). If you do not wish to use all GPUs on a machine, set the CUDA_VISIBLE_DEVICES environment variable (for example, CUDA_VISIBLE_DEVICES=5,6 restricts the job to GPUs 5 and 6 and re-assigns their IDs to 0 and 1 in this job scope).
- Parameters
  job_fn – A callable object to launch. Pass your main function doing training, validation etc. here.
  num_machines – Number of machines, each with num_gpus_per_machine GPUs.
  num_gpus_per_machine – Number of GPUs per machine, with IDs as (0, 1, 2 ...).
  machine_rank – A manually specified rank of the machine; serves as a unique identifier, useful for assigning global ranks to processes.
  dist_url – Distributed process communication URL as tcp://x.x.x.x:port. Set this as the IP (and a free port) of the machine with rank 0.
  args – Arguments to be passed to job_fn.
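A minimal usage sketch of launch(). The argument parser, the flag name, and the body of main are illustrative only; the documented behavior is that launch() spawns one process per GPU and passes args on to job_fn in every spawned process.

```python
import argparse

from virtex.utils import distributed as dist


def main(args: argparse.Namespace):
    # Runs once per GPU process spawned by launch(). The helpers in this
    # module fall back gracefully for single-GPU / CPU-only runs.
    if dist.is_master_process():
        print(f"Launched {dist.get_world_size()} process(es).")
    # ... build dataloaders, model, optimizer and train here ...


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--num-gpus-per-machine", type=int, default=1)
    args = parser.parse_args()

    # Single machine: machine_rank 0 and a localhost dist_url (the defaults).
    dist.launch(
        main,
        num_machines=1,
        num_gpus_per_machine=args.num_gpus_per_machine,
        machine_rank=0,
        dist_url="tcp://127.0.0.1:23456",
        args=(args,),
    )
```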
- virtex.utils.distributed._job_worker(local_rank: int, job_fn: Callable, world_size: int, num_gpus_per_machine: int, machine_rank: int, dist_url: str, args: Tuple)[source]
Single distributed process worker. This should never be used directly; it is only used by launch().
- virtex.utils.distributed.synchronize() → None [source]
Synchronize (barrier) all processes in a process group.
- virtex.utils.distributed.get_world_size() → int [source]
Return the number of processes in the process group; each process uses one GPU.
- virtex.utils.distributed.get_rank() → int [source]
Return rank of current process in the process group.
- virtex.utils.distributed.is_master_process() → bool [source]
Check whether the current process is the master process. This check is useful to restrict logging and checkpointing to the master process. It will always return True for single machine, single GPU execution.
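A short sketch showing how these helpers are typically combined inside a job function to restrict disk writes to the master process. The function name and checkpoint path are illustrative, not part of this module.

```python
import torch

from virtex.utils import distributed as dist


def log_and_checkpoint(model: torch.nn.Module, iteration: int):
    # Wait for every process to finish the current iteration before the
    # master process writes anything to disk.
    dist.synchronize()

    if dist.is_master_process():
        print(f"Rank {dist.get_rank()} of {dist.get_world_size()}: "
              f"saving checkpoint at iteration {iteration}")
        torch.save(model.state_dict(), f"checkpoint_{iteration}.pth")
```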
- virtex.utils.distributed.average_across_processes(t: Union[torch.Tensor, Dict[str, torch.Tensor]])[source]
Averages a tensor, or a dict of tensors, across all processes in a process group. Objects in all processes will finally have the same mean value; see the sketch after the parameter list.
Note
Nested dicts of tensors are not supported.
- Parameters
t (torch.Tensor or Dict[str, torch.Tensor]) – A tensor or dict of tensors to average across processes.
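A sketch of averaging per-process losses before logging. The loss names are made up, and the in-place behavior is assumed from the description above (all processes end up holding the same mean values).

```python
import torch

from virtex.utils import distributed as dist

# Hypothetical per-process loss values. In a multi-GPU run these should
# live on the current process's GPU so the underlying reduction can act
# on them; in a single-process / CPU-only run the call is a no-op.
losses = {
    "captioning_loss": torch.tensor(2.3),
    "total_loss": torch.tensor(2.3),
}

# Afterwards every process holds the mean of each entry across the group.
dist.average_across_processes(losses)

if dist.is_master_process():
    print({name: loss.item() for name, loss in losses.items()})
```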