Chargement de bases de données pour l'apprentissage distribué en PyTorch
Dans cette page, nous mettons en pratique la gestion des Datasets et DataLoaders pour l'apprentissage distribué en PyTorch. Nous nous intéressons aux problématiques présentées dans la page principale sur le chargement des données.
Nous présentons ici l'usage :
- des Datasets (prédéfinis et personnalisés)
- des outils de transformations des données d'entrée (prédéfinis et personnalisés)
- des DataLoaders
La documentation se conclut sur la présentation d'un exemple complet de chargement optimisé de données, et sur une mise en pratique sur Jean Zay via un Jupyter Notebook.
Dans cette documentation, nous ne parlerons pas des objets de type IterableDataset qui permettent de traiter des bases de données dont la structure est inconnue. Ce genre d'objet est parcouru à l'aide d'un itérateur dont le mécanisme se réduit à acquérir l'élément suivant (si il existe). Ce mécanisme empêche l'utilisation directe de certaines fonctionnalités mentionnées dans la section « DataLoader », comme le shuffling et le mutiprocessing, qui se basent sur des manipulations d'indices et ont besoin d'une vision globale de la base de données.
Datasets
Datasets prédéfinis dans PyTorch
PyTorch propose un ensemble de Datasets prédéfinis dans les librairies torchvision, torchaudio et torchtext. Ces librairies gèrent la création d'un objet Dataset pour des bases de données standards listées dans les documentations officielles :
- liste des Datasets prédéfinis dans torchvision
- liste des Datasets prédéfinis dans torchaudio
- liste des Datasets prédéfinis dans torchtext
Le chargement d'une base données se fait via le module Datasets. Par exemple, le chargement de la base de données d'images ImageNet peut se faire avec torchvision de la manière suivante :
import torchvision
# load imagenet dataset stored in DSDIR
root = os.environ['DSDIR']+'/imagenet'
imagenet_dataset = torchvision.datasets.ImageNet(root=root)
La plupart du temps, il est possible de différencier au chargement les données dédiées à l'entraînement des données dédiées à la validation. Par exemple, pour la base ImageNet :
import torchvision
# load imagenet dataset stored in DSDIR
root = os.environ['DSDIR']+'/imagenet'
## load data for training
imagenet_train_dataset = torchvision.datasets.ImageNet(root=root,
split='train')
## load data for validation
imagenet_val_dataset = torchvision.datasets.ImageNet(root=root,
split='val')
Chaque fonction de chargement propose ensuite des fonctionnalités spécifiques aux bases de données (qualité des données, extraction d'une sous-partie des données, etc). Nous vous invitons à consulter les documentations officielles pour plus de détails.
La librairie torchvision contient une fonction générique de chargement torchvision.Datasets.ImageFolder. Elle est adaptée à toute base de données d’images, sous condition que celle-ci soit stockée dans un certain format (voir la documentation officielle pour plus de détails).
Certaines fonctions proposent de télécharger les bases données en ligne grâce à l’argument download=True. Nous vous rappelons que les nœuds de calcul Jean Zay n’ont pas accès à internet et que de telles opérations doivent se faire en amont depuis une frontale ou un nœud de pré/post-traitement. Nous vous rappelons également que des bases de données publiques et volumineuses sont déjà disponibles sur l’espace commun DSDIR de Jean Zay. Cet espace peut-être enrichi sur demande auprès de l’assistance IDRIS (assist@idris.fr).
Datasets personnalisés
Il est possible de créer ses propres classes Datasets en définissant trois fonctions caractéristiques :
__init__initialise la variable contenant les données à traiter__len__retourne la longueur de la base de données__getitem__retourne la donnée correspondant à un indice donnée
Par exemple :
class myDataset(Dataset):
def __init__(self, data):
# Initialise dataset from source dataset
self.data = data
def __len__(self):
# Return length of the dataset
return len(self.data)
def __getitem__(self, idx):
# Return one element of the dataset according to its index
return self.data[idx]
Transformations
Transformations prédéfinies dans PyTorch
Les librairies torchvision, torchtext et torchaudio offrent un panel de transformations pré-implémentées, accessibles via le module transforms de la classe Datasets. Ces transformations sont listées dans les documentations officielles :
- liste des transformations prédéfinies dans torchvision
- liste des transformations prédéfinies dans torchaudio
- liste des transformations prédéfinies dans torchtext
Les instructions de transformation sont portées par l'objet Dataset. Il est possible de cumuler différents types de transformations grâce à la fonction transforms.Compose(). Par exemple, pour redimensionner l'ensemble des images de la base de données ImageNet :
import torchvision
# define list of transformations to apply
data_transform = torchvision.transforms.Compose([torchvision.transforms.Resize((300,300)),
torchvision.transforms.ToTensor()])
# load imagenet dataset and apply transformations
root = os.environ['DSDIR']+'/imagenet'
imagenet_dataset = torchvision.datasets.ImageNet(root=root,
transform=data_transform)
La transformation transforms.ToTensor() permet de convertir une image PIL ou un tableau NumPy en tenseur.
Pour appliquer des transformations sur un Dataset personnalisé, il faut modifier celui-ci en conséquence, par exemple de la manière suivante :
class myDataset(Dataset):
def __init__(self, data, transform=None):
# Initialise dataset from source dataset
self.data = data
self.transform = transform
def __len__(self):
# Return length of the dataset
return len(self.data)
def __getitem__(self, idx):
# Return one element of the dataset according to its index
x = self.data[idx]
# apply transformation if requested
if self.transform:
x = self.transform(x)
return x
Transformations personnalisées
Il est aussi possible de créer ses propres transformations en définissant des fonctions callable et en les communiquant directement à transforms.Compose(). On peut par exemple définir des transformations de type somme (Add) et multiplication (Mult) de la manière suivante :
# define Add tranformation
class Add(object):
def __init__(self, value):
self.value = value
def __call__(self, sample):
# add a constant to the data
return sample + self.value
# define Mult transformation
class Mult(object):
def __init__(self, value):
self.value = value
def __call__(self, sample):
# multiply the data by a constant
return sample * self.value
# define list of transformations to apply
data_transform = transforms.Compose([Add(2),Mult(3)])
DataLoaders
Un objet DataLoader est une sur-couche d'un objet Dataset qui permet de structurer les données (création de batches), de les pré-traiter (shuffling, transformations) et de les diffuser aux GPU pour la phase d'entraînement.
Le DataLoader est un objet de la classe torch.utils.data.DataLoader :
import torch
# define DataLoader for a given dataset
dataloader = torch.utils.data.DataLoader(dataset)