PyTorch: Database loading for distributed learning of a model

In this page, we describe the management of Datasets and DataLoaders for distributed learning with PyTorch. We focus here on the issues introduced in the main page on data loading. We will describe the usage of Datasets (pre-defined and custom), transformations and DataLoaders.

This page concludes with the presentation of a complete example of optimised data loading and its implementation on Jean Zay via a Jupyter Notebook.

Preliminary comment: In this documentation, we do not address IterableDataset objects, which enable processing databases whose structure is unknown. This type of object is traversed with an iterator which has a mechanism to acquire the next element (if there is one). This mechanism prevents the direct usage of certain DataLoader functionalities such as shuffling and multiprocessing, which are based on index manipulation and require an overall view of the database.
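
For readers unfamiliar with this type of object, here is a minimal sketch (not used in the rest of this page); the class and names are purely illustrative:

from torch.utils.data import IterableDataset, DataLoader

# purely illustrative: an IterableDataset only exposes __iter__, not indices,
# so the DataLoader cannot apply index-based shuffling or samplers to it
class MyStream(IterableDataset):
    def __init__(self, source):
        self.source = source           # any iterable (generator, file handle, ...)
    def __iter__(self):
        return iter(self.source)

loader = DataLoader(MyStream(range(10)), batch_size=4)
for batch in loader:
    print(batch)                       # batches of consecutive elements, in order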

Datasets

Pre-defined datasets in PyTorch

PyTorch offers a collection of pre-defined Datasets in the torchvision, torchaudio and torchtext libraries. These libraries manage the creation of a Dataset object for the standard databases listed in their official documentation.

Loading a database is done via the datasets module of these libraries. For example, loading the ImageNet database can be done with torchvision in the following way:

import os
import torchvision
 
# load imagenet dataset stored in DSDIR
root = os.environ['DSDIR']+'/imagenet/RawImages'
imagenet_dataset = torchvision.datasets.ImageNet(root=root)

Most of the time, it is possible at load time to separate the data dedicated to training from the data dedicated to validation. For example, for the ImageNet database:

import os
import torchvision
 
# load imagenet dataset stored in DSDIR
root = os.environ['DSDIR']+'/imagenet/RawImages'
 
## load data for training
imagenet_train_dataset = torchvision.datasets.ImageNet(root=root,
                                                       split='train')
## load data for validation
imagenet_val_dataset = torchvision.datasets.ImageNet(root=root,
                                                     split='val')

Each loading function then offers functionalities specific to the database (data quality, extraction of a subset of the data, etc.). For more information, please consult the official documentation.

Comments:

  • The torchvision library contains a generic loading function: torchvision.datasets.ImageFolder. It is adapted to any image database, provided that the database is stored in a certain format (see the official documentation for more information); a short usage sketch is given after this list.
  • Certain functions propose downloading the databases online using the argument download=True. We remind you that the Jean Zay compute nodes do not have access to the internet; such operations must be done from a front end or a pre-/post-processing node. We also remind you that large public databases are already available in the Jean Zay DSDIR common space. This space can be enriched upon request to the IDRIS support team (assist@idris.fr).
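
As an illustration of the first comment, a minimal sketch of torchvision.datasets.ImageFolder usage; the directory path and layout below are hypothetical:

import torchvision

# hypothetical directory layout expected by ImageFolder (one sub-folder per class):
#   ./my_images/cat/xxx.png
#   ./my_images/dog/yyy.png
my_dataset = torchvision.datasets.ImageFolder(root='./my_images',
                                              transform=torchvision.transforms.ToTensor())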

Custom Datasets

It is possible to create your own Dataset classes by defining the following three functions:

  • __init__ initialises the variable containing the data to process.
  • __len__ returns the length of the database.
  • __getitem__ returns the data corresponding to a given index.

For example:

from torch.utils.data import Dataset

class myDataset(Dataset):
    def __init__(self, data):
        # Initialise dataset from source dataset
        self.data = data
 
    def __len__(self):
        # Return length of the dataset
        return len(self.data)
 
    def __getitem__(self, idx):
        # Return one element of the dataset according to its index
        return self.data[idx]
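
For illustration, a minimal usage sketch of this class with an arbitrary list as source:

# usage sketch with an arbitrary list as source
dataset = myDataset([10, 20, 30])
print(len(dataset))    # 3
print(dataset[1])      # 20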

Transformations

Predefined transformations in PyTorch

The torchvision, torchtext and torchaudio libraries offer a selection of pre-implemented transformations, accessible via the transforms module of each library. These transformations are listed in the official documentation.

The transformation instructions are carried by the Dataset object. It is possible to chain different types of transformations by using the transforms.Compose() function. For example, to resize all of the images of the ImageNet database:

import os
import torchvision
 
# define list of transformations to apply
data_transform = torchvision.transforms.Compose([torchvision.transforms.Resize((300,300)),
                                                 torchvision.transforms.ToTensor()])
 
# load imagenet dataset and apply transformations
root = os.environ['DSDIR']+'/imagenet/RawImages'
imagenet_dataset = torchvision.datasets.ImageNet(root=root,
                                                 transform=data_transform)

Comment: The transforms.ToTensor() transformation enables converting a PIL image or a NumPy array into a tensor. To apply transformations to a custom Dataset, you must modify it as shown in the following example:

from torch.utils.data import Dataset

class myDataset(Dataset):
 
    def __init__(self, data, transform=None):
        # Initialise dataset from source dataset
        self.data = data
        self.transform = transform
 
    def __len__(self):
        # Return length of the dataset
        return len(self.data)
 
    def __getitem__(self, idx):
        # Return one element of the dataset according to its index
        x = self.data[idx]
 
        # apply transformation if requested
        if self.transform:
            x = self.transform(x)
 
        return x
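
For illustration, a minimal usage sketch of this class, with a tensor as the source and a simple doubling function as the transform (both are arbitrary choices made here for the example):

import torch

# usage sketch: any callable can be passed as transform
def double(x):
    return 2 * x

dataset = myDataset(torch.arange(4), transform=double)
print(dataset[3])      # tensor(6): the element is transformed by __getitem__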

Custom transformations

It is also possible to create your own transformations by defining callable classes and passing them directly to transforms.Compose(). For example, you can define Add and Mult (multiply) transformations in the following way:

from torchvision import transforms

# define Add transformation
class Add(object):
    def __init__(self, value):
        self.value = value
    def __call__(self, sample):
        # add a constant to the data
        return sample + self.value
 
# define Mult transformation
class Mult(object):
    def __init__(self, value):
        self.value = value
    def __call__(self, sample):
        # multiply the data by a constant
        return sample * self.value
 
# define list of transformations to apply
data_transform = transforms.Compose([Add(2), Mult(3)])
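
For illustration, applying this composition to a small tensor (the input values are arbitrary): Add(2) then Mult(3) computes (x + 2) * 3 for each element.

import torch

sample = torch.tensor([1.0, 2.0])
print(data_transform(sample))     # tensor([ 9., 12.])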

DataLoaders

A DataLoader object is a dataset wrapper which enables data structuring (batch creation), pre-processing (shuffling, transforming) and data transmission to the GPUs for the training phase.

The DataLoader is an object of the torch.utils.data.DataLoader class:

import torch
 
# define DataLoader for a given dataset
dataloader = torch.utils.data.DataLoader(dataset)

Optimisation of data loading parameters

The configurable arguments of the DataLoader class are the following:

DataLoader(dataset,
           shuffle=False,
           sampler=None, batch_sampler=None, collate_fn=None,
           batch_size=1, drop_last=False,
           num_workers=0, worker_init_fn=None, persistent_workers=False,  
           pin_memory=False, timeout=0,
           prefetch_factor=2, *
          )

Randomized processing of input data

The shuffle=True argument enables random shuffling of the input data. Caution: This functionality must be delegated to the sampler if you are using a distributed sampler (see the following point).

Data distribution on more than one process for distributed learning

With the sampler argument, you can specify the type of database sampling which you wish to implement. To distribute the data on more than one process, you must use the DistributedSampler provided by the torch.utils.data.distributed PyTorch module. For example:

import torch
import idr_torch # IDRIS package available in all PyTorch modules
 
# define distributed sampler
data_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset,
                                                               shuffle=True,
                                                               num_replicas=idr_torch.size,
                                                               rank=idr_torch.rank)

This sampler takes as arguments the shuffling activation flag, the number of available processes (num_replicas) and the rank of the current process. The shuffling step is delegated to the sampler so that it can be processed in parallel. The number of processes and the rank are determined from the Slurm environment in which the training script was run; we use the idr_torch library here to retrieve this information. This library was developed by IDRIS and is present in all the PyTorch modules on Jean Zay.

Comment: The DistributedSampler is adapted to the torch.nn.parallel.DistributedDataParallel strategy which we describe on this page.
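
As a complement, a minimal sketch of the epoch loop with such a sampler: calling its set_epoch() method at each epoch makes the shuffling differ from one epoch to the next. The dataloader and num_epochs names are assumed here; the DataLoader is built with sampler=data_sampler as in the complete example below.

# minimal sketch of an epoch loop with a DistributedSampler
for epoch in range(num_epochs):
    data_sampler.set_epoch(epoch)      # vary the shuffling seed at each epoch
    for batch in dataloader:
        ...                            # training step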

Optimisation of the use of resources during the learning

The size of a batch (defined by the batch_size argument) is optimal if it allows good usage of the computing resources; that is, if the memory of each GPU is used as fully as possible and the workload is distributed equally between the GPUs.

It can happen that the quantity of input data is not a multiple of the requested batch size. In this case, to prevent the DataLoader from generating an "incomplete" batch with the last extracted data, thereby causing an imbalance in the GPU workload, it is possible to ask the DataLoader to ignore this last batch by using the argument drop_last=True. This can, however, represent a loss of information which needs to be estimated beforehand.
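
For illustration, a small sketch of the effect of drop_last on a toy dataset of 10 elements with a batch size of 4:

import torch

toy_dataset = list(range(10))

# drop_last=False would give 3 batches of sizes 4, 4 and 2 (the last one incomplete);
# drop_last=True gives 2 batches of size 4 and ignores the last 2 elements
loader = torch.utils.data.DataLoader(toy_dataset, batch_size=4, drop_last=True)
for batch in loader:
    print(batch)      # tensor([0, 1, 2, 3]) then tensor([4, 5, 6, 7])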

Transfer/computation overlapping

It is possible to optimise the batch transfers from CPU to GPU by generating transfer/computation overlapping. A first optimisation, done during the training, consists of pre-loading the next batches to be processed. The quantity of pre-loaded batches is controlled by the prefetch_factor argument. By default, this value is set to 2, which is suitable in most cases.

A second optimisation consists of requesting the DataLoader to store the batches in pinned memory on the CPU (pin_memory=True). With this strategy, certain copying steps can be avoided during the transfers from CPU to GPU. It also enables using the non_blocking=True asynchronous mechanism during the call to the .to() or .cuda() transfer functions.
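
A minimal sketch combining the two mechanisms, assuming a dataset returning (image, label) pairs such as imagenet_dataset above and a single visible GPU:

import torch

loader = torch.utils.data.DataLoader(imagenet_dataset, batch_size=128, pin_memory=True)

device = torch.device('cuda')
for images, labels in loader:
    # asynchronous copies are only effective from pinned CPU memory
    images = images.to(device, non_blocking=True)
    labels = labels.to(device, non_blocking=True)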

Acceleration of data pre-processing (transformations)

The data pre-processing (transformation) step uses a large amount of CPU resources. To accelerate it, the operations can be parallelized over several CPUs by using the DataLoader multiprocessing functionality. The number of processes involved is specified with the num_workers argument.

The persistent_workers=True argument enables keeping the processes active throughout the training, thereby avoiding their reinitialisation at every epoch. In return, this time saving implies a potentially significant usage of RAM, particularly if more than one DataLoader is used.

Complete example of optimised data loading

Here is a complete example of optimised loading of the ImageNet database for distributed learning on Jean Zay:

import os
import torch
import torchvision
import idr_torch # IDRIS package available in all PyTorch modules
 
# define list of transformations to apply
data_transform = torchvision.transforms.Compose([torchvision.transforms.Resize((300,300)),
                                                 torchvision.transforms.ToTensor()])
 
# load imagenet dataset and apply transformations
root = os.environ['DSDIR']+'/imagenet/RawImages'
imagenet_dataset = torchvision.datasets.ImageNet(root=root,
                                                 transform=data_transform)
 
# define distributed sampler
data_sampler = torch.utils.data.distributed.DistributedSampler(imagenet_dataset,
                                                               shuffle=True,
                                                               num_replicas=idr_torch.size,
                                                               rank=idr_torch.rank
                                                              )
 
# define DataLoader
batch_size = 128                       # adjust batch size according to GPU type (16GB or 32GB in memory)
drop_last = True                       # set to False if it represents important information loss
num_workers = 4                        # adjust number of CPU workers per process
persistent_workers = True              # set to False if CPU RAM must be released
pin_memory = True                      # optimize CPU to GPU transfers
non_blocking = True                    # activate asynchronism to speed up CPU/GPU transfers
prefetch_factor = 2                    # adjust number of batches to preload
 
dataloader = torch.utils.data.DataLoader(imagenet_dataset,
                                         sampler=data_sampler,
                                         batch_size=batch_size,
                                         drop_last=drop_last,
                                         num_workers=num_workers,
                                         persistent_workers=persistent_workers,
                                         pin_memory=pin_memory,
                                         prefetch_factor=prefetch_factor
                                        )
 
# define the target GPU (simple assumption here; in a real multi-GPU job, GPU binding
# according to the local rank, e.g. torch.cuda.set_device(idr_torch.local_rank), is done beforehand)
gpu = torch.device("cuda")

# loop over batches
for i, (images, labels) in enumerate(dataloader):
    images = images.to(gpu, non_blocking=non_blocking)
    labels = labels.to(gpu, non_blocking=non_blocking)

Implementation on Jean Zay

In order to implement the above documentation and get an idea of the gains brought by each of the functionalities offered by the PyTorch DataLoader, you need to retrieve the Jupyter Notebook notebook_data_preprocessing_pytorch-eng.ipynb from the DSDIR. For example, to copy it into your WORK:

$ cp $DSDIR/examples_IA/Torch_parallel/notebook_data_preprocessing_pytorch-eng.ipynb $WORK

You may also download the Notebook here.

Then you can open and execute the Notebook from the IDRIS JupyterHub platform. To use JupyterHub, please refer to the corresponding documentation: Jean Zay: Access to JupyterHub and https://jupyterhub.idris.fr/services/documentation/.

Official documentation