PyTorch: Multi-GPU and multi-node data parallelism

This page explains how to distribute the training of an artificial neural network model written in PyTorch, using the data parallelism method.

Here, we document the integrated DistributedDataParallel solution, which is the most efficient according to the PyTorch documentation. It is a multi-process parallelism which works equally well in mono-node and multi-node configurations.

Multi-process configuration with SLURM

For multi-node execution, it is necessary to use multi-processing managed by SLURM (execution via the SLURM command srun). For mono-node, it is possible to use torch.multiprocessing.spawn as indicated in the PyTorch documentation. However, it is possible, and more practical, to use SLURM multi-processing in either case, mono-node or multi-node. This is what we document on this page.

When you launch a script with the SLURM srun command, the script is automatically distributed over all the predefined tasks. For example, if we reserve four 8-GPU nodes and request 3 GPUs per node, we obtain:

  • 4 nodes, indexed from 0 to 3.
  • 3 GPUs/node, indexed from 0 to 2 on each node.
  • 4 x 3 = 12 processes in total, allowing the execution of 12 tasks, with ranks from 0 to 11.

Multi-process with SLURM

Illustration of a SLURM reservation of 4 nodes and 3 GPUs per node, equalling 12 processes.
The collective inter-node communications are managed by the NCCL library.
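
Each process can check the rank, local rank and node that SLURM assigned to it by reading the SLURM environment variables. The following is a minimal sketch of our own (the file name check_dist.py is hypothetical) which can be launched with srun python check_dist.py to verify the task layout:

# check_dist.py -- minimal sketch to inspect the task layout defined by SLURM
import os
import socket

rank = int(os.environ['SLURM_PROCID'])         # global rank, from 0 to ntasks-1
local_rank = int(os.environ['SLURM_LOCALID'])  # rank inside the node, from 0 to ntasks_per_node-1
world_size = int(os.environ['SLURM_NTASKS'])   # total number of tasks

print(f"host {socket.gethostname()}: rank {rank}/{world_size}, local rank {local_rank}")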

The following are two examples of SLURM scripts for Jean-Zay:

  • For a reservation of N four-GPU V100 nodes via the default GPU partition:
    #!/bin/bash
    #SBATCH --job-name=torch-multi-gpu
    #SBATCH --nodes=N            # total number of nodes (N to be defined)
    #SBATCH --ntasks-per-node=4  # number of tasks per node (here 4 tasks, or 1 task per GPU)
    #SBATCH --gres=gpu:4         # number of GPUs reserved per node (here 4, or all the GPUs)
    #SBATCH --cpus-per-task=10   # number of cores per task (4x10 = 40 cores, or all the cores)
    #SBATCH --hint=nomultithread
    #SBATCH --time=40:00:00
    #SBATCH --output=torch-multi-gpu%j.out
    ##SBATCH --account=abc@v100
     
    module load pytorch-gpu/py3/1.11.0
     
    srun python myscript.py


    Comment: Here, the nodes are reserved exclusively. Of particular note, this gives access to the entire memory of each node.

  • For a reservation of N eight-GPU A100 nodes:
    #!/bin/bash
    #SBATCH --job-name=torch-multi-gpu
    #SBATCH --nodes=N            # total number of nodes (N to be defined)
    #SBATCH --ntasks-per-node=8  # number of tasks per node (here 8 tasks, or 1 task per GPU)
    #SBATCH --gres=gpu:8         # number of GPUs reserved per node (here 8, or all the GPUs)
    #SBATCH --cpus-per-task=8    # number of cores per task (8x8 = 64 cores, or all the cores)
    #SBATCH --hint=nomultithread
    #SBATCH --time=40:00:00
    #SBATCH --output=torch-multi-gpu%j.out
    #SBATCH -C a100
    ##SBATCH --account=abc@a100
     
    module load cpuarch/amd
    module load pytorch-gpu/py3/1.11.0
     
    srun python myscript.py


    Comment: Here, the nodes are reserved exclusively. Of particular note, this gives access to the entire memory of each node.

Implementation of the DistributedDataParallel solution

To implement the DistributedDataParallel solution in PyTorch, it is necessary to:

  1. Define the environment variables linked to the master node.
    • MASTER_ADDR: The IP address or the hostname of the node corresponding to task 0 (the first node on the node list). If you are in mono-node, the value localhost is sufficient.
    • MASTER_PORT: The number of an arbitrary free port. To avoid conflicts, and by convention, we use a port number between 10001 and 20000 (for example, 12345).
    • On Jean Zay, a library developed by IDRIS has been included in the PyTorch modules to define the MASTER_ADDR and MASTER_PORT variables automatically. You simply need to import it into your script:
       import idr_torch 

      This import alone creates the variables. For your information, the following shows what is contained in this script:

      idr_torch.py
      #!/usr/bin/env python
      # coding: utf-8
       
      import os
      import hostlist
       
      # get SLURM variables
      rank = int(os.environ['SLURM_PROCID'])
      local_rank = int(os.environ['SLURM_LOCALID'])
      size = int(os.environ['SLURM_NTASKS'])
      cpus_per_task = int(os.environ['SLURM_CPUS_PER_TASK'])
       
      # get node list from slurm
      hostnames = hostlist.expand_hostlist(os.environ['SLURM_JOB_NODELIST'])
       
      # get IDs of reserved GPUs
      gpu_ids = os.environ['SLURM_STEP_GPUS'].split(",")
       
      # define MASTER_ADDR & MASTER_PORT
      os.environ['MASTER_ADDR'] = hostnames[0]
      os.environ['MASTER_PORT'] = str(12345 + int(min(gpu_ids))) # to avoid port conflict on the same node

      Comment: The idr_torch module retrieves the values of these environment variables. You can reuse them in your script by calling idr_torch.rank, idr_torch.local_rank, idr_torch.size and/or idr_torch.cpus_per_task.

  2. Initialise the process group (i.e. the number of processes, the collective communication protocol or backend, …). The possible backends are NCCL, GLOO and MPI. NCCL is recommended both for its performance and for its guaranteed correct functioning on Jean Zay.
    import torch.distributed as dist
     
    from torch.nn.parallel import DistributedDataParallel as DDP
     
    dist.init_process_group(backend='nccl',
                            init_method='env://',
                            world_size=idr_torch.size,
                            rank=idr_torch.rank)
  3. Send the model to the GPU. Note that local_rank (numbered 0, 1, 2, … on each node) serves as the GPU identifier.
    torch.cuda.set_device(idr_torch.local_rank)
    gpu = torch.device("cuda")
    model = model.to(gpu)
  4. Transform the model into a distributed model associated with a GPU.
    ddp_model = DDP(model, device_ids=[idr_torch.local_rank])
  5. Send the micro-batches and labels to the dedicated GPU during the training.
    for (images, labels) in train_loader:
        images = images.to(gpu, non_blocking=True)
        labels = labels.to(gpu, non_blocking=True)

    Comment: Here, the option non_blocking=True is necessary if the DataLoader uses the pin_memory functionality to accelerate the loading of the inputs.

    The code shown below illustrates the usage of the DataLoader with a sampler adapted to data parallelism.

    batch_size = args.batch_size
    batch_size_per_gpu = batch_size // idr_torch.size
     
    # define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss()  
    optimizer = torch.optim.SGD(ddp_model.parameters(), 1e-4)
     
    # Data loading code
    train_dataset = torchvision.datasets.MNIST(root=os.environ['DSDIR'],
                                                   train=True,
                                                   transform=transforms.ToTensor(),
                                                   download=False)
     
    train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset,
                                                                        num_replicas=idr_torch.size,
                                                                        rank=idr_torch.rank,
                                                                        shuffle=True)
     
    train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                                   batch_size=batch_size_per_gpu,
                                                   shuffle=False,
                                                   num_workers=0,
                                                   pin_memory=True,
                                                   sampler=train_sampler)

Be careful: shuffling is delegated to the DistributedSampler, which is why shuffle=False is set in the DataLoader. Furthermore, for the shuffling seed to be different at each epoch, you need to call train_sampler.set_epoch(epoch) at the beginning of each epoch.
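
Putting these elements together, a minimal training loop might look like the following sketch (num_epochs is assumed to be defined elsewhere; the other names are those used above):

for epoch in range(num_epochs):
    train_sampler.set_epoch(epoch)                   # change the shuffling seed at each epoch
    for images, labels in train_loader:
        images = images.to(gpu, non_blocking=True)   # send the micro-batch to the dedicated GPU
        labels = labels.to(gpu, non_blocking=True)
        optimizer.zero_grad()
        outputs = ddp_model(images)                  # forward pass through the distributed model
        loss = criterion(outputs, labels)
        loss.backward()                              # gradients are averaged across GPUs by DDP
        optimizer.step()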

Saving and loading checkpoints

It is possible to put checkpoints in place during a distributed training on GPUs.

Saving

Since the model is replicated on each GPU, the saving of checkpoints can be performed by just one GPU in order to limit the writing operations. By convention, we use the GPU of rank 0:

if idr_torch.rank == 0:
    torch.save(ddp_model.state_dict(), CHECKPOINT_PATH)

Consequently, the checkpoint contains information from GPU rank 0 which is saved in a format specific to distributed models.
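
As an illustration (a sketch of our own, not part of the official example), the keys of a state_dict saved from a DDP-wrapped model carry a module. prefix, which is what makes the format specific to distributed models:

if idr_torch.rank == 0:
    # the keys of the DDP state_dict are prefixed with 'module.'
    print(list(ddp_model.state_dict().keys())[:3])   # e.g. ['module.fc1.weight', 'module.fc1.bias', ...]
    # saving ddp_model.module.state_dict() instead would drop this prefix,
    # making the checkpoint loadable by the non-wrapped model as well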

Loading

At the beginning of the training, the loading of a checkpoint is first performed by the CPU. Then, the information is sent to the GPU.

By default and by convention, the information is sent back to the memory location used during the saving step. In our example, only GPU 0 would therefore load the model into memory.

For the information to be communicated to all the GPUs, it is necessary to use the map_location argument of the torch.load function to redirect the memory storage.

In the example below, the map_location argument redirects the memory storage to the local GPU rank. Since this function is called by all the GPUs, each GPU loads the checkpoint into its own memory:

map_location = {'cuda:%d' % 0: 'cuda:%d' % idr_torch.local_rank} # remap storage from GPU 0 to local GPU
ddp_model.load_state_dict(torch.load(CHECKPOINT_PATH, map_location=map_location)) # load checkpoint

Comment: If a checkpoint is loaded just after a save, as in the PyTorch tutorial, it is necessary to call the dist.barrier() method before the loading. This call to dist.barrier() ensures the synchronisation of the GPUs, guaranteeing that the saving of the checkpoint by GPU rank 0 is completely finished before the other GPUs attempt to load it.
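
A sketch of this save-then-load sequence, using the same names as above, could therefore be:

# rank 0 writes the checkpoint
if idr_torch.rank == 0:
    torch.save(ddp_model.state_dict(), CHECKPOINT_PATH)

dist.barrier()  # wait until the checkpoint is completely written before reading it

# every process then loads the checkpoint into the memory of its own GPU
map_location = {'cuda:%d' % 0: 'cuda:%d' % idr_torch.local_rank}
ddp_model.load_state_dict(torch.load(CHECKPOINT_PATH, map_location=map_location))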

Distributed validation

The validation step, performed after each epoch or after a set of training iterations, can be distributed over all the GPUs engaged in the model training. When data parallelism is used and the validation dataset is large, distributing the validation over the GPUs appears to be the most efficient and fastest solution.

Here, the challenge is to calculate the metrics (loss, accuracy, etc.) per batch and per GPU, then to compute their weighted average over the whole validation dataset.

For this, it is necessary to:

  1. Load the validation dataset in the same way as the training dataset, but without randomised transformations such as data augmentation or shuffling (see the documentation on loading PyTorch databases):
    # validation dataset loading (imagenet for example)                
    val_dataset = torchvision.datasets.ImageNet(root=root,split='val', transform=val_transform)
     
    # define distributed sampler for validation                                    
    val_sampler = torch.utils.data.distributed.DistributedSampler(val_dataset,
                                                                  num_replicas=idr_torch.size,
                                                                  rank=idr_torch.rank,
                                                                  shuffle=False)
     
    # define dataloader for validation                                                              
    val_loader = torch.utils.data.DataLoader(dataset=val_dataset,
                                             batch_size=batch_size_per_gpu,                    
                                             shuffle=False,
                                             num_workers=4,
                                             pin_memory=True,
                                             sampler=val_sampler,
                                             prefetch_factor=2)
  2. Switch from “training” mode to “validation” mode to disable some training-specific features that are costly and unnecessary here:
    • model.eval() to switch the model to “validation” mode (dropout layers disabled, batch normalisation layers using their running statistics, etc.)
    • with torch.no_grad() to disable gradient calculation
    • optionally, with autocast() to use AMP (mixed precision)
  3. Evaluate the model and calculate the metric by batch in the usual way (here we take the example of calculating the loss; it will be the same for other metrics):
    • outputs = model(val_images), followed by loss = criterion(outputs, val_labels)
  4. Weight and accumulate the metric per GPU:
    • val_loss += loss * val_images.size(0) / N with val_images.size(0) as the size of the batch and N the global size of the validation dataset. Knowing that the batches do not necessarily have the same size (the last batch is sometimes smaller), it is better here to use the value val_images.size(0).
  5. Sum the metric weighted averages over all GPUs:
    • dist.all_reduce(val_loss, op=dist.ReduceOp.SUM) to sum the metric values calculated per GPU and communicate the result to all GPUs. This operation results in inter-GPU communications.

Example after loading validation data:

model.eval()                          # switch into validation mode
val_loss = torch.Tensor([0.]).to(gpu) # initialize val_loss value
N = len(val_dataset)                  # get validation dataset length

for val_images, val_labels in val_loader:                # loop over validation batches

    val_images = val_images.to(gpu, non_blocking=True)   # transfer images and labels to GPUs
    val_labels = val_labels.to(gpu, non_blocking=True)

    with torch.no_grad():                                # deactivate gradient computation
        with autocast():                                 # activate AMP
            outputs = model(val_images)                  # evaluate the model
            loss = criterion(outputs, val_labels)        # compute the loss

    val_loss += loss * val_images.size(0) / N            # accumulate the weighted mean per GPU

dist.all_reduce(val_loss, op=dist.ReduceOp.SUM)          # sum the weighted means and broadcast the value to each GPU

model.train() # switch back into training mode

Application example

Multi-GPU and multi-node execution with "DistributedDataParallel"

An example is found on Jean Zay in $DSDIR/examples_IA/Torch_parallel/Example_DataParallelism_Pytorch-eng.ipynb; it uses the MNIST database and a simple dense network. The example is a Notebook which allows you to create an execution script.

The notebook should be copied to your personal space (ideally to your $WORK).

$ cp $DSDIR/examples_IA/Torch_parallel/Example_DataParallelism_PyTorch-eng.ipynb $WORK

You should then execute the Notebook from a Jean Zay front end after loading a PyTorch module (see our JupyterHub documentation for more information on how to run Jupyter Notebook).

Appendices

On Jean Zay, for a ResNet-101 model, with the mini-batch size per GPU fixed (so the global batch size grows with the number of GPUs involved), we obtain the following throughputs, which increase with the number of GPUs involved. The NCCL communication protocol is consistently more efficient than GLOO. Communication between octo-GPU nodes appears slower than between quad-GPU nodes.

Comparison of Gloo vs NCCL

Throughputs according to the communication backend during data parallelism in PyTorch.

For NCCL, here are the average times of a training iteration as a function of the number of GPUs involved in the distribution. The time gaps correspond to the synchronisation time between GPUs.

Training iteration time

Average times of a training iteration during data parallelism in PyTorch.

⇐ Return to the page on distributed training