ResNet18 from Scratch Using PyTorch
ResNet18 is a variant of the Residual Network (ResNet) architecture, which was introduced to address the vanishing gradient problem in deep neural networks. The architecture is designed to allow networks to be deeper, thus improving their ability to learn complex patterns in data.
This article will guide you through the process of implementing ResNet18 from scratch using PyTorch, covering the theoretical background, implementation details, and training the model.
Understanding ResNet Architecture
The ResNet18 model has 18 layers with learnable weights: an initial convolutional layer, sixteen convolutional layers organized into eight residual blocks, and a final fully connected layer. Residual blocks are the core component of ResNet architectures; each one includes a skip connection that bypasses its layers.
A residual block lets the input bypass one or more layers via a shortcut connection, which helps mitigate the vanishing gradient problem. A typical residual block in ResNet18 consists of two 3x3 convolutional layers, each followed by batch normalization, with a ReLU activation in between.
The shortcut connection adds the input of the block to the output of the second convolutional layer, allowing gradients to flow through the network more effectively. The key idea is that the block learns a residual mapping rather than the original, unreferenced mapping: if the desired mapping is H(x), the block learns the residual F(x) = H(x) - x and outputs F(x) + x.
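A minimal conceptual sketch of this idea (purely illustrative; here residual_fn stands in for the block's two convolutional layers):
Python
import torch

def apply_residual_block(x, residual_fn):
    # residual_fn computes the learned residual F(x);
    # the skip connection adds the original input back
    return residual_fn(x) + x

# With a zero residual (F(x) = 0), the block reduces to the identity mapping
x = torch.ones(2, 3)
print(apply_residual_block(x, torch.zeros_like))  # prints a tensor equal to x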
Implementing ResNet18 from Scratch
PyTorch is an easy-to-use, flexible framework that lets us define and train neural networks without worrying too much about the underlying mathematical operations.
Before we start coding, we need to set up our environment by installing the required Python libraries:
pip install torch torchvision numpy
- We will use PyTorch for building the model, torchvision for datasets and transformations, and numpy for basic array operations.
- Next, we will define the BasicBlock class, which contains the convolutional, batch normalization, and ReLU layers that make up a single ResNet block.
- Define the ResNet18 class: This class assembles the BasicBlocks into the full ResNet18 model, including the initial convolutional layer, the max pooling layer, and the fully connected layer at the end.
- Train the model on a dataset: We will use the CIFAR-10 dataset in this example. Training involves forward passes through the model, computing the loss, and updating the model parameters via backpropagation.
To create ResNet18, we start with its two main building blocks: the BasicBlock class and the ResNet18 class itself. Below is a simplified version of the code showing how they work:
Building the BasicBlock and the ResNet18 Architecture
- Defining BasicBlock: The core structure of ResNet, with convolution, batch normalization, ReLU activation, and shortcut (residual) connections.
- Building ResNet18: A full implementation with four layers of BasicBlocks, where the number of channels increases and downsampling is performed.
- Initial Layers: Includes the first convolution, batch normalization, ReLU activation, and max pooling.
- Residual Layers: Four main layers, each with multiple residual blocks.
- Pooling and Classification: Adaptive average pooling followed by a fully connected layer for classification.
- Forward Propagation: Defines how data flows through the network to produce the final prediction.
Python
import torch.nn as nn


class BasicBlock(nn.Module):
    """A residual block with two 3x3 convolutions and a shortcut connection."""

    def __init__(self, in_channels, out_channels, stride=1):
        super(BasicBlock, self).__init__()
        # First convolution may downsample (stride > 1) and change the channel count
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3,
                               stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        # Second convolution keeps the spatial size and channel count
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3,
                               stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)

        # Identity shortcut by default; a 1x1 convolution matches dimensions
        # when the block changes the spatial size or channel count
        self.shortcut = nn.Sequential()
        if stride != 1 or in_channels != out_channels:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1,
                          stride=stride, bias=False),
                nn.BatchNorm2d(out_channels)
            )

    def forward(self, x):
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)
        out += self.shortcut(x)  # skip connection: F(x) + x
        out = self.relu(out)
        return out


class ResNet18(nn.Module):
    """ResNet18 built from BasicBlocks, with a 3x3 stem suited to small inputs like CIFAR-10."""

    def __init__(self, num_classes=10):
        super(ResNet18, self).__init__()
        self.in_channels = 64
        # Stem: a 3x3 convolution (the original ImageNet ResNet uses a 7x7, stride-2 stem)
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        # Four stages of two BasicBlocks each; stages 2-4 halve the
        # spatial resolution and double the channel count
        self.layer1 = self._make_layer(BasicBlock, 64, 2, stride=1)
        self.layer2 = self._make_layer(BasicBlock, 128, 2, stride=2)
        self.layer3 = self._make_layer(BasicBlock, 256, 2, stride=2)
        self.layer4 = self._make_layer(BasicBlock, 512, 2, stride=2)
        # Global average pooling followed by the classification head
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512, num_classes)

    def _make_layer(self, block, out_channels, num_blocks, stride):
        # Only the first block in a stage downsamples; the rest use stride 1
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        for stride in strides:
            layers.append(block(self.in_channels, out_channels, stride))
            self.in_channels = out_channels
        return nn.Sequential(*layers)

    def forward(self, x):
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.maxpool(out)
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        out = self.avgpool(out)
        out = out.view(out.size(0), -1)  # flatten to (batch_size, 512)
        out = self.fc(out)
        return out
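As a quick sanity check (a minimal sketch, not part of the original listing), we can pass a random batch of CIFAR-10-sized images through the model and confirm the output shape:
Python
import torch

model = ResNet18(num_classes=10)
x = torch.randn(4, 3, 32, 32)  # a dummy batch of four 32x32 RGB images
out = model(x)
print(out.shape)  # expected: torch.Size([4, 10])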
Training Loop Implementation for ResNet18 Architecture
To recap, the model we just built includes:
- An initial convolutional layer to process the input image.
- Max pooling to reduce the size of the data while keeping important features.
- Several layers made up of our BasicBlocks, which help the model learn at different levels of complexity.
- A final fully connected layer that outputs the predictions.
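Before training, we need the CIFAR-10 data itself. The training loop below expects a trainloader, which the article does not define; here is a minimal loading sketch using torchvision (the batch size and normalization statistics are common choices, not values fixed by the article):
Python
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

# Convert images to tensors and normalize with commonly used CIFAR-10 statistics
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465),
                         (0.2470, 0.2435, 0.2616)),
])

trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=transform)
trainloader = DataLoader(trainset, batch_size=128, shuffle=True, num_workers=2)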
To train the model, we create an instance of the ResNet18 class, define the loss function and optimizer, and then run the training loop. Below is a simplified example:
Python
import torch.optim as optim

model = ResNet18()
criterion = nn.CrossEntropyLoss()  # standard loss for multi-class classification
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

for epoch in range(10):
    running_loss = 0.0
    for i, data in enumerate(trainloader, 0):
        inputs, labels = data
        optimizer.zero_grad()              # clear gradients from the previous step
        outputs = model(inputs)            # forward pass
        loss = criterion(outputs, labels)  # compute the loss
        loss.backward()                    # backpropagate
        optimizer.step()                   # update the weights
        running_loss += loss.item()
    print(f'Epoch {epoch + 1} loss: {running_loss / len(trainloader):.4f}')
The above code creates an instance of the ResNet18 model, defines the loss function and optimizer, and then runs the training loop for 10 epochs, printing the average loss after each epoch.
- For each batch of data, the model makes predictions and calculates how far off those predictions are from the actual labels (this is the loss).
- The model then adjusts its internal settings (weights) based on the loss to improve its predictions for the next round.
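Once training finishes, we usually want to measure accuracy on held-out data. Below is a minimal evaluation sketch; it reuses the transform and DataLoader from the data-loading sketch above and builds a testloader with train=False:
Python
import torch

testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                       download=True, transform=transform)
testloader = DataLoader(testset, batch_size=128, shuffle=False)

model.eval()                  # switch batch norm layers to inference mode
correct, total = 0, 0
with torch.no_grad():         # no gradients needed for evaluation
    for inputs, labels in testloader:
        outputs = model(inputs)
        _, predicted = outputs.max(1)  # class with the highest score
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
print(f'Test accuracy: {100 * correct / total:.2f}%')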
Conclusion
Implementing ResNet18 from scratch using PyTorch provides a deeper understanding of the architecture and the benefits of residual learning. This model is a powerful tool for image classification tasks and can be further extended to more complex versions like ResNet34, ResNet50, etc. Understanding the intricacies of building and training such models is crucial for any deep learning practitioner looking to solve complex problems in computer vision.