
PyTorch Lightning Tutorial: Simplifying Deep Learning with PyTorch

Last Updated : 08 Oct, 2024

PyTorch Lightning is an open-source library that extends PyTorch. It provides ready-made training and testing loops, which keeps code simple and reduces boilerplate, and it also supports multi-GPU and distributed training. Some other features of PyTorch Lightning are as follows:

  • Integration with loggers such as the CSV Logger and TensorBoard Logger.
  • Checkpoints to save model weights during the training phase.
  • Custom callbacks to access metrics and training events programmatically. A short sketch combining some of these features appears after this list.
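
As a quick taste of these features, here is a minimal, illustrative sketch (assuming an existing LightningModule and data loaders, which are placeholders here) that combines a CSV logger with a checkpoint callback:

Python
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import CSVLogger

# Save the best weights whenever the logged validation loss improves
checkpoint_cb = ModelCheckpoint(monitor="val_loss", mode="min", save_top_k=1)

# Write every metric logged with self.log to a CSV file under ./logs
csv_logger = CSVLogger(save_dir="logs", name="my_experiment")

trainer = pl.Trainer(max_epochs=5, callbacks=[checkpoint_cb], logger=csv_logger)
# trainer.fit(MyModel(), train_loader, val_loader)  # MyModel and the loaders are placeholders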

Setup and Installation

In this tutorial we will create a PyTorch model and use PyTorch Lightning to train and test it on the MNIST dataset. Before that, we need to install the required libraries using pip or conda.

pip install torch torchvision pytorch-lightning torchmetrics comet-ml
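
A quick, optional way to confirm the installation (the printed versions will depend on your environment):

Python
import torch
import torchvision
import pytorch_lightning as pl
import torchmetrics

print(torch.__version__, torchvision.__version__, pl.__version__, torchmetrics.__version__)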

The article is divided into three progressive sections, Beginner, Intermediate, and Advanced, each building on the previous one to develop the reader's proficiency with PyTorch Lightning.

Beginner Tutorial: Creating a PyTorch Model with PyTorch-Lightning

The Beginner Tutorial is the entry point. It guides newcomers through setting up the environment, creating a simple convolutional neural network (CNN) for the MNIST dataset, and running basic training and testing without writing manual loops, showing how PyTorch Lightning's streamlined training mechanisms reduce code verbosity and improve clarity.

1. Creating a Model

After installing and importing the libraries, we define the architecture of the model.

  • The layers and activation functions are defined in this step.
  • The model uses two convolution layers, one max-pooling layer, two linear layers and ReLU as the activation function.
  • For accuracy we use the Accuracy metric from the torchmetrics library. The forward pass is defined here as well.
Python
import torch
import torch.nn as nn
import torch.nn.functional as F
import pytorch_lightning as pl
from torchmetrics import Accuracy  # Use torchmetrics for accuracy

# Define the PyTorch Lightning model
class MNISTModel(pl.LightningModule):
    def __init__(self):
        super(MNISTModel, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3)  # Conv layer 1
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3)  # Conv layer 2
        self.pool = nn.MaxPool2d(2, 2)  # Max Pool layer
        self.fc1 = nn.Linear(64 * 5 * 5, 128)  # Adjusted Linear layer 1 (64 * 5 * 5 = 1600)
        self.fc2 = nn.Linear(128, 10)  # Output layer

        self.accuracy = Accuracy(task='multiclass', num_classes=10)  # Initialize accuracy metric

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))  # (batch_size, 32, 26, 26) -> (batch_size, 32, 13, 13)
        x = self.pool(F.relu(self.conv2(x)))  # (batch_size, 64, 11, 11) -> (batch_size, 64, 5, 5)
        x = x.view(x.size(0), -1)  # Flatten the tensor
        x = F.relu(self.fc1(x))  # (batch_size, 1600) -> (batch_size, 128)
        x = self.fc2(x)  # (batch_size, 10)
        return x
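
Before wiring up training, we can sanity-check the architecture by pushing a dummy MNIST-sized batch through it; this snippet is only an illustrative check, not part of the tutorial code:

Python
model = MNISTModel()
dummy = torch.randn(4, 1, 28, 28)   # a fake batch of four 28x28 grayscale images
print(model(dummy).shape)           # expected: torch.Size([4, 10])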

2. Training and Optimizing our model

In this step we do not write any training loop ourselves; this is where PyTorch Lightning comes into play.

  • The following methods are defined inside the MNISTModel class. In each step method we receive a batch, pass its inputs through the model, and use the loss function to compute the loss for that batch.
  • We use the self.log method to log the training loss. Finally, configure_optimizers returns the Adam optimizer used to update the model weights.
Python
    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=0.001)

    def training_step(self, batch, batch_idx):
        data, target = batch
        output = self(data)
        loss = F.cross_entropy(output, target)
        self.log('train_loss', loss)
        return loss

    def validation_step(self, batch, batch_idx):
        data, target = batch
        output = self(data)
        loss = F.cross_entropy(output, target)
        self.log('val_loss', loss)

    def test_step(self, batch, batch_idx):
        data, target = batch
        output = self(data)
        loss = F.cross_entropy(output, target)
        acc = self.accuracy(output, target)  # Calculate accuracy using torchmetrics
        self.log('test_loss', loss)
        self.log('test_acc', acc)  # Log accuracy as well
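
self.log also accepts optional flags that control when and where a metric is recorded. As an illustrative variant (not part of the tutorial code), the training step could log both per-step and per-epoch values and show the loss in the progress bar:

Python
    def training_step(self, batch, batch_idx):
        data, target = batch
        loss = F.cross_entropy(self(data), target)
        # on_step / on_epoch / prog_bar are standard self.log options
        self.log('train_loss', loss, on_step=True, on_epoch=True, prog_bar=True)
        return loss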

3. Preparation of the dataset

In this step we load the MNIST dataset and split it into training, validation and test sets. For validation we hold out 20% of the training data at random. Lastly, we create data loaders with a batch size of 64.

Python
from torch.utils.data import DataLoader, random_split
from torchvision import datasets, transforms

# Dataset and DataLoader
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
train_data = datasets.MNIST(root='data', train=True, download=True, transform=transform)
test_data = datasets.MNIST(root='data', train=False, download=True, transform=transform)

# Split training and validation sets
train_size = int(0.8 * len(train_data))
val_size = len(train_data) - train_size
train_data, val_data = random_split(train_data, [train_size, val_size])

# Data loaders
train_loader = DataLoader(train_data, batch_size=64, shuffle=True)
val_loader = DataLoader(val_data, batch_size=64)
test_loader = DataLoader(test_data, batch_size=64)
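
For larger projects, this data setup is often wrapped in a LightningDataModule, which keeps the download, split and loader creation in one place. A minimal sketch equivalent to the code above (reusing the imports from the earlier steps) could look like this:

Python
class MNISTDataModule(pl.LightningDataModule):
    def __init__(self, batch_size=64):
        super().__init__()
        self.batch_size = batch_size
        self.transform = transforms.Compose(
            [transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])

    def setup(self, stage=None):
        full = datasets.MNIST('data', train=True, download=True, transform=self.transform)
        train_size = int(0.8 * len(full))
        self.train_ds, self.val_ds = random_split(full, [train_size, len(full) - train_size])
        self.test_ds = datasets.MNIST('data', train=False, download=True, transform=self.transform)

    def train_dataloader(self):
        return DataLoader(self.train_ds, batch_size=self.batch_size, shuffle=True)

    def val_dataloader(self):
        return DataLoader(self.val_ds, batch_size=self.batch_size)

    def test_dataloader(self):
        return DataLoader(self.test_ds, batch_size=self.batch_size)

The Trainer can then consume it directly, for example trainer.fit(model, datamodule=MNISTDataModule()).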

4. Fit the data and test the model

In this step we initialize the Trainer, fit the model to the data and train it for 10 epochs. Then we use the test data loader to evaluate the model's performance.

Python
# Initialize and train the model
model = MNISTModel()
trainer = pl.Trainer(max_epochs=10)

# Train the model
trainer.fit(model, train_loader, val_loader)

# Test the model
trainer.test(model, test_loader)
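
After training, the model state can also be saved to and restored from a checkpoint; a small, optional sketch (the file name is just an example):

Python
# Save the current training state to a checkpoint file
trainer.save_checkpoint("mnist_model.ckpt")

# Later, rebuild the model with the weights from that checkpoint
restored_model = MNISTModel.load_from_checkpoint("mnist_model.ckpt")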

Full code implementation

Python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, random_split
from torchvision import datasets, transforms
import pytorch_lightning as pl
from torchmetrics import Accuracy  # Use torchmetrics for accuracy

# Define the PyTorch Lightning model
class MNISTModel(pl.LightningModule):
    def __init__(self):
        super(MNISTModel, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3)  # Conv layer 1
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3)  # Conv layer 2
        self.pool = nn.MaxPool2d(2, 2)  # Max Pool layer
        self.fc1 = nn.Linear(64 * 5 * 5, 128)  # Adjusted Linear layer 1 (64 * 5 * 5 = 1600)
        self.fc2 = nn.Linear(128, 10)  # Output layer

        self.accuracy = Accuracy(task='multiclass', num_classes=10)  # Initialize accuracy metric

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))  # (batch_size, 32, 26, 26) -> (batch_size, 32, 13, 13)
        x = self.pool(F.relu(self.conv2(x)))  # (batch_size, 64, 11, 11) -> (batch_size, 64, 5, 5)
        x = x.view(x.size(0), -1)  # Flatten the tensor
        x = F.relu(self.fc1(x))  # (batch_size, 1600) -> (batch_size, 128)
        x = self.fc2(x)  # (batch_size, 10)
        return x

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=0.001)

    def training_step(self, batch, batch_idx):
        data, target = batch
        output = self(data)
        loss = F.cross_entropy(output, target)
        self.log('train_loss', loss)
        return loss

    def validation_step(self, batch, batch_idx):
        data, target = batch
        output = self(data)
        loss = F.cross_entropy(output, target)
        self.log('val_loss', loss)

    def test_step(self, batch, batch_idx):
        data, target = batch
        output = self(data)
        loss = F.cross_entropy(output, target)
        acc = self.accuracy(output, target)  # Calculate accuracy using torchmetrics
        self.log('test_loss', loss)
        self.log('test_acc', acc)  # Log accuracy as well

# Dataset and DataLoader
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
train_data = datasets.MNIST(root='data', train=True, download=True, transform=transform)
test_data = datasets.MNIST(root='data', train=False, download=True, transform=transform)

# Split training and validation sets
train_size = int(0.8 * len(train_data))
val_size = len(train_data) - train_size
train_data, val_data = random_split(train_data, [train_size, val_size])

# Data loaders
train_loader = DataLoader(train_data, batch_size=64, shuffle=True)
val_loader = DataLoader(val_data, batch_size=64)
test_loader = DataLoader(test_data, batch_size=64)

# Initialize and train the model
model = MNISTModel()
trainer = pl.Trainer(max_epochs=10)
trainer.fit(model, train_loader, val_loader)

# Test the model
trainer.test(model, test_loader)

Output:


As we can see the test accuracy of our model is 98.7%.

Intermediate Tutorial: Mixed Precision Training

Moving into the Intermediate Tutorial, the focus shifts to optimizing model performance and resource efficiency. Here, readers learn to implement mixed precision training, which balances 16-bit and 32-bit floating-point computations to accelerate training and minimize memory usage.

Additionally, this section introduces the concept of custom callbacks, allowing users to inject custom behaviors—such as printing epoch numbers—into the training loop, thereby providing greater control and flexibility over the training process.

Mixed precision training utilizes both 16-bit and 32-bit floating-point types:

  • 16-bit (FP16): Reduces memory consumption and increases computational speed.
  • 32-bit (FP32): Maintains model stability during weight updates.
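
For intuition, the sketch below shows roughly what mixed precision looks like when written by hand in plain PyTorch with torch.cuda.amp; PyTorch Lightning automates all of this when a precision value is passed to the Trainer. The snippet is illustrative only and assumes a CUDA device plus an existing model, optimizer and train_loader:

Python
import torch.nn.functional as F
from torch.cuda.amp import autocast, GradScaler

scaler = GradScaler()                        # scales the loss so small FP16 gradients do not underflow
for data, target in train_loader:            # model, optimizer and train_loader assumed to exist
    optimizer.zero_grad()
    with autocast():                         # forward pass runs in FP16 where it is safe
        loss = F.cross_entropy(model(data), target)
    scaler.scale(loss).backward()            # backward pass on the scaled loss
    scaler.step(optimizer)                   # optimizer step; master weights stay in FP32
    scaler.update()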

Here too we create a model and define the training, validation and test methods inside the class. Finally, we initialize the model and the Trainer for training and testing. We also use a custom callback to print the epoch number during the training phase.

Modify the trainer to enable mixed precision and add a custom callback to monitor epochs:

Python
trainer = pl.Trainer(
    max_epochs=5,
    precision=16,                       # Enable mixed precision training
    callbacks=[PrintEpochCallback()]    # Custom callback defined in the full code below
)

As the code shows, enabling mixed precision only requires passing a precision value when the Trainer is created. Here we pass 16, so 16-bit floating-point numbers are used where it is safe, typically during the forward pass and gradient computation, while 32-bit floating point is kept for numerically sensitive operations such as the weight updates.
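
Note that in PyTorch Lightning 2.x the same option is usually written as a string; an equivalent call (assuming a 2.x installation) would be:

Python
trainer = pl.Trainer(max_epochs=5, precision="16-mixed", callbacks=[PrintEpochCallback()])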

Python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, random_split
from torchvision import datasets, transforms
import pytorch_lightning as pl
from pytorch_lightning.callbacks import Callback  # Needed for the custom callback below

# Define the PyTorch Lightning model
class MNISTModel(pl.LightningModule):
    def __init__(self):
        super(MNISTModel, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3)  # Conv layer 1
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3)  # Conv layer 2
        self.pool = nn.MaxPool2d(2, 2)  # Max Pool layer
        self.fc1 = nn.Linear(64 * 5 * 5, 128)  # Adjusted Linear layer 1
        self.fc2 = nn.Linear(128, 10)  # Output layer

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(x.size(0), -1)  # Flatten the tensor
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=0.001)

    def training_step(self, batch, batch_idx):
        data, target = batch
        output = self(data)
        loss = F.cross_entropy(output, target)
        self.log('train_loss', loss)
        return loss

    def validation_step(self, batch, batch_idx):
        data, target = batch
        output = self(data)
        loss = F.cross_entropy(output, target)
        self.log('val_loss', loss)

    def test_step(self, batch, batch_idx):
        data, target = batch
        output = self(data)
        loss = F.cross_entropy(output, target)
        self.log('test_loss', loss)

# Custom callback to print the current epoch
class PrintEpochCallback(Callback):
    def on_train_epoch_start(self, trainer, pl_module):
        print(f"Starting Epoch: {trainer.current_epoch + 1}")

# Dataset and DataLoader
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
train_data = datasets.MNIST(root='data', train=True, download=True, transform=transform)
test_data = datasets.MNIST(root='data', train=False, download=True, transform=transform)

# Split training and validation sets
train_size = int(0.8 * len(train_data))
val_size = len(train_data) - train_size
train_data, val_data = random_split(train_data, [train_size, val_size])

# Data loaders
train_loader = DataLoader(train_data, batch_size=64, shuffle=True)
val_loader = DataLoader(val_data, batch_size=64)
test_loader = DataLoader(test_data, batch_size=64)


# Initialize and train the model
model = MNISTModel()
trainer = pl.Trainer(
    max_epochs=1,
    precision=16,                       # Enable mixed precision training
    callbacks=[PrintEpochCallback()]    # Add the custom callback
)
trainer.fit(model, train_loader, val_loader)

# Test the model
trainer.test(model, test_loader)
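
The callbacks list can be extended with Lightning's built-in callbacks as well. For example, a sketch that adds early stopping on the validation loss alongside the custom epoch printer (the patience value is chosen arbitrarily here):

Python
from pytorch_lightning.callbacks import EarlyStopping

early_stop = EarlyStopping(monitor="val_loss", patience=3, mode="min")
trainer = pl.Trainer(max_epochs=20, precision=16,
                     callbacks=[PrintEpochCallback(), early_stop])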


Advanced Tutorial: Integrating Comet Logger

Finally, the Advanced Tutorial delves into sophisticated integrations and experiment management techniques. It demonstrates how to incorporate external tools like Comet.ml for comprehensive experiment tracking and visualization, enabling users to log metrics, compare different training runs, and collaborate more effectively.

In this tutorial we monitor the training phase. PyTorch Lightning integrates with many loggers, such as the TensorBoard Logger and the Comet Logger. Here we use the Comet Logger to log our metrics and visualize them interactively; it also keeps track of hyperparameters and provides charts, all without adding code complexity.

Setting Up Comet.ml:

  • Sign Up: Create an account at Comet.ml.
  • Obtain API Key: Navigate to Settings to find your API key.
  • Install Comet-ML: Ensure it's installed via pip:

After signing up to Comet.ml you receive an API key, and a workspace is created for you by default; the workspace contains your list of projects. We also need to install comet-ml using pip or conda.

pip install comet-ml
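
Rather than hard-coding the API key in source files, comet-ml can usually pick it up from the environment instead; a common pattern (the value below is a placeholder) is:

Python
import os

# Placeholder value; CometLogger typically falls back to the COMET_API_KEY
# environment variable when api_key is not passed explicitly.
os.environ["COMET_API_KEY"] = "YOUR_API_KEY"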

  • Now we create the model, configure the optimizer, define the forward pass and provide the training, validation and test methods.
  • Then we create the data loaders and, lastly, initialize the Comet logger with the API key, workspace name and project name.
Python
# Initialize CometLogger
comet_logger = CometLogger(
    api_key="API",  # Replace with your Comet API key
    project_name="mnist-classification",
    workspace="WORKSPACE_NAME"  # Replace with your workspace name
)
  • CometLogger: Captures and logs metrics automatically when using self.log in the model.
  • Visualization: Access interactive dashboards on Comet.ml to monitor training progress, compare experiments, and analyze hyperparameters.
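
Hyperparameters can also be pushed to Comet explicitly through the logger; a small illustrative sketch using the values from this article:

Python
comet_logger.log_hyperparams({"batch_size": 64, "lr": 0.001, "max_epochs": 5})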

Below is the full implementation of the code:

Python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, random_split
from torchvision import datasets, transforms
import pytorch_lightning as pl
from pytorch_lightning.loggers import CometLogger

# Define the PyTorch Lightning model
class MNISTModel(pl.LightningModule):
    def __init__(self):
        super(MNISTModel, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3)  # Conv layer 1
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3)  # Conv layer 2
        self.pool = nn.MaxPool2d(2, 2)  # Max Pool layer
        self.fc1 = nn.Linear(64 * 5 * 5, 128)  # Adjusted Linear layer 1
        self.fc2 = nn.Linear(128, 10)  # Output layer

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(x.size(0), -1)  # Flatten the tensor
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=0.001)

    def training_step(self, batch, batch_idx):
        data, target = batch
        output = self(data)
        loss = F.cross_entropy(output, target)
        self.log('train_loss', loss)
        return loss

    def validation_step(self, batch, batch_idx):
        data, target = batch
        output = self(data)
        loss = F.cross_entropy(output, target)
        self.log('val_loss', loss)

    def test_step(self, batch, batch_idx):
        data, target = batch
        output = self(data)
        loss = F.cross_entropy(output, target)
        self.log('test_loss', loss)

# Dataset and DataLoader
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
train_data = datasets.MNIST(root='data', train=True, download=True, transform=transform)
test_data = datasets.MNIST(root='data', train=False, download=True, transform=transform)

# Split training and validation sets
train_size = int(0.8 * len(train_data))
val_size = len(train_data) - train_size
train_data, val_data = random_split(train_data, [train_size, val_size])

# Data loaders
train_loader = DataLoader(train_data, batch_size=64, shuffle=True)
val_loader = DataLoader(val_data, batch_size=64)
test_loader = DataLoader(test_data, batch_size=64)

# Initialize CometLogger
comet_logger = CometLogger(
    api_key="sbMMY0ClIkTR7QoREyRBFP3Ju",  # Replace with your Comet API key
    project_name="mnist-classification",
    workspace="baidehi1874"  # Replace with your workspace name
)

# Initialize and train the model
model = MNISTModel()
trainer = pl.Trainer(max_epochs=5, logger=comet_logger)
trainer.fit(model, train_loader, val_loader)

# Test the model
trainer.test(model, test_loader)

Benefits of Using Comet Logger:

  • Real-Time Tracking: Monitor training metrics in real-time through interactive dashboards.
  • Experiment Management: Compare different runs, track hyperparameters, and maintain reproducibility.
  • Collaboration: Share experiments and results with team members seamlessly.

Conclusion

PyTorch-Lightning significantly simplifies the PyTorch workflow by abstracting complex training loops, enabling advanced features with minimal code changes, and integrating seamlessly with various tools for logging and monitoring. Whether you're a beginner aiming to build and train models efficiently or an advanced practitioner looking to optimize and monitor large-scale experiments, PyTorch-Lightning offers robust solutions to enhance your deep learning projects.

