使用ONNX将模型转移至Caffe2和移动端
在本教程中,我们将介绍如何使用 ONNX 将 PyTorch 中定义的模型转换为 ONNX 格式,然后将其加载到 Caffe2 中。一旦进入 Caffe2,我们就可以运行模型来仔细检查它是否正确导出,然后我们展示了如何使用 Caffe2 功能(如移动导出器)在移动设备上执行模型。
在本教程中,您需要安装onnx和Caffe2。您可以使用pip install onnx来获取 onnx。
注意:本教程需要 PyTorch master 分支,可以按照这里说明进行安装。
1.引入模型
# 一些包的导入 import io import numpy as np from torch import nn import torch.utils.model_zoo as model_zoo import torch.onnx
1.1 `SuperResolution`模型
超分辨率是一种提高图像、视频分辨率的方法,广泛用于图像处理或视频剪辑。在本教程中,我们将首先使用带有虚拟输入的小型超分辨率模型。
首先,让我们在 PyTorch 中创建一个SuperResolution模型。这个模型直接来自 PyTorch 的例子,没有修改:
# PyTorch中定义的Super Resolution模型 import torch.nn as nn import torch.nn.init as init class SuperResolutionNet(nn.Module): def __init__(self, upscale_factor, inplace=False): super(SuperResolutionNet, self).__init__() self.relu = nn.ReLU(inplace=inplace) self.conv1 = nn.Conv2d(1, 64, (5, 5), (1, 1), (2, 2)) self.conv2 = nn.Conv2d(64, 64, (3, 3), (1, 1), (1, 1)) self.conv3 = nn.Conv2d(64, 32, (3, 3), (1, 1), (1, 1)) self.conv4 = nn.Conv2d(32, upscale_factor ** 2, (3, 3), (1, 1), (1, 1)) self.pixel_shuffle = nn.PixelShuffle(upscale_factor) self._initialize_weights() def forward(self, x): x = self.relu(self.conv1(x)) x = self.relu(self.conv2(x)) x = self.relu(self.conv3(x)) x = self.pixel_shuffle(self.conv4(x)) return x def _initialize_weights(self): init.orthogonal_(self.conv1.weight, init.calculate_gain('relu')) init.orthogonal_(self.conv2.weight, init.calculate_gain('relu')) init.orthogonal_(self.conv3.weight, init.calculate_gain('relu')) init.orthogonal_(self.conv4.weight) # 使用上面模型定义,创建super-resolution模型 torch_model = SuperResolutionNet(upscale_factor=3