Understanding PyTorch Lightning DataModules
Last Updated: 08 Dec, 2020
PyTorch Lightning aims to make PyTorch code more structured and readable, and that extends not just to the PyTorch model but also to the data itself. In PyTorch we use DataLoaders to train or test our model. While we can still pass plain DataLoaders to PyTorch Lightning to train the model, Lightning also provides a better approach called DataModules. A DataModule is a reusable and shareable class that encapsulates the DataLoaders along with the steps required to process the data. Creating dataloaders can get messy, which is why it's better to bundle the dataset handling into a DataModule. It's recommended that you already know how to define a neural network using PyTorch Lightning.
Installing PyTorch Lightning:
Installing Lightning is the same as installing any other Python library:
pip install pytorch-lightning
Or, if you want to install it in a conda environment, you can use the following command:
conda install -c conda-forge pytorch-lightning
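To verify that the installation succeeded, you can print the installed version:
python -c "import pytorch_lightning as pl; print(pl.__version__)"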
PyTorch Lightning DataModule Format
To define a Lightning DataModule, we follow this format:
import pytorch_lightning as pl
from torch.utils.data import random_split, DataLoader

class DataModuleClass(pl.LightningDataModule):
    def __init__(self):
        # Define required parameters here
        super().__init__()

    def prepare_data(self):
        # Define steps that should be done
        # on only one GPU, like downloading data
        pass

    def setup(self, stage=None):
        # Define steps that should be done on
        # every GPU, like splitting data, applying
        # transforms, etc.
        pass

    def train_dataloader(self):
        # Return DataLoader for training data here
        pass

    def val_dataloader(self):
        # Return DataLoader for validation data here
        pass

    def test_dataloader(self):
        # Return DataLoader for testing data here
        pass
Note: The method names above must be kept exactly as shown.
Understanding the DataModule Class
For this article, I'll be using the MNIST dataset as an example. As we can see, the first requirement for creating a Lightning DataModule is to inherit from the LightningDataModule class:
import pytorch_lightning as pl
from torch.utils.data import random_split, DataLoader

class DataModuleMNIST(pl.LightningDataModule):
__init__() method:
It is used to store configuration such as the batch size and the transforms to apply to the data.
def __init__(self):
    super().__init__()
    self.download_dir = ''
    self.batch_size = 32
    self.transform = transforms.Compose([
        transforms.ToTensor()
    ])
prepare_data() method:
This method is used to define the steps that should be performed by only one GPU (i.e., a single process), such as downloading the data.
def prepare_data(self):
    datasets.MNIST(self.download_dir,
                   train=True, download=True)
    datasets.MNIST(self.download_dir,
                   train=False, download=True)
setup() method:
This method is used to define the steps that should be performed on every GPU, such as loading the data, splitting it, and applying transforms.
def setup(self, stage=None):
    data = datasets.MNIST(self.download_dir,
                          train=True, transform=self.transform)
    self.train_data, self.valid_data = random_split(data,
                                                    [55000, 5000])
    self.test_data = datasets.MNIST(self.download_dir,
                                    train=False, transform=self.transform)
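setup() also receives an optional stage argument (e.g., 'fit' or 'test'), which Lightning fills in depending on the phase, so you can load only what the current phase needs. A minimal sketch of splitting the method above by stage:
def setup(self, stage=None):
    # Create the train/valid splits only when fitting
    if stage == 'fit' or stage is None:
        data = datasets.MNIST(self.download_dir,
                              train=True, transform=self.transform)
        self.train_data, self.valid_data = random_split(data,
                                                        [55000, 5000])
    # Load the test set only when testing
    if stage == 'test' or stage is None:
        self.test_data = datasets.MNIST(self.download_dir,
                                        train=False, transform=self.transform)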
train_dataloader() method:
This method returns the DataLoader for the training data.
def train_dataloader(self):
    return DataLoader(self.train_data, batch_size=self.batch_size)
val_dataloader() method:
This method returns the DataLoader for the validation data.
def val_dataloader(self):
    return DataLoader(self.valid_data, batch_size=self.batch_size)
test_dataloader() method:
This method returns the DataLoader for the testing data.
def test_dataloader(self):
    return DataLoader(self.test_data, batch_size=self.batch_size)
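DataLoader accepts more options than just batch_size. Training usually benefits from shuffling the data every epoch, and num_workers loads batches in parallel worker processes. A sketch of train_dataloader() with these options (the worker count of 4 is an illustrative choice, not from the original code; tune it for your machine):
def train_dataloader(self):
    # Shuffle the training data each epoch and load
    # batches with multiple worker processes
    return DataLoader(self.train_data,
                      batch_size=self.batch_size,
                      shuffle=True,
                      num_workers=4)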
Training a PyTorch Lightning Model Using a DataModule:
In PyTorch Lightning, we use Trainer() to train our model, and we can pass it the data either as DataLoaders or as a DataModule. Let's use the following model as an example:
class model(pl.LightningModule):
    def __init__(self):
        super(model, self).__init__()
        self.fc1 = nn.Linear(28*28, 256)
        self.fc2 = nn.Linear(256, 128)
        self.out = nn.Linear(128, 10)
        self.lr = 0.01
        self.loss = nn.CrossEntropyLoss()

    def forward(self, x):
        batch_size, _, _, _ = x.size()
        x = x.view(batch_size, -1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.out(x)

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=self.lr)

    def training_step(self, train_batch, batch_idx):
        x, y = train_batch
        logits = self.forward(x)
        loss = self.loss(logits, y)
        return loss

    def validation_step(self, valid_batch, batch_idx):
        x, y = valid_batch
        logits = self.forward(x)
        loss = self.loss(logits, y)
        # Log the validation loss so it is actually tracked
        self.log('val_loss', loss)
Now, to train this model, we create a Trainer() object and call fit() on it, passing our model and DataModule as arguments.
clf = model()
mnist = DataModuleMNIST()
trainer = pl.Trainer(gpus=1)
trainer.fit(clf, mnist)
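The same DataModule can be reused for evaluation once training finishes. A sketch, assuming the model also defines a test_step() (the model above does not, so one would have to be added first):
trainer.test(clf, datamodule=mnist)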
Below is the full implementation:
Python3
# import module
import torch

# To get the layers and losses for our model
from torch import nn

import pytorch_lightning as pl

# To get the activation function for our model
import torch.nn.functional as F

# To get MNIST data and transforms
from torchvision import datasets, transforms

# To get the optimizer for our model
from torch.optim import SGD

# To get random_split to split training
# data into training and validation data
# and DataLoader to create dataloaders for train,
# valid and test data to be returned
# by our data module
from torch.utils.data import random_split, DataLoader


class model(pl.LightningModule):
    def __init__(self):
        super(model, self).__init__()

        # Defining our model architecture
        self.fc1 = nn.Linear(28*28, 256)
        self.fc2 = nn.Linear(256, 128)
        self.out = nn.Linear(128, 10)

        # Defining learning rate
        self.lr = 0.01

        # Defining loss
        self.loss = nn.CrossEntropyLoss()

    def forward(self, x):
        # Defining the forward pass of the model
        batch_size, _, _, _ = x.size()
        x = x.view(batch_size, -1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.out(x)

    def configure_optimizers(self):
        # Defining and returning the optimizer for our model
        # with the defined parameters
        return SGD(self.parameters(), lr=self.lr)

    def training_step(self, train_batch, batch_idx):
        # Defining the training step for our model
        x, y = train_batch
        logits = self.forward(x)
        loss = self.loss(logits, y)
        return loss

    def validation_step(self, valid_batch, batch_idx):
        # Defining the validation step for our model
        x, y = valid_batch
        logits = self.forward(x)
        loss = self.loss(logits, y)

        # Logging the validation loss
        self.log('val_loss', loss)


class DataModuleMNIST(pl.LightningDataModule):
    def __init__(self):
        super().__init__()

        # Directory to store MNIST Data
        self.download_dir = ''

        # Defining batch size of our data
        self.batch_size = 32

        # Defining transforms to be applied on the data
        self.transform = transforms.Compose([
            transforms.ToTensor()
        ])

    def prepare_data(self):
        # Downloading our data
        datasets.MNIST(self.download_dir,
                       train=True, download=True)
        datasets.MNIST(self.download_dir,
                       train=False, download=True)

    def setup(self, stage=None):
        # Loading our data after applying the transforms
        data = datasets.MNIST(self.download_dir,
                              train=True,
                              transform=self.transform)

        self.train_data, self.valid_data = random_split(data,
                                                        [55000, 5000])

        self.test_data = datasets.MNIST(self.download_dir,
                                        train=False,
                                        transform=self.transform)

    def train_dataloader(self):
        # Generating train_dataloader
        return DataLoader(self.train_data,
                          batch_size=self.batch_size)

    def val_dataloader(self):
        # Generating val_dataloader
        return DataLoader(self.valid_data,
                          batch_size=self.batch_size)

    def test_dataloader(self):
        # Generating test_dataloader
        return DataLoader(self.test_data,
                          batch_size=self.batch_size)


clf = model()
mnist = DataModuleMNIST()
trainer = pl.Trainer()
trainer.fit(clf, mnist)
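A DataModule can also be used on its own, outside the Trainer, which is handy for debugging. A quick sketch that inspects one training batch from the class above:
mnist = DataModuleMNIST()
mnist.prepare_data()  # downloads MNIST
mnist.setup()         # creates the train/valid/test splits

# Fetch one batch and check its shape
x, y = next(iter(mnist.train_dataloader()))
print(x.shape)  # torch.Size([32, 1, 28, 28])
print(y.shape)  # torch.Size([32])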