Implement Deep Autoencoder in PyTorch for Image Reconstruction
With staggering amounts of data available on the internet, researchers and scientists from industry and academia keep trying to develop data transfer methods that are more efficient and reliable than the current state of the art. With their simple and intuitive architecture, autoencoders have become one of the key tools for this task in recent times.
Broadly, once an autoencoder is trained, the encoder weights can be sent to the transmitter side and the decoder weights to the receiver side. The transmitter can then send data in an encoded format (saving time and bandwidth) while the receiver can decode it with very little overhead. This article explores an interesting application of autoencoders: image reconstruction on the famous MNIST digits dataset, using the PyTorch framework in Python.
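To make this idea concrete, here is a minimal, hypothetical sketch of how the two halves of a trained autoencoder could be saved and shipped separately. The layer sizes and file names below are illustrative placeholders, not part of this article's pipeline.
Python
import torch

# A toy encoder/decoder pair standing in for a trained autoencoder.
encoder = torch.nn.Linear(784, 10)
decoder = torch.nn.Linear(10, 784)

# Transmitter keeps the encoder weights, receiver keeps the decoder weights.
torch.save(encoder.state_dict(), "encoder.pth")
torch.save(decoder.state_dict(), "decoder.pth")

# Receiver side: rebuild the decoder and load only its weights.
restored_decoder = torch.nn.Linear(10, 784)
restored_decoder.load_state_dict(torch.load("decoder.pth"))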
Autoencoders
As shown in the figure below, a very basic autoencoder consists of two main parts:
- An Encoder and,
- A Decoder
Through a series of layers, the encoder maps the high-dimensional input down to a low-dimensional latent representation. The decoder takes this latent representation and outputs the reconstructed data.
For a deeper understanding of the theory, the reader is encouraged to go through the following article: ML | Auto-Encoders

A basic 2 layer Autoencoder
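The two-layer structure in the figure can be expressed directly in PyTorch. The sketch below assumes a 784-dimensional input (a flattened 28x28 image) and a 64-dimensional latent space; both sizes are illustrative choices, not fixed by the figure.
Python
import torch

# A basic 2-layer autoencoder: one linear layer down to the latent
# space (the encoder) and one linear layer back up (the decoder).
class BasicAutoencoder(torch.nn.Module):
    def __init__(self, input_dim=784, latent_dim=64):
        super().__init__()
        self.encoder = torch.nn.Linear(input_dim, latent_dim)
        self.decoder = torch.nn.Linear(latent_dim, input_dim)

    def forward(self, x):
        latent = torch.relu(self.encoder(x))         # compress
        return torch.sigmoid(self.decoder(latent))   # reconstruct

x = torch.rand(8, 784)               # a dummy batch of flattened images
print(BasicAutoencoder()(x).shape)   # torch.Size([8, 784])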
Installation:
Aside from the usual libraries like NumPy and Matplotlib, we only need the torch and torchvision libraries from the PyTorch toolchain for this article. You can install all of them with the following command:
pip3 install torch torchvision torchaudio numpy matplotlib
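A quick way to confirm the installation worked is to import the libraries and print their versions; the exact version numbers will vary by machine.
Python
import torch, torchvision, numpy, matplotlib

# Print the installed versions of each required library.
print(torch.__version__, torchvision.__version__)
print(numpy.__version__, matplotlib.__version__)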
Now onto the most interesting part: the code. The article assumes basic familiarity with the PyTorch workflow and its various utilities, such as DataLoaders, Datasets, and tensor transforms.
The code is divided into 5 steps for a better flow of the material and should be executed sequentially. Each step starts with a few notes to help the reader better understand that step's code.
Stepwise implementation:
Step 1: Loading data and printing some sample images from the training set.
- Initializing Transform: First, we initialize the transform that will be applied to each entry in the dataset. Since tensors are internal to PyTorch's functioning, we convert each item to a tensor; ToTensor also scales the pixel values to lie between 0 and 1, which makes the optimization process easier and faster.
- Downloading Dataset: Then, we download the dataset using the torchvision.datasets utility and store it on our local machine in the folders ./MNIST/train and ./MNIST/test for the training and testing sets respectively. We also wrap these datasets in data loaders with a batch size of 256 for faster training. The reader is encouraged to play around with these values and can expect consistent results.
- Plotting Dataset: Lastly, we plot 25 randomly chosen images from the dataset to get a better view of the data we're dealing with.
Code:
Python
import numpy as np
import matplotlib.pyplot as plt
import torchvision
import torch

plt.rcParams['figure.figsize'] = 15, 10

# ToTensor converts images to tensors and scales pixel values to [0, 1]
transform = torchvision.transforms.ToTensor()

# Download the training and testing sets
train_dataset = torchvision.datasets.MNIST(
    root="./MNIST/train", train=True,
    transform=transform, download=True)

test_dataset = torchvision.datasets.MNIST(
    root="./MNIST/test", train=False,
    transform=transform, download=True)

# Wrap the datasets in data loaders with a batch size of 256
train_loader = torch.utils.data.DataLoader(
    train_dataset, batch_size=256)
test_loader = torch.utils.data.DataLoader(
    test_dataset, batch_size=256)

# Pick 25 random indices and plot the corresponding images
random_samples = np.random.randint(
    0, len(train_dataset), (25,))

for idx in range(random_samples.shape[0]):
    plt.subplot(5, 5, idx + 1)
    image, label = train_dataset[random_samples[idx]]
    plt.imshow(image[0].numpy(), cmap='gray')
    plt.title(label)
    plt.axis('off')

plt.tight_layout()
plt.show()
Output:

Random samples from the training set
Step 2: Initializing the Deep Autoencoder model and other hyperparameters
In this step, we initialize our DeepAutoencoder class, a subclass of torch.nn.Module. This abstracts away a lot of boilerplate code and lets us focus on building the model architecture, which is as follows:

Model Architecture
As described above, the encoder layers form the first half of the network and the decoder layers form the other half. We use the torch.nn.Sequential utility to keep the encoder and decoder separate from one another, which makes the model's architecture easier to follow. After that, we initialize the model hyperparameters: training runs for 100 epochs with the Mean Squared Error loss and the Adam optimizer.
Python
class DeepAutoencoder(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # Encoder: 784 -> 256 -> 128 -> 64 -> 10
        self.encoder = torch.nn.Sequential(
            torch.nn.Linear(28 * 28, 256),
            torch.nn.ReLU(),
            torch.nn.Linear(256, 128),
            torch.nn.ReLU(),
            torch.nn.Linear(128, 64),
            torch.nn.ReLU(),
            torch.nn.Linear(64, 10)
        )
        # Decoder: 10 -> 64 -> 128 -> 256 -> 784
        self.decoder = torch.nn.Sequential(
            torch.nn.Linear(10, 64),
            torch.nn.ReLU(),
            torch.nn.Linear(64, 128),
            torch.nn.ReLU(),
            torch.nn.Linear(128, 256),
            torch.nn.ReLU(),
            torch.nn.Linear(256, 28 * 28),
            torch.nn.Sigmoid()
        )

    def forward(self, x):
        encoded = self.encoder(x)
        decoded = self.decoder(encoded)
        return decoded

# Instantiate the model and set the training hyperparameters
model = DeepAutoencoder()
criterion = torch.nn.MSELoss()
num_epochs = 100
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
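As a quick sanity check (an addition, not part of the original listing), a dummy batch can be pushed through the freshly initialized model to confirm that the output shape matches the input shape:
Python
# A random batch of 4 flattened 28x28 images should come back
# with the same shape after encoding and decoding.
dummy = torch.rand(4, 28 * 28)
print(model(dummy).shape)  # torch.Size([4, 784])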
Step 3: Training loop
The training loop runs for the 100 epochs and does the following things:
- Iterates over each batch and calculates the loss between the reconstructed image and the original image (which also serves as the target).
- Averages the loss over the batches and stores the last batch of images and their reconstructions for each epoch.
After the loop ends, we plot the training loss to better understand the training process. As we can see, the loss decreases with each consecutive epoch, so the training can be deemed successful.
Python
train_loss = []
outputs = {}
num_batches = len(train_loader)

for epoch in range(num_epochs):
    running_loss = 0

    for batch in train_loader:
        img, _ = batch
        # Flatten each 28x28 image into a 784-dimensional vector
        img = img.reshape(-1, 28 * 28)
        out = model(img)

        # The original image is also the target
        loss = criterion(out, img)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    # Average the loss over all batches in this epoch
    running_loss /= num_batches
    train_loss.append(running_loss)

    # Store the last batch and its reconstructions for this epoch
    outputs[epoch + 1] = {'img': img, 'out': out}

# Plot the training loss against the epoch number
plt.plot(range(1, num_epochs + 1), train_loss)
plt.xlabel("Number of epochs")
plt.ylabel("Training Loss")
plt.show()
Output:

Training loss vs. Epochs
Step 4: Visualizing the reconstruction
The best part of this project is that the reader can visualize the reconstruction at each epoch and watch the model learn iteratively.
- First, we plot the first 5 reconstructed (outputted) images for epochs 1, 5, 10, 50, and 100.
- Then we plot the corresponding original images at the bottom for comparison.
We can see how the reconstruction improves with each epoch and gets very close to the original by the last epoch.
Python
counter = 1
epochs_list = [1, 5, 10, 50, 100]

# Plot the reconstructions stored for the chosen epochs
for epoch in epochs_list:
    recon = outputs[epoch]['out'].detach().numpy()
    title_text = f"Epoch = {epoch}"

    for idx in range(5):
        plt.subplot(7, 5, counter)
        plt.title(title_text)
        plt.imshow(recon[idx].reshape(28, 28), cmap='gray')
        plt.axis('off')
        counter += 1

# Plot the corresponding original images for comparison; the last
# training batch is the same every epoch since the loader does not shuffle
orig_imgs = outputs[10]['img']
for idx in range(5):
    plt.subplot(7, 5, counter)
    plt.imshow(orig_imgs[idx].reshape(28, 28).numpy(), cmap='gray')
    plt.title("Original Image")
    plt.axis('off')
    counter += 1

plt.tight_layout()
plt.show()
Output:

Visualizing the reconstruction from the data collected during the training process
Step 5: Checking performance on the test set.
It is good practice in machine learning to check the model's performance on the test set as well. To do that, we take the following steps:
- Generate reconstructions for the last batch of the test set.
- Plot the first 10 reconstructions and the corresponding original images for comparison.
As we can see, the reconstruction is excellent on the test set as well, which completes the pipeline.
Python
# Reconstruct the last batch of the test set
outputs = {}
img, _ = list(test_loader)[-1]
img = img.reshape(-1, 28 * 28)
out = model(img)

outputs['img'] = img
outputs['out'] = out

counter = 1
recon = outputs['out'].detach().numpy()

# Top row: the first 10 reconstructed images
for idx in range(10):
    plt.subplot(2, 10, counter)
    plt.title("Reconstructed \n image")
    plt.imshow(recon[idx].reshape(28, 28), cmap='gray')
    plt.axis('off')
    counter += 1

# Bottom row: the corresponding original images
orig = outputs['img']
for idx in range(10):
    plt.subplot(2, 10, counter)
    plt.imshow(orig[idx].reshape(28, 28).numpy(), cmap='gray')
    plt.title("Original Image")
    plt.axis('off')
    counter += 1

plt.tight_layout()
plt.show()
Output:

Verifying performance on the test set
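Beyond the visual check, the reconstruction error can also be quantified. The snippet below (an addition to the article's pipeline, reusing the model, criterion, and test_loader defined earlier) averages the MSE loss over every batch of the test set:
Python
# Average reconstruction MSE over the entire test set
model.eval()
total_loss = 0
with torch.no_grad():
    for img, _ in test_loader:
        img = img.reshape(-1, 28 * 28)
        total_loss += criterion(model(img), img).item()
print(f"Mean test loss: {total_loss / len(test_loader):.6f}")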
Conclusion:
Autoencoders are fast becoming one of the most exciting areas of research in machine learning. This article covered the PyTorch implementation of a deep autoencoder for image reconstruction. The reader is encouraged to play around with the network architecture and hyperparameters to improve the reconstruction quality and lower the loss.
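One easy experiment, sketched below, is to make the bottleneck width a constructor argument so that different latent sizes can be compared. This is a variant of the article's model, not the original: the hidden sizes here are illustrative, and the rest of the pipeline stays unchanged.
Python
# A variant of DeepAutoencoder whose bottleneck width is configurable
class ConfigurableAutoencoder(torch.nn.Module):
    def __init__(self, latent_dim=10):
        super().__init__()
        self.encoder = torch.nn.Sequential(
            torch.nn.Linear(28 * 28, 256), torch.nn.ReLU(),
            torch.nn.Linear(256, 64), torch.nn.ReLU(),
            torch.nn.Linear(64, latent_dim),
        )
        self.decoder = torch.nn.Sequential(
            torch.nn.Linear(latent_dim, 64), torch.nn.ReLU(),
            torch.nn.Linear(64, 256), torch.nn.ReLU(),
            torch.nn.Linear(256, 28 * 28), torch.nn.Sigmoid(),
        )

    def forward(self, x):
        return self.decoder(self.encoder(x))

# Wider bottlenecks usually reconstruct better but compress less
model_wide = ConfigurableAutoencoder(latent_dim=32)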