# PyTorch Lightning: multi-GPU and multi-node Data Parallelism
## Hands-on

*Notebook written by the IDRIS AI support team*

Versions:
* May 2022, creation
* April 2023, fixes for JupyterHub

This document presents the method to use on Jean Zay to distribute your PyTorch Lightning training. It takes the [PyTorch Lightning documentation](https://pytorch-lightning.readthedocs.io/en/stable/accelerators/gpu.html#multi-gpu-training) as its reference and illustrates the [IDRIS documentation](http://www.idris.fr/jean-zay/gpu/jean-zay-gpu-lightning-multi.html).

In the proposed example, we train a ResNet on the CIFAR-10 dataset. Training runs on several GPUs and several Jean Zay compute nodes.

The goals are to:
* write the Python script for distributed training with Lightning
* run a parallel execution on Jean Zay

Note that the CIFAR-10 data and the model used in this example are very simple. This keeps the code short and makes it quick to test the *Data Parallelism* configuration, but it does not allow a training speed-up to be measured. Indeed, the transfer times between GPUs and the initialization time of the GPU *kernels* are not negligible compared with the execution times.
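
To make the data-parallel setting concrete, here is a minimal sketch of how the per-GPU batch size relates to the global batch size and to the number of steps per epoch. It assumes one process per GPU, a batch size that applies to each process (as with distributed data parallelism), no dropped last batch, and the 50,000 training images of CIFAR-10. The function name `global_batch_and_steps` is ours and is not part of Lightning.

```python
import math

def global_batch_and_steps(dataset_size, per_gpu_batch, gpus_per_node, num_nodes):
    """Return (world_size, global_batch, steps_per_epoch) for a data-parallel run."""
    # One process per GPU: the world size is the total number of GPUs.
    world_size = gpus_per_node * num_nodes
    # Each process consumes its own batch at every step.
    global_batch = per_gpu_batch * world_size
    # The dataset is split evenly (rounding up) between the processes.
    samples_per_rank = math.ceil(dataset_size / world_size)
    steps_per_epoch = math.ceil(samples_per_rank / per_gpu_batch)
    return world_size, global_batch, steps_per_epoch

# Per-GPU batch of 128 on 1 GPU, 4 GPUs on one node, then 2 nodes of 4 GPUs.
for nodes, gpus in [(1, 1), (1, 4), (2, 4)]:
    print(nodes, gpus, global_batch_and_steps(50_000, 128, gpus, nodes))
```

With a fixed per-GPU batch size, adding GPUs increases the global batch size and reduces the number of steps per epoch, which is why the learning rate often needs to be revisited when scaling out.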

------------------------

### Computing environment

This notebook can run on any Jean Zay node. However, a Jean Zay JupyterHub front-end node is preferred.

The *hostname* should be jean-zay-srv2.

In [None]:
!hostname

This notebook does not require loading any particular module. The training script it generates uses `pytorch-gpu/py3/1.11.0`.

In [None]:
!module list

------------------------

### Writing the Python script for distributed training with Lightning

In this section, we write the Python training script to the file `LitResNet.py`.

* Loading the libraries:

In [None]:
%%writefile LitResNet.py
# Base
import os
import sys
from time import time
from pathlib import Path
from argparse import ArgumentParser
# Torch 
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import SGD, Adam
import torchvision.models as models
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import ImageFolder
# Pytorch Lightning
import pytorch_lightning as pl
from pytorch_lightning.profiler import SimpleProfiler, AdvancedProfiler, PyTorchProfiler
# IDRIS
import idr_torch 
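
The script imports `idr_torch`, while Lightning itself reads the Slurm environment to set up the processes. As an illustration only, the sketch below shows the kind of information that can be derived from standard Slurm variables (`SLURM_PROCID`, `SLURM_LOCALID`, `SLURM_NTASKS`, `SLURM_CPUS_PER_TASK`). The function `slurm_process_info` is hypothetical and is not the `idr_torch` API.

```python
import os

def slurm_process_info(env=None):
    """Collect rank information from Slurm variables (hypothetical helper)."""
    env = os.environ if env is None else env
    return {
        # Global index of this process among all tasks of the job
        "rank": int(env.get("SLURM_PROCID", 0)),
        # Index of this process on its own node
        "local_rank": int(env.get("SLURM_LOCALID", 0)),
        # Total number of tasks in the job
        "world_size": int(env.get("SLURM_NTASKS", 1)),
        # CPU cores reserved for each task (useful to size DataLoader workers)
        "cpus_per_task": int(env.get("SLURM_CPUS_PER_TASK", 1)),
    }

# Outside a Slurm job the defaults describe a single process.
print(slurm_process_info())
```

Passing a dictionary instead of `os.environ` makes the helper easy to try outside a job.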

* Creating the Lightning model:

In [None]:
%%writefile -a LitResNet.py
# A Lightning module is organized in 6 sections:
# Computations (init).
# Train Loop (training_step)
# Validation Loop (validation_step)
# Test Loop (test_step)
# Prediction Loop (predict_step)
# Optimizers and LR Schedulers (configure_optimizers)
class LitResNetClassifier(pl.LightningModule):
    def __init__(self, num_classes, 
                 resnet_version,
                 optimizer='adam', 
                 lr=1e-3, 
                 batch_size=16):
        
        super().__init__()

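        # Store the hyperparameters used by the other methods
        self.num_classes = num_classes
        self.resnet_version = resnet_version
        self.lr = lr
        self.batch_size = batch_size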
        resnets = {
            18: models.resnet18, 34: models.resnet34,
            50: models.resnet50, 101: models.resnet101,
            152: models.resnet152
        }
        
        optimizers = {'adam': Adam, 'sgd': SGD}
        self.optimizer = optimizers[optimizer]
        
        # instantiate loss criterion
        self.criterion = nn.BCEWithLogitsLoss() if num_classes == 2 else nn.CrossEntropyLoss()
        
        # instantiate model
        self.resnet_model = resnets[resnet_version]()
        
        # Input size of the original final fully connected layer
        linear_size = list(self.resnet_model.children())[-1].in_features
        # Replace the final layer with a new one sized for num_classes
        self.resnet_model.fc = nn.Linear(linear_size, num_classes)

        
    def forward(self, X):
        return self.resnet_model(X)

    def training_step(self, batch, batch_idx):
        x, y = batch
        preds = self(x)
        # Accuracy is computed against the integer class labels
        acc = (y == torch.argmax(preds, 1)).float().mean()
        if self.num_classes == 2:
            # BCEWithLogitsLoss expects float targets with the same shape as preds
            y = F.one_hot(y, num_classes=2).float()
        
        loss = self.criterion(preds, y)
        # perform logging
        self.log("train_loss", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
        self.log("train_acc", acc, on_step=True, on_epoch=True, prog_bar=True, logger=True)
        return loss
    
#     def validation_step(self, batch, batch_idx):
#     x, y = batch
#     preds = self(x)
#     if self.num_classes == 2:
#         y = F.one_hot(y, num_classes=2).float()

#     loss = self.criterion(preds, y)
#     acc = (y == torch.argmax(preds,1)).type(torch.FloatTensor).mean()
#     # perform logging
#     # Add sync_dist=True to sync logging across all GPU workers (may have performance impact)
#     self.log("validation_loss", loss, on_step=True, prog_bar=True, logger=True, sync_dist=True)
#     self.log("validation_acc", acc, on_step=True, prog_bar=True, logger=True, sync_dist=True)

    def test_step(self, batch, batch_idx):
        x, y = batch
        preds = self(x)
        # Accuracy is computed against the integer class labels
        acc = (y == torch.argmax(preds, 1)).float().mean()
        if self.num_classes == 2:
            # BCEWithLogitsLoss expects float targets with the same shape as preds
            y = F.one_hot(y, num_classes=2).float()
        
        loss = self.criterion(preds, y)
        # perform logging
        # Add sync_dist=True to sync logging across all GPU workers (may have performance impact)
        self.log("test_loss", loss, on_step=True, prog_bar=True, logger=True, sync_dist=True)
        self.log("test_acc", acc, on_step=True, prog_bar=True, logger=True, sync_dist=True)
        
    def configure_optimizers(self):
        return self.optimizer(self.parameters(), lr=self.lr)
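
The accuracy logged in the steps above is the fraction of samples whose highest-scoring class matches the label. The sketch below reproduces that computation in pure Python on a hand-written batch, as a way to check the logic without a GPU. The helper names and the sample values are ours, and the labels are assumed to be integer class indices (not one-hot vectors).

```python
def argmax(values):
    """Index of the largest value in a list (first one in case of ties)."""
    return max(range(len(values)), key=lambda i: values[i])

def batch_accuracy(logits, labels):
    """Fraction of rows whose arg-max class equals the integer label."""
    correct = sum(1 for row, label in zip(logits, labels) if argmax(row) == label)
    return correct / len(labels)

# Four samples, three classes: illustrative values only.
logits = [
    [2.0, 0.1, -1.0],
    [0.3, 0.2, 1.5],
    [0.0, 3.0, 0.5],
    [1.0, 0.9, 0.8],
]
labels = [0, 2, 0, 0]
print(batch_accuracy(logits, labels))
```

The comparison only makes sense between class indices, so it has to be done on the integer labels rather than on their one-hot encoding.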

* Creating the DataModule:

In [None]:
%%writefile -a LitResNet.py    
class Cifar10DataModule(pl.LightningDataModule):
    def __init__(self, data_dir=os.environ['DSDIR']+'/CIFAR-10-images/', batch_size=32):
        super().__init__()
        self.train_path = str(data_dir)+'/train'
        #self.test_val = str(data_dir)+'/test'
        self.test_path = str(data_dir)+'/test'
        
        self.batch_size = batch_size
        self.transform = transforms.Compose([
                            transforms.Resize((128,128)),
                            transforms.ToTensor()# convert the PIL Image to a tensor
                            ])

    def setup(self, stage = None):
        self.img_train = ImageFolder(self.train_path, transform=self.transform)
        #self.img_val = ImageFolder(self.test_val, transform=self.transform)
        self.img_test = ImageFolder(self.test_path, transform=self.transform)

    def train_dataloader(self):
        return DataLoader(self.img_train, 
                          batch_size=self.batch_size, 
                          shuffle=True,  # reshuffle the training set at every epoch
                          num_workers=10,
                          pin_memory=True,
                          persistent_workers=True,
                          prefetch_factor=2,
                         )

    # def val_dataloader(self):
    #     return DataLoader(self.img_val, batch_size=self.batch_size)

    def test_dataloader(self):
        return DataLoader(self.img_test,                          
                          batch_size=self.batch_size, 
                          shuffle=False,  # keep a fixed order for evaluation
                          num_workers=10,
                          pin_memory=True,)

    # def predict_dataloader(self):
    #     return DataLoader(self.img_predict, batch_size=self.batch_size)
    # def teardown(self, stage: Optional[str] = None):
    #     # Used to clean-up when the run is finished
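
With the `ddp` strategy, each process must work on a different part of the dataset. PyTorch provides `torch.utils.data.DistributedSampler` for this purpose, and to our knowledge Lightning 1.x inserts one automatically by default. The sketch below mimics in pure Python, without shuffling, how such a sampler can split the indices between ranks. It is an illustration under these assumptions, not the library's code, and the name `rank_indices` is ours.

```python
import math

def rank_indices(dataset_size, world_size, rank):
    """Indices seen by one rank when the dataset is split without shuffling."""
    # Every rank must receive the same number of samples.
    num_samples = math.ceil(dataset_size / world_size)
    total_size = num_samples * world_size
    indices = list(range(dataset_size))
    # Pad with the first indices so that the list divides evenly
    # (assumes the padding is not larger than the dataset itself).
    indices += indices[: total_size - dataset_size]
    # Interleave: rank r takes elements r, r + world_size, r + 2 * world_size, ...
    return indices[rank:total_size:world_size]

# 10 samples split between 4 processes: two samples are seen twice.
for rank in range(4):
    print(rank, rank_indices(10, 4, rank))
```

Because of the padding, a few samples may be duplicated across ranks when the dataset size is not a multiple of the number of processes, which slightly affects metrics aggregated over all ranks.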

* Configuring the Trainer and launching training:

In [None]:
%%writefile -a LitResNet.py  
# Example call: python3 LitResNet.py 50 10 2 $DSDIR/CIFAR-10-images
# -> Train a Lightning Resnet50, 10 classes, for 2 epochs, from DSDIR CIFAR10 dataset
if __name__ == "__main__":
    
    parser = ArgumentParser()
    # Required arguments :
    parser.add_argument("model", help="""Choose one of the predefined ResNet models provided by torchvision. e.g. 50""", type=int)
    parser.add_argument("num_classes", help="""Number of classes to be learned.""", type=int)
    parser.add_argument("num_epochs", help="""Number of Epochs to Run.""", type=int)
    parser.add_argument("dataset_path", help="""Path to training data folder.""", type=Path)
    # Optional arguments :
    parser.add_argument("-o", "--optimizer", help="""PyTorch optimizer to use. Defaults: adam.""", default='adam')
    parser.add_argument("-lr", "--learning_rate", help="Adjust learning rate of optimizer.", type=float, default=1e-3)
    parser.add_argument("-bs", "--batch_size", help="""Manually determine batch size. Defaults: 16.""",type=int, default=16)
    parser.add_argument("-s", "--save_path", help="""Directory in which to save the trained model checkpoint.""")
    parser.add_argument("-l", "--load_path", help="""Path of a trained model checkpoint to load.""")
    args = parser.parse_args()

    # Instantiate Model -> Resnet Classifier
    model_kwargs = dict(num_classes = args.num_classes,
                        resnet_version = args.model,
                        optimizer = args.optimizer,
                        lr = args.learning_rate,
                        batch_size = args.batch_size,)
    if args.load_path is not None:
        # Start from an existing checkpoint
        model = LitResNetClassifier.load_from_checkpoint(checkpoint_path=args.load_path,
                                                         **model_kwargs)
    else:
        model = LitResNetClassifier(**model_kwargs)
        
    # Instantiate Dataset -> Cifar10
    data = Cifar10DataModule(data_dir = args.dataset_path, 
                             batch_size = args.batch_size,)
    
    
    # Instantiate Profiler
    profiler = PyTorchProfiler(schedule=torch.profiler.schedule(wait=1, warmup=1, active=3, repeat=1),
                                record_shapes=False,
                                profile_memory=True,
                                with_stack=False,)
    
    # Instantiate lightning trainer ...
    trainer_args = {
                    'accelerator': 'gpu', 
                    'devices': int(os.environ['SLURM_GPUS_ON_NODE']), 
                    'num_nodes': int(os.environ['SLURM_NNODES']),
                    'strategy': 'ddp',
        
                    'max_epochs': args.num_epochs,
                    'amp_backend': 'native', 
                    'precision': 16, 
                    'profiler': profiler,
                   }
    trainer = pl.Trainer(**trainer_args)
    
    # ... and train model on data
    trainer.fit(model, data)
    
    # Save trained model (in the current directory unless --save_path is given)
    save_path = os.path.join(args.save_path or '.', 'trained_model.ckpt')
    trainer.save_checkpoint(save_path)
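
The profiler above is configured with `torch.profiler.schedule(wait=1, warmup=1, active=3, repeat=1)`. The sketch below reproduces, in pure Python, our understanding of how such a schedule maps a training step to a phase: steps are skipped during `wait`, traced but discarded during `warmup`, recorded during `active`, and ignored once `repeat` cycles are done. The function `profiler_phase` and the phase names are ours, not PyTorch's.

```python
def profiler_phase(step, wait=1, warmup=1, active=3, repeat=1):
    """Phase of the profiling schedule at a given step (illustrative sketch)."""
    cycle = wait + warmup + active
    # With repeat > 0, profiling stops after that many full cycles.
    if repeat > 0 and step >= cycle * repeat:
        return "idle"
    position = step % cycle
    if position < wait:
        return "wait"
    if position < wait + warmup:
        return "warmup"
    return "active"

phases = [profiler_phase(step) for step in range(7)]
print(phases)
```

With the values used in the script, only three steps near the start of training are recorded, which keeps the profiling overhead and the size of the trace small.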

---

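### Example of a single-node multi-GPU run

* Writing the Slurm submission script

Replace `--account=xxx@v100` with your own account. Note that, as written, the script below requests a single GPU (`--ntasks-per-node=1`, `--gres=gpu:1`); to run on several GPUs of the node, set both options to the same number, with one task per GPU.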

In [None]:
%%writefile LitResNet.slurm
#!/bin/sh
#SBATCH --job-name=LitResNet
#SBATCH --output=LitResNet.out
#SBATCH --error=LitResNet.out
#SBATCH --nodes=1
#SBATCH --ntasks-per-node=1
#SBATCH --gres=gpu:1
#SBATCH --cpus-per-task=10
#SBATCH --hint=nomultithread
#SBATCH --account=xxx@v100
#SBATCH --time=00:20:00
#SBATCH --qos=qos_gpu-dev

# go into the submission directory 
cd ${SLURM_SUBMIT_DIR}

# cleans out modules loaded in interactive and inherited by default
module purge

# loading modules
module load pytorch-gpu/py3/1.11.0

# echo of launched commands
set -x

# code execution
srun python -u LitResNet.py 50 10 10 $DSDIR/CIFAR-10-images -bs 128

* Submitting the job with Slurm

In [None]:
%%bash
# submit job
sbatch LitResNet.slurm
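
On success, `sbatch` prints a line of the form `Submitted batch job <id>`. If the job identifier is needed later (for example to query or cancel that specific job), it can be extracted from this line. The sketch below shows one way to do so; the function name `parse_job_id` is ours and the job number in the example is made up.

```python
import re

def parse_job_id(sbatch_output):
    """Extract the numeric job id from the text printed by sbatch, or None."""
    match = re.search(r"Submitted batch job (\d+)", sbatch_output)
    return int(match.group(1)) if match else None

# Made-up output, for illustration only.
print(parse_job_id("Submitted batch job 123456"))
```

In a notebook, the text could be captured with the same IPython syntax as in the monitoring cell, for example `out = !sbatch LitResNet.slurm`, and then passed to the function as `'\n'.join(out)`.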

* Displaying the job output

In [None]:
from threading import Event
import signal
from IPython.display import clear_output

# Event used to stop the polling loop cleanly when the kernel is interrupted
stop = Event()

def handle_signal(signo, _frame):
    print("Interrupted by %d, shutting down" % signo)
    stop.set()

for sig in ('TERM', 'HUP', 'INT'):
    signal.signal(getattr(signal, 'SIG' + sig), handle_signal)

# Refresh the display every second while the job is listed by squeue
sq = !squeue -u $USER -n LitResNet
tail = !tail -n 10 LitResNet.out
while len(sq) >= 2 and not stop.is_set():
    
    clear_output(wait=True)
    print(sq[0], sq[1], sep='\n')
    print(*tail, sep='\n')
    
    sq = !squeue -u $USER -n LitResNet
    tail = !tail -n 10 LitResNet.out
    
    stop.wait(1)
print('\n Done!')

In [None]:
!cat LitResNet.out

---

### Example of a multi-node multi-GPU run

* Writing the Slurm submission script

Replace `--account=xxx@v100` with your own account.

In [None]:
%%writefile LitResNet.slurm
#!/bin/sh
#SBATCH --job-name=LitResNet
#SBATCH --output=LitResNet.out
#SBATCH --error=LitResNet.out
#SBATCH --nodes=2
#SBATCH --ntasks-per-node=4
#SBATCH --gres=gpu:4
#SBATCH --cpus-per-task=10
#SBATCH --hint=nomultithread
#SBATCH --account=xxx@v100
#SBATCH --time=00:20:00
#SBATCH --qos=qos_gpu-dev

# go into the submission directory 
cd ${SLURM_SUBMIT_DIR}

# cleans out modules loaded in interactive and inherited by default
module purge

# loading modules
module load pytorch-gpu/py3/1.11.0

# echo of launched commands
set -x

# code execution
srun python -u LitResNet.py 50 10 10 $DSDIR/CIFAR-10-images -bs 128
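
The script above requests 2 nodes with 4 tasks and 4 GPUs each, that is one process per GPU and 8 processes in total. As a sanity check before submitting, the sketch below parses the `#SBATCH` directives of a script and verifies that the number of tasks per node matches the number of GPUs per node. The helper names are ours, and the parsing assumes directives of the form `--option=value` with `--gres=gpu:N`.

```python
import re

def sbatch_options(script_text):
    """Return the '#SBATCH --option=value' directives of a script as a dict."""
    options = {}
    for line in script_text.splitlines():
        match = re.match(r"#SBATCH\s+--([\w-]+)=(\S+)", line)
        if match:
            options[match.group(1)] = match.group(2)
    return options

def check_topology(options):
    """Check one task per GPU and return the total number of processes."""
    nodes = int(options["nodes"])
    tasks_per_node = int(options["ntasks-per-node"])
    # Assumes a gres value such as 'gpu:4'
    gpus_per_node = int(options["gres"].split(":")[-1])
    if tasks_per_node != gpus_per_node:
        raise ValueError("ntasks-per-node should equal the number of GPUs per node")
    return nodes * tasks_per_node

script = """#!/bin/sh
#SBATCH --nodes=2
#SBATCH --ntasks-per-node=4
#SBATCH --gres=gpu:4
"""
world_size = check_topology(sbatch_options(script))
print(world_size)
```

With the `-bs 128` argument passed to `LitResNet.py`, this topology corresponds to a global batch of 128 × 8 = 1024 images per step.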