Image Classification using ResNet

Last Updated : 13 Feb, 2025

This article walks you through the steps to implement ResNet for image classification using Python and TensorFlow/Keras.

Image classification assigns an image to one of several predefined categories. ResNet (Residual Network) is a family of deep convolutional networks that introduced residual (skip) connections to address the vanishing gradient problem in very deep neural networks.

Here are the key reasons to use ResNet for image classification:

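  • Enables Deeper Networks: Residual connections make it possible to train networks with hundreds of layers without the drop in training accuracy that plain (non-residual) deep networks show.
  • Improved Performance: Because each block only has to learn a residual correction to its input, deeper ResNets tend to reach higher accuracy than plain networks of similar depth on tasks like image classification.
  • Better Generalization with Pre-trained Weights: ImageNet pre-trained ResNet weights provide general-purpose features, which often improves performance on unseen data when the target dataset is small.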
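
To make the idea of a residual connection concrete, here is a minimal pure-Python sketch, not the actual ResNet implementation. A residual block computes `transform(x) + x`, so if the inner layers output zeros the block simply passes its input through. The names `residual_block`, `zero_transform` and `small_transform` are hypothetical and chosen only for illustration.

```python
def residual_block(x, transform):
    """Return transform(x) + x, the core computation of a residual block."""
    fx = transform(x)
    return [f + xi for f, xi in zip(fx, x)]


def zero_transform(x):
    # Hypothetical stand-in for a stack of layers whose output is all zeros.
    return [0.0 for _ in x]


def small_transform(x):
    # Hypothetical stand-in for layers that learn a small correction.
    return [0.1 * xi for xi in x]


x = [1.0, -2.0, 3.0]
identity_out = residual_block(x, zero_transform)    # input passes through unchanged
corrected_out = residual_block(x, small_transform)  # input plus a small correction
print(identity_out)
print(corrected_out)
```

Because the skip path adds the input back unchanged, a block can fall back to the identity mapping, which is part of why stacking many such blocks does not hurt training the way stacking plain layers can.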

Image Classification Using ResNet on CIFAR-10

Here’s a step-by-step guide to implement image classification using the CIFAR-10 dataset and ResNet50 in TensorFlow:

1. Import Libraries

We begin by importing the necessary libraries from TensorFlow and Keras:

Python
# Pre-trained ResNet50 architecture and the CIFAR-10 dataset loader
from tensorflow.keras.applications import ResNet50
from tensorflow.keras.datasets import cifar10
# Building blocks for the classification head
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, GlobalAveragePooling2D
# Optimizer and label-encoding helper
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.utils import to_categorical

2. Load and Preprocess the CIFAR-10 Dataset

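We load the CIFAR-10 dataset using tensorflow.keras.datasets.cifar10, which returns 50,000 training and 10,000 test images of size 32x32x3. We then scale the pixel values to the range 0 to 1 by dividing by 255, and one-hot encode the integer labels so they match the output format expected by categorical cross-entropy. Note that the ImageNet weights were originally trained with the preprocessing in tensorflow.keras.applications.resnet50.preprocess_input (applied to raw 0-255 pixels); dividing by 255 is a simpler alternative used here, so it may be worth trying both.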

Python
# Load CIFAR-10 dataset
(x_train, y_train), (x_test, y_test) = cifar10.load_data()

# Preprocess the data
x_train = x_train.astype('float32') / 255.0
x_test = x_test.astype('float32') / 255.0

# One-hot encode the labels
y_train = to_categorical(y_train, 10)
y_test = to_categorical(y_test, 10)
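
To show what these two preprocessing steps do to the data, here is a small pure-Python sketch on toy values. It only mirrors the behaviour of the division by 255 and of `to_categorical`; the helper names `scale_pixels` and `one_hot` are hypothetical.

```python
def scale_pixels(pixels):
    """Map 8-bit pixel values (0-255) to floats in the range [0, 1]."""
    return [p / 255.0 for p in pixels]


def one_hot(label, num_classes):
    """Return a list with 1.0 at index `label` and 0.0 everywhere else."""
    vec = [0.0] * num_classes
    vec[label] = 1.0
    return vec


scaled = scale_pixels([0, 51, 255])
encoded = one_hot(3, 10)  # in CIFAR-10, class index 3 is "cat"
print(scaled)
print(encoded)
```

Each CIFAR-10 label therefore becomes a vector of length 10 with a single 1.0, which lines up with the 10 softmax outputs of the model.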

3. Load ResNet50 Pre-trained on ImageNet

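We use ResNet50, pre-trained on the ImageNet dataset. The include_top=False parameter ensures that the fully connected layers (the classification head) are not included, so we can add our custom layers. The input_shape=(32, 32, 3) argument matches the CIFAR-10 images; 32x32 is also the smallest input size this Keras model accepts. Setting trainable = False freezes the pre-trained weights so that only the layers we add are updated during training.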

Python
base_model = ResNet50(weights='imagenet', 
                      include_top=False, 
                      input_shape=(32, 32, 3))

# Freeze the base model
base_model.trainable = False

4. Build the Classification Model

We now build the model using the pre-trained ResNet50 as a base. We add a GlobalAveragePooling2D layer to reduce the dimensions of the feature maps from the ResNet base model, followed by a Dense layer for classification.

The final layer has 10 neurons, one for each class in the CIFAR-10 dataset, with a softmax activation function.

Python
# Build the classification model
model = Sequential([
    base_model,
    GlobalAveragePooling2D(),
    Dense(1024, activation='relu'),
    Dense(10, activation='softmax')  
])
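
As a rough illustration of what GlobalAveragePooling2D computes, the following pure-Python sketch averages a tiny feature map over its height and width, leaving one value per channel. It is a simplified stand-in for the Keras layer (no batch dimension), and `global_average_pool` is a hypothetical helper name.

```python
def global_average_pool(feature_map):
    """Average a (height, width, channels) feature map over height and width."""
    height = len(feature_map)
    width = len(feature_map[0])
    channels = len(feature_map[0][0])
    totals = [0.0] * channels
    for row in feature_map:
        for pixel in row:
            for c, value in enumerate(pixel):
                totals[c] += value
    return [t / (height * width) for t in totals]


# Toy feature map: 2x2 spatial positions, 3 channels
feature_map = [
    [[1.0, 10.0, 0.0], [3.0, 20.0, 0.0]],
    [[5.0, 30.0, 0.0], [7.0, 40.0, 4.0]],
]
pooled = global_average_pool(feature_map)
print(pooled)
```

With 32x32 inputs, the ResNet50 base should already reduce the spatial size to 1x1 with 2048 channels, so in this particular model the pooling layer mainly flattens the output into a 2048-element vector for the Dense layers.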

5. Compile the Model

We use the Adam optimizer with a small learning rate (0.0001) so that the newly added layers are updated in small, stable steps, and categorical cross-entropy as the loss function, which matches the one-hot labels and the softmax output. We also track the accuracy metric during training.

Python
# Compile the model
model.compile(optimizer=Adam(learning_rate=0.0001), 
              loss='categorical_crossentropy', 
              metrics=['accuracy'])
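
For intuition about the loss, here is a minimal pure-Python sketch of softmax followed by categorical cross-entropy for a single example with three classes. It is illustrative only and not how Keras implements the loss internally; the function names and the toy `logits` are assumptions made for this example.

```python
import math


def softmax(logits):
    """Convert raw scores into probabilities that sum to 1."""
    largest = max(logits)
    exps = [math.exp(z - largest) for z in logits]  # subtract max for numerical stability
    total = sum(exps)
    return [e / total for e in exps]


def categorical_cross_entropy(one_hot_target, probs):
    """Negative log-probability assigned to the true class."""
    return -sum(t * math.log(p) for t, p in zip(one_hot_target, probs) if t > 0)


logits = [2.0, 1.0, 0.1]   # hypothetical raw scores for 3 classes
target = [1.0, 0.0, 0.0]   # the true class is index 0
probs = softmax(logits)
loss = categorical_cross_entropy(target, probs)
print(probs)
print(loss)
```

The loss shrinks toward 0 as the probability assigned to the true class approaches 1, which is what the optimizer works toward during training.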

6. Train the Model

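We then train the model on the CIFAR-10 training data, using a batch size of 64 and 10 epochs. We also pass the test data as validation data to monitor the model’s performance after each epoch. If you plan to tune hyperparameters or use early stopping based on this signal, hold out a separate validation split from the training data instead.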

Python
# Train the model
model.fit(x_train, y_train, 
          batch_size=64, 
          epochs=10, 
          validation_data=(x_test, y_test))

7. Evaluate the Model

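Once the model is trained, we evaluate it on the test data to check its accuracy. Because the new layers are initialized randomly and the training data is shuffled, the exact figure will differ from run to run.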

Python
# Evaluate the model
test_loss, test_acc = model.evaluate(x_test, y_test)
print(f"Test accuracy: {test_acc}")

Output:

Test accuracy: 0.8741999864578247
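
The accuracy metric itself is simple: it is the fraction of examples whose highest-probability class matches the true class. The pure-Python sketch below shows this on made-up predictions; `argmax` and `accuracy` are hypothetical helpers, and the numbers are toy values, not model output.

```python
def argmax(values):
    """Index of the largest value in a list."""
    return max(range(len(values)), key=lambda i: values[i])


def accuracy(one_hot_targets, predicted_probs):
    """Fraction of examples where the predicted class equals the true class."""
    correct = sum(
        argmax(t) == argmax(p) for t, p in zip(one_hot_targets, predicted_probs)
    )
    return correct / len(one_hot_targets)


# Toy data: 4 examples, 3 classes
targets = [[1, 0, 0], [0, 1, 0], [0, 0, 1], [0, 1, 0]]
preds = [
    [0.7, 0.2, 0.1],  # predicted class 0, correct
    [0.1, 0.8, 0.1],  # predicted class 1, correct
    [0.3, 0.4, 0.3],  # predicted class 1, true class is 2
    [0.2, 0.5, 0.3],  # predicted class 1, correct
]
acc = accuracy(targets, preds)
print(acc)
```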

ResNet's residual connections make very deep models trainable, and its ImageNet pre-trained weights provide useful features even when the target dataset is small. In this example we froze the entire ResNet50 base and trained only the new classification layers. A common next step is fine-tuning: unfreezing some of the top layers of the base model and continuing training with a low learning rate.
