PyTorch Lightning: Multi-GPU and Multi-node Data Parallelism

This page explains how to distribute a neural network model implemented in PyTorch Lightning code, using the data parallelism method.

Before going further, you should be familiar with the basics of PyTorch Lightning. The PyTorch Lightning documentation is very thorough and provides many examples.

Multi-process configuration with SLURM

For multi-node execution, it is necessary to use multi-processing managed by SLURM (execution via the SLURM command srun). For single-node execution, it is possible to use torch.multiprocessing.spawn as indicated in the PyTorch documentation. However, it is also possible, and more practical, to use SLURM multi-processing in either case, single-node or multi-node. This is what we document on this page.

When you launch a script with the SLURM srun command, the script is automatically distributed on all the predefined tasks. For example, if we reserve four 8-GPU nodes and request 3 GPUs per node, we obtain:

 4 nodes, indexed from 0 to 3.
 3 GPUs/node, indexed from 0 to 2 on each node.
 4 x 3 = 12 processes in total, allowing the execution of **12 tasks**, with ranks from 0 to 11.
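
The rank layout above can be sketched in a few lines of Python. This is only an illustration: it assumes the usual block distribution in which tasks are numbered node by node, and the function name global_rank is hypothetical. Under SLURM, each task can read the corresponding values from the SLURM_NODEID, SLURM_LOCALID and SLURM_PROCID environment variables.

```python
def global_rank(node_id: int, local_rank: int, tasks_per_node: int) -> int:
    """Global rank of a task, assuming tasks are numbered node by node."""
    if not 0 <= local_rank < tasks_per_node:
        raise ValueError("local_rank must be in [0, tasks_per_node)")
    return node_id * tasks_per_node + local_rank


# Example from the text: 4 nodes, 3 tasks (one per GPU) on each node.
nodes, tasks_per_node = 4, 3
ranks = [global_rank(node, local, tasks_per_node)
         for node in range(nodes)
         for local in range(tasks_per_node)]
print(ranks)  # one rank per process, from 0 to 11
```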

The following are examples of SLURM script headers for Jean Zay:

  • For a reservation of N four-GPU nodes via the default gpu partition:
    #SBATCH --nodes=N            # total number of nodes (N to be defined)
    #SBATCH --ntasks-per-node=4  # number of tasks per node (here 4 tasks, or 1 task per GPU)
    #SBATCH --gres=gpu:4         # number of GPUs reserved per node (here 4, or all the GPUs)
    #SBATCH --cpus-per-task=10   # number of cores per task (4x10 = 40 cores, or all the cores)
    #SBATCH --hint=nomultithread

    Comment: Here, the nodes are reserved exclusively. In particular, this gives access to the entire memory of each node.

  • For a reservation of N eight-GPU nodes via the gpu_p2 partition:
    #SBATCH --partition=gpu_p2
    #SBATCH --nodes=N            # total number of nodes (N to be defined)
    #SBATCH --ntasks-per-node=8  # number of tasks per node (here 8 tasks, or 1 task per GPU)
    #SBATCH --gres=gpu:8         # number of GPUs reserved per node (here 8, or all the GPUs)
    #SBATCH --cpus-per-task=3    # number of cores per task (8x3 = 24 cores, or all the cores)
    #SBATCH --hint=nomultithread

    Comment: Here, the nodes are reserved exclusively. In particular, this gives access to the entire memory of each node.
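
The --cpus-per-task values in both headers follow from dividing the cores of a node evenly among its tasks, with one task per GPU. A minimal sketch of that arithmetic is given below; the helper name cpus_per_task is hypothetical, and the core counts (40 and 24) are those quoted in the header comments above.

```python
def cpus_per_task(cores_per_node: int, gpus_per_node: int) -> int:
    """Cores available to each task when one task drives one GPU."""
    return cores_per_node // gpus_per_node


four_gpu_node = cpus_per_task(cores_per_node=40, gpus_per_node=4)
eight_gpu_node = cpus_per_task(cores_per_node=24, gpus_per_node=8)
print(four_gpu_node, eight_gpu_node)  # 10 3
```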

Configuration and distribution strategy of PyTorch Lightning

You can find most of the possible configurations and strategies (dp, ddp, ddp_spawn, ddp2, horovod, deepspeed, fairscale, etc.) in the multi-GPU training documentation.

On Jean Zay, we recommend the DDP strategy because it is the one with the fewest restrictions in PyTorch Lightning. It is the same as the DistributedDataParallel provided by PyTorch, but accessed through the Lightning layer.

Comment: Lightning selects the nccl backend by default. It gives the best performance on Jean Zay, but gloo and mpi can also be used.

This strategy is selected through the Trainer arguments. Here is an example of a possible implementation which directly retrieves the SLURM environment variables:

import os
import pytorch_lightning as pl

# Read the job geometry from the SLURM environment
trainer_args = {'accelerator': 'gpu',
                'devices': int(os.environ['SLURM_GPUS_ON_NODE']),
                'num_nodes': int(os.environ['SLURM_NNODES']),
                'strategy': 'ddp'}
trainer = pl.Trainer(**trainer_args)
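
When the same script is also run outside a SLURM job (for a quick interactive test, for example), the two variables may be missing. The sketch below wraps the lookup in a hypothetical helper, trainer_args_from_slurm, which falls back to 1 GPU and 1 node; this fallback is an assumption of the example, not a Lightning or Jean Zay convention. The helper only builds the dictionary, so it can be checked without a GPU.

```python
import os
from typing import Mapping


def trainer_args_from_slurm(env: Mapping[str, str] = os.environ) -> dict:
    """Build pl.Trainer keyword arguments from SLURM variables.

    Falls back to 1 GPU and 1 node when a variable is missing
    (an assumption of this sketch).
    """
    return {
        "accelerator": "gpu",
        "devices": int(env.get("SLURM_GPUS_ON_NODE", 1)),
        "num_nodes": int(env.get("SLURM_NNODES", 1)),
        "strategy": "ddp",
    }


# Simulated environment for a job on 2 four-GPU nodes.
fake_env = {"SLURM_GPUS_ON_NODE": "4", "SLURM_NNODES": "2"}
args = trainer_args_from_slurm(fake_env)
print(args)
```

In a real job, the result would be passed on as pl.Trainer(**trainer_args_from_slurm()).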

Note: The performance tests done on Jean Zay show that, in the best cases, the Lightning layer incurs a run-time overhead of about 5% compared to PyTorch DistributedDataParallel. The reference tests done by Lightning report a similar figure. Lightning is intended to be a high-level interface, so it is natural that some performance is traded for ease of implementation.

Saving and loading checkpoints


By default, Lightning automatically saves checkpoints.

In multi-GPU mode, Lightning ensures that saving is done only by the main process. No specific code needs to be added.

trainer = Trainer(strategy="ddp")
model = MyLightningModule(hparams)
trainer.fit(model)
# Saves only on the main process
trainer.save_checkpoint("example.ckpt")

Comment: To avoid saving conflicts, Lightning advises always using the save_checkpoint() function. If you use a “homemade” function, remember to decorate it with @rank_zero_only.

from pytorch_lightning.utilities import rank_zero_only

# The decorated function is executed only by the process of rank 0
@rank_zero_only
def homemade_save_checkpoint(path):
    ...
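
To make the effect of the decorator concrete, here is a conceptual sketch of a rank-zero guard. It is not Lightning's implementation: run_on_rank_zero is a hypothetical stand-in, and reading the rank from SLURM_PROCID (defaulting to 0 outside SLURM) is an assumption of this example.

```python
import functools
import os


def run_on_rank_zero(fn):
    """Hypothetical stand-in for rank_zero_only: call fn on rank 0 only."""
    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        # SLURM_PROCID is the global rank set by srun; default to 0 outside SLURM.
        if int(os.environ.get("SLURM_PROCID", 0)) == 0:
            return fn(*args, **kwargs)
        return None  # every other rank skips the call
    return wrapper


@run_on_rank_zero
def homemade_save_checkpoint(path):
    return f"saved to {path}"  # placeholder for the real saving logic


# Simulate the same call made by rank 0 and by rank 3.
os.environ["SLURM_PROCID"] = "0"
on_rank_0 = homemade_save_checkpoint("ckpt.pt")
os.environ["SLURM_PROCID"] = "3"
on_rank_3 = homemade_save_checkpoint("ckpt.pt")
print(on_rank_0, on_rank_3)
```

Only the call made as rank 0 runs the function body; the other rank gets None, which is why a single file is written.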


At the start of a training run, a checkpoint is first loaded on the CPU and its content is then transferred to the GPU. Here again, no specific code needs to be added; Lightning takes care of everything.

# Import the model weights for a new training run
model = LitModel.load_from_checkpoint(PATH)
# Or restore the previous training (model, epoch, step, LR schedulers, apex, etc.)
model = LitModel()
trainer = Trainer()
trainer.fit(model, ckpt_path="some/path/to/my_checkpoint.ckpt")

Distributed validation

The validation step, executed after each epoch or after a set number of training iterations, can be distributed on all the GPUs engaged in training the model. By default, Lightning distributes the validation as soon as several GPUs are used. When data parallelism is used and the validation dataset is large, distributing the validation on the GPUs appears to be the most efficient and fastest solution.

Here, the challenge is to compute the metrics (loss, accuracy, etc.) per batch and per GPU, and then to take their weighted average over the whole validation dataset.
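
The weighting matters as soon as batches do not all have the same size, which typically happens with the last batch. The plain-Python sketch below compares a naive mean with a mean weighted by batch size; the per-GPU values are invented for the illustration and the helper name weighted_mean is hypothetical.

```python
def weighted_mean(values, weights):
    """Mean of per-batch metric values, weighted by batch size."""
    total = sum(weights)
    return sum(v * w for v, w in zip(values, weights)) / total


# Hypothetical per-GPU results: (accuracy of the batch, number of samples).
per_gpu = [(0.90, 128), (0.80, 128), (0.50, 64)]
accuracies = [acc for acc, _ in per_gpu]
batch_sizes = [n for _, n in per_gpu]

naive = sum(accuracies) / len(accuracies)       # ignores batch sizes
exact = weighted_mean(accuracies, batch_sizes)  # accuracy over all samples
print(f"naive={naive:.4f} weighted={exact:.4f}")
```

In this example the smaller batch counts as much as the others in the naive mean, so the naive value underestimates the accuracy measured over all the samples.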

If you use native metrics, or metrics based on TorchMetrics, you don't need to do anything.

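import torchmetrics
from pytorch_lightning import LightningModule

class MyModel(LightningModule):
    def __init__(self):
        super().__init__()  # required before registering sub-modules
        # Note: recent TorchMetrics releases also require a task argument here
        self.accuracy = torchmetrics.Accuracy()

    def training_step(self, batch, batch_idx):
        x, y = batch
        preds = self(x)
        # log step metric
        self.accuracy(preds, y)
        self.log("train_acc_step", self.accuracy, on_epoch=True)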

Conversely, if you use other metrics, you need to synchronize them during the logging:

def validation_step(self, batch, batch_idx):
    x, y = batch
    logits = self(x)
    loss = self.loss(logits, y)
    # Add sync_dist=True to sync logging across all GPU workers (may have performance impact)
    self.log("validation_loss", loss, on_step=True, on_epoch=True, sync_dist=True)

def test_step(self, batch, batch_idx):
    x, y = batch
    tensors = self(x)
    return tensors

def test_epoch_end(self, outputs):
    mean = torch.mean(self.all_gather(outputs))
    # When logging only on rank 0, don't forget to add
    # ``rank_zero_only=True`` to avoid deadlocks on synchronization.
    if self.trainer.is_global_zero:
        self.log("my_reduced_metric", mean, rank_zero_only=True)

Application example

Multi-GPU, multi-node execution with "ddp" on PyTorch Lightning

An example is found in $DSDIR/examples_IA/Torch_parallel/Example_DataParallelism_PyTorch_Lightning.ipynb on Jean Zay; it uses the CIFAR dataset and a ResNet50. The example is a Notebook from which an execution script can be created.

You can also download the Notebook by clicking on this link.

It should be copied into your personal space (ideally, in your $WORK).

$ cp $DSDIR/examples_IA/Torch_parallel/Example_DataParallelism_PyTorch_Lightning.ipynb $WORK

You can then open and execute the Notebook from your JupyterHub service. To learn how to use JupyterHub, you may consult the dedicated documentation: Jean Zay: Access to JupyterHub and Jean Zay JupyterHub documentation.

Documentation and sources