{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Pytorch: Model parallelism on two GPUs\n", "\n", "*Notebook written by the IDRIS AI support team, May 2021*\n", "\n", "This notebook demonstrates how to implement model parallelism on Jean Zay as presented in [PyTorch: Multi GPU model parallelism](http://www.idris.fr/eng/ia/model-parallelism-pytorch-eng.html). For the sake of clarity, this demo shows a simple training (no validation step) using the **Resnet101** model on **Imagenet** dataset.\n", "\n", "Be advised that only the pipeline version of model parallelism is described.\n", "\n", "It consists here of:\n", "\n", "* Distribution of the model layers on two GPUs.\n", "* Data loading on the GPU containing the first layers of the model, label loading on the other GPU.\n", "* Slurm file setup.\n", "\n", "----\n", "\n", "\n", "## Environment checks\n", " \n", "This notebook is intended for execution from a Jean Zay front end. The hostname must be jean-zay[1-5]. A PyTorch module must be loaded beforehand in order for this Notebook to function correctly. For example, the ``pytorch-gpu/py3/1.8.0`` module.\n", " \n" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "jean-zay1\r\n" ] } ], "source": [ "!hostname" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\u001b[?1h\u001b=\r", "Currently Loaded Modulefiles:\u001b[m\r\n", " 1) gcc/8.3.1 4) cudnn/8.0.4.30-cuda-10.2 7) \u001b[4mopenmpi/4.0.5-cuda\u001b[0m \u001b[m\r\n", " 2) cuda/10.2 5) intel-mkl/2020.4 8) pytorch-gpu/py3/1.8.0 \u001b[m\r\n", " 3) nccl/2.8.3-1-cuda 6) magma/2.5.4-cuda \u001b[m\r\n", "\r", "\u001b[K\u001b[?1l\u001b>" ] } ], "source": [ "!module list" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Additional setup consists in the creation of several folders: for the Slurm configuration, the logs and model saving.\n" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "mkdir: cannot create directory ‘slurm’: File exists\r\n" ] } ], "source": [ "!mkdir slurm" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "mkdir: cannot create directory ‘log’: File exists\r\n" ] } ], "source": [ "!mkdir log" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "mkdir: cannot create directory ‘checkpoint’: File exists\r\n" ] } ], "source": [ "!mkdir checkpoint" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Model setup\n", "\n", "\n", "We are relying on `torchvision` model zoo from which we import **Resnet101**.\n", "\n", "The required adaptations are:\n", "\n", "* Distribution of the model layers on all the available GPUs and communication setup between GPUs (referred as `dev0` and `dev1`).\n", "* Setup of the pipeline, ie. the data batch is split to allow concurrent execution on both GPUs (see variable `split_size` and `forward()` function).\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Model creation\n", "\n" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Overwriting resnet.py\n" ] } ], "source": [ "%%writefile resnet.py\n", "\n", "import torch\n", "import torch.nn as nn\n", "from torchvision.models.resnet import ResNet, Bottleneck\n", "\n", "num_classes = 1000\n", "\n", "\n", "class PipelinedResnet(ResNet):\n", " def __init__(self, dev0, dev1, split_size=8, *args, **kwargs):\n", " super(PipelinedResnet, self).__init__(\n", " Bottleneck, [3, 4, 23, 3], num_classes=num_classes, *args, **kwargs)\n", " # dev0 and dev1 point to the GPU (usually gpu:0 and gpu:1)\n", " self.dev0 = dev0\n", " self.dev1 = dev1\n", " self.split_size = split_size\n", "\n", " self.seq0 = nn.Sequential(\n", " self.conv1,\n", " self.bn1,\n", " self.relu,\n", " self.maxpool,\n", " self.layer1,\n", " self.layer2\n", " ).to(self.dev0) # sends the first sequence of the model to the first GPU\n", "\n", " self.seq1 = nn.Sequential(\n", " self.layer3,\n", " self.layer4,\n", " self.avgpool,\n", " ).to(self.dev1) # sends the second sequence of the model to the second GPU\n", "\n", " self.fc.to(self.dev1) # last layer is on the second GPU\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Communication between layers \n", "\n", "In simple parallelism mode (no pipeline), the `forward()` function describes the data flow between the two GPUs, from input `x` to the last model layer:\n", "\n", " \n", "```py \n", "def forward(self, x):\n", " x= self.seq0(x) # apply first sequence of the model on input x\n", " x= x.to(self.dev1) # send the intermediary result to the second GPU\n", " x = self.seq1(x) # apply second sequence of the model to x\n", " return self.fc(x.view(x.size(0), -1))\n", "```\n", "\n", "The pipeline mode depends on the `split_size` variable: this variable describes how many chunks of data go through the GPUs. Here, we are using the default value but it should be set on a case-by-case basis, ideally through a benchmarks (as presented on [Pytorch documentation](https://pytorch.org/tutorials/intermediate/model_parallel_tutorial.html#speed-up-by-pipelining-inputs)). \n" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Appending to resnet.py\n" ] } ], "source": [ "%%writefile -a resnet.py\n", "\n", " def forward(self, x):\n", " # split setup for x, containing a batch of (image, label) as a tensor\n", " splits = iter(x.split(self.split_size, dim=0))\n", " s_next = next(splits)\n", " s_prev = self.seq0(s_next).to(self.dev1)\n", " ret = []\n", "\n", " for s_next in splits:\n", " # A. s_prev runs on dev1\n", " s_prev = self.seq1(s_prev)\n", " ret.append(self.fc(s_prev.view(s_prev.size(0), -1)))\n", "\n", " # B. s_next runs on dev0, which can run concurrently with A\n", " s_prev = self.seq0(s_next).to(self.dev1)\n", "\n", " s_prev = self.seq1(s_prev)\n", " ret.append(self.fc(s_prev.view(s_prev.size(0), -1)))\n", "\n", " return torch.cat(ret)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Dataset creation\n", "\n", "The sub-dataset \"validation\" from Imagenet (`val`) is used since it allows short training loops (containing only 50.000 samples). In the same vein, we only use the minimum amount of transformations on the data. \n" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Overwriting dataset.py\n" ] } ], "source": [ "%%writefile dataset.py\n", "\n", "import os\n", "\n", "import idr_torch # see http://www.idris.fr/jean-zay/gpu/jean-zay-gpu-torch-multi.html\n", "import torch\n", "import torchvision\n", "import torchvision.transforms as transforms\n", "\n", "def get_dataset(ds_name='val', batch_size=256, num_workers=8):\n", " transform = transforms.Compose([\n", " transforms.RandomResizedCrop(224), # Random resize - Data Augmentation\n", " transforms.ToTensor(), # convert the PIL Image to a tensor\n", " ])\n", " # dataset creation\n", " train_ds = torchvision.datasets.ImageNet(root=os.environ['DSDIR'] + '/imagenet/RawImages',\n", " transform=transform,\n", " split=ds_name) \n", " # data loader creation with a couple of optimization (pin_memory and prefetch)\n", " train_loader = torch.utils.data.DataLoader(dataset=train_ds,\n", " batch_size=batch_size,\n", " shuffle=True, num_workers=num_workers,\n", " pin_memory=True, drop_last=True, \n", " prefetch_factor=2) \n", " \n", " return train_loader" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Main code\n", "\n", "### Imports \n", "\n" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Overwriting demo_model_distribution.py\n" ] } ], "source": [ "%%writefile demo_model_distribution.py\n", "\n", "import argparse\n", "import time\n", "from datetime import timedelta\n", "\n", "import idr_torch # see http://www.idris.fr/jean-zay/gpu/jean-zay-gpu-torch-multi.html\n", "import torch\n", "import torch.distributed as dist\n", "import torch.nn as nn\n", "\n", "from resnet import PipelinedResnet\n", "import dataset" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Useful functions" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Appending to demo_model_distribution.py\n" ] } ], "source": [ "%%writefile -a demo_model_distribution.py\n", "\n", "def parse_args():\n", " parser = argparse.ArgumentParser()\n", " parser.add_argument('-b', '--batch-size', default=128, type=int,\n", " help='batch size')\n", " parser.add_argument('-e', '--epochs', default=2, type=int, metavar='N',\n", " help='number of total epochs to run')\n", " parser.add_argument('-k', '--checkpoint-path', default='checkpoint/test_', type=str,\n", " help='relative path where model can be saved')\n", " parser.add_argument('-r', '--resume-model', default='', type=str,\n", " help='relative path to model to load before resuming training')\n", " args = parser.parse_args()\n", " return args\n", "\n", "\n", "def convert(seconds):\n", " return time.strftime(\"%H:%M:%S\", time.gmtime(seconds))\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Training loop\n", "\n", "Now that the model is split on two GPUs, it is time to define to which GPU the input and output should be sent:\n", "\n", "* Images from samples will go to the first GPU.\n", "* Labels from samples must be on the same GPU as the output of the models (ie. the prediction), hence the second GPU.\n", "\n", "Saving the model (see function `training()`) is the same as when using only one GPU.\n" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Appending to demo_model_distribution.py\n" ] } ], "source": [ "%%writefile -a demo_model_distribution.py\n", "\n", "def train(model, optimizer, criterion, train_loader, batch_size, gpu):\n", " model.train()\n", " if idr_torch.rank == 0:\n", " running_train_loss = 0.0\n", " running_train_corrects = 0\n", " for batch_counter, (images, labels) in enumerate(train_loader):\n", " # images are sent to the first GPU\n", " images = images.to(gpu[0], non_blocking=True)\n", " # zero the parameter gradients\n", " optimizer.zero_grad()\n", " # forward\n", " with torch.set_grad_enabled(True):\n", " outputs = model(images)\n", " # labels (ground truth) are sent to the GPU where the outputs of the model\n", " # reside, which in this case is the second GPU \n", " labels = labels.to(outputs.device, non_blocking=True)\n", " _, preds = torch.max(outputs, 1)\n", " loss = criterion(outputs, labels)\n", " # backward + optimize only if in training phase\n", " loss.backward()\n", " optimizer.step()\n", " if idr_torch.rank == 0:\n", " # statistics\n", " running_train_loss += loss.item()\n", " running_train_corrects += torch.sum(preds == labels.data).item()\n", " if idr_torch.rank == 0:\n", " epoch_loss = running_train_loss / (batch_counter + 1)\n", " epoch_acc = 100.0 * running_train_corrects / ((batch_counter + 1) * batch_size)\n", " print(f'Epoch Train Loss: {epoch_loss:.2f} Acc: {epoch_acc:.2f}')\n", " \n", "def training(model, train_loader, epochs, batch_size, gpu, checkpoint_path):\n", " criterion = nn.CrossEntropyLoss()\n", " optimizer = torch.optim.SGD(model.parameters(), 1e-3)\n", " scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.1)\n", "\n", " if idr_torch.rank == 0:\n", " total_time = time.time()\n", "\n", " for epoch in range(epochs):\n", " if idr_torch.rank == 0:\n", " print(f\"Epoch {epoch + 1}/{epochs}\")\n", " t = time.time()\n", "\n", " train(model, optimizer, criterion, train_loader, batch_size, gpu)\n", " scheduler.step()\n", "\n", " if idr_torch.rank == 0:\n", " duration = time.time() - t\n", " print(f\"\\t Duration : {duration:.2f}\")\n", " print(f\"Saving model at epoch {epoch}\")\n", " name = f\"{checkpoint_path}{epoch}.pt\"\n", " torch.save(model.state_dict(), name)\n", "\n", " if idr_torch.rank == 0:\n", " total_time_elapsed = time.time() - total_time\n", "\n", " if idr_torch.rank == 0:\n", " print(\"-------------------------------------\")\n", " print(f\"Total time: {total_time_elapsed:.2f} \\t {convert(total_time_elapsed)}\")\n", " print(\"-------------------------------------\")\n", " time.sleep(2) # used only to get a clean log\n", " for g in gpu:\n", " print(f\"Device id {g} max memory usage: {torch.cuda.max_memory_allocated(g) // (1024 * 1024)} GB\")\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Setup \n", "\n", "\n" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Appending to demo_model_distribution.py\n" ] } ], "source": [ "%%writefile -a demo_model_distribution.py\n", "\n", "def get_model(gpu, load_model=None):\n", " mp_model = PipelinedResnet(dev0=gpu[0], dev1=gpu[1])\n", " if load_model:\n", " print(f\"Loading model {load_model}\")\n", " mp_model.load_state_dict(torch.load(load_model))\n", " return mp_model\n", "\n", "def main():\n", " args = parse_args()\n", " batch_size, epochs, checkpoint_path = args.batch_size, args.epochs, args.checkpoint_path\n", " if idr_torch.rank == 0:\n", " print(f\"Current setup:\")\n", " print(f\"\\tTraining on {batch_size} samples per batch, for {epochs} epoch(s).\")\n", "\n", " gpu = [0, 1]\n", " torch.cuda.set_device(0)\n", " train_loader = dataset.get_dataset(batch_size=batch_size)\n", " time.sleep(2) # used only to get a clean log\n", " model = get_model(gpu=gpu, load_model=args.resume_model)\n", " time.sleep(2) # used only to get a clean log\n", " training(model, train_loader, epochs, batch_size, gpu, checkpoint_path)\n", "\n", "\n", "if __name__ == '__main__':\n", " main()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Demonstration\n", "\n", "We will perform two runs (each using two GPUs for the model):\n", "\n", "* The first execution creates checkpoints, ie. saves the model's current state at each epoch.\n", "* The second execution loads a checkpoint from first execution, then resumes training.\n", "\n", "On a quadri-GPU node, we need to reserve half of the GPU cards and only one task. In this particular setup, half the memory of the node can be reserved through the option `#SBATCH --cpus-per-task` (since half of the GPUs are requested).\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### First execution \n", "\n", "It is important to correctly define the number of GPUs and tasks per node in the Slurm file.\n", "\n", "**Reminder**: If your single project has both CPU and GPU hours, or if your login is attached to more than one project, you must specify for which allocation the consumed hours should be counted by adding the option `--account=my_project@gpu` as explained in the [IDRIS documentation](http://www.idris.fr/eng/jean-zay/cpu/jean-zay-cpu-doc_account-eng.html).\n" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Overwriting slurm/prime_run.slurm\n" ] } ], "source": [ "%%writefile slurm/prime_run.slurm\n", "#!/bin/bash\n", "#SBATCH --job-name=save\n", "#SBATCH --output=log/prime_run.out\n", "#SBATCH --error=log/prime_run.err\n", "#SBATCH --gres=gpu:2\n", "#SBATCH --nodes=1\n", "#SBATCH --ntasks-per-node=1\n", "#SBATCH --hint=nomultithread\n", "#SBATCH --time=00:30:00\n", "#SBATCH --qos=qos_gpu-dev\n", "#SBATCH -C v100-16g\n", "#SBATCH --cpus-per-task=20 # on a node with 4 GPU, this is half of the available CPU (and memory)\n", "\n", "\n", "## load Pytorch module\n", "module purge\n", "module load pytorch-gpu/py3/1.8.0\n", "\n", "## launch script on every node\n", "set -x\n", "time srun python -u demo_model_distribution.py -b 128 -e 2 -k \"checkpoint/test_\"\n", "date" ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Submitted batch job 1103082\r\n" ] } ], "source": [ "!sbatch 'slurm/prime_run.slurm'" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "`idr_pytools` is a small script which enables access to Slurm variables. For more information, see [Python scripts for automated execution of GPU jobs](http://www.idris.fr/eng/jean-zay/gpu/scripts-python-execution-travaux-gpu-eng.html).\n" ] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ " JOBID PARTITION NAME USER ST TIME NODES NODELIST(REASON)\n", " 1103082 gpu_p13 save ssos021 R 13:41 1 r10i1n0\n", "\n", " Done!\n" ] } ], "source": [ "from idr_pytools import display_slurm_queue\n", "\n", "job_name = 'save' \n", "display_slurm_queue(job_name)" ] }, { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Loading pytorch-gpu/py3/1.8.0\r\n", " Loading requirement: gcc/8.3.1 cuda/10.2 nccl/2.8.3-1-cuda\r\n", " cudnn/8.0.4.30-cuda-10.2 intel-mkl/2020.4 magma/2.5.4-cuda\r\n", " openmpi/4.0.5-cuda\r\n", "+ srun python -u demo_model_distribution.py -b 128 -e 2 -k checkpoint/test_\r\n", "\r\n", "real\t13m40.827s\r\n", "user\t0m0.014s\r\n", "sys\t0m0.009s\r\n", "+ date\r\n" ] } ], "source": [ "!cat \"log/prime_run.err\"" ] }, { "cell_type": "code", "execution_count": 17, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Current setup:\r\n", "\tTraining on 128 samples per batch, for 2 epoch(s).\r\n", "Epoch 1/2\r\n", "Epoch Train Loss: 7.00 Acc: 0.13\r\n", "\t Duration : 403.34\r\n", "Saving model at epoch 0\r\n", "Epoch 2/2\r\n", "Epoch Train Loss: 6.93 Acc: 0.12\r\n", "\t Duration : 393.20\r\n", "Saving model at epoch 1\r\n", "-------------------------------------\r\n", "Total time: 797.47 \t 00:13:17\r\n", "-------------------------------------\r\n", "Device id 0 max memory usage: 7950 GB\r\n", "Device id 1 max memory usage: 8224 GB\r\n", "Mon Apr 26 17:36:01 CEST 2021\r\n" ] } ], "source": [ "!cat \"log/prime_run.out\"" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Second execution: Load a checkpoint before resuming training\n", "\n", "The model parameters saved during the previous execution will be loaded during the model creation step.\n" ] }, { "cell_type": "code", "execution_count": 18, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "check_0.pt check_1.pt\ttest_0.pt test_1.pt\r\n" ] } ], "source": [ "!ls \"checkpoint\"" ] }, { "cell_type": "code", "execution_count": 19, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Overwriting slurm/second_run.slurm\n" ] } ], "source": [ "%%writefile slurm/second_run.slurm\n", "#!/bin/bash\n", "#SBATCH --job-name=load\n", "#SBATCH --output=log/second_run.out\n", "#SBATCH --error=log/second_run.err\n", "#SBATCH --gres=gpu:2\n", "#SBATCH --nodes=1\n", "#SBATCH --ntasks-per-node=1\n", "#SBATCH --hint=nomultithread\n", "#SBATCH --time=00:30:00\n", "#SBATCH --qos=qos_gpu-dev\n", "#SBATCH -C v100-16g\n", "#SBATCH --cpus-per-task=20\n", "\n", "\n", "## load Pytorch module\n", "module purge\n", "module load pytorch-gpu/py3/1.8.0\n", "\n", "## launch script on every node\n", "set -x\n", "time srun python -u demo_model_distribution.py -b 128 -e 2 -k \"checkpoint/check_\" -r \"checkpoint/test_1.pt\"\n", "date" ] }, { "cell_type": "code", "execution_count": 20, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Submitted batch job 1103323\r\n" ] } ], "source": [ "!sbatch 'slurm/second_run.slurm'" ] }, { "cell_type": "code", "execution_count": 21, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ " JOBID PARTITION NAME USER ST TIME NODES NODELIST(REASON)\n", " 1103323 gpu_p13 load ssos021 R 13:32 1 r10i1n0\n", "\n", " Done!\n" ] } ], "source": [ "job_name = 'load' \n", "display_slurm_queue(job_name)" ] }, { "cell_type": "code", "execution_count": 22, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Loading pytorch-gpu/py3/1.8.0\r\n", " Loading requirement: gcc/8.3.1 cuda/10.2 nccl/2.8.3-1-cuda\r\n", " cudnn/8.0.4.30-cuda-10.2 intel-mkl/2020.4 magma/2.5.4-cuda\r\n", " openmpi/4.0.5-cuda\r\n", "+ srun python -u demo_model_distribution.py -b 128 -e 2 -k checkpoint/check_ -r checkpoint/test_1.pt\r\n", "\r\n", "real\t13m30.334s\r\n", "user\t0m0.014s\r\n", "sys\t0m0.008s\r\n", "+ date\r\n" ] } ], "source": [ "!cat \"log/second_run.err\"" ] }, { "cell_type": "code", "execution_count": 23, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Current setup:\r\n", "\tTraining on 128 samples per batch, for 2 epoch(s).\r\n", "Loading model checkpoint/test_1.pt\r\n", "Epoch 1/2\r\n", "Epoch Train Loss: 6.91 Acc: 0.11\r\n", "\t Duration : 397.74\r\n", "Saving model at epoch 0\r\n", "Epoch 2/2\r\n", "Epoch Train Loss: 6.90 Acc: 0.16\r\n", "\t Duration : 392.86\r\n", "Saving model at epoch 1\r\n", "-------------------------------------\r\n", "Total time: 791.35 \t 00:13:11\r\n", "-------------------------------------\r\n", "Device id 0 max memory usage: 7950 GB\r\n", "Device id 1 max memory usage: 8219 GB\r\n", "Mon Apr 26 17:49:42 CEST 2021\r\n" ] } ], "source": [ "!cat \"log/second_run.out\"" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.9" } }, "nbformat": 4, "nbformat_minor": 2 }