TensorFlow: Multi-GPU and multi-node data parallelism

This page explains how to distribute a neural network model implemented in TensorFlow code using the data parallelism method.

An application example is provided at the bottom of the page so that you can access a functional implementation of the following descriptions.

Implementation of a distribution strategy

To distribute a model in TensorFlow, we define a distribution strategy by creating an instance of the tf.distribute.Strategy class. This strategy enables us to manage how the data and computations are distributed on the GPUs.

Choice of strategy


TensorFlow provides several pre-implemented strategies. In this documentation, we present only tf.distribute.MultiWorkerMirroredStrategy. This strategy has the advantage of being generic: it can be used in both multi-GPU and multi-node configurations without any performance loss compared to the other strategies we tested.

With MultiWorkerMirroredStrategy, the model variables are replicated, or « mirrored », on all of the detected GPUs. Each GPU processes a part of the data (a mini-batch), and collective reduction operations are used at each step to aggregate the tensors and update the variables on every GPU.
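To make this synchronisation step concrete, here is a toy simulation in plain Python (not TensorFlow code). Each simulated replica holds an identical copy of a scalar weight and computes a gradient on its own mini-batch; an all-reduce then sums the local gradients so that every replica applies the same update and the copies stay identical. The linear model, squared loss and learning rate are illustrative assumptions.

```python
# Toy simulation of synchronous data parallelism with mirrored variables.
# Model (illustrative assumption): prediction = w * x, per-sample loss = (w * x - y) ** 2

def local_gradient(w, xs, ys, global_batch_size):
    # gradient of the summed per-sample losses, scaled by the *global* batch size
    return sum(2 * (w * x - y) * x for x, y in zip(xs, ys)) / global_batch_size

def all_reduce_sum(values):
    # every replica receives the same sum of all local values
    total = sum(values)
    return [total for _ in values]

def data_parallel_step(weights, shards, lr):
    # weights: one mirrored copy per replica; shards: one (xs, ys) mini-batch per replica
    global_batch_size = sum(len(xs) for xs, _ in shards)
    grads = [local_gradient(w, xs, ys, global_batch_size)
             for w, (xs, ys) in zip(weights, shards)]
    reduced = all_reduce_sum(grads)
    # each replica applies the same update, so the copies remain identical
    return [w - lr * g for w, g in zip(weights, reduced)]

# one global batch of 4 samples split over 2 replicas
shards = [([1.0, 2.0], [2.0, 4.0]), ([3.0, 4.0], [6.0, 8.0])]
weights = data_parallel_step([0.0, 0.0], shards, lr=0.01)
print(weights)
```

Both replicas end up with w = 0.3, which is exactly the update a single device would compute from the mean gradient over the whole global batch.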

Multi-worker environment

MultiWorkerMirroredStrategy is the multi-worker version of tf.distribute.MirroredStrategy. It runs on multiple tasks (or workers), each identified by a host name and a port number. Together, the workers form a cluster, which the distribution strategy relies on to synchronise the GPUs.

The concept of worker (and cluster) makes it possible to run on multiple compute nodes. Each worker can be associated with one or more GPUs. On Jean Zay, we advise defining one worker per GPU.

The multi-worker environment of an execution is automatically detected based on the Slurm variables defined in your submission script via the tf.distribute.cluster_resolver.SlurmClusterResolver class.

MultiWorkerMirroredStrategy can rely on two types of inter-GPU communication protocol: gRPC or NCCL. To obtain the best performance on Jean Zay, we advise requesting the NCCL communication protocol.

Defining the strategy

Defining the MultiWorkerMirroredStrategy is done in the following lines:

import tensorflow as tf

# build multi-worker environment from Slurm variables
cluster_resolver = tf.distribute.cluster_resolver.SlurmClusterResolver(port_base=12345)
# use NCCL communication protocol
implementation = tf.distribute.experimental.CommunicationImplementation.NCCL
communication_options = tf.distribute.experimental.CommunicationOptions(implementation=implementation)
# declare distribution strategy
strategy = tf.distribute.MultiWorkerMirroredStrategy(cluster_resolver=cluster_resolver,
                                                     communication_options=communication_options)

Comment: On Jean Zay, you can use port numbers between 10000 and 20000 (inclusive).
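As a rough sanity check on port_base, the sketch below lists the "host:port" addresses of a 2-node, 4-tasks-per-node job and checks them against the Jean Zay range. It assumes, as we understand SlurmClusterResolver to behave, that the tasks of each node receive consecutive ports starting at port_base; the helper functions and the host names node1/node2 are hypothetical and are not part of the TensorFlow API.

```python
# Hypothetical helpers, assuming consecutive ports port_base, port_base + 1, ...
# are given to the tasks of each node.
ALLOWED_MIN, ALLOWED_MAX = 10000, 20000  # Jean Zay range, inclusive

def worker_addresses(hosts, tasks_per_node, port_base):
    # one "host:port" entry per task (i.e. per worker)
    return [f"{host}:{port_base + i}" for host in sorted(hosts) for i in range(tasks_per_node)]

def ports_in_allowed_range(port_base, tasks_per_node):
    # the last port used on a node is port_base + tasks_per_node - 1
    last_port = port_base + tasks_per_node - 1
    return ALLOWED_MIN <= port_base and last_port <= ALLOWED_MAX

workers = worker_addresses(["node1", "node2"], tasks_per_node=4, port_base=12345)
print(workers)
print(ports_in_allowed_range(12345, 4))
```

With port_base=12345 and 4 tasks per node, each node would use ports 12345 to 12348, well inside the allowed range.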

Important: There is currently a TensorFlow limitation on declaring the strategy; it must be done before any other call to a TensorFlow operation.

Integration into the learning model

To replicate a model on multiple GPUs, it must be created in the context of strategy.scope().

The .scope() method provides a context manager which captures the TensorFlow variables and distributes them to each GPU according to the chosen strategy. Within this context, we define the elements that create the variables related to the model: loading a saved model, defining a model, the model.compile() call, the optimiser, etc.

The following is an example of the declaration of a model to replicate on all of the GPUs:

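import os
import tensorflow as tf

# get total number of workers
n_workers = int(os.environ['SLURM_NTASKS'])
# define batch size
batch_size_per_gpu = 64
global_batch_size = batch_size_per_gpu * n_workers
# load dataset
dataset = tf.data.Dataset.from_tensor_slices(...)
# [...]
dataset = dataset.batch(global_batch_size)
# model building/compiling needs to be within `strategy.scope()`
with strategy.scope():
  multi_worker_model = tf.keras.Sequential([
    tf.keras.Input(shape=(28, 28)),
    tf.keras.layers.Reshape(target_shape=(28, 28, 1)),
    # [...] remaining layers; the two below are only an illustrative output head
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(10),
  ])
  # illustrative compile arguments; the optimiser is also created inside the scope
  multi_worker_model.compile(
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    optimizer=tf.keras.optimizers.SGD(learning_rate=0.001),
    metrics=['accuracy'])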

Comment: With certain optimisers, such as SGD, the learning rate must be adjusted in proportion to the global batch size, and therefore to the number of GPUs.
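One common heuristic for this adjustment is the linear scaling rule: multiply a learning rate tuned for a reference batch size by the ratio of the global batch size to that reference. The sketch below assumes that rule and uses illustrative values (a learning rate of 0.1 tuned for a batch of 64 on one GPU); other schemes, such as square-root scaling or a warm-up phase, exist and may suit a given model better.

```python
def scaled_learning_rate(base_lr, base_batch_size, batch_size_per_gpu, n_gpus):
    # linear scaling rule (assumed heuristic): lr grows with the global batch size
    global_batch_size = batch_size_per_gpu * n_gpus
    return base_lr * global_batch_size / base_batch_size

# illustrative values: lr = 0.1 tuned on a single GPU with a batch of 64
for n_gpus in (1, 2, 4, 8):
    print(n_gpus, scaled_learning_rate(0.1, 64, 64, n_gpus))
```

The resulting value would then be passed as learning_rate to the optimiser created inside strategy.scope().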

Distributed training of a tf.keras.Model

The model.fit() function of the Keras library automatically manages training distribution according to the chosen strategy. The training is, therefore, launched in the usual manner. For example:

multi_worker_model.fit(train_dataset, epochs=10, steps_per_epoch=100)

Comments:

  • Distribution of the input data across the different processes (data sharding) is managed automatically within the model.fit() function.
  • Evaluation is also performed in distributed mode automatically when the evaluation dataset is passed to the model.fit() function:
    multi_worker_model.fit(train_dataset, epochs=3, steps_per_epoch=100, validation_data=valid_dataset)

Distributed training with a custom loop

We have seen how to define a distribution with MultiWorkerMirroredStrategy and, in particular, how to use the .scope() method to synchronise the values of the variables created by the workers and to copy these variables onto each GPU of a worker. Note that 1 worker = 1 execution of the Python code. A worker may need to work with several GPUs, in which case the variables must be copied onto each GPU it manipulates; with 1 GPU per worker, the Python code is run as many times as there are GPUs, and a single script never has to manage several GPUs. Once the model has been created with the Keras API inside .scope(), the .fit() method makes it easy to launch distributed training.

To have greater control over model training and evaluation, you can dispense with the call to the .fit() method and define the distributed training loop yourself. To do this, you need to pay attention to the following three things:

  1. Synchronize and distribute the dataset across the workers and the GPUs of each worker.
  2. Be sure to train the model on one mini-batch per GPU and to combine the local gradients to form the global gradient.
  3. Distribute the metrics calculation evenly across the GPUs and combine the local metrics to determine the global metrics.

1) To distribute the dataset manually, we can use the .experimental_distribute_dataset(dataset) method, which takes a tf.data.Dataset as argument and divides it evenly among the workers. We begin by defining the tf.data.Dataset.

Two important operations on such a dataset are shuffling the dataset and dividing it into batches.

  1. Batching is necessary because the dataset must first be divided into so-called « global batches »; then, when the dataset is distributed, each global batch is divided into as many mini-batches as there are GPUs (with 1 GPU per worker, one mini-batch per worker).
  2. Shuffling is strongly recommended when training a neural network, but be careful: two workers could shuffle the dataset in two different ways. For this reason, it is imperative to set a global seed to ensure that the same shuffle is applied on all workers.
  3. For example, in a dataset containing 12 objects, each described by 5 values, the basic element is a tensor of shape (5,). If we divide it into batches of 3 objects, the basic element of the new dataset is a tensor of shape (3,5), and the dataset contains 4 such tensors.
  4. Consequently, .shuffle().batch() and .batch().shuffle() do not have the same effect, and the recommended operation is .shuffle().batch(). The second form only shuffles the order of the batches: the content of each batch remains identical from one epoch to the next.
  5. To facilitate the correct calculation of the gradients and metrics, it is recommended to drop the last global batch if it is smaller than global_batch_size. To do this, use the optional drop_remainder argument of batch().
import tensorflow as tf

# Create a tf.data.Dataset
train_dataset = ...
# Set a global seed (any fixed value, identical on all workers)
tf.random.set_seed(1234)
# Shuffle tf.data.Dataset
train_dataset = train_dataset.shuffle(...)
# Batching using `global_batch_size`
train_dataset = train_dataset.batch(global_batch_size, drop_remainder=True)

If, however, batching is done without drop_remainder, you can still remove the last batch manually:

# Remove the last global batch if it is smaller than the others
# We suppose that a global batch is a tuple of (inputs,labels)
# `len()` is not defined on symbolic tensors inside `filter()`, so use `tf.shape()`
train_dataset = train_dataset.filter(lambda x, y: tf.shape(x)[0] == global_batch_size)
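The difference between .shuffle().batch() and .batch().shuffle(), the role of the shared seed and the effect of drop_remainder can be reproduced in plain Python on the 12-object example. In the sketch below, Python's random module stands in for tf.random.set_seed, and the helpers batch(), shuffled(), shuffle_then_batch() and batch_then_shuffle() are hypothetical stand-ins for the tf.data operations (a full-buffer shuffle is assumed).

```python
import random

def batch(items, batch_size, drop_remainder=False):
    # split a list into consecutive batches, optionally dropping a short last batch
    batches = [items[i:i + batch_size] for i in range(0, len(items), batch_size)]
    if drop_remainder and batches and len(batches[-1]) < batch_size:
        batches.pop()
    return batches

def shuffled(items, rng):
    # shuffled copy of the list (full-buffer shuffle)
    items = list(items)
    rng.shuffle(items)
    return items

objects = list(range(12))  # 12 objects, identified here by their index

def shuffle_then_batch(seed, n_epochs, batch_size=3):
    # reshuffle the objects at each epoch, then batch them
    rng = random.Random(seed)
    return [batch(shuffled(objects, rng), batch_size) for _ in range(n_epochs)]

def batch_then_shuffle(seed, n_epochs, batch_size=3):
    # batch once, then only shuffle the order of the batches at each epoch
    rng = random.Random(seed)
    batches = batch(objects, batch_size)
    return [shuffled(batches, rng) for _ in range(n_epochs)]

# two simulated workers using the same seed see exactly the same epochs
worker_0 = shuffle_then_batch(seed=1234, n_epochs=2)
worker_1 = shuffle_then_batch(seed=1234, n_epochs=2)
```

With shuffle_then_batch(), the composition of the 4 batches normally changes from one epoch to the next; with batch_then_shuffle(), the same four groups of three objects always come back, only in a different order.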

We can finally obtain our tf.distribute.DistributedDataset. When iterating over this dataset, each worker retrieves a tf.distribute.DistributedValues object containing as many mini-batches as it has GPUs; for example:

# Suppose we have 2 workers with 3 GPUs each
# Each global batch is separated into 6 mini-batches
train_dist_dataset = strategy.experimental_distribute_dataset(train_dataset)
# On one worker, the distributed_batch contains 3 mini-batches from the same global batch
# On the other worker, the distributed_batch contains the 3 remaining mini-batches from the same global batch
distributed_batch = next(iter(train_dist_dataset))

In this context, next() is a synchronous operation: if one worker does not make this call, the other workers are blocked.

2) Let us now move on to the training loop. We consider the particular case, recommended on Jean Zay, of 1 GPU per worker. Here we define the train_epoch() function, which trains the model for one epoch:

import tensorflow as tf

# assumes `strategy`, `global_batch_size` and `train_dist_dataset` are defined as above,
# and that `model` and `optimizer` were created inside `strategy.scope()`
# loss of each input in the mini-batch, without reduction
loss_object = tf.keras.losses.SparseCategoricalCrossentropy(reduction=tf.keras.losses.Reduction.NONE)

def step_fn(mini_batch): # ------------------------------------------------------------ train 1 step
  x, y = mini_batch
  with tf.GradientTape() as tape:
    # compute predictions
    predictions = model(x, training=True)
    # compute the loss of each input in the mini-batch without reduction
    losses_mini_batch_per_gpu = loss_object(y, predictions)
    # sum all the individual losses and divide by `global_batch_size`
    loss_mini_batch_per_gpu = tf.nn.compute_average_loss(losses_mini_batch_per_gpu,
                                                         global_batch_size=global_batch_size)
  # compute the gradient (outside the tape context)
  grads = tape.gradient(loss_mini_batch_per_gpu, model.trainable_variables)
  # inside MultiWorkerMirroredStrategy, `grads` will be summed across all GPUs first
  # before updating the parameters
  optimizer.apply_gradients(zip(grads, model.trainable_variables))
  return loss_mini_batch_per_gpu

@tf.function
def distributed_train_step(distributed_batch):
  # run `step_fn` on each GPU of the worker with its own mini-batch
  return strategy.run(step_fn, args=(distributed_batch,))

def train_epoch(): # ------------------------------------------------------------------- train 1 epoch
  # loss computed on the whole dataset
  total_loss_gpu = 0.0
  train_n_batches = 0
  for distributed_batch in train_dist_dataset:
    train_n_batches += 1
    # with 1 GPU per worker, `strategy.run()` returns a single tensor here
    total_loss_gpu += distributed_train_step(distributed_batch)
  # we get the global loss across the train dataset for 1 GPU
  total_loss_gpu /= float(train_n_batches)
  # `strategy.reduce()` will sum the given value across GPUs and return result on current device
  return strategy.reduce(tf.distribute.ReduceOp.SUM, total_loss_gpu, axis=None)


Comments:

  1. Decorating the function that calls strategy.run() with @tf.function is a recommended optimisation.
  2. In step_fn(), each GPU calculates the loss of each sample of its mini-batch. It is then important to sum these losses and divide the sum by the size of the global batch, not by the size of the mini-batch, because TensorFlow automatically computes the global gradient by summing the local gradients.
  3. If there are 2 or more GPUs per worker, distributed_batch contains multiple mini-batches, and strategy.run() distributes them on the GPUs.
  4. total_loss_gpu only contains the loss calculated on the portion of the dataset attributed to the GPU; the .reduce() call is therefore needed to obtain the loss over the whole epoch.
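The reason for dividing by the global batch size rather than by the mini-batch size can be checked numerically. In this plain-Python sketch, the per-sample loss values are arbitrary illustrative numbers, and compute_average_loss() is a local stand-in for what tf.nn.compute_average_loss does with unweighted losses (sum of the per-sample losses divided by global_batch_size). Summing the per-GPU values, as the SUM reduction does, gives the global mean only with the global divisor.

```python
def compute_average_loss(per_example_loss, global_batch_size):
    # plain-Python analogue of tf.nn.compute_average_loss without sample weights
    return sum(per_example_loss) / global_batch_size

# per-sample losses of one global batch of 8 samples split over 4 GPUs (illustrative values)
mini_batches = [[0.5, 1.5], [1.0, 1.0], [2.0, 0.0], [0.25, 0.75]]
global_batch_size = sum(len(mb) for mb in mini_batches)

# correct: each GPU divides by the global batch size, then values are SUMmed across GPUs
summed = sum(compute_average_loss(mb, global_batch_size) for mb in mini_batches)

# incorrect: each GPU divides by its mini-batch size, then values are SUMmed across GPUs
summed_wrong = sum(sum(mb) / len(mb) for mb in mini_batches)

# reference: mean loss over the whole global batch
global_mean = sum(sum(mb) for mb in mini_batches) / global_batch_size
print(summed, summed_wrong, global_mean)
```

The incorrect version overestimates the loss, and therefore the gradient, by a factor equal to the number of GPUs (here 3.5 instead of 0.875).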

3) The metrics calculation during model evaluation closely resembles training, without the gradient computation and the weight update.
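As for the loss, a metric such as accuracy is best combined from local counts rather than by averaging local ratios, which differ as soon as the GPUs evaluate shards of unequal size. The plain-Python sketch below assumes that each GPU reports its number of correct predictions and its number of samples, and that both are SUM-reduced (the role strategy.reduce() plays in the training loop); the helper names and per-GPU values are illustrative.

```python
def local_counts(predictions, labels):
    # per-GPU partial results: (number of correct predictions, number of samples)
    correct = sum(int(p == y) for p, y in zip(predictions, labels))
    return correct, len(labels)

def global_accuracy(per_gpu_counts):
    # SUM-reduce numerators and denominators separately, then divide once
    total_correct = sum(c for c, _ in per_gpu_counts)
    total_samples = sum(n for _, n in per_gpu_counts)
    return total_correct / total_samples

# two GPUs evaluating shards of different sizes (illustrative values)
gpu0 = local_counts([1, 0, 2, 2], [1, 0, 2, 1])  # 3 correct out of 4
gpu1 = local_counts([0, 1], [1, 1])              # 1 correct out of 2
print(global_accuracy([gpu0, gpu1]))
```

The global accuracy here is 4/6, whereas naively averaging the two local accuracies (0.75 and 0.5) would give 0.625.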

Configuration of the computing environment

Jean Zay modules

To obtain optimal performance with a distributed execution on Jean Zay, you need to load one of the following modules:

  • tensorflow-gpu/py3/2.4.0-noMKL
  • tensorflow-gpu/py3/2.5.0+nccl-2.8.3
  • Any module version ≥ 2.6.0

Warning: In the other Jean Zay environments, the TensorFlow library was compiled with options that can cause significant performance losses.

Configuration of the Slurm reservation

A correct Slurm reservation contains as many tasks as there are workers, because the Python script must be executed on each worker. Since we associate one worker with each GPU, there are as many Slurm tasks as there are GPUs.

Important: By default, TensorFlow attempts to use the HTTP protocol. To prevent this, it is necessary to deactivate the Jean Zay HTTP proxy by deleting the http_proxy, https_proxy, HTTP_PROXY and HTTPS_PROXY environment variables.

The following is an example of a reservation with 2 four-GPU nodes:

#!/bin/bash
#SBATCH --job-name=tf_distributed
#SBATCH --nodes=2 #------------------------- number of nodes
#SBATCH --ntasks=8 #------------------------ number of tasks / workers
#SBATCH --gres=gpu:4 #---------------------- number of GPUs per node
#SBATCH --cpus-per-task=10
#SBATCH --hint=nomultithread
#SBATCH --time=01:00:00
#SBATCH --output=tf_distributed%j.out
#SBATCH --error=tf_distributed%j.err
# deactivate the HTTP proxy
unset http_proxy https_proxy HTTP_PROXY HTTPS_PROXY
# activate the environment
module purge
module load tensorflow-gpu/py3/2.6.0
# run the Python script once per task (i.e. once per worker/GPU)
srun python3 -u script.py
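The unset line in the submission script is the documented way to deactivate the proxy. As a possible complement (our assumption, not a procedure from this documentation), the same variables can be removed at the very start of script.py, before TensorFlow is imported, so that no TensorFlow component sees them. The helper name unset_proxy_variables() is hypothetical.

```python
import os

PROXY_VARIABLES = ("http_proxy", "https_proxy", "HTTP_PROXY", "HTTPS_PROXY")

def unset_proxy_variables(environ=os.environ):
    # remove the proxy variables if present; missing ones are silently ignored
    for name in PROXY_VARIABLES:
        environ.pop(name, None)

unset_proxy_variables()
# import tensorflow as tf  # import TensorFlow only after the variables are removed
```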

Application example

Multi-GPU and multi-node execution with the MultiWorkerMirroredStrategy

An example in Notebook form is found on Jean Zay in $DSDIR/examples_IA/Tensorflow_parallel/Example_DataParallelism_TensorFlow.ipynb. You can also download it by clicking on this link.

The example is a Notebook in which the trainings presented above are implemented and executed on 1, 2, 4 and 8 GPUs. The examples are based on the ResNet50 model and the CIFAR-10 dataset.

You must first copy the Notebook into your personal space (for example, into your $WORK):

$ cp $DSDIR/examples_IA/Tensorflow_parallel/Example_DataParallelism_TensorFlow.ipynb $WORK

You can then execute the Notebook from a Jean Zay front end by first loading a TensorFlow module (see our JupyterHub documentation for more information on how to run Jupyter Notebook).

Sources and documentation