How to handle overfitting in PyTorch models using Early Stopping
Last Updated: 10 Dec, 2024
Overfitting is a common challenge in machine learning: a model performs well on training data but poorly on unseen data because it has learned noise and incidental details of the training set rather than the underlying patterns.
In the context of deep learning with PyTorch, one effective method to combat overfitting is implementing early stopping. This article explains how early stopping works, demonstrates how to implement it in PyTorch, and explores its benefits and considerations.
What is Early Stopping?
Early stopping is a regularization technique used to avoid overfitting during the training process. It involves stopping the training phase if the model's performance on a validation set does not improve for a specified number of consecutive epochs, called the "patience" period. This ensures the model does not learn the noise and specific details of the training data, thereby enhancing its generalization capabilities.
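Stripped to its essentials, the mechanism is just a counter that resets whenever the validation loss improves and triggers a stop once it reaches the patience limit. The following minimal sketch illustrates the idea with a made-up sequence of per-epoch validation losses (the numbers are purely illustrative):

# Simulated per-epoch validation losses, for illustration only
val_losses = [0.30, 0.25, 0.22, 0.21, 0.215, 0.22, 0.218, 0.23]

best_val_loss = float('inf')
patience = 3  # epochs to tolerate without improvement
counter = 0

for epoch, val_loss in enumerate(val_losses, start=1):
    if val_loss < best_val_loss:
        best_val_loss = val_loss  # improvement: record it and reset the counter
        counter = 0
    else:
        counter += 1              # no improvement this epoch
        if counter >= patience:
            print(f"Stopping early at epoch {epoch}")
            break

With patience=3, the loop stops at epoch 7: the loss last improved at epoch 4, and the three following epochs fail to beat it.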
Benefits of Early Stopping
- Prevents Overfitting: By halting training at the right time, early stopping ensures the model does not overfit.
- Saves Time and Resources: It reduces unnecessary training time and computational resources by stopping the training early.
- Optimizes Model Performance: Helps in selecting the version of the model that performs best on unseen data.
Steps to Implement Early Stopping in PyTorch
In this section, we are going to walk through the process of creating, training, and evaluating a simple neural network using PyTorch, focusing on the implementation of early stopping to prevent overfitting.
Step 1: Import Libraries
First, we import the necessary libraries (copy is used later to snapshot the best model weights, and random_split to carve out a validation set):
import copy

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, random_split
Step 2: Define the Neural Network Architecture
Next, we define a simple neural network class using PyTorch's nn.Module. The network has:
- fc1, fc2, fc3: fully connected layers; ReLU activations follow fc1 and fc2, while fc3 produces raw logits for the 10 classes.
- forward method: defines the forward pass of the network.
class SimpleNN(nn.Module):
    def __init__(self):
        super(SimpleNN, self).__init__()
        self.fc1 = nn.Linear(784, 128)  # 28x28 input pixels, flattened
        self.fc2 = nn.Linear(128, 64)
        self.fc3 = nn.Linear(64, 10)    # one logit per digit class

    def forward(self, x):
        x = torch.flatten(x, 1)          # (batch, 1, 28, 28) -> (batch, 784)
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        x = self.fc3(x)                  # raw logits; CrossEntropyLoss handles softmax
        return x
Step 3: Implement Early Stopping
We implement an EarlyStopping class to halt training when the validation loss stops improving. Its parameters are:
- patience: number of epochs to wait for an improvement before stopping.
- delta: minimum change in the monitored quantity to qualify as an improvement (for example, with delta=0.01 the validation loss must drop by at least 0.01 below the best value seen so far).
- best_score, best_model_state: track the best validation score and the corresponding model weights.
- __call__ method: updates the early-stopping state after each epoch.
class EarlyStopping:
    def __init__(self, patience=5, delta=0):
        self.patience = patience
        self.delta = delta
        self.best_score = None
        self.early_stop = False
        self.counter = 0
        self.best_model_state = None

    def __call__(self, val_loss, model):
        score = -val_loss  # higher score means lower validation loss
        if self.best_score is None:
            self.best_score = score
            self.best_model_state = copy.deepcopy(model.state_dict())
        elif score < self.best_score + self.delta:
            # No sufficient improvement: count this epoch against patience
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            # Improvement: record the new best score and snapshot the weights
            self.best_score = score
            self.best_model_state = copy.deepcopy(model.state_dict())
            self.counter = 0

    def load_best_model(self, model):
        model.load_state_dict(self.best_model_state)
Note that we deep-copy the state dict: state_dict() returns references to the live parameter tensors, so without the copy, best_model_state would silently track the current weights instead of the best ones.
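A typical usage pattern, sketched below with the same patience and delta values used in the full example later in this article, is to instantiate the class once and call it after each epoch's validation pass (val_loss and model come from the surrounding training loop):

early_stopping = EarlyStopping(patience=5, delta=0.01)

for epoch in range(num_epochs):
    # ... train for one epoch and compute val_loss on the validation set ...
    early_stopping(val_loss, model)   # update the counter and best-model snapshot
    if early_stopping.early_stop:
        break

early_stopping.load_best_model(model)  # restore the best recorded weights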
Step 4: Load the Data
We load and normalize the MNIST dataset, then split the training set into training and validation subsets so that early stopping can monitor performance on data the model is never trained on.
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)

# Hold out 20% of the training data as a validation set
train_size = int(0.8 * len(train_dataset))
val_size = len(train_dataset) - train_size
train_dataset, val_dataset = random_split(train_dataset, [train_size, val_size])

train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=64, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)
Step 5: Initialize the Model, Loss Function, and Optimizer
We set up the model, criterion, and optimizer.
model = SimpleNN()
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
Step 6: Train the Model with Early Stopping
We instantiate EarlyStopping and train the model for up to 100 epochs, checking after every epoch whether training should stop. Here,
- Train loop: trains the model, updates the weights, and accumulates the training loss.
- Validation loop: evaluates the model on the validation set and computes the validation loss.
- Early stopping check: applies the early-stopping logic after each epoch and halts training once patience is exhausted.
early_stopping = EarlyStopping(patience=5, delta=0.01)

# Training loop
num_epochs = 100
for epoch in range(num_epochs):
    model.train()
    train_loss = 0
    for data, target in train_loader:
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
        train_loss += loss.item() * data.size(0)
    train_loss /= len(train_loader.dataset)

    # Validation step (using the validation set, not the test set)
    model.eval()
    val_loss = 0
    with torch.no_grad():
        for data, target in val_loader:
            output = model(data)
            loss = criterion(output, target)
            val_loss += loss.item() * data.size(0)
    val_loss /= len(val_loader.dataset)

    print(f'Epoch {epoch+1}, Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}')

    early_stopping(val_loss, model)
    if early_stopping.early_stop:
        print("Early stopping")
        break

# Restore the weights from the best epoch
early_stopping.load_best_model(model)
Step 7: Evaluate the Model
Finally, we evaluate the model's accuracy on the test dataset. The evaluation loop computes the accuracy by comparing predicted labels with true labels.
model.eval()
correct = 0
total = 0
with torch.no_grad():
    for data, target in test_loader:
        outputs = model(data)
        _, predicted = torch.max(outputs, 1)  # class with the highest logit
        total += target.size(0)
        correct += (predicted == target).sum().item()

print(f'Accuracy of the model on the test images: {100 * correct / total:.2f}%')
Building and Training a Simple Neural Network with Early Stopping in PyTorch
import copy

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, random_split


class SimpleNN(nn.Module):
    def __init__(self):
        super(SimpleNN, self).__init__()
        self.fc1 = nn.Linear(784, 128)
        self.fc2 = nn.Linear(128, 64)
        self.fc3 = nn.Linear(64, 10)

    def forward(self, x):
        x = torch.flatten(x, 1)
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        x = self.fc3(x)
        return x


class EarlyStopping:
    def __init__(self, patience=5, delta=0):
        self.patience = patience
        self.delta = delta
        self.best_score = None
        self.early_stop = False
        self.counter = 0
        self.best_model_state = None

    def __call__(self, val_loss, model):
        score = -val_loss
        if self.best_score is None:
            self.best_score = score
            self.best_model_state = copy.deepcopy(model.state_dict())
        elif score < self.best_score + self.delta:
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_score = score
            self.best_model_state = copy.deepcopy(model.state_dict())
            self.counter = 0

    def load_best_model(self, model):
        model.load_state_dict(self.best_model_state)


# Data loading
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)

# Split the training dataset into training and validation sets
train_size = int(0.8 * len(train_dataset))  # 80% for training
val_size = len(train_dataset) - train_size  # 20% for validation
train_dataset, val_dataset = random_split(train_dataset, [train_size, val_size])

train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=64, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)

# Model, loss function, and optimizer
model = SimpleNN()
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Early stopping
early_stopping = EarlyStopping(patience=5, delta=0.01)

# Training loop
num_epochs = 100
for epoch in range(num_epochs):
    model.train()
    train_loss = 0
    for data, target in train_loader:
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
        train_loss += loss.item() * data.size(0)
    train_loss /= len(train_loader.dataset)

    # Validation step (using the validation set, not the test set)
    model.eval()
    val_loss = 0
    with torch.no_grad():
        for data, target in val_loader:
            output = model(data)
            loss = criterion(output, target)
            val_loss += loss.item() * data.size(0)
    val_loss /= len(val_loader.dataset)

    print(f'Epoch {epoch+1}, Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}')

    early_stopping(val_loss, model)
    if early_stopping.early_stop:
        print("Early stopping")
        break

# Load the best model
early_stopping.load_best_model(model)

# Final evaluation on the test set
model.eval()
correct = 0
total = 0
with torch.no_grad():
    for data, target in test_loader:
        outputs = model(data)
        _, predicted = torch.max(outputs, 1)
        total += target.size(0)
        correct += (predicted == target).sum().item()

print(f'Accuracy of the model on the test images: {100 * correct / total:.2f}%')
Output:
Epoch 1, Train Loss: 0.4373, Val Loss: 0.2750
Epoch 2, Train Loss: 0.2244, Val Loss: 0.1835
Epoch 3, Train Loss: 0.1617, Val Loss: 0.1441
...
Epoch 14, Train Loss: 0.0445, Val Loss: 0.1036
Epoch 15, Train Loss: 0.0398, Val Loss: 0.1205
Epoch 16, Train Loss: 0.0388, Val Loss: 0.0934
Early stopping
Accuracy of the model on the test images: 97.35%
Conclusion
In this tutorial, we demonstrated how to build, train, and evaluate a simple neural network using PyTorch, with a focus on implementing early stopping to prevent overfitting. This approach helps achieve better generalization by halting training when the validation performance stops improving.