pytorch自身部署较麻烦,一般使用onnx和mnn较为实用
训练模型的代码:
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
import torch.optim as optim
from torch.optim import lr_scheduler
import torch.onnx
if __name__ == '__main__':
# transforms
transform = transforms.Compose([transforms.ToTensor(),
transforms.Normalize((0.5,), (0.5,))])
# datasets
trainset = torchvision.datasets.FashionMNIST('data',download=True,train=True, transform=transform)
testset = torchvision.datasets.FashionMNIST('data',download=True,train=False,transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=1,
shuffle=True, num_workers=2)
testloader = torch.utils.data.DataLoader(testset, batch_size=1,
shuffle=False, num_workers=2)
classes = ('T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat','Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle Boot')
print('trainloader--------',trainloader)
class FashionMNIST(nn.Module):
def __init__(self,num_classes=10):
super(FashionMNIST,self).__init__()
self.conv1 = nn.Conv2d(in_channels=1, out_channels=16, kernel_size=3, stride=1, padding=1)
self.conv2 = nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3, stride=1, padding=1)
self.pool = nn.MaxPool2d(kernel_size=2)
self.drop = nn.Dropout(0.2)
self.conv3 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, stride=1, padding=1)
self.fc1 = nn.Linear(in_features=64* 3 * 3,

这篇博客介绍了如何使用PyTorch训练一个FashionMNIST分类模型,然后将其导出为ONNX格式,并进一步转换为MNN格式,以适应移动端的部署。过程中详细阐述了每个步骤,包括数据预处理、模型定义、训练、模型保存、ONNX导出以及ONNX和MNN模型的运行与验证。
最低0.47元/天 解锁文章
9142

被折叠的 条评论
为什么被折叠?



