PyTorch Lightning : Parallélisme de données multi-GPU et multi-nœuds

Cette page explique comment distribuer un modèle neuronal artificiel implémenté dans un code Pytorch Lightning, selon la méthode du parallélisme de données.

Avant d'aller plus loin, il est obligatoire d'avoir les bases concernant l'utilisation de Pytorch Lightning. La documentation de Pytorch Lightning est très complète et propose beaucoup d'exemples.

Configuration multi-process avec SLURM

Pour le multi-nœuds, il est nécessaire d'utiliser le multi-processing géré par SLURM (exécution via la commande SLURM srun). Pour le mono-nœud il est possible d'utiliser des stratégies mono-nœud comme dp. Cependant, il est possible et plus pratique d'utiliser le multi-processing SLURM dans tous les cas (mono-nœud ou multi-nœuds). C'est ce que nous documentons dans cette page.

Dans SLURM, lorsqu'on lance un script avec la commande srun, le script est automatiquement distribué sur toutes les tâches prédéfinies. Par exemple, si nous réservons 4 nœuds quadri-GPU en demandant 4 GPU par nœud, nous obtenons :

  4 nœuds, indexés de 0 à 3
  4 GPU/nœud indexés de 0 à 3 sur chaque nœud,
  4 x 4 = 16 processus au total permettant d'exécuter 16 tâches avec les rangs de 0 à 15

Voici des exemples d'en-têtes de script SLURM pour Jean-Zay :

  • pour une réservation de N nœuds quadri-GPU via la partition gpu par défaut :
    #SBATCH --nodes=N            # nombre total de nœuds (N à définir)
    #SBATCH --ntasks-per-node=4  # nombre de tache par noeud (ici 4 taches soit 1 tache par GPU)
    #SBATCH --gres=gpu:4         # nombre de GPU réservés par nœud (ici 4 soit tous les GPU)
    #SBATCH --cpus-per-task=10   # nombre de cœurs par tache (donc 4x10 = 40 cœurs soit tous les cœurs)
    #SBATCH --hint=nomultithread


  • pour une réservation de N nœuds octo-GPU via la partition gpu_p2 :
    #SBATCH --partition=gpu_p2
    #SBATCH --nodes=N            # nombre total de nœuds (N à définir)
    #SBATCH --ntasks-per-node=8  # nombre de tache par nœud (ici 8 taches soit 1 tache par GPU)
    #SBATCH --gres=gpu:8         # nombre de GPU réservés par nœud (ici 8 soit tous les GPU)
    #SBATCH --cpus-per-task=3    # nombre de cœurs par tache (donc 8x3 = 24 cœurs soit tous les cœurs)
    #SBATCH --hint=nomultithread


Configuration et stratégie de distribution de Pytorch Lightning

La documentation entraînement multi-gpu permet de découvrir la majorité des configurations et stratégies possibles (dp, ddp,ddp_spawn, ddp2, horovod, deepspeed, fairscale, etc).

Sur Jean Zay, on recommande d'utiliser la stratégie ddp car c'est celle qui a le moins de restriction sur Pytorch Lightning. C'est la même chose que le DistributedDataParallel fourni par Pytorch mais à travers la surcouche Lightning.

Remarque : Lightning sélectionne par défaut le backend nccl. C'est ce qui est le plus performant sur Jean Zay mais on peut aussi utiliser gloo et mpi.

L'utilisation de cette stratégie se gère dans les arguments du Trainer. Voici un exemple d'implémentation possible qui récupère directement les variables d'environnements Slurm :

trainer_args = {'accelerator': 'gpu', 
                'devices': int(os.environ['SLURM_GPUS_ON_NODE']), 
                'num_nodes': int(os.environ['SLURM_NNODES']),
                'strategy': 'ddp'} 
trainer = pl.Trainer(**trainer_args)
trainer.fit(model)

Note : Les quelques essais de performance réalisés sur Jean Zay, montre que la surcouche Lightning génère, dans le meilleur des cas, une perte de performance temporelle de l'ordre de 5% par rapport à DistributedDataParallel de Pytorch. On retrouve cet ordre de grandeur sur les tests de référence réalisés par Lightning. Lightning se veut être une interface de haut-niveau, il est donc normal d'échanger de la performance contre une facilité d'implémentation.

Sauvegarde et chargement de checkpoints

Sauvegarde

Par défaut Lightning, réalise automatiquement la sauvegarde de checkpoints.

En multi-gpu, Lightning s'assure que la sauvegarde n'est réalisée que par le processus principal. Il n'y a pas de code spécifique à ajouter.

trainer = Trainer(strategy="ddp")
model = MyLightningModule(hparams)
trainer.fit(model)
# Saves only on the main process
trainer.save_checkpoint("example.ckpt")

Remarque : Afin d'éviter des conflits de sauvegarde, Lightning préconise de toujours utiliser la fonction save_checkpoint(). Si vous utiliser une fonction “maison”, il faut penser à décorer celle-ci avec rank_zero_only().

from pytorch_lightning.utilities import rank_zero_only
 
@rank_zero_only
def homemade_save_checkpoint(path):
    ...
 
homemade_save_checkpoint("example.ckpt")

Chargement

Au début d'un apprentissage, le chargement d'un checkpoint est d'abord opéré par le CPU, puis l'information est envoyée sur le GPU. Là encore pas de code spécifique à ajouter, Lightning s'occupe de tout.

#On importe les poids du modèle pour un nouvel apprentissage 
model = LitModel.load_from_checkpoint(PATH)
 
#Ou on restaure l'entraînement précédent (model, epoch, step, LR schedulers, apex, etc...)
model = LitModel()
trainer = Trainer()
trainer.fit(model, ckpt_path="some/path/to/my_checkpoint.ckpt")

Validation distribuée

L'étape de validation exécutée après chaque epoch ou après un nombre fixé d'itérations d'apprentissage peut se distribuer sur tous les GPU engagés dans l'apprentissage du modèle. La validation est par défaut distribuée sur Lightning dès que l'on passe en multi-GPU. Lorsque le parallélisme de données est utilisé et que l'ensemble des données de validation est conséquent, cette solution de validation distribuée sur les GPU semble être la plus efficace et la plus rapide.

Ici, l'enjeu est de calculer les métriques (loss, accuracy, etc…) par batch et par GPU, puis de les pondérer et de les moyenner sur l'ensemble des données de validation.

Si vous utilisez des métriques natives ou basées sur Torchmetric, il n'y a rien à faire.

class MyModel(LightningModule):
    def __init__(self):
        ...
        self.accuracy = torchmetrics.Accuracy()
 
    def training_step(self, batch, batch_idx):
        x, y = batch
        preds = self(x)
        ...
        # log step metric
        self.accuracy(preds, y)
        self.log("train_acc_step", self.accuracy, on_epoch=True)
        ...

Dans le cas inverse il faut réaliser une synchronisation lors du logging :

def validation_step(self, batch, batch_idx):
    x, y = batch
    logits = self(x)
    loss = self.loss(logits, y)
    # Add sync_dist=True to sync logging across all GPU workers (may have performance impact)
    self.log("validation_loss", loss, on_step=True, on_epoch=True, sync_dist=True)
def test_step(self, batch, batch_idx):
    x, y = batch
    tensors = self(x)
    return tensors
 
 
def test_epoch_end(self, outputs):
    mean = torch.mean(self.all_gather(outputs))
 
    # When logging only on rank 0, don't forget to add
    # ``rank_zero_only=True`` to avoid deadlocks on synchronization.
    if self.trainer.is_global_zero:
        self.log("my_reduced_metric", mean, rank_zero_only=True)

Exemple d'application

Exécution multi-GPU, multi-nœuds avec "ddp" sur Pytorch Lightning

Un exemple se trouve dans $DSDIR/examples_IA/Torch_parallel/Example_DataParallelism_PyTorch_Lightning.ipynb sur Jean-Zay, il utilise la base de données CIFAR et un Resnet50. L'exemple est un Notebook qui permet de créer un script d'exécution.

Vous pouvez aussi télécharger le notebook en cliquant sur ce lien.

Il est à copier sur votre espace personnel (idéalement sur votre $WORK).

$ cp $DSDIR/examples_IA/Torch_parallel/Example_DataParallelism_PyTorch_Lightning.ipynb $WORK

Vous pouvez ensuite ouvrir et exécuter le Notebook depuis notre service jupyterhub. Pour l'utilisation de jupyterhub, vous pouvez consulter les documentations dédiées : Jean Zay : Accès à JupyterHub et documentation JupyterHub Jean-Zay

Documentation et sources