PyTorch Lightning Tutorial: Simplifying Deep Learning with PyTorch
Last Updated: 08 Oct, 2024
PyTorch Lightning is an open-source library that extends PyTorch. It provides ready-made training and testing loops, which keeps code simple and reduces boilerplate, and it also supports multi-GPU and distributed training. Some other features of PyTorch Lightning are as follows:
- Integration with loggers such as the CSV Logger and TensorBoard Logger.
- Checkpoints for saving model weights during the training phase.
- Custom callbacks for accessing metrics and training details programmatically (a short sketch follows this list).
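As a quick illustration, the sketch below wires these features into the Trainer. It is a minimal example, not part of the tutorial code that follows; PrintValLossCallback and the logging directory are assumptions made purely for demonstration.
Python
import pytorch_lightning as pl
from pytorch_lightning.loggers import CSVLogger
from pytorch_lightning.callbacks import ModelCheckpoint, Callback

class PrintValLossCallback(Callback):
    # Hypothetical callback: read a logged metric programmatically after each validation epoch
    def on_validation_epoch_end(self, trainer, pl_module):
        print("val_loss:", trainer.callback_metrics.get("val_loss"))

trainer = pl.Trainer(
    max_epochs=5,
    logger=CSVLogger("logs/"),  # CSV logging; TensorBoardLogger works the same way
    callbacks=[
        ModelCheckpoint(monitor="val_loss", save_top_k=1),  # checkpoint the best weights
        PrintValLossCallback(),                             # custom callback
    ],
)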
Setup and Installation
In this tutorial we will create a PyTorch model and use PyTorch Lightning for training and testing, using the MNIST dataset. Before that, we need to install the required libraries with pip or conda.
pip install torch torchvision pytorch-lightning torchmetrics comet-ml
The article is divided into three progressive sections, Beginner, Intermediate, and Advanced, each building on the previous one to develop the reader's proficiency with PyTorch Lightning.
Beginner Tutorial: Creating a PyTorch Model with PyTorch Lightning
The Beginner Tutorial serves as an entry point, guiding newcomers through the essential steps of setting up their environment, creating a simple convolutional neural network (CNN) using the MNIST dataset, and executing basic training and testing procedures without the complexity of manual loops. This section emphasizes understanding the foundational architecture and leveraging PyTorch-Lightning's streamlined training mechanisms to reduce code verbosity and enhance clarity.
1. Creating a Model
After installing and importing the libraries, we define the architecture of the model.
- The layers and activation functions are defined in this step.
- The model uses two convolution layers, one max-pooling layer, two linear layers, and ReLU as the activation function.
- For accuracy we use the Accuracy metric from the torchmetrics library. The forward pass method is also defined here.
Python
import torch
import torch.nn as nn
import torch.nn.functional as F
import pytorch_lightning as pl
from torchmetrics import Accuracy  # Use torchmetrics for accuracy

# Define the PyTorch Lightning model
class MNISTModel(pl.LightningModule):
    def __init__(self):
        super(MNISTModel, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3)   # Conv layer 1
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3)  # Conv layer 2
        self.pool = nn.MaxPool2d(2, 2)                 # Max pool layer
        self.fc1 = nn.Linear(64 * 5 * 5, 128)          # Linear layer 1 (64 * 5 * 5 = 1600 inputs)
        self.fc2 = nn.Linear(128, 10)                  # Output layer
        self.accuracy = Accuracy(task='multiclass', num_classes=10)  # Accuracy metric

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))  # (batch_size, 32, 26, 26) -> (batch_size, 32, 13, 13)
        x = self.pool(F.relu(self.conv2(x)))  # (batch_size, 64, 11, 11) -> (batch_size, 64, 5, 5)
        x = x.view(x.size(0), -1)             # Flatten to (batch_size, 1600)
        x = F.relu(self.fc1(x))               # (batch_size, 128)
        x = self.fc2(x)                        # (batch_size, 10)
        return x
2. Training and Optimizing the Model
In this step we do not write any training loop ourselves; this is where PyTorch Lightning comes into play.
- For each batch we unpack the inputs and targets, pass the inputs through the model, and compute the loss with the loss function.
- We use the self.log method to log the training loss. Finally, configure_optimizers returns the Adam optimizer that updates the model weights.
Python
    # The following methods are defined inside the MNISTModel class
    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=0.001)

    def training_step(self, batch, batch_idx):
        data, target = batch
        output = self(data)
        loss = F.cross_entropy(output, target)
        self.log('train_loss', loss)
        return loss

    def validation_step(self, batch, batch_idx):
        data, target = batch
        output = self(data)
        loss = F.cross_entropy(output, target)
        self.log('val_loss', loss)

    def test_step(self, batch, batch_idx):
        data, target = batch
        output = self(data)
        loss = F.cross_entropy(output, target)
        acc = self.accuracy(output, target)  # Calculate accuracy using torchmetrics
        self.log('test_loss', loss)
        self.log('test_acc', acc)            # Log accuracy as well
3. Preparing the Dataset
In this step we load the MNIST dataset and split it into training, validation, and test sets. About 20% of the training data is held out at random for validation. Finally, we create data loaders with a batch size of 64.
Python
from torch.utils.data import DataLoader, random_split
from torchvision import datasets, transforms
# Dataset and DataLoader
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
train_data = datasets.MNIST(root='data', train=True, download=True, transform=transform)
test_data = datasets.MNIST(root='data', train=False, download=True, transform=transform)
# Split training and validation sets
train_size = int(0.8 * len(train_data))
val_size = len(train_data) - train_size
train_data, val_data = random_split(train_data, [train_size, val_size])
# Data loaders
train_loader = DataLoader(train_data, batch_size=64, shuffle=True)
val_loader = DataLoader(val_data, batch_size=64)
test_loader = DataLoader(test_data, batch_size=64)
4. Fitting the Data and Testing the Model
Here we initialize the Trainer, fit the model for 10 epochs, and then evaluate it on the test data loader.
Python
# Initialize and train the model
model = MNISTModel()
trainer = pl.Trainer(max_epochs=10)
# Train the model
trainer.fit(model, train_loader, val_loader)
# Test the model
trainer.test(model, test_loader)
Full code implementation
Python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, random_split
from torchvision import datasets, transforms
import pytorch_lightning as pl
from torchmetrics import Accuracy  # Use torchmetrics for accuracy

# Define the PyTorch Lightning model
class MNISTModel(pl.LightningModule):
    def __init__(self):
        super(MNISTModel, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3)   # Conv layer 1
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3)  # Conv layer 2
        self.pool = nn.MaxPool2d(2, 2)                 # Max pool layer
        self.fc1 = nn.Linear(64 * 5 * 5, 128)          # Linear layer 1 (64 * 5 * 5 = 1600 inputs)
        self.fc2 = nn.Linear(128, 10)                  # Output layer
        self.accuracy = Accuracy(task='multiclass', num_classes=10)  # Accuracy metric

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))  # (batch_size, 32, 26, 26) -> (batch_size, 32, 13, 13)
        x = self.pool(F.relu(self.conv2(x)))  # (batch_size, 64, 11, 11) -> (batch_size, 64, 5, 5)
        x = x.view(x.size(0), -1)             # Flatten to (batch_size, 1600)
        x = F.relu(self.fc1(x))               # (batch_size, 128)
        x = self.fc2(x)                        # (batch_size, 10)
        return x

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=0.001)

    def training_step(self, batch, batch_idx):
        data, target = batch
        output = self(data)
        loss = F.cross_entropy(output, target)
        self.log('train_loss', loss)
        return loss

    def validation_step(self, batch, batch_idx):
        data, target = batch
        output = self(data)
        loss = F.cross_entropy(output, target)
        self.log('val_loss', loss)

    def test_step(self, batch, batch_idx):
        data, target = batch
        output = self(data)
        loss = F.cross_entropy(output, target)
        acc = self.accuracy(output, target)  # Calculate accuracy using torchmetrics
        self.log('test_loss', loss)
        self.log('test_acc', acc)            # Log accuracy as well

# Dataset and DataLoader
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
train_data = datasets.MNIST(root='data', train=True, download=True, transform=transform)
test_data = datasets.MNIST(root='data', train=False, download=True, transform=transform)

# Split training and validation sets
train_size = int(0.8 * len(train_data))
val_size = len(train_data) - train_size
train_data, val_data = random_split(train_data, [train_size, val_size])

# Data loaders
train_loader = DataLoader(train_data, batch_size=64, shuffle=True)
val_loader = DataLoader(val_data, batch_size=64)
test_loader = DataLoader(test_data, batch_size=64)

# Initialize and train the model
model = MNISTModel()
trainer = pl.Trainer(max_epochs=10)
trainer.fit(model, train_loader, val_loader)

# Test the model
trainer.test(model, test_loader)
Output:
As we can see, the test accuracy of our model is about 98.7%.
Intermediate Tutorial: Mixed Precision Training and Custom Callbacks
Moving into the Intermediate Tutorial, the focus shifts to optimizing model performance and resource efficiency. Here, readers learn to implement mixed precision training, which balances 16-bit and 32-bit floating-point computations to accelerate training and reduce memory usage.
Additionally, this section introduces custom callbacks, which let users inject custom behavior, such as printing the epoch number, into the training loop, providing greater control and flexibility over the training process (see the sketch below).
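For instance, a minimal callback that prints the current epoch number looks like this; the same PrintEpochCallback is reused in the full listing later in this section:
Python
from pytorch_lightning.callbacks import Callback

class PrintEpochCallback(Callback):
    def on_train_epoch_start(self, trainer, pl_module):
        # Called at the start of every training epoch
        print(f"Starting Epoch: {trainer.current_epoch + 1}")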
Mixed precision training utilizes both 16-bit and 32-bit floating-point types:
- 16-bit (FP16): Reduces memory consumption and increases computational speed.
- 32-bit (FP32): Maintains model stability during weight updates.
Here too we create a model and define the training, validation, and test methods inside the class. Then we initialize the model and the Trainer class for training and testing. We also use a custom callback to print the epoch number during the training phase.
Modify the trainer to enable mixed precision and add a custom callback to monitor epochs:
Python
trainer = pl.Trainer(
    max_epochs=5,
    precision=16,                      # Enable mixed precision training
    callbacks=[PrintEpochCallback()],  # Custom callback shown above
)
As the code shows, enabling mixed precision only requires passing a precision value when constructing the Trainer; here we pass 16. With this setting, 16-bit floating-point numbers are typically used for the forward pass and gradient computations, while 32-bit floating-point is kept for the weight updates.
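Note: in newer PyTorch Lightning releases (2.x), the preferred spelling for this option is precision="16-mixed"; passing precision=16 is typically still accepted and maps to the same behaviour. A sketch of the equivalent Trainer call, assuming Lightning 2.x:
Python
trainer = pl.Trainer(
    max_epochs=5,
    precision="16-mixed",              # mixed precision spelling on Lightning >= 2.0
    callbacks=[PrintEpochCallback()],  # custom callback from above
)
The full intermediate example, combining the callback and mixed precision training, follows: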
Python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, random_split
from torchvision import datasets, transforms
import pytorch_lightning as pl
from pytorch_lightning.callbacks import Callback

# Define the PyTorch Lightning model
class MNISTModel(pl.LightningModule):
    def __init__(self):
        super(MNISTModel, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3)   # Conv layer 1
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3)  # Conv layer 2
        self.pool = nn.MaxPool2d(2, 2)                 # Max pool layer
        self.fc1 = nn.Linear(64 * 5 * 5, 128)          # Linear layer 1
        self.fc2 = nn.Linear(128, 10)                  # Output layer

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(x.size(0), -1)  # Flatten the tensor
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=0.001)

    def training_step(self, batch, batch_idx):
        data, target = batch
        output = self(data)
        loss = F.cross_entropy(output, target)
        self.log('train_loss', loss)
        return loss

    def validation_step(self, batch, batch_idx):
        data, target = batch
        output = self(data)
        loss = F.cross_entropy(output, target)
        self.log('val_loss', loss)

    def test_step(self, batch, batch_idx):
        data, target = batch
        output = self(data)
        loss = F.cross_entropy(output, target)
        self.log('test_loss', loss)

# Custom callback to print the epoch number
class PrintEpochCallback(Callback):
    def on_train_epoch_start(self, trainer, pl_module):
        print(f"Starting Epoch: {trainer.current_epoch + 1}")

# Dataset and DataLoader
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
train_data = datasets.MNIST(root='data', train=True, download=True, transform=transform)
test_data = datasets.MNIST(root='data', train=False, download=True, transform=transform)

# Split training and validation sets
train_size = int(0.8 * len(train_data))
val_size = len(train_data) - train_size
train_data, val_data = random_split(train_data, [train_size, val_size])

# Data loaders
train_loader = DataLoader(train_data, batch_size=64, shuffle=True)
val_loader = DataLoader(val_data, batch_size=64)
test_loader = DataLoader(test_data, batch_size=64)

# Initialize and train the model
model = MNISTModel()
trainer = pl.Trainer(
    max_epochs=1,
    precision=16,                      # Enable mixed precision training
    callbacks=[PrintEpochCallback()],  # Add the custom callback
)
trainer.fit(model, train_loader, val_loader)

# Test the model
trainer.test(model, test_loader)
Output:
Advanced Tutorial: Integrating Comet Logger
Finally, the Advanced Tutorial delves into sophisticated integrations and experiment management techniques. It demonstrates how to incorporate external tools like Comet.ml for comprehensive experiment tracking and visualization, enabling users to log metrics, compare different training runs, and collaborate more effectively.
In this tutorial we monitor the training phase. PyTorch Lightning integrates with many loggers, such as the TensorBoard Logger and the Comet Logger. Here we use the Comet Logger to log our metrics and visualize them interactively. It also keeps track of hyperparameters and provides charts, reducing code complexity.
Setting Up Comet.ml:
- Sign Up: Create an account at Comet.ml.
- Obtain API Key: Navigate to Settings to find your API key.
- Install comet-ml: Ensure it is installed via pip:
pip install comet-ml
After signing up at Comet.ml, you get an API key and a workspace name; a workspace is created by default and holds your list of projects.
- Now we create the model, configure the optimizer, define the forward pass, and provide the training, validation, and test methods.
- Then we create the data loaders and finally initialize the Comet logger with the API key, workspace name, and project name.
Python
# Initialize CometLogger
comet_logger = CometLogger(
    api_key="API",                        # Replace with your Comet API key
    project_name="mnist-classification",
    workspace="WORKSPACE_NAME"            # Replace with your workspace name
)
- CometLogger: Captures and logs metrics automatically whenever self.log is called in the model.
- Visualization: Access interactive dashboards on Comet.ml to monitor training progress, compare experiments, and analyze hyperparameters.
Below is the full implementation of the code
Python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, random_split
from torchvision import datasets, transforms
import pytorch_lightning as pl
from pytorch_lightning.loggers import CometLogger

# Define the PyTorch Lightning model
class MNISTModel(pl.LightningModule):
    def __init__(self):
        super(MNISTModel, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3)   # Conv layer 1
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3)  # Conv layer 2
        self.pool = nn.MaxPool2d(2, 2)                 # Max pool layer
        self.fc1 = nn.Linear(64 * 5 * 5, 128)          # Linear layer 1
        self.fc2 = nn.Linear(128, 10)                  # Output layer

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(x.size(0), -1)  # Flatten the tensor
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=0.001)

    def training_step(self, batch, batch_idx):
        data, target = batch
        output = self(data)
        loss = F.cross_entropy(output, target)
        self.log('train_loss', loss)
        return loss

    def validation_step(self, batch, batch_idx):
        data, target = batch
        output = self(data)
        loss = F.cross_entropy(output, target)
        self.log('val_loss', loss)

    def test_step(self, batch, batch_idx):
        data, target = batch
        output = self(data)
        loss = F.cross_entropy(output, target)
        self.log('test_loss', loss)

# Dataset and DataLoader
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
train_data = datasets.MNIST(root='data', train=True, download=True, transform=transform)
test_data = datasets.MNIST(root='data', train=False, download=True, transform=transform)

# Split training and validation sets
train_size = int(0.8 * len(train_data))
val_size = len(train_data) - train_size
train_data, val_data = random_split(train_data, [train_size, val_size])

# Data loaders
train_loader = DataLoader(train_data, batch_size=64, shuffle=True)
val_loader = DataLoader(val_data, batch_size=64)
test_loader = DataLoader(test_data, batch_size=64)

# Initialize CometLogger
comet_logger = CometLogger(
    api_key="YOUR_API_KEY",              # Replace with your Comet API key
    project_name="mnist-classification",
    workspace="YOUR_WORKSPACE"           # Replace with your workspace name
)

# Initialize and train the model
model = MNISTModel()
trainer = pl.Trainer(max_epochs=5, logger=comet_logger)
trainer.fit(model, train_loader, val_loader)

# Test the model
trainer.test(model, test_loader)
Output:
Benefits of Using Comet Logger:
- Real-Time Tracking: Monitor training metrics in real-time through interactive dashboards.
- Experiment Management: Compare different runs, track hyperparameters, and maintain reproducibility.
- Collaboration: Share experiments and results with team members seamlessly.
Conclusion
PyTorch-Lightning significantly simplifies the PyTorch workflow by abstracting complex training loops, enabling advanced features with minimal code changes, and integrating seamlessly with various tools for logging and monitoring. Whether you're a beginner aiming to build and train models efficiently or an advanced practitioner looking to optimize and monitor large-scale experiments, PyTorch-Lightning offers robust solutions to enhance your deep learning projects.