{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Horovod + Tensorflow : Multi-GPU and multi-node data parallelism\n", "## Implementation\n", "\n", "*Notebook written by the IDRIS AI support team, November 2020*" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "This document presents the method to use on Jean Zay to distribute your Horovod training with TensorFlow 2, with or without Keras, depending on the **Data Parallelism** method. [Horovod documentation](https://horovod.readthedocs.io/en/stable/) is used as reference and illustrates the [IDRIS documentation](http://www.idris.fr/eng/jean-zay/gpu/jean-zay-gpu-hvd-tf-multi-eng.html).\n", "In this example, we are training a convolutional neural network on the MNIST database. The learning executes on several Jean Zay GPUs and compute nodes.\n", "\n", "It consists here of: \n", "* Preparing the MNIST database \n", "* Writing the Python script for distributed learning (Data Parallelism) \n", "* Running a parallel execution on Jean Zay\n", "\n", "Note that the MNIST data and the model used in this example are very simple. This allows us to present a short code and to test the Data Parallelism configuration rapidly, but not to measure an acceleration of the training. In fact, the transfer time between GPUs together with the initialization time of the GPU kernels is sizeable in relation to the execution times." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "------------------------" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Computing environment" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "This notebook is intended for execution from a Jean Zay front end. The hostname must be jean-zay[1-5]." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "!hostname" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "A TensorFlow module must be loaded beforehand in order for this Notebook to function correctly. \n", "For example, the ``tensorflow-gpu/py3/2.3.1`` module:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "!module list" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "------------------------------------" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Preparation of the MNIST database" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The MNIST database is available in the DSDIR of Jean Zay." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "**Comment**: The DSDIR, like the SCRATCH, is a GPFS disk space which has a bandwidth of about 300 GB/s in write and in read. These are the preferred spaces for codes having intense usage for input/output operations. Your personal SCRATCH space is dedicated to your private databases; the common DSDIR space includes most of the public databases. 
" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "You can test the data access with following command:" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Dataset MNIST\n", "\tNumber of datapoints: 60000\n", "\tFile location: /gpfsdswork/dataset/MNIST/mnist.npz\n", "\tSplit: Train\n" ] } ], "source": [ "import os\n", "import tensorflow as tf\n", "import numpy as np\n", "\n", "path = os.environ['DSDIR']+'/MNIST/mnist.npz'\n", "(x_train, y_train), _ = tf.keras.datasets.mnist.load_data(path)\n", "\n", "print('Dataset MNIST\\n\\tNumber of datapoints: {}\\n\\tFile location: {}\\n\\tSplit: Train'.format(len(x_train), path))\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Horovod + TensorFlow 2" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Writing the Python script for distributed learning (Data Parallelism)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "In this section, we write the Python training script in the ‘mnist-distributed.py’ file." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "* Loading libraries, creation of the data iterator, creation of the learning model (shallow convolutional neural network with 1 convolutional layer and 2 dense layers):" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Overwriting mnist-distributed.py\n" ] } ], "source": [ "%%writefile mnist-distributed.py \n", "\n", "import os\n", "import subprocess\n", "import json\n", "import datetime\n", "import argparse\n", "\n", "import tensorflow as tf\n", "import horovod.tensorflow as hvd\n", "import numpy as np\n", "\n", "def mnist_dataset(batch_size):\n", " path = os.environ['DSDIR']+'/MNIST/mnist.npz'\n", " (x_train, y_train), _ = tf.keras.datasets.mnist.load_data(path)\n", " # The `x` arrays are in uint8 and have values in the range [0, 255].\n", " # You need to convert them to float32 with values in the range [0, 1]\n", " x_train = x_train / np.float32(255)\n", " y_train = y_train.astype(np.int64)\n", " train_dataset = tf.data.Dataset.from_tensor_slices(\n", " (x_train, y_train)).repeat().shuffle(60000).batch(batch_size)\n", " return train_dataset\n", "\n", "def build_cnn_model():\n", " model = tf.keras.Sequential([\n", " tf.keras.Input(shape=(28, 28)),\n", " tf.keras.layers.Reshape(target_shape=(28, 28, 1)),\n", " tf.keras.layers.Conv2D(32, 3, activation='relu'),\n", " tf.keras.layers.Flatten(),\n", " tf.keras.layers.Dense(128, activation='relu'),\n", " tf.keras.layers.Dense(10)\n", " ])\n", "\n", " return model" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "* Defining the distributed learning function (the timers and displays are managed by process 0, which is the master process)" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Appending to mnist-distributed.py\n" ] } ], "source": [ "%%writefile -a mnist-distributed.py\n", "\n", "def main():\n", " parser = argparse.ArgumentParser()\n", " parser.add_argument('-b', '--batch-size', default=128, type =int,\n", " help='batch size. 
 " hvd.init()\n", " \n", " # display info\n", " if hvd.rank() == 0:\n", " print(\">>> Training on \", hvd.size() // hvd.local_size(), \" nodes and \", hvd.size(), \" processes\")\n", " print(\"- Process {} corresponds to GPU {} of node {}\".format(hvd.rank(), hvd.local_rank(), hvd.rank() // hvd.local_size()))\n", " \n", " # Pin GPU to be used to process local rank (one GPU per process)\n", " gpus = tf.config.experimental.list_physical_devices('GPU')\n", " for gpu in gpus:\n", " tf.config.experimental.set_memory_growth(gpu, True)\n", " if gpus:\n", " tf.config.experimental.set_visible_devices(gpus[hvd.local_rank()], 'GPU')\n", " \n", " mnist_model = build_cnn_model()\n", " \n", " loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)\n", "\n", " # Horovod: adjust learning rate based on number of GPUs.\n", " opt = tf.optimizers.Adam(0.001 * hvd.size())\n", " \n", " # ### Get data\n", " dataset = mnist_dataset(args.batch_size)\n", " \n", " @tf.function\n", " def training_step(images, labels, first_batch):\n", " with tf.GradientTape() as tape:\n", " probs = mnist_model(images, training=True)\n", " loss_value = loss(labels, probs)\n", "\n", " # Horovod: add Horovod Distributed GradientTape.\n", " tape = hvd.DistributedGradientTape(tape)\n", "\n", " grads = tape.gradient(loss_value, mnist_model.trainable_variables)\n", " opt.apply_gradients(zip(grads, mnist_model.trainable_variables))\n", "\n", " # Horovod: broadcast initial variable states from rank 0 to all other processes.\n", " # This is necessary to ensure consistent initialization of all workers when\n", " # training is started with random weights or restored from a checkpoint.\n", " #\n", " # Note: broadcast should be done after the first gradient step to ensure optimizer\n", " # initialization.\n", " if first_batch:\n", " hvd.broadcast_variables(mnist_model.variables, root_rank=0)\n", " hvd.broadcast_variables(opt.variables(), root_rank=0)\n", "\n", " return loss_value\n", "\n", "\n", " # Horovod: adjust number of steps based on number of GPUs.\n", " start = datetime.datetime.now()\n", " for batch, (images, labels) in enumerate(dataset.take(args.epochs * 500 // hvd.size())):\n", " loss_value = training_step(images, labels, batch == 0)\n", "\n", " if batch % 100 == 0 and hvd.rank() == 0:\n", " print('Step #%d\\tLoss: %.6f' % (batch, loss_value))\n", " \n", " duration = datetime.datetime.now() - start\n", "\n", " if hvd.rank() == 0:\n", " print(' -- Trained in ' + str(duration) + ' -- ')\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "* Defining the entry point:" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Appending to mnist-distributed.py\n" ] } ], "source": [ "%%writefile -a mnist-distributed.py\n", "\n", "if __name__ == '__main__':\n", "\n", " main()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Example of mono-GPU mono-node execution" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "* Writing the submission batch script\n", "\n", "**Remember**: If your single project has both CPU and GPU hours, or if your login is attached to more than one project, you must specify for which allocation the consumed hours should be counted by adding the option 
`--account=my_project@gpu` as explained in the [IDRIS documentation](http://www.idris.fr/eng/jean-zay/cpu/jean-zay-cpu-doc_account-eng.html)." ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Overwriting batch_monogpu.slurm\n" ] } ], "source": [ "%%writefile batch_monogpu.slurm\n", "#!/bin/sh\n", "#SBATCH --job-name=mnist_tensorflow_monogpu\n", "#SBATCH --output=mnist_tensorflow_monogpu.out\n", "#SBATCH --error=mnist_tensorflow_monogpu.err\n", "#SBATCH --ntasks=1\n", "#SBATCH --gres=gpu:1\n", "#SBATCH --cpus-per-task=10\n", "#SBATCH --hint=nomultithread\n", "##SBATCH --qos=qos_gpu-dev\n", "#SBATCH --time=00:10:00\n", "\n", "# go into the submission directory \n", "cd ${SLURM_SUBMIT_DIR}\n", "\n", "# cleans out modules loaded in interactive and inherited by default\n", "module purge\n", "\n", "# loading modules\n", "module load tensorflow-gpu/py3/2.3.1\n", "\n", "# echo of launched commands\n", "set -x\n", "\n", "# code execution\n", "srun python -u mnist-distributed.py --epochs 8 --batch-size 128" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "* Submission of the batch script and display of the output" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Submitted batch job 1381377\n" ] } ], "source": [ "%%bash\n", "# submit job\n", "sbatch batch_monogpu.slurm" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ " JOBID PARTITION NAME USER ST TIME NODES NODELIST(REASON)\n", " 1381377 gpu_p13 mnist_te ssos040 CG 0:35 1 r14i4n4\n", " Done!\n" ] } ], "source": [ "# watch Slurm queue line until the job is done\n", "# execution should take about 1 minute\n", "import time\n", "sq = !squeue -u $USER\n", "print(sq[0])\n", "while len(sq) >= 2:\n", " print(sq[1],end='\\r')\n", " time.sleep(5)\n", " sq = !squeue -u $USER\n", "print('\\n Done!')" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ ">>> Training on 1 nodes and 1 processes\n", "- Process 0 corresponds to GPU 0 of node 0\n", "Step #0\tLoss: 2.318774\n", "Step #100\tLoss: 0.294576\n", "Step #200\tLoss: 0.157830\n", "Step #300\tLoss: 0.042475\n", "Step #400\tLoss: 0.067899\n", "Step #500\tLoss: 0.052315\n", "Step #600\tLoss: 0.028083\n", "Step #700\tLoss: 0.029146\n", "Step #800\tLoss: 0.030713\n", "Step #900\tLoss: 0.043855\n", "Step #1000\tLoss: 0.029083\n", "Step #1100\tLoss: 0.010532\n", "Step #1200\tLoss: 0.049265\n", "Step #1300\tLoss: 0.009070\n", "Step #1400\tLoss: 0.008195\n", "Step #1500\tLoss: 0.031665\n", "Step #1600\tLoss: 0.025499\n", "Step #1700\tLoss: 0.021718\n", "Step #1800\tLoss: 0.021872\n", "Step #1900\tLoss: 0.007493\n", "Step #2000\tLoss: 0.007308\n", "Step #2100\tLoss: 0.011244\n", "Step #2200\tLoss: 0.007231\n", "Step #2300\tLoss: 0.015845\n", "Step #2400\tLoss: 0.002200\n", "Step #2500\tLoss: 0.000758\n", "Step #2600\tLoss: 0.010867\n", "Step #2700\tLoss: 0.021032\n", "Step #2800\tLoss: 0.002063\n", "Step #2900\tLoss: 0.005934\n", "Step #3000\tLoss: 0.011248\n", "Step #3100\tLoss: 0.007928\n", "Step #3200\tLoss: 0.000729\n", "Step #3300\tLoss: 0.002969\n", "Step #3400\tLoss: 0.012642\n", "Step #3500\tLoss: 0.000213\n", "Step #3600\tLoss: 0.015799\n", "Step #3700\tLoss: 0.002759\n", "Step #3800\tLoss: 0.001522\n", "Step #3900\tLoss: 0.001440\n", " -- Trained in 
0:00:10.997139 -- \n" ] } ], "source": [ "#display output \n", "%cat mnist_tensorflow_monogpu.out" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Example of multi-GPU mono-node execution" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "* Writing the submission batch script\n", "\n", "**Remember**: If your single project has both CPU and GPU hours, or if your login is attached to more than one project, you must specify for which allocation the consumed hours should be counted by adding the option `--account=my_project@gpu` as explained in the [IDRIS documentation](http://www.idris.fr/eng/jean-zay/cpu/jean-zay-cpu-doc_account-eng.html)." ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Overwriting batch_mononode.slurm\n" ] } ], "source": [ "%%writefile batch_mononode.slurm\n", "#!/bin/sh\n", "#SBATCH --job-name=mnist_tensorflow_mononode\n", "#SBATCH --output=mnist_tensorflow_mononode.out\n", "#SBATCH --error=mnist_tensorflow_mononode.err\n", "#SBATCH --ntasks=4\n", "#SBATCH --gres=gpu:4\n", "#SBATCH --cpus-per-task=10\n", "#SBATCH --hint=nomultithread\n", "##SBATCH --qos=qos_gpu-dev\n", "#SBATCH --time=00:10:00\n", "\n", "# go into the submission directory \n", "cd ${SLURM_SUBMIT_DIR}\n", "\n", "# cleans out modules loaded in interactive and inherited by default\n", "module purge\n", "\n", "# loading modules\n", "module load tensorflow-gpu/py3/2.3.1\n", "\n", "# echo of launched commands\n", "set -x\n", "\n", "# code execution\n", "srun python -u mnist-distributed.py --epochs 8 --batch-size 128" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "* Submission of the batch script and display of the output" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Submitted batch job 1381398\n" ] } ], "source": [ "%%bash\n", "# submit job\n", "sbatch batch_mononode.slurm" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ " JOBID PARTITION NAME USER ST TIME NODES NODELIST(REASON)\n", " 1381398 gpu_p13 mnist_te ssos040 CG 0:32 1 r9i5n1\n", " Done!\n" ] } ], "source": [ "# watch Slurm queue line until the job is done\n", "# execution should take less than 1 minute\n", "import time\n", "sq = !squeue -u $USER -n mnist_tensorflow_mononode\n", "print(sq[0])\n", "while len(sq) >= 2:\n", " print(sq[1],end='\\r')\n", " time.sleep(5)\n", " sq = !squeue -u $USER -n mnist_tensorflow_mononode\n", "print('\\n Done!')" ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ ">>> Training on 1 nodes and 4 processes\n", "- Process 3 corresponds to GPU 3 of node 0\n", "- Process 0 corresponds to GPU 0 of node 0\n", "- Process 1 corresponds to GPU 1 of node 0\n", "- Process 2 corresponds to GPU 2 of node 0\n", "Step #0\tLoss: 2.306480\n", "Step #100\tLoss: 0.043011\n", "Step #200\tLoss: 0.053806\n", "Step #300\tLoss: 0.023688\n", "Step #400\tLoss: 0.026603\n", "Step #500\tLoss: 0.017523\n", "Step #600\tLoss: 0.009134\n", "Step #700\tLoss: 0.012096\n", "Step #800\tLoss: 0.000577\n", "Step #900\tLoss: 0.007398\n", " -- Trained in 0:00:11.279000 -- \n" ] } ], "source": [ "#display output \n", "%cat mnist_tensorflow_mononode.out" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Example of multi-GPU multi-node execution" ] }, { "cell_type": 
"markdown", "metadata": {}, "source": [ "* Writing the submission batch script\n", "\n", "**Remember**: If your single project has both CPU and GPU hours, or if your login is attached to more than one project, you must specify for which allocation the consumed hours should be counted by adding the option `--account=my_project@gpu` as explained in the [IDRIS documentation](http://www.idris.fr/eng/jean-zay/cpu/jean-zay-cpu-doc_account-eng.html)." ] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Overwriting batch_multinode.slurm\n" ] } ], "source": [ "%%writefile batch_multinode.slurm\n", "#!/bin/sh\n", "#SBATCH --job-name=mnist_tensorflow_multinode\n", "#SBATCH --output=mnist_tensorflow_multinode.out\n", "#SBATCH --error=mnist_tensorflow_multinode.err\n", "#SBATCH --nodes=3\n", "#SBATCH --ntasks-per-node=4\n", "#SBATCH --gres=gpu:4\n", "#SBATCH --cpus-per-task=10\n", "#SBATCH --hint=nomultithread\n", "##SBATCH --qos=qos_gpu-dev\n", "#SBATCH --time=00:10:00\n", "\n", "\n", "# go into the submission directory \n", "cd ${SLURM_SUBMIT_DIR}\n", "\n", "# cleans out modules loaded in interactive and inherited by default\n", "module purge\n", "\n", "# loading modules\n", "module load tensorflow-gpu/py3/2.2.0\n", "\n", "# echo of launched commands\n", "set -x\n", "\n", "# code execution\n", "srun python -u mnist-distributed.py --epochs 8 --batch-size 128" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "* Submission of the batch script and display of the output" ] }, { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Submitted batch job 1381428\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "sbatch: IDRIS: setting exclusive mode for the job.\n" ] } ], "source": [ "%%bash\n", "# submit job\n", "sbatch batch_multinode.slurm" ] }, { "cell_type": "code", "execution_count": 17, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ " JOBID PARTITION NAME USER ST TIME NODES NODELIST(REASON)\n", " 1381428 gpu_p13 mnist_te ssos040 R 0:40 3 r10i3n[2-4]\n", " Done!\n" ] } ], "source": [ "# watch Slurm queue line until the job is done\n", "# execution should take about 1 minute\n", "import time\n", "sq = !squeue -u $USER -n mnist_tensorflow_multinode\n", "print(sq[0])\n", "while len(sq) == 2:\n", " print(sq[1],end='\\r')\n", " time.sleep(5)\n", " sq = !squeue -u $USER -n mnist_tensorflow_multinode\n", "print('\\n Done!')" ] }, { "cell_type": "code", "execution_count": 18, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "- Process 8 corresponds to GPU 0 of node 2\n", "- Process 9 corresponds to GPU 1 of node 2\n", "- Process 11 corresponds to GPU 3 of node 2\n", "- Process 6 corresponds to GPU 2 of node 1\n", ">>> Training on 3 nodes and 12 processes\n", "- Process 10 corresponds to GPU 2 of node 2\n", "- Process 4 corresponds to GPU 0 of node 1\n", "- Process 2 corresponds to GPU 2 of node 0\n", "- Process 5 corresponds to GPU 1 of node 1\n", "- Process 3 corresponds to GPU 3 of node 0\n", "- Process 7 corresponds to GPU 3 of node 1\n", "- Process 0 corresponds to GPU 0 of node 0\n", "- Process 1 corresponds to GPU 1 of node 0\n", "Step #0\tLoss: 2.303728\n", "Step #100\tLoss: 0.039106\n", "Step #200\tLoss: 0.021328\n", "Step #300\tLoss: 0.002840\n", " -- Trained in 0:00:08.967639 -- \n" ] } ], "source": [ "# display output\n", "%cat 
{ "cell_type": "markdown", "metadata": {}, "source": [ "## Horovod + TensorFlow 2 with Keras" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "In this section, we write the Python training script in the `mnist-distributed.py` file.\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "* Loading the libraries, creating the data iterator and creating the learning model (a shallow convolutional neural network with 1 convolutional layer and 2 dense layers):" ] }, { "cell_type": "code", "execution_count": 19, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Overwriting mnist-distributed.py\n" ] } ], "source": [ "%%writefile mnist-distributed.py \n", "\n", "import os\n", "import subprocess\n", "import json\n", "import datetime\n", "import argparse\n", "\n", "import tensorflow as tf\n", "import horovod.tensorflow.keras as hvd\n", "import numpy as np\n", "\n", "def mnist_dataset(batch_size):\n", " path = os.environ['DSDIR']+'/MNIST/mnist.npz'\n", " (x_train, y_train), _ = tf.keras.datasets.mnist.load_data(path)\n", " # The `x` arrays are in uint8 and have values in the range [0, 255].\n", " # You need to convert them to float32 with values in the range [0, 1]\n", " x_train = x_train / np.float32(255)\n", " y_train = y_train.astype(np.int64)\n", " train_dataset = tf.data.Dataset.from_tensor_slices(\n", " (x_train, y_train)).repeat().shuffle(60000).batch(batch_size)\n", " return train_dataset\n", "\n", "def build_cnn_model():\n", " model = tf.keras.Sequential([\n", " tf.keras.Input(shape=(28, 28)),\n", " tf.keras.layers.Reshape(target_shape=(28, 28, 1)),\n", " tf.keras.layers.Conv2D(32, 3, activation='relu'),\n", " tf.keras.layers.Flatten(),\n", " tf.keras.layers.Dense(128, activation='relu'),\n", " tf.keras.layers.Dense(10)\n", " ])\n", "\n", " return model" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "* Defining the distributed learning function (the timers and displays are managed by process 0, the master process):\n" ] }, { "cell_type": "code", "execution_count": 20, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Appending to mnist-distributed.py\n" ] } ], "source": [ "%%writefile -a mnist-distributed.py\n", "\n", "def main():\n", " parser = argparse.ArgumentParser()\n", " parser.add_argument('-b', '--batch-size', default=128, type=int,\n", " help='batch size per worker (the global batch size is the batch size times the number of workers)')\n", " parser.add_argument('-e','--epochs', default=2, type=int, metavar='N',\n", " help='number of total epochs to run')\n", " args = parser.parse_args()\n", " \n",
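 " # Horovod: initialize as in the pure TensorFlow version; with Keras the\n", " # distribution is handled by wrapping the optimizer in\n", " # hvd.DistributedOptimizer() below, instead of wrapping the GradientTape.\n",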
 " hvd.init()\n", " \n", " # display info\n", " if hvd.rank() == 0:\n", " print(\">>> Training on \", hvd.size() // hvd.local_size(), \" nodes and \", hvd.size(), \" processes\")\n", " print(\"- Process {} corresponds to GPU {} of node {}\".format(hvd.rank(), hvd.local_rank(), hvd.rank() // hvd.local_size()))\n", " \n", " # Pin GPU to be used to process local rank (one GPU per process)\n", " gpus = tf.config.experimental.list_physical_devices('GPU')\n", " for gpu in gpus:\n", " tf.config.experimental.set_memory_growth(gpu, True)\n", " if gpus:\n", " tf.config.experimental.set_visible_devices(gpus[hvd.local_rank()], 'GPU')\n", "\n", " # Horovod: adjust learning rate based on number of GPUs.\n", " scaled_lr = 0.001 * hvd.size()\n", " opt = tf.optimizers.Adam(scaled_lr)\n", " \n", " # Horovod: add Horovod DistributedOptimizer.\n", " opt = hvd.DistributedOptimizer(opt)\n", " \n", " model = build_cnn_model()\n", " \n", " callbacks = [\n", " # Horovod: broadcast initial variable states from rank 0 to all other processes.\n", " # This is necessary to ensure consistent initialization of all workers when\n", " # training is started with random weights or restored from a checkpoint.\n", " hvd.callbacks.BroadcastGlobalVariablesCallback(0),\n", " ]\n", "\n", " # Horovod: Specify `experimental_run_tf_function=False` to ensure TensorFlow\n", " # uses hvd.DistributedOptimizer() to compute gradients.\n", " model.compile(\n", " loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),\n", " optimizer=opt,\n", " metrics=['accuracy'],\n", " experimental_run_tf_function=False)\n", "\n", " # ### Get data\n", "\n", " multi_worker_dataset = mnist_dataset(args.batch_size)\n", "\n", " # ### Train the model using \"fit\" method\n", " start = datetime.datetime.now()\n", " # Train the model.\n", " # Horovod: adjust number of steps based on number of GPUs.\n", " model.fit(multi_worker_dataset,\n", " steps_per_epoch=500 // hvd.size(),\n", " epochs=args.epochs,\n", " callbacks=callbacks,\n", " verbose=1 if hvd.rank() == 0 else 0)\n", " duration = datetime.datetime.now() - start\n", "\n", " if hvd.rank() == 0:\n", " print(' -- Trained in ' + str(duration) + ' -- ')\n", " " ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "* Defining the entry point:" ] }, { "cell_type": "code", "execution_count": 21, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Appending to mnist-distributed.py\n" ] } ], "source": [ "%%writefile -a mnist-distributed.py\n", "\n", "if __name__ == '__main__':\n", "\n", " main()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Example of mono-GPU mono-node execution" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "* Writing the submission batch script\n", "\n", "**Remember**: If your single project has both CPU and GPU hours, or if your login is attached to more than one project, you must specify\n", "for which allocation the consumed hours should be counted by adding the option `--account=my_project@gpu` as explained in the [IDRIS documentation](http://www.idris.fr/eng/jean-zay/cpu/jean-zay-cpu-doc_account-eng.html)."
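, "\n", "For example (`my_project` being a placeholder for your own project name), the directive reads: `#SBATCH --account=my_project@gpu`"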
] }, { "cell_type": "code", "execution_count": 22, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Overwriting batch_monogpu.slurm\n" ] } ], "source": [ "%%writefile batch_monogpu.slurm\n", "#!/bin/sh\n", "#SBATCH --job-name=mnist_tensorflow_monogpu\n", "#SBATCH --output=mnist_tensorflow_monogpu.out\n", "#SBATCH --error=mnist_tensorflow_monogpu.err\n", "#SBATCH --ntasks=1\n", "#SBATCH --gres=gpu:1\n", "#SBATCH --cpus-per-task=10\n", "#SBATCH --hint=nomultithread\n", "##SBATCH --qos=qos_gpu-dev\n", "#SBATCH --time=00:10:00\n", "\n", "# go into the submission directory \n", "cd ${SLURM_SUBMIT_DIR}\n", "\n", "# cleans out modules loaded in interactive and inherited by default\n", "module purge\n", "\n", "# loading modules\n", "module load tensorflow-gpu/py3/2.3.1\n", "\n", "# echo of launched commands\n", "set -x\n", "\n", "# code execution\n", "srun python -u mnist-distributed.py --epochs 8 --batch-size 128" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "* Submission of the batch script and display of the output" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "%%bash\n", "# submit job\n", "sbatch batch_monogpu.slurm" ] }, { "cell_type": "code", "execution_count": 24, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ " JOBID PARTITION NAME USER ST TIME NODES NODELIST(REASON)\n", " 1381475 gpu_p13 mnist_te ssos040 R 0:25 1 r13i7n1\n", " Done!\n" ] } ], "source": [ "# watch Slurm queue line until the job is done\n", "# execution should take about 1 minute\n", "import time\n", "sq = !squeue -u $USER\n", "print(sq[0])\n", "while len(sq) >= 2:\n", " print(sq[1],end='\\r')\n", " time.sleep(5)\n", " sq = !squeue -u $USER\n", "print('\\n Done!')" ] }, { "cell_type": "code", "execution_count": 25, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ ">>> Training on 1 nodes and 1 processes\n", "- Process 0 corresponds to GPU 0 of node 0\n", "Epoch 1/8\n", "500/500 [==============================] - 1s 3ms/step - loss: 0.1652 - accuracy: 0.9516\n", "Epoch 2/8\n", "500/500 [==============================] - 1s 3ms/step - loss: 0.0499 - accuracy: 0.9859\n", "Epoch 3/8\n", "500/500 [==============================] - 1s 3ms/step - loss: 0.0288 - accuracy: 0.9913\n", "Epoch 4/8\n", "500/500 [==============================] - 1s 3ms/step - loss: 0.0177 - accuracy: 0.9948\n", "Epoch 5/8\n", "500/500 [==============================] - 1s 3ms/step - loss: 0.0116 - accuracy: 0.9965\n", "Epoch 6/8\n", "500/500 [==============================] - 1s 3ms/step - loss: 0.0085 - accuracy: 0.9974\n", "Epoch 7/8\n", "500/500 [==============================] - 1s 3ms/step - loss: 0.0068 - accuracy: 0.9980\n", "Epoch 8/8\n", "500/500 [==============================] - 1s 3ms/step - loss: 0.0050 - accuracy: 0.9985\n", " -- Trained in 0:00:13.589227 -- \n" ] } ], "source": [ "# display output\n", "%cat mnist_tensorflow_monogpu.out" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Example of multi-GPU mono-node execution" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "* Writing the submission batch script" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "**Remember**: If your single project has both CPU and GPU hours, or if your login is attached to more than one project, you must specify\n", "for which allocation the consumed hours should be counted by adding the option `--account=my_project@gpu` as explained in the 
[IDRIS documentation](http://www.idris.fr/eng/jean-zay/cpu/jean-zay-cpu-doc_account-eng.html)." ] }, { "cell_type": "code", "execution_count": 26, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Overwriting batch_mononode.slurm\n" ] } ], "source": [ "%%writefile batch_mononode.slurm\n", "#!/bin/sh\n", "#SBATCH --job-name=mnist_tensorflow_mononode\n", "#SBATCH --output=mnist_tensorflow_mononode.out\n", "#SBATCH --error=mnist_tensorflow_mononode.err\n", "#SBATCH --ntasks=4\n", "#SBATCH --gres=gpu:4\n", "#SBATCH --cpus-per-task=10\n", "#SBATCH --hint=nomultithread\n", "##SBATCH --qos=qos_gpu-dev\n", "#SBATCH --time=00:10:00\n", "\n", "# go into the submission directory \n", "cd ${SLURM_SUBMIT_DIR}\n", "\n", "# cleans out modules loaded in interactive and inherited by default\n", "module purge\n", "\n", "# loading modules\n", "module load tensorflow-gpu/py3/2.3.1\n", "\n", "# echo of launched commands\n", "set -x\n", "\n", "# code execution\n", "srun python -u mnist-distributed.py --epochs 8 --batch-size 128" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "* Submission of the batch script and display of the output" ] }, { "cell_type": "code", "execution_count": 27, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Submitted batch job 1381497\n" ] } ], "source": [ "%%bash\n", "# submit job\n", "sbatch batch_mononode.slurm" ] }, { "cell_type": "code", "execution_count": 28, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ " JOBID PARTITION NAME USER ST TIME NODES NODELIST(REASON)\n", " 1381497 gpu_p13 mnist_te ssos040 R 0:13 1 r8i0n3\n", " Done!\n" ] } ], "source": [ "# watch Slurm queue line until the job is done\n", "# execution should take less than 1 minute\n", "import time\n", "sq = !squeue -u $USER -n mnist_tensorflow_mononode\n", "print(sq[0])\n", "while len(sq) >= 2:\n", " print(sq[1],end='\\r')\n", " time.sleep(5)\n", " sq = !squeue -u $USER -n mnist_tensorflow_mononode\n", "print('\\n Done!')" ] }, { "cell_type": "code", "execution_count": 29, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "- Process 1 corresponds to GPU 1 of node 0\n", "- Process 3 corresponds to GPU 3 of node 0\n", ">>> Training on 1 nodes and 4 processes\n", "- Process 0 corresponds to GPU 0 of node 0\n", "- Process 2 corresponds to GPU 2 of node 0\n", "Epoch 1/8\n", "125/125 [==============================] - 1s 5ms/step - loss: 0.1969 - accuracy: 0.9391\n", "Epoch 2/8\n", "125/125 [==============================] - 1s 5ms/step - loss: 0.0438 - accuracy: 0.9882\n", "Epoch 3/8\n", "125/125 [==============================] - 1s 5ms/step - loss: 0.0264 - accuracy: 0.9924\n", "Epoch 4/8\n", "125/125 [==============================] - 1s 5ms/step - loss: 0.0171 - accuracy: 0.9951\n", "Epoch 5/8\n", "125/125 [==============================] - 1s 5ms/step - loss: 0.0111 - accuracy: 0.9967\n", "Epoch 6/8\n", "125/125 [==============================] - 1s 5ms/step - loss: 0.0044 - accuracy: 0.9987\n", "Epoch 7/8\n", "125/125 [==============================] - 1s 5ms/step - loss: 0.0043 - accuracy: 0.9986\n", "Epoch 8/8\n", "125/125 [==============================] - 1s 5ms/step - loss: 0.0077 - accuracy: 0.9979\n", " -- Trained in 0:00:08.241701 -- \n" ] } ], "source": [ "#display output \n", "%cat mnist_tensorflow_mononode.out" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Example of multi-GPU multi-node execution" ] }, { "cell_type": "markdown", 
"metadata": {}, "source": [ "* Writing the submission batch script\n", "\n", "**Remember**: If your single project has both CPU and GPU hours, or if your login is attached to more than one project, you must specify\n", "for which allocation the consumed hours should be counted by adding the option `--account=my_project@gpu` as explained in the [IDRIS documentation](http://www.idris.fr/eng/jean-zay/cpu/jean-zay-cpu-doc_account-eng.html)." ] }, { "cell_type": "code", "execution_count": 30, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Overwriting batch_multinode.slurm\n" ] } ], "source": [ "%%writefile batch_multinode.slurm\n", "#!/bin/sh\n", "#SBATCH --job-name=mnist_tensorflow_multinode\n", "#SBATCH --output=mnist_tensorflow_multinode.out\n", "#SBATCH --error=mnist_tensorflow_multinode.err\n", "#SBATCH --nodes=3\n", "#SBATCH --ntasks-per-node=4\n", "#SBATCH --gres=gpu:4\n", "#SBATCH --cpus-per-task=10\n", "#SBATCH --hint=nomultithread\n", "##SBATCH --qos=qos_gpu-dev\n", "#SBATCH --time=00:10:00\n", "\n", "# go into the submission directory \n", "cd ${SLURM_SUBMIT_DIR}\n", "\n", "# cleans out modules loaded in interactive and inherited by default\n", "module purge\n", "\n", "# loading modules\n", "module load tensorflow-gpu/py3/2.2.0\n", "\n", "# echo of launched commands\n", "set -x\n", "\n", "# code execution\n", "srun python -u mnist-distributed.py --epochs 8 --batch-size 128" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "* Submission of the batch script and display of the output" ] }, { "cell_type": "code", "execution_count": 31, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Submitted batch job 1381500\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "sbatch: IDRIS: setting exclusive mode for the job.\n" ] } ], "source": [ "%%bash\n", "# submit job\n", "sbatch batch_multinode.slurm" ] }, { "cell_type": "code", "execution_count": 32, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ " JOBID PARTITION NAME USER ST TIME NODES NODELIST(REASON)\n", " 1381500 gpu_p13 mnist_te ssos040 R 0:15 3 r8i0n3,r8i3n[0-1]\n", " Done!\n" ] } ], "source": [ "# watch Slurm queue line until the job is done\n", "# execution should take about 1 minute\n", "import time\n", "sq = !squeue -u $USER -n mnist_tensorflow_multinode\n", "print(sq[0])\n", "while len(sq) == 2:\n", " print(sq[1],end='\\r')\n", " time.sleep(5)\n", " sq = !squeue -u $USER -n mnist_tensorflow_multinode\n", "print('\\n Done!')" ] }, { "cell_type": "code", "execution_count": 33, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "- Process 2 corresponds to GPU 2 of node 0\n", "- Process 3 corresponds to GPU 3 of node 0\n", "- Process 1 corresponds to GPU 1 of node 0\n", "- Process 4 corresponds to GPU 0 of node 1\n", "- Process 11 corresponds to GPU 3 of node 2\n", ">>> Training on 3 nodes and 12 processes\n", "- Process 5 corresponds to GPU 1 of node 1\n", "- Process 8 corresponds to GPU 0 of node 2\n", "- Process 0 corresponds to GPU 0 of node 0\n", "- Process 6 corresponds to GPU 2 of node 1\n", "- Process 9 corresponds to GPU 1 of node 2\n", "- Process 7 corresponds to GPU 3 of node 1\n", "- Process 10 corresponds to GPU 2 of node 2\n", "Epoch 1/8\n", "41/41 [==============================] - 0s 10ms/step - loss: 0.6382 - accuracy: 0.8178\n", "Epoch 2/8\n", "41/41 [==============================] - 0s 10ms/step - loss: 0.0741 - accuracy: 0.9771\n", "Epoch 3/8\n", 
"41/41 [==============================] - 0s 10ms/step - loss: 0.0436 - accuracy: 0.9874\n", "Epoch 4/8\n", "41/41 [==============================] - 0s 10ms/step - loss: 0.0301 - accuracy: 0.9912\n", "Epoch 5/8\n", "41/41 [==============================] - 0s 10ms/step - loss: 0.0174 - accuracy: 0.9958\n", "Epoch 6/8\n", "41/41 [==============================] - 0s 10ms/step - loss: 0.0104 - accuracy: 0.9977\n", "Epoch 7/8\n", "41/41 [==============================] - 0s 10ms/step - loss: 0.0053 - accuracy: 0.9987\n", "Epoch 8/8\n", "41/41 [==============================] - 0s 10ms/step - loss: 0.0043 - accuracy: 0.9992\n", " -- Trained in 0:00:08.324547 -- \n" ] } ], "source": [ "# display output\n", "%cat mnist_tensorflow_multinode.out" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.5" } }, "nbformat": 4, "nbformat_minor": 4 }