
Convert PyTorch model to TFLite with onnx-tf

Last Updated : 07 Oct, 2024

The increasing demand for deploying machine learning models on mobile and edge devices has led to the necessity of converting models into formats that are optimized for such environments. TensorFlow Lite (TFLite) is one such format that is widely used for deploying models on mobile devices.

Different frameworks are better supported in different deployment environments. For instance, a model may have been developed in PyTorch, while the target device (such as an Android phone or a Raspberry Pi) runs TensorFlow Lite. Instead of rebuilding the model from scratch in TensorFlow, we can convert the trained model and deploy it as is. This article provides a detailed guide on converting PyTorch models to TFLite using ONNX and TensorFlow as intermediate steps.

Understanding the Conversion Workflow

The conversion process from PyTorch to TFLite involves several steps, utilizing ONNX (Open Neural Network Exchange) as a bridge between PyTorch and TensorFlow. The workflow can be summarized as follows:

  1. PyTorch to ONNX: Export the PyTorch model to the ONNX format.
  2. ONNX to TensorFlow: Convert the ONNX model to a TensorFlow model.
  3. TensorFlow to TFLite: Finally, convert the TensorFlow model to TFLite format.
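The three steps can be wrapped into a single helper. The sketch below assumes the package versions listed in the setup section of this article. The function name pytorch_to_tflite, its parameters and its default file names are hypothetical. The onnx.checker.check_model call is an extra validation step that the walkthrough itself does not use.

```python
def pytorch_to_tflite(model, dummy_input, onnx_path="model.onnx",
                      saved_model_dir="tf_model", tflite_path="model.tflite",
                      opset_version=11, quantize=True):
    """Hypothetical helper: PyTorch -> ONNX -> TensorFlow SavedModel -> TFLite."""
    # Heavy imports live inside the function so this file can be imported without them
    import onnx
    import tensorflow as tf
    import torch
    from onnx_tf.backend import prepare

    # 1. PyTorch -> ONNX: tracing runs dummy_input through the model
    #    (dummy_input must be on the same device as the model)
    model.eval()
    torch.onnx.export(model, dummy_input, onnx_path,
                      input_names=["input"], output_names=["output"],
                      opset_version=opset_version)

    # 2. ONNX -> TensorFlow: validate the graph, then write a SavedModel directory
    onnx_model = onnx.load(onnx_path)
    onnx.checker.check_model(onnx_model)
    prepare(onnx_model).export_graph(saved_model_dir)

    # 3. TensorFlow -> TFLite, optionally with default post-training quantization
    converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir)
    if quantize:
        converter.optimizations = [tf.lite.Optimize.DEFAULT]
    with open(tflite_path, "wb") as f:
        f.write(converter.convert())
    return tflite_path
```

With the MNIST model trained below, a call would look like pytorch_to_tflite(model, torch.randn(1, 1, 28, 28).to(device)).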

Exporting a PyTorch Model to ONNX

ONNX (Open Neural Network Exchange) is an open format for representing deep learning models. It stores the computation graph (the architecture) together with the trained weights. It serves as an intermediate representation when a model has to be moved from one framework to another.

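Here we use the MNIST dataset and build a small convolutional network with two convolutional layers followed by two linear layers. The model is trained with the Adam optimizer and cross-entropy loss. We then pass a dummy input of shape (1, 1, 28, 28) to the exporter. ONNX export traces the model by running this input through it, which records the operations performed and fixes the expected input shape. Finally, we export the model to ONNX format, naming its input and output tensors and choosing the ONNX opset version.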
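The input size of fc1, 64 * 7 * 7, follows from how each convolution and pooling layer changes the 28x28 image. The sketch below reproduces that arithmetic with the standard output-size formula (dilation 1, no ceil mode). The function names are hypothetical.

```python
def conv2d_out(size, kernel, stride=1, padding=0):
    """Spatial output size of a Conv2d or MaxPool2d layer (dilation = 1)."""
    return (size + 2 * padding - kernel) // stride + 1

def flattened_features(size=28, channels=64):
    size = conv2d_out(size, kernel=3, stride=1, padding=1)  # conv1: 28 -> 28
    size = conv2d_out(size, kernel=2, stride=2)             # pool:  28 -> 14
    size = conv2d_out(size, kernel=3, stride=1, padding=1)  # conv2: 14 -> 14
    size = conv2d_out(size, kernel=2, stride=2)             # pool:  14 -> 7
    return channels * size * size

print(flattened_features())  # 3136, i.e. 64 * 7 * 7
```

If the input resolution or the layer settings change, this is the number that has to be updated in nn.Linear.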

Python
# Step 1: Import the required libraries
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from torchvision import datasets
import onnx
from onnx_tf.backend import prepare
import tensorflow as tf
import numpy as np

# Step 2: Define the simple CNN model for MNIST
class SimpleCNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1)   # keeps 28x28
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)  # keeps 14x14
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)                   # halves height and width
        self.fc1 = nn.Linear(64 * 7 * 7, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = self.pool(torch.relu(self.conv1(x)))  # -> (N, 32, 14, 14)
        x = self.pool(torch.relu(self.conv2(x)))  # -> (N, 64, 7, 7)
        x = x.view(x.size(0), -1)                 # flatten to (N, 3136)
        x = torch.relu(self.fc1(x))
        return self.fc2(x)                        # raw logits for the 10 digit classes

# Step 3: Load MNIST dataset
transform = transforms.Compose([transforms.ToTensor()])
train_dataset = datasets.MNIST(root='./data', train=True, transform=transform, download=True)
test_dataset = datasets.MNIST(root='./data', train=False, transform=transform, download=True)
train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=64, shuffle=True)
test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=64, shuffle=False)

# Step 4: Initialize and train the model
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = SimpleCNN().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

# Training the model for 10 epochs
for epoch in range(10):  # Train for 10 epochs
    model.train()
    total_loss = 0
    for images, labels in train_loader:
        images = images.to(device)
        labels = labels.to(device)
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()

    print(f'Epoch [{epoch+1}/10], Loss: {total_loss/len(train_loader):.4f}')

# Step 5: Export the model to ONNX format
model.eval()  # switch to inference mode before tracing
dummy_input = torch.randn(1, 1, 28, 28).to(device)  # example input with the shape used at inference time
onnx_filename = "mnist_model.onnx"
torch.onnx.export(model, dummy_input, onnx_filename,
                  input_names=['input'], output_names=['output'], opset_version=11)

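The torch.onnx.export function runs dummy_input through the model, records the operations that were executed, and writes the resulting graph and weights to mnist_model.onnx. The input_names and output_names arguments label the graph's input and output tensors, and opset_version=11 selects the ONNX operator set used for the export.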
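After each conversion step, it is worth checking that the new model reproduces the PyTorch outputs for the same input. The helper below is a minimal NumPy sketch with a hypothetical name and assumed default tolerances. You would feed it, for example, model(dummy_input).detach().cpu().numpy() and the output of the converted model. One way to get the ONNX output is onnxruntime.InferenceSession(onnx_filename).run(None, {"input": x}), but onnxruntime is an extra package that this article does not otherwise use. A quantized TFLite model needs a much looser tolerance.

```python
import numpy as np

def outputs_match(reference, candidate, atol=1e-4, rtol=1e-3):
    """Return (match, max_abs_diff) for two model outputs computed on the same input."""
    reference = np.asarray(reference, dtype=np.float32)
    candidate = np.asarray(candidate, dtype=np.float32)
    if reference.shape != candidate.shape:
        raise ValueError(f"shape mismatch: {reference.shape} vs {candidate.shape}")
    max_abs_diff = float(np.max(np.abs(reference - candidate)))
    return bool(np.allclose(candidate, reference, atol=atol, rtol=rtol)), max_abs_diff

# Synthetic logits standing in for real model outputs
ref = np.array([[0.1, 2.0, -1.5]], dtype=np.float32)
ok, diff = outputs_match(ref, ref + 1e-5)
print(ok, diff)
```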

Installing and Setting up ONNX-TF

ONNX-TF is a converter that translates ONNX models into TensorFlow models. Here we use Python 3.8 because later Python versions caused version compatibility issues. The libraries below must be installed with these specific versions using pip, so that no errors occur while setting up ONNX-TF:

  • onnx 1.16.2
  • onnx-tf 1.10.0
  • numpy 1.23.5
  • Python 3.8.15 (the interpreter itself, not a pip package)
  • tensorflow 2.10.0
  • tensorflow-addons 0.21.0
  • tensorflow-probability 0.13.0
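To confirm that an environment matches these pins before converting, the installed versions can be read with the standard library's importlib.metadata (available from Python 3.8). The helper name version_mismatches is hypothetical. The underscore spellings of the package names match the installed .dist-info folder names.

```python
from importlib import metadata

# Package versions listed in this article
PINNED = {
    "onnx": "1.16.2",
    "onnx_tf": "1.10.0",
    "numpy": "1.23.5",
    "tensorflow": "2.10.0",
    "tensorflow_addons": "0.21.0",
    "tensorflow_probability": "0.13.0",
}

def version_mismatches(pinned):
    """Return {package: (expected, installed_or_None)} for every package that differs."""
    mismatches = {}
    for package, expected in pinned.items():
        try:
            installed = metadata.version(package)
        except metadata.PackageNotFoundError:
            installed = None
        if installed != expected:
            mismatches[package] = (expected, installed)
    return mismatches

for name, (expected, installed) in version_mismatches(PINNED).items():
    print(f"{name}: expected {expected}, found {installed or 'not installed'}")
```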

Converting ONNX Model to TensorFlow

To convert the ONNX model to TensorFlow, we first load it with onnx.load. The prepare function from onnx_tf.backend then converts the ONNX graph into a TensorFlow representation. Its export_graph method saves that representation as a TensorFlow SavedModel in the given directory.

Python
onnx_model = onnx.load(onnx_filename)
tf_rep = prepare(onnx_model)

tf_model_dir = "./tf_model"
tf_rep.export_graph(tf_model_dir)
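A SavedModel directory normally contains a saved_model.pb file and a variables/ subdirectory. Checking for both is a cheap way to catch a failed or partial export before handing the directory to the TFLite converter. This is a hypothetical helper and only a layout check, not a validation of the model itself. The demo builds a throwaway directory with that layout.

```python
import tempfile
from pathlib import Path

def looks_like_saved_model(directory):
    """Cheap sanity check that a directory has the usual SavedModel layout."""
    path = Path(directory)
    return (path / "saved_model.pb").is_file() and (path / "variables").is_dir()

# Demo on a temporary directory that only mimics the layout (not a real model)
with tempfile.TemporaryDirectory() as tmp:
    fake = Path(tmp)
    (fake / "variables").mkdir()
    (fake / "saved_model.pb").write_bytes(b"")
    demo_result = looks_like_saved_model(fake)

print(demo_result)  # True
```

In the article's pipeline you would call looks_like_saved_model(tf_model_dir) right after export_graph.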

Optimizing and Converting TensorFlow Model to TensorFlow Lite

Optimizing here means making the model smaller and faster so that it can run on mobile devices with little or no loss in accuracy. TensorFlow Lite is already a lightweight runtime, but the model itself can be shrunk further. In this step we convert the TensorFlow model to TensorFlow Lite and apply an optimization technique called quantization, which stores numbers at a lower precision than 32-bit floating point. With tf.lite.Optimize.DEFAULT and no representative dataset, the converter performs dynamic range quantization: the weights are stored as 8-bit integers instead of 32-bit floats.
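To see what storing weights as 8-bit integers means numerically, here is a simplified NumPy sketch of symmetric per-tensor int8 quantization. It is an illustration of the idea, not TensorFlow Lite's exact algorithm, which differs in details such as using per-channel scales. Each weight is approximated as scale * q with q in [-127, 127], so storage drops by 4x and the rounding error per weight is at most half a scale step. The weights are random stand-ins shaped like fc2.weight.

```python
import numpy as np

def quantize_int8(weights):
    """Symmetric per-tensor quantization: weights ~= scale * q, q in [-127, 127]."""
    max_abs = float(np.max(np.abs(weights)))
    scale = max_abs / 127.0 if max_abs > 0 else 1.0
    q = np.clip(np.round(weights / scale), -127, 127).astype(np.int8)
    return q, scale

def dequantize(q, scale):
    return q.astype(np.float32) * scale

rng = np.random.default_rng(0)
w = rng.normal(0.0, 0.05, size=(10, 128)).astype(np.float32)  # same shape as fc2.weight
q, scale = quantize_int8(w)
w_hat = dequantize(q, scale)
max_err = float(np.max(np.abs(w - w_hat)))

print(w.nbytes, q.nbytes)  # 5120 1280
print(max_err, scale / 2)  # error stays within half a quantization step
```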

Python
# Step 8: Convert the TensorFlow model to TensorFlow Lite with optimization (quantization)
converter = tf.lite.TFLiteConverter.from_saved_model(tf_model_dir)

# Apply post-training quantization to reduce model size
converter.optimizations = [tf.lite.Optimize.DEFAULT]  # Default quantization


# Convert the model
tflite_model = converter.convert()

# Step 9: Save the optimized TensorFlow Lite model
with open("mnist_model_optimized.tflite", "wb") as f:
    f.write(tflite_model)

print("Optimized TensorFlow Lite model saved successfully!")

Here we use the default optimization setting to quantize the weights and call the convert method to produce the TFLite model. Finally, we save the converted model (graph and weights) as a .tflite file.
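A back-of-the-envelope estimate shows what size reduction to expect for this particular model. The sketch counts the SimpleCNN parameters and compares float32 storage with a simplified quantized layout. That layout assumes every weight tensor becomes int8 while biases stay float32. Real .tflite files also contain the graph structure and quantization parameters, and the converter may leave some tensors unquantized, so actual file sizes will differ.

```python
def conv_params(in_ch, out_ch, k):
    return in_ch * out_ch * k * k, out_ch  # (weight count, bias count)

def linear_params(in_f, out_f):
    return in_f * out_f, out_f

layers = [
    conv_params(1, 32, 3),           # conv1
    conv_params(32, 64, 3),          # conv2
    linear_params(64 * 7 * 7, 128),  # fc1 (dominates the total)
    linear_params(128, 10),          # fc2
]
n_weights = sum(w for w, _ in layers)
n_biases = sum(b for _, b in layers)

float32_bytes = 4 * (n_weights + n_biases)
int8_bytes = 1 * n_weights + 4 * n_biases  # simplification: all weights int8, biases float32

print(n_weights + n_biases)                  # 421642
print(float32_bytes, int8_bytes)             # 1686568 422344
print(round(float32_bytes / int8_bytes, 2))  # 3.99
```

The fc1 layer alone holds about 95% of the parameters, so it accounts for most of the savings.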

Testing the TensorFlow Lite Model

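In this step we load the .tflite model into a tf.lite.Interpreter, which runs TensorFlow Lite models, and evaluate it on the MNIST test set. Each batch from the PyTorch DataLoader is converted to a NumPy array, and each image is fed to the interpreter individually. The predicted class is the index of the largest output value, and the accuracy is the fraction of test images that are classified correctly.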
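The per-image loop compares one argmax at a time. If the interpreter outputs are collected into a single array instead, the same accuracy can be computed in one vectorized step. The helper below is a hypothetical sketch using synthetic logits.

```python
import numpy as np

def accuracy_from_logits(logits, labels):
    """Fraction of rows whose highest-scoring class equals the true label."""
    predictions = np.argmax(np.asarray(logits), axis=1)
    return float(np.mean(predictions == np.asarray(labels)))

logits = np.array([[2.0, 0.1, -1.0],
                   [0.3, 0.2, 0.9],
                   [1.5, 1.4, 0.0]])
labels = np.array([0, 2, 1])
acc = accuracy_from_logits(logits, labels)
print(acc)  # 2 of 3 predictions are correct
```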

Python
interpreter = tf.lite.Interpreter(model_path="mnist_model_optimized.tflite")
interpreter.allocate_tensors()

# Get input and output tensors
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()

# Evaluate the TensorFlow Lite model
correct = 0
total = 0

# Loop through test data
for images, labels in test_loader:
    # Convert PyTorch tensor to numpy
    images_np = images.numpy()

    for i in range(len(images_np)):
        input_data = np.expand_dims(images_np[i], axis=0).astype(np.float32)  # Adjust shape for TFLite

        # Set the tensor to the input data
        interpreter.set_tensor(input_details[0]['index'], input_data)

        # Invoke the interpreter
        interpreter.invoke()

        # Get output predictions
        output_data = interpreter.get_tensor(output_details[0]['index'])
        prediction = np.argmax(output_data)

        # Compare prediction to the true label
        if prediction == labels[i].item():
            correct += 1
        total += 1

# Print the accuracy of the optimized TensorFlow Lite model
accuracy = correct / total
print(f'Optimized TensorFlow Lite Model Accuracy: {accuracy * 100:.2f}%')

The optimized TensorFlow Lite model reaches an accuracy of 98.86% on the MNIST test set. The conversion steps print many warnings, but they can be ignored.

Troubleshooting Common Issues in Model Conversion

Several issues commonly come up when converting a model from one framework to another. Some of them are as follows:

  • Input shapes and layouts may not match across frameworks. For example, PyTorch uses the NCHW layout (batch, channels, height, width), while TensorFlow usually uses NHWC. Always export with a dummy input that has the exact shape used at inference time, and check the converted model's expected input shape (for example with interpreter.get_input_details()) before feeding it data.
  • Version compatibility is another issue. With the latest Python versions, TensorFlow, ONNX and onnx-tf may not have mutually compatible releases. To avoid this, use an earlier Python version such as Python 3.8 and pin package versions that are known to work together.
  • Some operations available in PyTorch have no equivalent in ONNX or TensorFlow (and vice versa), which can make the conversion fail. To avoid this problem, replace such operations with simpler, widely supported ones.
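For the layout issue, the fix is usually a single transpose. In this article's pipeline, the converted model keeps the ONNX input shape (1, 1, 28, 28), which is why the test loop feeds NCHW arrays directly. The hypothetical helper below compares a batch against the shape the model reports, for example tuple(input_details[0]['shape']), and transposes NCHW to NHWC only when that is what the model expects.

```python
import numpy as np

def match_input_layout(batch, expected_shape):
    """Return batch in the layout the model expects, converting NCHW -> NHWC if needed."""
    expected_shape = tuple(int(d) for d in expected_shape)
    if batch.shape == expected_shape:
        return batch
    nhwc = np.transpose(batch, (0, 2, 3, 1))  # (N, C, H, W) -> (N, H, W, C)
    if nhwc.shape == expected_shape:
        return nhwc
    raise ValueError(f"cannot map {batch.shape} onto {expected_shape}")

image = np.zeros((1, 1, 28, 28), dtype=np.float32)  # PyTorch-style NCHW batch
print(match_input_layout(image, (1, 28, 28, 1)).shape)  # (1, 28, 28, 1)
```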

Conclusion

Model conversion is a useful technique because it lets a model trained in one framework be deployed in environments that support another. Combined with optimizations such as quantization, the converted models become small and fast enough to deploy on mobile phones and other edge devices.

