Convert PyTorch model to TFLite with onnx-tf
Last Updated: 07 Oct, 2024
The increasing demand for deploying machine learning models on mobile and edge devices has led to the necessity of converting models into formats that are optimized for such environments. TensorFlow Lite (TFLite) is one such format that is widely used for deploying models on mobile devices.
Different frameworks suit different deployment environments. For instance, suppose we have a model developed in Keras and want to run it on a Raspberry Pi, and PyTorch is better supported there. Instead of rebuilding the model from scratch with the PyTorch library, we can convert it to PyTorch and meet our requirements. The same idea works in other directions: this article provides a detailed guide on converting PyTorch models to TFLite, using ONNX and TensorFlow as intermediate steps.
Understanding the Conversion Workflow
The conversion process from PyTorch to TFLite involves several steps, utilizing ONNX (Open Neural Network Exchange) as a bridge between PyTorch and TensorFlow. The workflow can be summarized as follows:
- PyTorch to ONNX: Export the PyTorch model to the ONNX format.
- ONNX to TensorFlow: Convert the ONNX model to a TensorFlow model.
- TensorFlow to TFLite: Finally, convert the TensorFlow model to TFLite format.
Exporting a PyTorch Model to ONNX
ONNX (Open Neural Network Exchange) is an open format for representing the architecture and weights of deep learning models. It acts as an intermediate representation when we need to convert a model from one framework to another.
Here we use the MNIST dataset and build a small CNN with two convolutional layers and two linear layers. The model is trained with the Adam optimizer and a cross-entropy loss. We then provide a dummy input so that the exporter knows the shape and type of data the model expects. Finally, we export the model to ONNX format, naming the input and output tensors and setting the opset version.
Python
# Step 1: Import the required libraries
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from torchvision import datasets
import onnx
from onnx_tf.backend import prepare
import tensorflow as tf
import numpy as np
# Step 2: Define the simple CNN model for MNIST
class SimpleCNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)  # Halves height and width
        self.fc1 = nn.Linear(64 * 7 * 7, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = self.pool(torch.relu(self.conv1(x)))  # 28x28 -> 14x14
        x = self.pool(torch.relu(self.conv2(x)))  # 14x14 -> 7x7
        x = x.view(x.size(0), -1)  # Flatten the tensor
        x = torch.relu(self.fc1(x))
        x = self.fc2(x)
        return x
# Step 3: Load MNIST dataset
transform = transforms.Compose([transforms.ToTensor()])
train_dataset = datasets.MNIST(root='./data', train=True, transform=transform, download=True)
test_dataset = datasets.MNIST(root='./data', train=False, transform=transform, download=True)
train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=64, shuffle=True)
test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=64, shuffle=False)
# Step 4: Initialize and train the model
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = SimpleCNN().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
# Training the model for 10 epochs
for epoch in range(10):  # Train for 10 epochs
    model.train()
    total_loss = 0
    for images, labels in train_loader:
        images = images.to(device)
        labels = labels.to(device)
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    print(f'Epoch [{epoch+1}/10], Loss: {total_loss/len(train_loader):.4f}')
# Step 5: Export the model to ONNX format
model.eval()  # Switch to inference mode before exporting
dummy_input = torch.randn(1, 1, 28, 28).to(device)  # Example input
onnx_filename = "mnist_model.onnx"
torch.onnx.export(
    model, dummy_input, onnx_filename,
    input_names=['input'], output_names=['output'], opset_version=11,
)
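The torch.onnx.export method exports the PyTorch model to ONNX format: it runs the model once on the dummy input to record the operations, then writes the resulting graph to mnist_model.onnx.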
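The `64 * 7 * 7` input size of `fc1` in the model above follows from the layer arithmetic: each 3x3 convolution with padding 1 keeps the 28x28 spatial size, and each 2x2 max-pool halves it. A minimal sketch of that calculation is shown below. The helper `conv_out` is a hypothetical name used for illustration and is not part of PyTorch; it may be useful when adapting the dummy input shape to a different model.

```python
def conv_out(size, kernel, stride=1, padding=0):
    """Output length of a conv or pooling layer along one spatial dimension."""
    return (size + 2 * padding - kernel) // stride + 1

size = 28                                              # MNIST images are 28x28
size = conv_out(size, kernel=3, stride=1, padding=1)   # conv1 keeps 28
size = conv_out(size, kernel=2, stride=2)              # pool gives 14
size = conv_out(size, kernel=3, stride=1, padding=1)   # conv2 keeps 14
size = conv_out(size, kernel=2, stride=2)              # pool gives 7

flattened = 64 * size * size  # 64 channels come out of conv2
print(flattened)              # 3136, i.e. 64 * 7 * 7
```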
Installing and Setting up ONNX-TF
ONNX-TF is a converter that is used to convert ONNX models to TensorFlow models. Here we have used Python 3.8 because there are version compatibility issues with later versions of Python. Some libraries must be installed with specific versions, which can be done with the pip command. Below are the versions used, so that there are no conflicts during the setup of ONNX-TF.
onnx 1.16.2
onnx-tf 1.10.0
numpy 1.23.5
python 3.8.15
tensorflow 2.10.0
tensorflow-addons 0.21.0
tensorflow-probability 0.13.0
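Before running the conversion, it can help to confirm that the environment matches the list above. The following is a minimal sketch using the standard library's `importlib.metadata`; the `PINNED` dictionary and the `check_versions` helper are names introduced here for illustration and are not part of ONNX-TF.

```python
from importlib import metadata

# Pinned versions copied from the list above
PINNED = {
    "onnx": "1.16.2",
    "onnx-tf": "1.10.0",
    "numpy": "1.23.5",
    "tensorflow": "2.10.0",
    "tensorflow-addons": "0.21.0",
    "tensorflow-probability": "0.13.0",
}

def check_versions(pinned):
    """Return {package: (expected, installed)} for every mismatch."""
    mismatches = {}
    for name, expected in pinned.items():
        try:
            installed = metadata.version(name)
        except metadata.PackageNotFoundError:
            installed = None  # Package is not installed at all
        if installed != expected:
            mismatches[name] = (expected, installed)
    return mismatches

for name, (expected, installed) in check_versions(PINNED).items():
    print(f"{name}: expected {expected}, found {installed}")
```

An empty result means every listed package is installed at the pinned version.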
Converting ONNX Model to TensorFlow
To convert the ONNX model to TensorFlow, we first load it with onnx.load. The prepare function from onnx_tf.backend then converts the ONNX model into a TensorFlow representation, and the export_graph method saves the result to a directory as a TensorFlow SavedModel.
Python
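# Step 6: Load the exported ONNX model
onnx_model = onnx.load(onnx_filename)
# Step 7: Convert it to TensorFlow and save it as a SavedModel
tf_rep = prepare(onnx_model)
tf_model_dir = "./tf_model"
tf_rep.export_graph(tf_model_dir)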
Optimizing and Converting TensorFlow Model to TensorFlow Lite
Optimizing here means reducing the size of the model so that it can be deployed on mobile devices with little loss in prediction quality. TensorFlow Lite is already a lightweight runtime for TensorFlow models, and we can shrink the model further during conversion. In this step we convert our TensorFlow model to TensorFlow Lite and apply an optimization technique called quantization. Quantization reduces the precision of the numbers stored in the model, for example by storing 32-bit floating-point weights as 8-bit integers.
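To make the idea of reduced precision concrete, here is a minimal NumPy sketch of 8-bit affine quantization applied to random weights. It only illustrates the principle and is not TFLite's actual implementation; the function names `quantize_int8` and `dequantize` are hypothetical.

```python
import numpy as np

def quantize_int8(weights):
    """Map float32 values onto int8 using a scale and a zero point."""
    w_min, w_max = float(weights.min()), float(weights.max())
    scale = (w_max - w_min) / 255.0                # Spread the range over 256 levels
    zero_point = int(round(-128 - w_min / scale))  # Integer that w_min maps near -128
    q = np.round(weights / scale) + zero_point
    return np.clip(q, -128, 127).astype(np.int8), scale, zero_point

def dequantize(q, scale, zero_point):
    """Recover approximate float32 values from the int8 representation."""
    return (q.astype(np.float32) - zero_point) * scale

rng = np.random.default_rng(0)
weights = rng.normal(0.0, 0.1, size=(128, 10)).astype(np.float32)

q, scale, zero_point = quantize_int8(weights)
restored = dequantize(q, scale, zero_point)

print("bytes before:", weights.nbytes, "after:", q.nbytes)
print("max abs error:", float(np.abs(weights - restored).max()))
```

The stored array is four times smaller, and the rounding error of each weight stays within about half of one quantization step (`scale / 2`).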
Python
# Step 8: Convert the TensorFlow model to TensorFlow Lite with optimization (quantization)
converter = tf.lite.TFLiteConverter.from_saved_model(tf_model_dir)
# Apply post-training quantization to reduce model size
converter.optimizations = [tf.lite.Optimize.DEFAULT] # Default quantization
# Convert the model
tflite_model = converter.convert()
# Step 9: Save the optimized TensorFlow Lite model
with open("mnist_model_optimized.tflite", "wb") as f:
    f.write(tflite_model)
print("Optimized TensorFlow Lite model saved successfully!")
Here we used the default optimization, which quantizes the weights, and the convert method to convert the model. Lastly, we saved the converted model in .tflite format.
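One way to see the effect of the optimization is to compare file sizes with and without it. The sketch below assumes a second, unoptimized export saved under the hypothetical name `mnist_model_float.tflite` (produced by skipping the `converter.optimizations` line); `size_report` is an illustrative helper, and files that do not exist are skipped.

```python
import os

def size_report(paths):
    """Return {path: size in KiB} for the files that exist."""
    return {p: os.path.getsize(p) / 1024 for p in paths if os.path.isfile(p)}

# "mnist_model_float.tflite" is a hypothetical name for an unoptimized export
report = size_report(["mnist_model_float.tflite", "mnist_model_optimized.tflite"])
for path, kib in report.items():
    print(f"{path}: {kib:.1f} KiB")
```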
Testing the TensorFlow Lite Model
In this step we load the TFLite model and evaluate it on the MNIST test dataset. Each batch is first converted to NumPy arrays. The samples are then passed one at a time through the interpreter, which runs our TensorFlow Lite model. Finally, we calculate the accuracy of the model.
Python
interpreter = tf.lite.Interpreter(model_path="mnist_model_optimized.tflite")
interpreter.allocate_tensors()
# Get input and output tensors
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()
# Evaluate the TensorFlow Lite model
correct = 0
total = 0
# Loop through test data
for images, labels in test_loader:
    # Convert PyTorch tensor to numpy
    images_np = images.numpy()
    for i in range(len(images_np)):
        input_data = np.expand_dims(images_np[i], axis=0).astype(np.float32)  # Adjust shape for TFLite
        # Set the tensor to the input data
        interpreter.set_tensor(input_details[0]['index'], input_data)
        # Invoke the interpreter
        interpreter.invoke()
        # Get output predictions
        output_data = interpreter.get_tensor(output_details[0]['index'])
        prediction = np.argmax(output_data)
        # Compare prediction to the true label
        if prediction == labels[i].item():
            correct += 1
        total += 1
# Print the accuracy of the optimized TensorFlow Lite model
accuracy = correct / total
print(f'Optimized TensorFlow Lite Model Accuracy: {accuracy * 100:.2f}%')
Output:
The accuracy of the model is 98.86%. Many warnings may appear along the way, but they can generally be ignored.
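Accuracy alone does not show how closely the converted model tracks the original. A complementary check is to compare the raw outputs of both models on the same inputs. The sketch below uses a hypothetical helper, `compare_logits`, and synthetic arrays as stand-ins; in practice the two arrays would hold the PyTorch outputs and the TFLite outputs for the same batch.

```python
import numpy as np

def compare_logits(reference, candidate):
    """Return (max absolute difference, fraction of matching top-1 classes)."""
    reference = np.asarray(reference, dtype=np.float32)
    candidate = np.asarray(candidate, dtype=np.float32)
    max_diff = float(np.abs(reference - candidate).max())
    agreement = float(np.mean(reference.argmax(axis=1) == candidate.argmax(axis=1)))
    return max_diff, agreement

# Synthetic stand-ins: 4 samples, 10 classes, with a small perturbation
rng = np.random.default_rng(0)
torch_logits = rng.normal(size=(4, 10)).astype(np.float32)
tflite_logits = torch_logits + rng.normal(scale=1e-3, size=(4, 10)).astype(np.float32)

max_diff, agreement = compare_logits(torch_logits, tflite_logits)
print(f"max abs diff: {max_diff:.4f}, top-1 agreement: {agreement:.0%}")
```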
Troubleshooting Common Issues in Model Conversion
Several issues come up frequently when converting a model from one framework to another. Some of them are as follows:
- Input shapes can mismatch between frameworks, for example because PyTorch uses a channels-first layout (NCHW) while TensorFlow defaults to channels-last (NHWC). To avoid this issue, always provide a dummy input with the correct shape during conversion, or reshape the inputs to match what the converted model expects.
- Version compatibility is another issue. With recent versions of Python, the available TensorFlow and ONNX packages may conflict. To avoid this, use an earlier version such as Python 3.8 and pin package versions that are compatible with one another.
- Some operations are available in PyTorch but not in TensorFlow or ONNX, and vice versa. To avoid this problem, replace unsupported operations with simpler, widely supported ones.
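For the shape issue in the first point, a common source of mismatch is data layout: PyTorch convolutions expect channels-first input (N, C, H, W), while TensorFlow defaults to channels-last (N, H, W, C). The NumPy sketch below shows the transposition and a simple shape check. Whether a given converted model needs it is an assumption to verify against the shape reported by `interpreter.get_input_details()`; `matches_model_input` is a hypothetical helper.

```python
import numpy as np

# One MNIST-sized batch in PyTorch's channels-first layout (N, C, H, W)
batch_nchw = np.zeros((1, 1, 28, 28), dtype=np.float32)

# Move the channel axis to the end to get channels-last (N, H, W, C)
batch_nhwc = np.transpose(batch_nchw, (0, 2, 3, 1))

def matches_model_input(batch, expected_shape):
    """Check a batch against the input shape a model reports."""
    return tuple(batch.shape) == tuple(expected_shape)

print(batch_nchw.shape, "->", batch_nhwc.shape)
```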
Conclusion
Model conversion is a useful technique because it lets us deploy a wide variety of models across different frameworks. The converted models can then be optimized, making them suitable for deployment on mobile and other edge devices.