PyTorch: Multi-GPU model parallelism

The methodology presented on this page shows how to adapt, in practice, a model which is too large to fit on a single GPU. It illustrates the concepts presented on the main page: Jean Zay: Multi-GPU and multi-node distribution for training a TensorFlow or PyTorch model.

The procedure involves the following steps:

  • Adaptation of the model
  • Adaptation of the training loop
  • Configuration of the Slurm computing environment

The steps concerning the creation of the DataLoader and the saving/loading of the model are unchanged compared to a model deployed on a single GPU.

This document presents only the changes to make when parallelising a model. A complete demonstration is found in a notebook which can be downloaded here or retrieved on Jean Zay in the DSDIR directory: /examples_IA/Torch_parallel/demo_model_parallelism_pytorch-eng.ipynb. To copy it into your $WORK space:

$ cp $DSDIR/examples_IA/Torch_parallel/demo_model_parallelism_pytorch.ipynb $WORK

This Jupyter Notebook is run from a front end, after having loaded a PyTorch module (see our JupyterHub documentation for more information on how to run Jupyter Notebook).

The principal source of information is the official PyTorch documentation: Single-Machine Model Parallel Best Practices.

Model adaptation

To illustrate the methodology, a ResNet model is distributed over two GPUs (we use the example proposed in the PyTorch documentation). For other applications whose model architectures differ from CNNs (such as transformers or multi-model architectures), the strategy for splitting the model into sequences could be different.

Two methods are proposed:

  • Simple distribution of the model - In this case, only one GPU works at a time. This results in longer training cycles due to inter-GPU communications (during the forward and backward propagation phases).

 Simple distribution of the model

  • Pipeline distribution of the model, with splitting of the data batches, which enables the GPUs to work concurrently. As a result, the execution time is reduced compared to the mono-GPU equivalent.

 Model distribution with pipelining

Adjustments to the model

The idea is to divide the model layers between two GPUs (here named dev0 and dev1).

The forward function enables communication between the two GPUs.

import torch
from torch import nn
from torchvision.models.resnet import ResNet, Bottleneck
 
num_classes = 1000
 
class ParallelResnet(ResNet):
    def __init__(self, dev0, dev1, *args, **kwargs):
        super(ParallelResnet, self).__init__(
            Bottleneck, [3, 4, 23, 3], num_classes=num_classes, *args, **kwargs)
        # dev0 and dev1 each point to a GPU device (usually 'cuda:0' and 'cuda:1')
        self.dev0 = dev0
        self.dev1 = dev1
 
        # split the model into two consecutive sequences: seq0 and seq1
        self.seq0 = nn.Sequential(
            self.conv1,
            self.bn1,
            self.relu,
            self.maxpool,
            self.layer1,
            self.layer2
        ).to(self.dev0)  # sends the first sequence of the model to the first GPU
 
        self.seq1 = nn.Sequential(
            self.layer3,
            self.layer4,
            self.avgpool,
        ).to(self.dev1)  # sends the second sequence of the model to the second GPU
 
        self.fc.to(self.dev1)  # last layer is on the second GPU
 
    def forward(self, x):
        x = self.seq0(x)    # apply the first sequence of the model to the input x
        x = x.to(self.dev1) # send the intermediate result to the second GPU
        x = self.seq1(x)    # apply the second sequence of the model to x
        return self.fc(x.view(x.size(0), -1))  
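
Below is a minimal usage sketch (not part of the original example) showing how the parallel model can be instantiated and tested with a dummy input; the batch size and image dimensions are illustrative.

model = ParallelResnet(dev0='cuda:0', dev1='cuda:1')
inputs = torch.randn(16, 3, 224, 224).to('cuda:0')  # dummy batch of 16 RGB images, sent to the first GPU
outputs = model(inputs)
print(outputs.device)  # cuda:1, since the last layers reside on the second GPU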

The pipelined version starts from the preceding (“parallel”) model and additionally splits the data batches so that the two GPUs can work quasi-concurrently. The optimal number of splits depends on the context (model, batch size) and must be estimated for each case.

A benchmark determining the best split choice for a ResNet-50 can be found in Single-Machine Model Parallel Best Practices / speed-up-by-pipelining-inputs.

Version implementing the pipeline concept:

class PipelinedResnet(ResNet):
    def __init__(self, dev0, dev1, split_size=8, *args, **kwargs):
        super(PipelinedResnet, self).__init__(
            Bottleneck, [3, 4, 23, 3], num_classes=num_classes, *args, **kwargs)
        # dev0 and dev1 each point to a GPU device (usually 'cuda:0' and 'cuda:1')
        self.dev0 = dev0
        self.dev1 = dev1
        self.split_size = split_size
 
        # split the model into two consecutive sequences: seq0 and seq1
        self.seq0 = nn.Sequential(
            self.conv1,
            self.bn1,
            self.relu,
            self.maxpool,
            self.layer1,
            self.layer2
        ).to(self.dev0)  # sends the first sequence of the model to the first GPU
 
        self.seq1 = nn.Sequential(
            self.layer3,
            self.layer4,
            self.avgpool,
        ).to(self.dev1)  # sends the second sequence of the model to the second GPU
 
        self.fc.to(self.dev1)  # last layer is on the second GPU
 
    def forward(self, x):
        # x contains a batch of images as a tensor; split it into mini-batches
        # of size split_size along the batch dimension
        splits = iter(x.split(self.split_size, dim=0))
        s_next = next(splits)
        # initialisation: 
        # - first mini batch goes through seq0 (on dev0)
        # - the output is sent to dev1
        s_prev = self.seq0(s_next).to(self.dev1)
        ret = []
 
        for s_next in splits:
            # A. s_prev runs on dev1
            s_prev = self.seq1(s_prev)
            ret.append(self.fc(s_prev.view(s_prev.size(0), -1)))
 
            # B. s_next runs on dev0, which can run concurrently with A
            s_prev = self.seq0(s_next).to(self.dev1)
 
        s_prev = self.seq1(s_prev)
        ret.append(self.fc(s_prev.view(s_prev.size(0), -1)))
 
        return torch.cat(ret)

In summary, only the forward function is modified, by adding a pipeline controlled by the split_size parameter passed as an argument when creating the model.
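
To estimate a good value of split_size on your own configuration, a simple approach is to time a few training iterations for several candidate values. The sketch below assumes two GPUs are available; the candidate values, batch size, loss and optimizer are illustrative.

import time

def benchmark_split_size(split_size, n_iters=10, batch_size=128):
    model = PipelinedResnet(dev0='cuda:0', dev1='cuda:1', split_size=split_size)
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.001)
    # dummy data: inputs on the first GPU, labels on the second GPU (where the outputs reside)
    inputs = torch.randn(batch_size, 3, 224, 224, device='cuda:0')
    labels = torch.randint(0, num_classes, (batch_size,), device='cuda:1')
    torch.cuda.synchronize('cuda:0'); torch.cuda.synchronize('cuda:1')
    start = time.time()
    for _ in range(n_iters):
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
    torch.cuda.synchronize('cuda:0'); torch.cuda.synchronize('cuda:1')
    return (time.time() - start) / n_iters  # mean time per iteration

for split_size in [2, 4, 8, 16, 32]:
    print(split_size, benchmark_split_size(split_size))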

Adaptation of the training loop

Creation of the model

The model is loaded into GPU memory during its creation. Therefore, you should not add .to(device) afterwards.

In the case of a mono-task job, PyTorch always numbers the visible GPUs starting from 0 (even if nvidia-smi shows that, for example, GPUs 2 and 3 of the node were allocated to the job). We can therefore use their fixed identifiers:

mp_model = PipelinedResnet(dev0='cuda:0', dev1='cuda:1')
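
As an optional sanity check, you can verify that the sub-modules already reside on their target GPUs, then create the loss and optimizer as usual; the loss and optimizer choices below are illustrative.

print(next(mp_model.seq0.parameters()).device)  # cuda:0
print(next(mp_model.seq1.parameters()).device)  # cuda:1
print(mp_model.fc.weight.device)                # cuda:1

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(mp_model.parameters(), lr=0.001, momentum=0.9)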

Extract of the training loop

def train(model, optimizer, criterion, train_loader, batch_size, dev0, dev1):
    model.train()
    for batch_counter, (images, labels) in enumerate(train_loader):
        # images are sent to the first GPU
        images = images.to(dev0, non_blocking=True)
        # zero the parameter gradients
        optimizer.zero_grad()
        # forward
        outputs = model(images)
        # labels (ground truth) are sent to the GPU where the outputs of the model
        # reside, which in this case is the second GPU 
        labels = labels.to(outputs.device, non_blocking=True)
        _, preds = torch.max(outputs, 1)  # predicted classes (e.g. for computing accuracy)
        loss = criterion(outputs, labels)
        # backward + optimize
        loss.backward()
        optimizer.step()

The inputs (images) are loaded onto the first GPU (cuda:0); the outputs (predicted values) end up on the second GPU (cuda:1, as set at model instantiation). The labels (ground truth) must therefore be loaded onto the same GPU as the outputs.

The non_blocking=True option should be used in combination with the pin_memory=True option of the DataLoader. Caution: these options generally increase the amount of RAM required.
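
For reference, a possible DataLoader definition using pin_memory=True is sketched below (train_dataset, batch_size and the number of workers are assumptions to adapt to your case); page-locked host memory is what allows the non_blocking=True copies above to be asynchronous.

train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=4,       # illustrative value
    pin_memory=True      # page-locked host memory, needed for asynchronous host-to-GPU copies
)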

Configuration of the Slurm computing environment

A single task must be instantiated per model (or group of models), together with an adequate number of GPUs, which must all be on the same node. In our example, 2 GPUs:

#SBATCH --gres=gpu:2
#SBATCH --ntasks-per-node=1
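
For illustration, a complete submission script could look like the following sketch; the job name, time limit, CPU count, module version and script name are assumptions to adapt to your environment and to the chosen partition.

#!/bin/bash
#SBATCH --job-name=model_parallel     # illustrative job name
#SBATCH --nodes=1                     # this method is limited to a single node
#SBATCH --ntasks-per-node=1           # a single task drives both GPUs
#SBATCH --gres=gpu:2                  # 2 GPUs on the same node
#SBATCH --cpus-per-task=10            # illustrative value, adapt to the partition
#SBATCH --hint=nomultithread
#SBATCH --time=01:00:00

module purge
module load pytorch-gpu/py3/1.11.0    # illustrative PyTorch module version

srun python -u my_model_parallel_training.py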

The methodology presented here, which relies only on the PyTorch library, is limited to mono-node multi-GPU parallelism (2, 4 or 8 GPUs) and cannot be applied to a multi-node case. It is strongly recommended to combine this technique with data parallelism, as described on the page Data and model parallelism with PyTorch, in order to effectively accelerate training.

If your model requires model parallelism on more than one node, we recommend that you explore the solution documented on the page Distributed Pipeline Parallelism Using RPC.

⇐ Return to the main page on distributed learning