{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Data loading and pre-processing with PyTorch\n", "\n", "## Implementation\n", "\n", "*Notebook written by the IDRIS AI support team, March 2021*" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ " This document describes the method to use on Jean Zay to load and pre-process input data for distributed training. It illustrates the [IDRIS documentation](http://www.idris.fr/eng/jean-zay/gpu/jean-zay-gpu-torch-data-preprocessing-eng.html)\n", " and uses the [PyTorch documentation](https://pytorch.org/docs/stable/data.html) as a reference.\n", "\n", "This notebook contains: \n", " * A [complete example](#exemple) of optimised loading.\n", " * [Comparison tests](#tests) of performance gains offered by each functionality described in the documentation (distribution, multiprocessing, prefetching, etc.)." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Computing environment" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "This notebook can be executed on any Jean Zay node but we advise using the jupyterhub front-end node (i.e. an *interactive* connection) to avoid consuming your allocation. In this case, the hostname is `jean-zay-srv2`:" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "jean-zay-srv2\n" ] } ], "source": [ "!hostname" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "You don't need to load any specific PyTorch module to run this notebook. The jobs will be submitted via Slurm in the `pytorch-gpu/py3/1.7.1` environment." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Complete example of optimised loading " ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Creation of the data loading Python script - optimised version" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Overwriting mnist_loader.py\n" ] } ], "source": [ "%%writefile mnist_loader.py \n", "import os\n", "import time\n", "import torch\n", "import torchvision\n", "import idr_torch # IDRIS package available in all PyTorch modules - interface with Slurm\n", "\n", "if idr_torch.rank == 0:\n", "    print(f' --- Running on {idr_torch.size} GPU ---')\n", "\n", "global_start_time=time.time()\n", "\n", "# init multiprocess environment\n", "torch.cuda.set_device(idr_torch.local_rank)\n", "gpu = torch.device(\"cuda\")\n", "\n", "# define list of transformations to apply\n", "data_transform = torchvision.transforms.Compose([torchvision.transforms.Resize((300,300)),\n", "                                                  torchvision.transforms.ToTensor()])\n", "\n", "# load mnist dataset and apply transformations\n", "root = os.environ['DSDIR']\n", "start_time = time.time()\n", "mnist_dataset = torchvision.datasets.MNIST(root=root,\n", "                                            transform=data_transform,\n", "                                            download=False)\n", "end_time = time.time()\n", "if idr_torch.rank == 0:\n", "    print(f'Rank {idr_torch.rank}: Loading dataset took {end_time - start_time} s')\n", "\n", "# define distributed sampler\n", "data_sampler = torch.utils.data.distributed.DistributedSampler(mnist_dataset,\n", "                                                                shuffle=True,\n", "                                                                num_replicas=idr_torch.size,\n", "                                                                rank=idr_torch.rank)\n", "\n", "# define DataLoader - optimised parameters\n", "batch_size = 128 # adjust batch size according to GPU type (16GB or 32GB in memory)\n", "drop_last = True # set to False if the last incomplete batch represents important information\n", "num_workers = 
idr_torch.cpus_per_task # define number of CPU workers per process\n", "persistent_workers = True # set to False if CPU RAM must be released\n", "pin_memory = True # optimise CPU to GPU transfers\n", "non_blocking = True # activate asynchronism to speed up CPU/GPU transfers\n", "prefetch_factor = 2 # adjust number of batches to preload\n", "\n", "if idr_torch.rank == 0:\n", " print(f'------')\n", " print(f'Config: batch_size={batch_size}, drop_last={drop_last}') \n", " print(f' num_workers={num_workers}, persistent_workers={persistent_workers},')\n", " print(f' pin_memory={pin_memory}, non_blocking={non_blocking}, prefetch_factor={prefetch_factor}')\n", " print(f'------')\n", "\n", "dataloader = torch.utils.data.DataLoader(mnist_dataset,\n", " sampler=data_sampler,\n", " batch_size=batch_size,\n", " drop_last=drop_last,\n", " num_workers=num_workers,\n", " persistent_workers=persistent_workers,\n", " pin_memory=pin_memory,\n", " prefetch_factor=prefetch_factor\n", " )\n", "\n", "# loop over batches\n", "transfer_time=[]\n", "len_dataloader=len(dataloader)\n", "for i, (images, labels) in enumerate(dataloader):\n", " \n", " start_time = time.time()\n", " images = images.to(gpu, non_blocking=non_blocking)\n", " labels = labels.to(gpu, non_blocking=non_blocking)\n", " end_time = time.time()\n", " transfer_time.append(end_time - start_time)\n", "\n", "mean_transfer_time = sum(transfer_time) / len(transfer_time)\n", "max_transfer_time = max(transfer_time)\n", "max_idx = transfer_time.index(max(transfer_time))\n", "if idr_torch.rank == 0:\n", " print(f'Rank {idr_torch.rank}: Loop ended')\n", " print(f'Rank {idr_torch.rank}: Mean transfer time = {int(mean_transfer_time*10**6)} μs (max = {int(max_transfer_time*10**6)} μs, reached at {max_idx}th transfer)')\n", "\n", "\n", "global_end_time=time.time()\n", "if idr_torch.rank == 0:\n", " print(f'Rank {idr_torch.rank}: Global time = {global_end_time - global_start_time} s\\n')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Creation of the Slurm submission script\n", "\n", "**Reminder**: If your single project has both CPU and GPU hours, or if your login is attached to more than one project, you must specify for which allocation the consumed hours should be counted by adding the option `--account=my_project@gpu` as explained in the [IDRIS documentation](http://www.idris.fr/eng/jean-zay/cpu/jean-zay-cpu-doc_account-eng.html)." 
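, "\n", "\n", "**Note**: the `#SBATCH` directives below (`--ntasks`, `--gres=gpu:`, `--cpus-per-task`) reserve the resources that the `idr_torch` package reads back inside `mnist_loader.py` (process rank, number of processes, number of CPU workers). As a minimal, purely illustrative sketch (assuming `idr_torch` simply wraps the standard Slurm environment variables; this is not its actual source code):\n", "\n", "```python\n", "import os\n", "\n", "# hypothetical equivalent of the idr_torch attributes used in this notebook\n", "rank = int(os.environ['SLURM_PROCID'])                  # global process rank\n", "local_rank = int(os.environ['SLURM_LOCALID'])           # rank within the node, passed to torch.cuda.set_device\n", "size = int(os.environ['SLURM_NTASKS'])                  # total number of processes (one per GPU)\n", "cpus_per_task = int(os.environ['SLURM_CPUS_PER_TASK'])  # used as num_workers in the DataLoader\n", "```"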
] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Overwriting job.slurm\n" ] } ], "source": [ "%%writefile job.slurm\n", "#!/bin/bash\n", "#SBATCH --job-name=data_loader_pytorch-eng\n", "##SBATCH --account=XXX@v100\n", "#SBATCH --nodes=1\n", "#SBATCH --ntasks=4\n", "#SBATCH --gres=gpu:4\n", "#SBATCH --cpus-per-task=10\n", "#SBATCH --hint=nomultithread\n", "#SBATCH --time=00:30:00\n", "#SBATCH --output=data_loader_pytorch.out\n", "\n", "module load pytorch-gpu/py3/1.7.1\n", "\n", "srun python -u mnist_loader.py" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Submission and execution of the optimised version" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "import time\n", "from IPython.display import clear_output\n", "def display_slurm_queue():\n", "    sq = !squeue -u $USER -n data_loader_pytorch-eng\n", "    while len(sq) >= 2:\n", "        clear_output(wait=True)\n", "        for l in sq: print(l)\n", "        time.sleep(10)\n", "        sq = !squeue -u $USER -n data_loader_pytorch-eng\n", "    print('\\n Done!')" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Submitted batch job 943007\n" ] } ], "source": [ "# submit job\n", "!sbatch job.slurm" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\n", " Done!\n" ] } ], "source": [ "display_slurm_queue()" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Loading pytorch-gpu/py3/1.7.1\n", "  Loading requirement: gcc/8.5.0 cuda/10.2 nccl/2.7.8-1-cuda\n", "    cudnn/8.0.4.30-cuda-10.2 intel-mkl/2020.1 magma/2.5.3-cuda\n", "    openmpi/4.0.2-cuda\n", " --- Running on 4 GPU ---\n", "Rank 0: Loading dataset took 0.11383652687072754 s\n", "------\n", "Config: batch_size=128, drop_last=True\n", "        num_workers=10, persistent_workers=True,\n", "        pin_memory=True, non_blocking=True, prefetch_factor=2\n", "------\n", "Rank 0: Loop ended\n", "Rank 0: Mean transfer time = 188 μs (max = 1305 μs, reached at 59th transfer)\n", "Rank 0: Global time = 5.309845685958862 s\n", "\n" ] } ], "source": [ "# display output\n", "!cat data_loader_pytorch.out" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Tests of the different optimisations " ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We wish here to observe the impact of the different parameters presented in the IDRIS documentation on the performance of data pre-processing. \n", "The parameters of interest in these tests are: \n", "* Number of GPUs (distribution) \n", "* Batch size \n", "* Number of workers (multiprocessing)\n", "* Memory pinning and asynchronism of CPU/GPU transfers\n", "* Pre-fetching batches for the GPUs\n", "\n", "It should be noted that these tests are run on the MNIST database, which is small in size. This choice was made for educational purposes so that the executions and performance comparisons can be done rapidly. The aim here is to convince ourselves of the benefit of each optimisation. The performance gain could potentially be much greater on a larger database. 
\n", "\n", "-----" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ " * Creation of files to store:\n", " * Slurm submission scripts\n", " * Python scripts for data loading\n", " * Standard outputs of executions " ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "mkdir: cannot create directory ‘slurm’: File exists\n", "mkdir: cannot create directory ‘scripts’: File exists\n", "mkdir: cannot create directory ‘logs’: File exists\n" ] } ], "source": [ "!mkdir slurm\n", "!mkdir scripts\n", "!mkdir logs" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "* Preliminary creation of Python and Slurm scripts with variable parameters" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [], "source": [ "def create_new_scripts(ngpus=1, batch_size=8, num_workers=1, pin_memory=False, non_blocking=False, prefetch_factor=1):\n", " \n", " slurm_fname=f'slurm/job_{ngpus}_{batch_size}_{num_workers}_{pin_memory}_{prefetch_factor}.slurm'\n", " script_fname=f'scripts/mnist_loader_{ngpus}_{batch_size}_{num_workers}_{pin_memory}_{prefetch_factor}.py'\n", " \n", " # create slurm submission script with new number of gpus\n", " ref_file = open(\"job.slurm\",\"r\")\n", " new_file = open(slurm_fname,\"w\")\n", " for line in ref_file:\n", " if line.strip().startswith('#SBATCH --ntasks='):\n", " line = f'#SBATCH --ntasks={ngpus}\\n' \n", " new_file.write(line)\n", " elif line.strip().startswith('#SBATCH --gres=gpu:'):\n", " line = f'#SBATCH --gres=gpu:{ngpus}\\n'\n", " new_file.write(line)\n", " elif line.strip().startswith('#SBATCH --output='):\n", " line = f'#SBATCH --output=logs/data_loader_{ngpus}_{batch_size}_{num_workers}_{pin_memory}_{prefetch_factor}.out\\n'\n", " new_file.write(line)\n", " elif line.strip().startswith('srun'):\n", " line = f'srun python -u ' + script_fname\n", " new_file.write(line)\n", " else:\n", " new_file.write(line)\n", " \n", " # create python script with new parameters\n", " ref_file = open(\"mnist_loader.py\",\"r\")\n", " new_file = open(script_fname,\"w\")\n", " for line in ref_file:\n", " if line.strip().startswith('batch_size = '):\n", " line = f'batch_size = {batch_size}\\n'\n", " new_file.write(line)\n", " elif line.strip().startswith('num_workers = '):\n", " line = f'num_workers = {num_workers}\\n'\n", " new_file.write(line)\n", " elif line.strip().startswith('pin_memory = '):\n", " line = f'pin_memory = {pin_memory}\\n'\n", " new_file.write(line)\n", " elif line.strip().startswith('non_blocking = '):\n", " line = f'non_blocking = {non_blocking}\\n'\n", " new_file.write(line)\n", " elif line.strip().startswith('prefetch_factor = '):\n", " line = f'prefetch_factor = {prefetch_factor}\\n'\n", " new_file.write(line)\n", " else:\n", " new_file.write(line)\n", " \n", " return slurm_fname" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Reference results of an under-optimised version" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The reference results correspond to an under-optimised version of the following parameters:\n", "* Number of GPUs = 1 \n", "* Batch size = 8 \n", "* Multiprocessing activated but only one worker \n", "* Non-pinned memory and synchronous CPU/GPU transfers\n", "* Pre-fetching of only one batch at a time" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "slurm/job_1_8_1_False_1.slurm\n", "Submitted batch job 943023\n" ] 
} ], "source": [ "# create and execute reference scripts\n", "slurm_fname = create_new_scripts()\n", "print(slurm_fname)\n", "!sbatch $slurm_fname" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\n", " Done!\n" ] } ], "source": [ "display_slurm_queue()" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Loading pytorch-gpu/py3/1.7.1\n", " Loading requirement: gcc/8.5.0 cuda/10.2 nccl/2.7.8-1-cuda\n", " cudnn/8.0.4.30-cuda-10.2 intel-mkl/2020.1 magma/2.5.3-cuda\n", " openmpi/4.0.2-cuda\n", " --- Running on 1 GPU ---\n", "Rank 0: Loading dataset took 0.050215721130371094 s\n", "------\n", "Config: batch_size=8, drop_last=True\n", " num_workers=1, persistent_workers=True,\n", " pin_memory=False, non_blocking=False, prefetch_factor=1\n", "------\n", "Rank 0: Loop ended\n", "Rank 0: Mean transfer time = 1066 μs (max = 1399016 μs, reached at 0th transfer)\n", "Rank 0: Global time = 64.90769052505493 s\n", "\n" ] } ], "source": [ "!cat logs/data_loader_1_8_1_False_1.out" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Number of GPUs" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "* Estimate of time gain when the number of GPUs is increased" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Submitted batch job 943044\n", "Submitted batch job 943045\n", "Submitted batch job 943053\n" ] } ], "source": [ "# create and execute scripts with increasing number of gpus (ngpus = 1 already done in ref job)\n", "for ngpus in [2, 3, 4]:\n", " slurm_fname = create_new_scripts(ngpus=ngpus)\n", " !sbatch $slurm_fname" ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\n", " Done!\n" ] } ], "source": [ "display_slurm_queue()" ] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ ">>> Ngpus = 1\n", "Rank 0: Global time = 64.90769052505493 s\n", ">>> Ngpus = 2\n", "Rank 0: Global time = 32.32062292098999 s\n", ">>> Ngpus = 3\n", "Rank 0: Global time = 22.028196811676025 s\n", ">>> Ngpus = 4\n", "Rank 0: Global time = 17.05109715461731 s\n" ] } ], "source": [ "%%bash\n", "for i in 1 2 3 4\n", "do\n", " echo \">>> Ngpus = $i\" \n", " grep \"Global time\" logs/data_loader_${i}_8_1_False_1.out\n", "done" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Batch size" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "* Estimate of time gain when the batch size is increased" ] }, { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Submitted batch job 943066\n", "Submitted batch job 943067\n", "Submitted batch job 943068\n", "Submitted batch job 943070\n" ] } ], "source": [ "# create and execute scripts with increasing batch size (batch_size=8 already done in ref job)\n", "for batch_size in [16, 32, 64, 128]:\n", " slurm_fname = create_new_scripts(batch_size=batch_size)\n", " !sbatch $slurm_fname" ] }, { "cell_type": "code", "execution_count": 17, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\n", " Done!\n" ] } ], "source": [ "display_slurm_queue()" ] }, { "cell_type": "code", "execution_count": 18, "metadata": {}, "outputs": [ { 
"name": "stdout", "output_type": "stream", "text": [ ">>> batch_size = 8\n", "Rank 0: Global time = 64.90769052505493 s\n", ">>> batch_size = 16\n", "Rank 0: Global time = 56.986307859420776 s\n", ">>> batch_size = 32\n", "Rank 0: Global time = 56.82447123527527 s\n", ">>> batch_size = 64\n", "Rank 0: Global time = 58.25776219367981 s\n", ">>> batch_size = 128\n", "Rank 0: Global time = 57.39804220199585 s\n" ] } ], "source": [ "%%bash\n", "for size in 8 16 32 64 128\n", "do\n", " echo \">>> batch_size = $size\" \n", " grep \"Global time\" logs/data_loader_1_${size}_1_False_1.out\n", "done" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Multiprocessing" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "* Estimate of the time gain when the number of workers is increased" ] }, { "cell_type": "code", "execution_count": 19, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Submitted batch job 943112\n", "Submitted batch job 943113\n", "Submitted batch job 943114\n", "Submitted batch job 943115\n", "Submitted batch job 943117\n" ] } ], "source": [ "# create and execute scripts with increasing number of workers (num_workers=1 already done in ref job)\n", "for num_workers in [2,4,6,8,10]:\n", " slurm_fname = create_new_scripts(num_workers=num_workers)\n", " !sbatch $slurm_fname" ] }, { "cell_type": "code", "execution_count": 20, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\n", " Done!\n" ] } ], "source": [ "display_slurm_queue()" ] }, { "cell_type": "code", "execution_count": 21, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ ">>> num_workers = 1\n", "Rank 0: Global time = 64.90769052505493 s\n", ">>> num_workers = 2\n", "Rank 0: Global time = 32.45805883407593 s\n", ">>> num_workers = 4\n", "Rank 0: Global time = 17.409164667129517 s\n", ">>> num_workers = 6\n", "Rank 0: Global time = 17.93054223060608 s\n", ">>> num_workers = 8\n", "Rank 0: Global time = 18.30246901512146 s\n", ">>> num_workers = 10\n", "Rank 0: Global time = 19.45599937438965 s\n" ] } ], "source": [ "%%bash\n", "for n in 1 2 4 6 8 10\n", "do\n", " echo \">>> num_workers = $n\" \n", " grep \"Global time\" logs/data_loader_1_8_${n}_False_1.out\n", "done" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### CPU/GPU transfers" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "* Estimate of the time gain when the memory is pinned and asynchronism is activated" ] }, { "cell_type": "code", "execution_count": 22, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Submitted batch job 943139\n" ] } ], "source": [ "# create and execute scripts with optimised CPU/GPU transfers (pin_memory=False already done in ref job)\n", "slurm_fname = create_new_scripts(pin_memory=True, non_blocking=True)\n", "!sbatch $slurm_fname" ] }, { "cell_type": "code", "execution_count": 23, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\n", " Done!\n" ] } ], "source": [ "display_slurm_queue()" ] }, { "cell_type": "code", "execution_count": 24, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ ">>> Pin memory = False \n", "Rank 0: Mean transfer time = 1066 μs (max = 1399016 μs, reached at 0th transfer)\n", "Rank 0: Global time = 64.90769052505493 s\n", ">>> Pin memory = True \n", "Rank 0: Mean transfer time = 84 μs (max = 761 μs, reached at 0th transfer)\n", "Rank 0: Global time = 
68.2191755771637 s\n" ] } ], "source": [ "%%bash\n", "echo \">>> Pin memory = False \"\n", "grep \"Mean transfer time\" logs/data_loader_1_8_1_False_1.out\n", "grep \"Global time\" logs/data_loader_1_8_1_False_1.out\n", "echo \">>> Pin memory = True \"\n", "grep \"Mean transfer time\" logs/data_loader_1_8_1_True_1.out\n", "grep \"Global time\" logs/data_loader_1_8_1_True_1.out" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Pre-fetching batches" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "* Estimate of the time gain when the number of pre-fetched batches is increased" ] }, { "cell_type": "code", "execution_count": 25, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Submitted batch job 943199\n", "Submitted batch job 943200\n", "Submitted batch job 943201\n" ] } ], "source": [ "# create and execute scripts with increasing prefetch_factor (prefetch_factor=1 already done in ref job)\n", "for prefetch_factor in [2,3,4]:\n", "    slurm_fname = create_new_scripts(prefetch_factor=prefetch_factor)\n", "    !sbatch $slurm_fname" ] }, { "cell_type": "code", "execution_count": 26, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\n", " Done!\n" ] } ], "source": [ "display_slurm_queue()" ] }, { "cell_type": "code", "execution_count": 27, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ ">>> prefetch_factor = 1\n", "Rank 0: Global time = 64.90769052505493 s\n", ">>> prefetch_factor = 2\n", "Rank 0: Global time = 52.71680998802185 s\n", ">>> prefetch_factor = 3\n", "Rank 0: Global time = 52.24641942977905 s\n", ">>> prefetch_factor = 4\n", "Rank 0: Global time = 52.70093059539795 s\n" ] } ], "source": [ "%%bash\n", "for factor in 1 2 3 4\n", "do\n", "    echo \">>> prefetch_factor = $factor\" \n", "    grep \"Global time\" logs/data_loader_1_8_1_False_${factor}.out\n", "done" ] }, { "cell_type": "code", "execution_count": 28, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.0" } }, "nbformat": 4, "nbformat_minor": 4 }