{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# PyTorch: Multi-GPU and multi-node Data Parallelism\n", "## Implementation\n", "\n", "*Notebook written by the IDRIS AI support team, November 2020*" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "This notebook shows how to distribute a PyTorch training on Jean Zay using **Data Parallelism**. It follows the [PyTorch documentation](https://pytorch.org/tutorials/intermediate/ddp_tutorial.html) and illustrates the [IDRIS documentation](http://www.idris.fr/eng/jean-zay/gpu/jean-zay-gpu-torch-multi-eng.html).\n", "\n", "In this example, we train a convolutional neural network on the MNIST database. \n", "The training runs on several Jean Zay GPUs and compute nodes.\n", "\n", "The notebook covers: \n", "* Preparing the MNIST database \n", "* Writing the Python script for distributed learning (Data Parallelism) \n", "* Running a parallel execution on Jean Zay\n", "\n", "Note that the MNIST data and the model used in this example are very simple.\n", "\n", "This keeps the code short and lets us test the Data Parallelism configuration quickly, but it does not allow us to measure a speed-up of the training: the transfer time between GPUs and the initialization time of the GPU kernels are significant compared with the execution time." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "------------------------" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Computing environment" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "This notebook is intended to be run from a Jean Zay front end. The hostname must be jean-zay[1-5]."
] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "jean-zay2\r\n" ] } ], "source": [ "!hostname" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "A PyTorch module must be loaded beforehand for this notebook to work correctly. For example, the ``pytorch-gpu/py3/1.7.0`` module:" ] },
{ "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\u001b[?1h\u001b=\r", "Currently Loaded Modulefiles:\u001b[m\r\n", " 1) gcc/8.3.1 4) cudnn/8.0.4.30-cuda-10.2 7) \u001b[4mopenmpi/4.0.5-cuda\u001b[0m \u001b[m\r\n", " 2) cuda/10.2 5) intel-mkl/2020.4 8) pytorch-gpu/py3/1.8.0 \u001b[m\r\n", " 3) nccl/2.8.3-1-cuda 6) magma/2.5.4-cuda \u001b[m\r\n", "\r", "\u001b[K\u001b[?1l\u001b>" ] } ], "source": [ "!module list" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "Create the ``checkpoint`` directory if it does not exist, and remove any checkpoints left over from a previous run." ] },
{ "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "mkdir: cannot create directory ‘checkpoint’: File exists\r\n" ] } ], "source": [ "!mkdir checkpoint\n", "!rm checkpoint/*" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "------------------------------------" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "### Preparation of the MNIST database" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "The MNIST database is available in the DSDIR of Jean Zay." ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "**Comment**: The DSDIR, like the SCRATCH, is a GPFS disk space with a read and write bandwidth of about 300 GB/s. These are the preferred spaces for codes with intensive input/output operations. Your personal SCRATCH space is dedicated to your private databases, while the shared DSDIR space holds most of the public databases." ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "You can test the data access with the following command:" ] },
{ "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "Dataset MNIST\n", "    Number of datapoints: 60000\n", "    Root location: /gpfsdswork/dataset\n", "    Split: Train\n", "    StandardTransform\n", "Transform: ToTensor()" ] }, "execution_count": 5, "metadata": {}, "output_type": "execute_result" } ], "source": [ "import os\n", "import torchvision\n", "import torchvision.transforms as transforms\n", "\n", "torchvision.datasets.MNIST(root=os.environ['DSDIR'],\n", "                           train=True,\n", "                           transform=transforms.ToTensor(),\n", "                           download=False)" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "### Writing the Python script for distributed learning (Data Parallelism)" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "In this section, we write the Python training script to the ``mnist-distributed.py`` file." ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "* Loading libraries and defining the main function:" ] },
{ "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Overwriting mnist-distributed.py\n" ] } ], "source": [ "%%writefile mnist-distributed.py \n", "\n", "import os\n", "from datetime import datetime\n", "from time import time\n", "import argparse\n", "import torch.multiprocessing as mp\n", "import torchvision\n", "import torchvision.transforms as transforms\n", "import torch\n", "import torch.nn as nn\n", "import torch.distributed as dist\n", "from torch.nn.parallel import DistributedDataParallel\n", "import idr_torch\n", "\n", "def main():\n", "    parser = argparse.ArgumentParser()\n", "    parser.add_argument('-b', '--batch-size', default=128, type=int,\n", "                        help='global batch size; it is divided into one mini-batch per worker')\n", "    parser.add_argument('-e', '--epochs', default=2, type=int, metavar='N',\n", "                        help='number of total epochs to run')\n", "    parser.add_argument('-c', '--checkpoint', default=None, type=str,\n", "                        help='path to checkpoint to load')\n", "    args = parser.parse_args()\n", "\n", "    train(args)" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "* Creating the model (a shallow convolutional neural network with 2 convolutional layers):" ] },
{ "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Appending to mnist-distributed.py\n" ] } ], "source": [ "%%writefile -a mnist-distributed.py\n", "\n", "class ConvNet(nn.Module):\n", "    def __init__(self, num_classes=10):\n", "        super(ConvNet, self).__init__()\n", "        self.layer1 = nn.Sequential(\n", "            nn.Conv2d(1, 16, kernel_size=5, stride=1, padding=2),\n", "            nn.BatchNorm2d(16),\n", "            nn.ReLU(),\n", "            nn.MaxPool2d(kernel_size=2, stride=2))\n", "        self.layer2 = nn.Sequential(\n", "            nn.Conv2d(16, 32, kernel_size=5, stride=1, padding=2),\n", "            nn.BatchNorm2d(32),\n", "            nn.ReLU(),\n", "            nn.MaxPool2d(kernel_size=2, stride=2))\n", "        self.fc = nn.Linear(7*7*32, num_classes)\n", "\n", "    def forward(self, x):\n", "        out = self.layer1(x)\n", "        out = self.layer2(out)\n", "        out = out.reshape(out.size(0), -1)\n", "        out = self.fc(out)\n", "        return out\n" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "* Defining the distributed training function (timers and logging are handled by process 0, the master process):" ] },
{ "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Appending to mnist-distributed.py\n" ] } ], "source": [ "%%writefile -a mnist-distributed.py\n", "\n", "def train(args):\n", "\n", "    # configure distribution method: define address and port of the master node and initialise communication backend (NCCL)\n", "    dist.init_process_group(backend='nccl', init_method='env://', world_size=idr_torch.size, rank=idr_torch.rank)\n", "\n", "    # distribute model\n", "    torch.cuda.set_device(idr_torch.local_rank)\n", "    gpu = torch.device(\"cuda\")\n", "    model = ConvNet().to(gpu)\n", "    ddp_model = DistributedDataParallel(model, device_ids=[idr_torch.local_rank])\n", "    if args.checkpoint is not None:\n", "        map_location = {'cuda:%d' % 0: 'cuda:%d' % idr_torch.local_rank}\n", "        ddp_model.load_state_dict(torch.load(args.checkpoint, map_location=map_location))\n", "\n", "    # distribute batch size (mini-batch)\n", "    batch_size = args.batch_size\n", "    batch_size_per_gpu = batch_size // idr_torch.size\n", "\n", "    # define loss function (criterion) and optimizer\n", "    criterion = nn.CrossEntropyLoss()\n", "    optimizer = torch.optim.SGD(ddp_model.parameters(), 1e-4)\n", "\n", "    # load data with distributed sampler\n", "    train_dataset = torchvision.datasets.MNIST(root=os.environ['DSDIR'],\n", "                                               train=True,\n", "                                               transform=transforms.ToTensor(),\n", "                                               download=False)\n", "\n", "    train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset,\n", "                                                                    num_replicas=idr_torch.size,\n", "                                                                    rank=idr_torch.rank)\n", "\n", "    train_loader = torch.utils.data.DataLoader(dataset=train_dataset,\n", "                                               batch_size=batch_size_per_gpu,\n", "                                               shuffle=False,\n", "                                               num_workers=0,\n", "                                               pin_memory=True,\n", "                                               sampler=train_sampler)\n", "\n", "    # training (timers and display handled by process 0)\n", "    if idr_torch.rank == 0: start = datetime.now()\n", "    total_step = len(train_loader)\n", "\n", "    for epoch in range(args.epochs):\n", "        # give the sampler the epoch number so that shuffling differs from one epoch to the next\n", "        train_sampler.set_epoch(epoch)\n", "        if idr_torch.rank == 0: start_dataload = time()\n", "\n", "        for i, (images, labels) in enumerate(train_loader):\n", "\n", "            # distribution of images and labels to all GPUs\n", "            images = images.to(gpu, non_blocking=True)\n", "            labels = labels.to(gpu, non_blocking=True)\n", "\n", "            if idr_torch.rank == 0: stop_dataload = time()\n", "\n", "            if idr_torch.rank == 0: start_training = time()\n", "\n", "            # forward pass\n", "            outputs = ddp_model(images)\n", "            loss = criterion(outputs, labels)\n", "\n", "            # backward and optimize\n", "            optimizer.zero_grad()\n", "            loss.backward()\n", "            optimizer.step()\n", "\n", "            if idr_torch.rank == 0: stop_training = time()\n", "            if (i + 1) % 200 == 0 and idr_torch.rank == 0:\n", "                print('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}, Time data load: {:.3f}ms, Time training: {:.3f}ms'.format(epoch + 1, args.epochs,\n", "                      i + 1, total_step, loss.item(), (stop_dataload - start_dataload)*1000,\n", "                      (stop_training - start_training)*1000))\n", "            if idr_torch.rank == 0: start_dataload = time()\n", "\n", "        # save a checkpoint at the end of every epoch\n", "        if idr_torch.rank == 0:\n", "            torch.save(ddp_model.state_dict(), './checkpoint/{}GPU_{}epoch.checkpoint'.format(idr_torch.size, epoch+1))\n", "\n", "    if idr_torch.rank == 0:\n", "        print(\">>> Training complete in: \" + str(datetime.now() - start))\n" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "* Defining the script entry point:" ] },
{ "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Appending to mnist-distributed.py\n" ] } ], "source": [ "%%writefile -a mnist-distributed.py\n", "\n", "if __name__ == '__main__':\n", "\n", "    # get distributed configuration from Slurm environment\n", "    NODE_ID = os.environ['SLURM_NODEID']\n", "    MASTER_ADDR = os.environ['MASTER_ADDR']\n", "\n", "    # display info\n", "    if idr_torch.rank == 0:\n", "        print(\">>> Training on \", len(idr_torch.hostnames), \" nodes and \", idr_torch.size, \" processes, master node is \", MASTER_ADDR)\n", "    print(\"- Process {} corresponds to GPU {} of node {}\".format(idr_torch.rank, idr_torch.local_rank, NODE_ID))\n", "\n", "    main()" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "### Example of mono-GPU mono-node execution" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [
"* Writing the submission batch script\n", "\n", "**Remember**: If your single project has both CPU and GPU hours, or if your login is attached to more than one project, you must specify for which allocation the consumed hours should be counted by adding the option ``--account=my_project@gpu`` as explained in the [IDRIS documentation](http://www.idris.fr/eng/jean-zay/cpu/jean-zay-cpu-doc_account-eng.html)." ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Overwriting batch_monogpu.slurm\n" ] } ], "source": [ "%%writefile batch_monogpu.slurm\n", "#!/bin/sh\n", "#SBATCH --job-name=mnist_pytorch_monogpu\n", "#SBATCH --output=mnist_pytorch_monogpu.out\n", "#SBATCH --error=mnist_pytorch_monogpu.out\n", "#SBATCH --nodes=1\n", "#SBATCH --ntasks=1\n", "#SBATCH --gres=gpu:1\n", "#SBATCH --cpus-per-task=10\n", "#SBATCH --hint=nomultithread\n", "#SBATCH --time=00:10:00\n", "#SBATCH --qos=qos_gpu-dev\n", "\n", "# go into the submission directory \n", "cd ${SLURM_SUBMIT_DIR}\n", "\n", "# cleans out modules loaded in interactive and inherited by default\n", "module purge\n", "\n", "# loading modules\n", "module load pytorch-gpu/py3/1.7.0\n", "\n", "# echo of launched commands\n", "set -x\n", "\n", "# code execution\n", "srun python -u mnist-distributed.py --epochs 8 --batch-size 128" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "* Submission of the batch script and display of the output" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Submitted batch job 210595\n" ] } ], "source": [ "%%bash\n", "# submit job\n", "sbatch batch_monogpu.slurm" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ " JOBID PARTITION NAME USER ST TIME NODES NODELIST(REASON) \n", " 210595 gpu_p13 mnist_py ssos040 R 0:51 1 r10i2n5 \n", " 
Done!\n" ] } ], "source": [ "# watch Slurm queue line until the job is done\n", "# execution should take about 1 minute\n", "import time\n", "sq = !squeue -u $USER -n mnist_pytorch_monogpu\n", "print(sq[0])\n", "while len(sq) >= 2:\n", " print(sq[1],end='\\r')\n", " time.sleep(5)\n", " sq = !squeue -u $USER -n mnist_pytorch_monogpu\n", "print('\\n Done!')" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Loading pytorch-gpu/py3/1.7.0\r\n", " Loading requirement: gcc/8.3.1 cuda/10.2 nccl/2.6.4-1-cuda\r\n", " cudnn/7.6.5.32-cuda-10.2 intel-mkl/2020.1 magma/2.5.3-cuda\r\n", " openmpi/4.0.2-cuda\r\n", "+ srun python -u mnist-distributed.py --epochs 8 --batch-size 128\r\n", ">>> Training on 1 nodes and 1 processes, master node is r10i2n5\r\n", "- Process 0 corresponds to GPU 0 of node 0\r\n", "Epoch [1/8], Step [200/469], Loss: 1.9836, Time data load: 9.957ms, Time training: 2.359ms\r\n", "Epoch [1/8], Step [400/469], Loss: 1.8060, Time data load: 10.119ms, Time training: 2.400ms\r\n", "Epoch [2/8], Step [200/469], Loss: 1.4512, Time data load: 11.848ms, Time training: 4.598ms\r\n", "Epoch [2/8], Step [400/469], Loss: 1.3760, Time data load: 9.945ms, Time training: 2.332ms\r\n", "Epoch [3/8], Step [200/469], Loss: 1.1132, Time data load: 9.964ms, Time training: 2.336ms\r\n", "Epoch [3/8], Step [400/469], Loss: 1.1073, Time data load: 10.023ms, Time training: 2.385ms\r\n", "Epoch [4/8], Step [200/469], Loss: 0.8998, Time data load: 10.277ms, Time training: 2.374ms\r\n", "Epoch [4/8], Step [400/469], Loss: 0.9311, Time data load: 9.972ms, Time training: 2.671ms\r\n", "Epoch [5/8], Step [200/469], Loss: 0.7595, Time data load: 10.046ms, Time training: 2.379ms\r\n", "Epoch [5/8], Step [400/469], Loss: 0.8081, Time data load: 10.046ms, Time training: 2.344ms\r\n", "Epoch [6/8], Step [200/469], Loss: 0.6610, Time data load: 9.960ms, Time training: 2.351ms\r\n", "Epoch [6/8], Step [400/469], 
Loss: 0.7174, Time data load: 9.982ms, Time training: 2.374ms\r\n", "Epoch [7/8], Step [200/469], Loss: 0.5885, Time data load: 10.031ms, Time training: 2.374ms\r\n", "Epoch [7/8], Step [400/469], Loss: 0.6481, Time data load: 9.977ms, Time training: 2.359ms\r\n", "Epoch [8/8], Step [200/469], Loss: 0.5330, Time data load: 9.987ms, Time training: 2.358ms\r\n", "Epoch [8/8], Step [400/469], Loss: 0.5936, Time data load: 9.966ms, Time training: 2.361ms\r\n", ">>> Training complete in: 0:00:47.944993\r\n" ] } ], "source": [ "# display output\n", "%cat mnist_pytorch_monogpu.out" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Example of multi-GPU mono-node execution" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "* Writing the submission batch script\n", "\n", "**Remember**: If your single project has both CPU and GPU hours, or if your login is attached to more than one project, you must specify for which allocation the consumed hours should be counted by adding the option ``--account=my_project@gpu`` as explained in the [IDRIS documentation](http://www.idris.fr/eng/jean-zay/cpu/jean-zay-cpu-doc_account-eng.html)." 
] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Overwriting batch_mononode.slurm\n" ] } ], "source": [ "%%writefile batch_mononode.slurm\n", "#!/bin/sh\n", "#SBATCH --job-name=mnist_pytorch_mononode\n", "#SBATCH --output=mnist_pytorch_mononode.out\n", "#SBATCH --error=mnist_pytorch_mononode.out\n", "#SBATCH --nodes=1\n", "#SBATCH --ntasks=4\n", "#SBATCH --gres=gpu:4\n", "#SBATCH --cpus-per-task=10\n", "#SBATCH --hint=nomultithread\n", "#SBATCH --time=00:10:00\n", "#SBATCH --qos=qos_gpu-dev\n", "\n", "# go into the submission directory \n", "cd ${SLURM_SUBMIT_DIR}\n", "\n", "# cleans out modules loaded in interactive and inherited by default\n", "module purge\n", "\n", "# loading modules\n", "module load pytorch-gpu/py3/1.7.0\n", "\n", "# echo of launched commands\n", "set -x\n", "\n", "# code execution\n", "srun python -u mnist-distributed.py --epochs 8 --batch-size 128" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "* Submission of the batch script and display of the output" ] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Submitted batch job 210599\n" ] } ], "source": [ "%%bash\n", "# submit job\n", "sbatch batch_mononode.slurm" ] }, { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ " JOBID PARTITION NAME USER ST TIME NODES NODELIST(REASON) \n", " 210599 gpu_p13 mnist_py ssos040 R 0:23 1 r10i7n0 \n", " Done!\n" ] } ], "source": [ "# watch Slurm queue line until the job is done\n", "# execution should take less than 1 minute\n", "import time\n", "sq = !squeue -u $USER -n mnist_pytorch_mononode\n", "print(sq[0])\n", "while len(sq) >= 2:\n", " print(sq[1],end='\\r')\n", " time.sleep(5)\n", " sq = !squeue -u $USER -n mnist_pytorch_mononode\n", "print('\\n Done!')" ] }, { "cell_type": "code", 
"execution_count": 17, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Loading pytorch-gpu/py3/1.7.0\r\n", " Loading requirement: gcc/8.3.1 cuda/10.2 nccl/2.6.4-1-cuda\r\n", " cudnn/7.6.5.32-cuda-10.2 intel-mkl/2020.1 magma/2.5.3-cuda\r\n", " openmpi/4.0.2-cuda\r\n", "+ srun python -u mnist-distributed.py --epochs 8 --batch-size 128\r\n", "- Process 3 corresponds to GPU 3 of node 0\r\n", ">>> Training on 1 nodes and 4 processes, master node is r10i7n0\r\n", "- Process 0 corresponds to GPU 0 of node 0\r\n", "- Process 1 corresponds to GPU 1 of node 0\r\n", "- Process 2 corresponds to GPU 2 of node 0\r\n", "Epoch [1/8], Step [200/469], Loss: 1.9805, Time data load: 2.751ms, Time training: 2.274ms\r\n", "Epoch [1/8], Step [400/469], Loss: 1.7145, Time data load: 2.744ms, Time training: 2.271ms\r\n", "Epoch [2/8], Step [200/469], Loss: 1.4873, Time data load: 2.740ms, Time training: 2.270ms\r\n", "Epoch [2/8], Step [400/469], Loss: 1.2656, Time data load: 2.743ms, Time training: 2.273ms\r\n", "Epoch [3/8], Step [200/469], Loss: 1.1683, Time data load: 2.735ms, Time training: 2.314ms\r\n", "Epoch [3/8], Step [400/469], Loss: 1.0061, Time data load: 2.734ms, Time training: 2.271ms\r\n", "Epoch [4/8], Step [200/469], Loss: 0.9584, Time data load: 2.796ms, Time training: 2.273ms\r\n", "Epoch [4/8], Step [400/469], Loss: 0.8424, Time data load: 2.749ms, Time training: 2.277ms\r\n", "Epoch [5/8], Step [200/469], Loss: 0.8142, Time data load: 2.742ms, Time training: 2.273ms\r\n", "Epoch [5/8], Step [400/469], Loss: 0.7320, Time data load: 2.735ms, Time training: 2.270ms\r\n", "Epoch [6/8], Step [200/469], Loss: 0.7107, Time data load: 2.755ms, Time training: 2.272ms\r\n", "Epoch [6/8], Step [400/469], Loss: 0.6522, Time data load: 2.735ms, Time training: 2.270ms\r\n", "Epoch [7/8], Step [200/469], Loss: 0.6334, Time data load: 2.734ms, Time training: 2.268ms\r\n", "Epoch [7/8], Step [400/469], Loss: 0.5920, Time data load: 2.759ms, Time 
training: 2.272ms\r\n", "Epoch [8/8], Step [200/469], Loss: 0.5736, Time data load: 2.742ms, Time training: 2.272ms\r\n", "Epoch [8/8], Step [400/469], Loss: 0.5449, Time data load: 2.748ms, Time training: 2.270ms\r\n", ">>> Training complete in: 0:00:20.244805\r\n" ] } ], "source": [ "#display output \n", "%cat mnist_pytorch_mononode.out" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Example of multi-GPU multi-node execution" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "* Writing the submission batch script\n", "\n", "**Remember**: If your single project has both CPU and GPU hours, or if your login is attached to more than one project, you must specify for which allocation the consumed hours should be counted by adding the option ``--account=my_project@gpu`` as explained in the [IDRIS documentation](http://www.idris.fr/eng/jean-zay/cpu/jean-zay-cpu-doc_account-eng.html).\n" ] }, { "cell_type": "code", "execution_count": 18, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Overwriting batch_multinode.slurm\n" ] } ], "source": [ "%%writefile batch_multinode.slurm\n", "#!/bin/sh\n", "#SBATCH --job-name=mnist_pytorch_multinode\n", "#SBATCH --output=mnist_pytorch_multinode.out\n", "#SBATCH --error=mnist_pytorch_multinode.out\n", "#SBATCH --nodes=3\n", "#SBATCH --gres=gpu:4\n", "#SBATCH --ntasks-per-node=4\n", "#SBATCH --cpus-per-task=10\n", "#SBATCH --hint=nomultithread\n", "#SBATCH --time=00:10:00\n", "#SBATCH --qos=qos_gpu-dev\n", "\n", "# go into the submission directory \n", "cd ${SLURM_SUBMIT_DIR}\n", "\n", "# cleans out modules loaded in interactive and inherited by default\n", "module purge\n", "\n", "# loading modules\n", "module load pytorch-gpu/py3/1.7.0\n", "\n", "# echo of launched commands\n", "set -x\n", "\n", "# code execution\n", "srun python -u mnist-distributed.py --epochs 8 --batch-size 128" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "* Submission of the batch 
script and display of the output" ] }, { "cell_type": "code", "execution_count": 19, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Submitted batch job 210604\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "sbatch: IDRIS: setting exclusive mode for the job.\n" ] } ], "source": [ "%%bash\n", "# submit job\n", "sbatch batch_multinode.slurm" ] }, { "cell_type": "code", "execution_count": 20, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ " JOBID PARTITION NAME USER ST TIME NODES NODELIST(REASON) \n", " 210604 gpu_p13 mnist_py ssos040 R 0:26 3 r13i5n[7-8],r13i6n0 \n", " Done!\n" ] } ], "source": [ "# watch Slurm queue line until the job is done\n", "# execution should take about 1 minute\n", "import time\n", "sq = !squeue -u $USER -n mnist_pytorch_multinode\n", "print(sq[0])\n", "while len(sq) >= 2:\n", " print(sq[1],end='\\r')\n", " time.sleep(5)\n", " sq = !squeue -u $USER -n mnist_pytorch_multinode\n", "print('\\n Done!')" ] }, { "cell_type": "code", "execution_count": 21, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Loading pytorch-gpu/py3/1.7.0\r\n", " Loading requirement: gcc/8.3.1 cuda/10.2 nccl/2.6.4-1-cuda\r\n", " cudnn/7.6.5.32-cuda-10.2 intel-mkl/2020.1 magma/2.5.3-cuda\r\n", " openmpi/4.0.2-cuda\r\n", "+ srun python -u mnist-distributed.py --epochs 8 --batch-size 128\r\n", ">>> Training on 3 nodes and 12 processes, master node is r13i5n7\r\n", "- Process 0 corresponds to GPU 0 of node 0\r\n", "- Process 1 corresponds to GPU 1 of node 0\r\n", "- Process 2 corresponds to GPU 2 of node 0\r\n", "- Process 3 corresponds to GPU 3 of node 0\r\n", "- Process 4 corresponds to GPU 0 of node 1\r\n", "- Process 5 corresponds to GPU 1 of node 1\r\n", "- Process 10 corresponds to GPU 2 of node 2\r\n", "- Process 6 corresponds to GPU 2 of node 1\r\n", "- Process 7 corresponds to GPU 3 of node 1\r\n", "- Process 11 corresponds to GPU 3 of node 
2\r\n", "- Process 8 corresponds to GPU 0 of node 2\r\n", "- Process 9 corresponds to GPU 1 of node 2\r\n", "Epoch [1/8], Step [200/500], Loss: 2.1815, Time data load: 1.044ms, Time training: 2.526ms\r\n", "Epoch [1/8], Step [400/500], Loss: 1.6933, Time data load: 1.042ms, Time training: 2.591ms\r\n", "Epoch [2/8], Step [200/500], Loss: 1.6502, Time data load: 1.045ms, Time training: 2.319ms\r\n", "Epoch [2/8], Step [400/500], Loss: 1.2381, Time data load: 1.036ms, Time training: 2.338ms\r\n", "Epoch [3/8], Step [200/500], Loss: 1.2961, Time data load: 1.026ms, Time training: 2.327ms\r\n", "Epoch [3/8], Step [400/500], Loss: 0.9567, Time data load: 1.039ms, Time training: 2.318ms\r\n", "Epoch [4/8], Step [200/500], Loss: 1.0635, Time data load: 1.046ms, Time training: 2.513ms\r\n", "Epoch [4/8], Step [400/500], Loss: 0.7811, Time data load: 1.041ms, Time training: 2.514ms\r\n", "Epoch [5/8], Step [200/500], Loss: 0.9117, Time data load: 1.036ms, Time training: 2.276ms\r\n", "Epoch [5/8], Step [400/500], Loss: 0.6651, Time data load: 1.040ms, Time training: 2.282ms\r\n", "Epoch [6/8], Step [200/500], Loss: 0.8072, Time data load: 1.342ms, Time training: 2.791ms\r\n", "Epoch [6/8], Step [400/500], Loss: 0.5828, Time data load: 1.042ms, Time training: 2.832ms\r\n", "Epoch [7/8], Step [200/500], Loss: 0.7293, Time data load: 1.040ms, Time training: 2.841ms\r\n", "Epoch [7/8], Step [400/500], Loss: 0.5212, Time data load: 1.040ms, Time training: 2.829ms\r\n", "Epoch [8/8], Step [200/500], Loss: 0.6686, Time data load: 1.043ms, Time training: 2.819ms\r\n", "Epoch [8/8], Step [400/500], Loss: 0.4729, Time data load: 1.036ms, Time training: 2.824ms\r\n", ">>> Training complete in: 0:00:16.697872\r\n" ] } ], "source": [ "# display output\n", "%cat mnist_pytorch_multinode.out" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Example of multi-node execution from a checkpoint" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "* Writing the 
submission batch script\n", "\n", "**Remember**: If your single project has both CPU and GPU hours, or if your login is attached to more than one project, you must specify for which allocation the consumed hours should be counted by adding the option ``--account=my_project@gpu`` as explained in the [IDRIS documentation](http://www.idris.fr/eng/jean-zay/cpu/jean-zay-cpu-doc_account-eng.html)." ] }, { "cell_type": "code", "execution_count": 22, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Overwriting batch_multinode.slurm\n" ] } ], "source": [ "%%writefile batch_multinode.slurm\n", "#!/bin/sh\n", "#SBATCH --job-name=mnist_pytorch_multinode\n", "#SBATCH --output=mnist_pytorch_multinode.out\n", "#SBATCH --error=mnist_pytorch_multinode.out\n", "#SBATCH --nodes=3\n", "#SBATCH --gres=gpu:4\n", "#SBATCH --ntasks-per-node=4\n", "#SBATCH --cpus-per-task=10\n", "#SBATCH --hint=nomultithread\n", "#SBATCH --time=00:10:00\n", "#SBATCH --qos=qos_gpu-dev\n", "\n", "# go into the submission directory \n", "cd ${SLURM_SUBMIT_DIR}\n", "\n", "# cleans out modules loaded in interactive and inherited by default\n", "module purge\n", "\n", "# loading modules\n", "module load pytorch-gpu/py3/1.7.0\n", "\n", "# echo of launched commands\n", "set -x\n", "\n", "# code execution\n", "srun python -u mnist-distributed.py --epochs 8 --batch-size 128 -c ./checkpoint/12GPU_8epoch.checkpoint" ] }, { "cell_type": "code", "execution_count": 23, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Submitted batch job 210608\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "sbatch: IDRIS: setting exclusive mode for the job.\n" ] } ], "source": [ "%%bash\n", "# submit job\n", "sbatch batch_multinode.slurm" ] }, { "cell_type": "code", "execution_count": 24, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ " JOBID PARTITION NAME USER ST TIME NODES NODELIST(REASON) \n", " 210608 gpu_p13 
mnist_py ssos040 R 0:19 3 r13i5n[7-8],r13i6n0 \n", " Done!\n" ] } ], "source": [ "# watch Slurm queue line until the job is done\n", "# execution should take about 1 minute\n", "import time\n", "sq = !squeue -u $USER -n mnist_pytorch_multinode\n", "print(sq[0])\n", "while len(sq) >= 2:\n", " print(sq[1],end='\\r')\n", " time.sleep(5)\n", " sq = !squeue -u $USER -n mnist_pytorch_multinode\n", "print('\\n Done!')" ] }, { "cell_type": "code", "execution_count": 25, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Loading pytorch-gpu/py3/1.7.0\r\n", " Loading requirement: gcc/8.3.1 cuda/10.2 nccl/2.6.4-1-cuda\r\n", " cudnn/7.6.5.32-cuda-10.2 intel-mkl/2020.1 magma/2.5.3-cuda\r\n", " openmpi/4.0.2-cuda\r\n", "+ srun python -u mnist-distributed.py --epochs 8 --batch-size 128 -c ./checkpoint/12GPU_8epoch.checkpoint\r\n", "- Process 3 corresponds to GPU 3 of node 0\r\n", ">>> Training on 3 nodes and 12 processes, master node is r13i5n7\r\n", "- Process 0 corresponds to GPU 0 of node 0\r\n", "- Process 1 corresponds to GPU 1 of node 0\r\n", "- Process 2 corresponds to GPU 2 of node 0\r\n", "- Process 4 corresponds to GPU 0 of node 1\r\n", "- Process 10 corresponds to GPU 2 of node 2\r\n", "- Process 7 corresponds to GPU 3 of node 1\r\n", "- Process 5 corresponds to GPU 1 of node 1\r\n", "- Process 6 corresponds to GPU 2 of node 1\r\n", "- Process 11 corresponds to GPU 3 of node 2\r\n", "- Process 8 corresponds to GPU 0 of node 2\r\n", "- Process 9 corresponds to GPU 1 of node 2\r\n", "Epoch [1/8], Step [200/500], Loss: 0.6191, Time data load: 1.076ms, Time training: 2.862ms\r\n", "Epoch [1/8], Step [400/500], Loss: 0.4334, Time data load: 1.072ms, Time training: 2.879ms\r\n", "Epoch [2/8], Step [200/500], Loss: 0.5782, Time data load: 1.065ms, Time training: 2.854ms\r\n", "Epoch [2/8], Step [400/500], Loss: 0.4005, Time data load: 1.064ms, Time training: 2.885ms\r\n", "Epoch [3/8], Step [200/500], Loss: 0.5435, Time data load: 1.067ms, 
Time training: 2.866ms\r\n", "Epoch [3/8], Step [400/500], Loss: 0.3728, Time data load: 1.065ms, Time training: 2.879ms\r\n", "Epoch [4/8], Step [200/500], Loss: 0.5138, Time data load: 1.065ms, Time training: 2.885ms\r\n", "Epoch [4/8], Step [400/500], Loss: 0.3489, Time data load: 1.072ms, Time training: 2.861ms\r\n", "Epoch [5/8], Step [200/500], Loss: 0.4879, Time data load: 1.066ms, Time training: 2.875ms\r\n", "Epoch [5/8], Step [400/500], Loss: 0.3277, Time data load: 1.070ms, Time training: 2.898ms\r\n", "Epoch [6/8], Step [200/500], Loss: 0.4651, Time data load: 1.031ms, Time training: 2.940ms\r\n", "Epoch [6/8], Step [400/500], Loss: 0.3088, Time data load: 1.030ms, Time training: 2.934ms\r\n", "Epoch [7/8], Step [200/500], Loss: 0.4450, Time data load: 1.029ms, Time training: 2.911ms\r\n", "Epoch [7/8], Step [400/500], Loss: 0.2920, Time data load: 1.026ms, Time training: 2.962ms\r\n", "Epoch [8/8], Step [200/500], Loss: 0.4270, Time data load: 1.027ms, Time training: 2.966ms\r\n", "Epoch [8/8], Step [400/500], Loss: 0.2766, Time data load: 1.029ms, Time training: 2.941ms\r\n", ">>> Training complete in: 0:00:17.237901\r\n" ] } ], "source": [ "# display output\n", "%cat mnist_pytorch_multinode.out" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.10" } }, "nbformat": 4, "nbformat_minor": 4 }