onnx重写输入和输出的维度

onnx模型输入是静态的,比如是1x3x960x960,但是想把输入改成动态输入,相应的输出也得改成动态,以下代码可以修改onnx模型的维度:

import onnx
import onnx.checker
import onnx.utils
from onnx.tools import update_model_dims
 
 
model = onnx.load('infer_rec.onnx')
# 此处可以理解为获得了一个维度 “引用”,通过该 “引用“可以修改其对应的维度                                                                                          
dim_proto0 = model.graph.input[0].type.tensor_type.shape.dim[0]
dim_proto3 = model.graph.input[0].type.tensor_type.shape.dim[3]
# 将该维度赋值为字符串,其维度不再为和dummy_input绑定的值
dim_proto0.dim_param = '1'
dim_proto3.dim_param = 'width'
dim_proto_0 = model.graph.output[0].type.tensor_type.shape.dim[0]
dim_proto_1 = model.graph.output[0].type.tensor_type.shape.dim[1]
dim_proto_0.dim_param = '1'
dim_proto_1.dim_param = 'width'
onnx.save(model, 'infer_rec_dynamic.onnx')

这样修改以后,就变成动态输入输出的模型了。

修改前:

修改后: 

但是转onnx需要主要指定 --opset_version 的版本为11,官方的文档解释如下:

Resize the input tensor. In general, it calculates every value in the output tensor as a weighted average of neighborhood (a.k.a. sampling locations) in the input tensor. Each dimension value of the output tensor is: output_dimension = floor(input_dimension * (roi_end - roi_start) * scale) if input "sizes" is not specified.

Version

This version of the operator has been available since version 11 of the default ONNX operator set.

如果不指定版本为11,有可能onnx转其他推理框架的模型可能会失败,这个地方需要注意。

当然,如果是自己通过pytorch代码也可以直接重新生成onnx为动态输入或动态输出模型,生成过程如下:

torch.onnx.export(rec_model, rec_input, "crnn.onnx",

dynamic_axes={'input' : {3 : 'width'},'output':{1:'width'}},

do_constant_folding=False,training=False,export_params=True)

参考:pytorch模型转onnx后修改模型输入维度 - 知乎 (zhihu.com)

<think>嗯,用户想了解如何将PyTorch模型转换为ONNX格式,我需要一步步理清楚这个过程。首先,我应该回忆一下PyTorch导出ONNX的基本步骤。记得需要用torch.onnx.export函数,对吧?这个函数需要模型、输入样例输出路径。然后,参数方面,比如input_namesoutput_names,这些是必须的吗? 对了,用户可能对动态轴设置不太清楚,这部分需要详细说明。比如,在处理不同批量大小的输入时,dynamic_axes参数怎么用。还有,验证转换后的模型是否正确,可能需要用ONNX Runtime来检查输出的一致性。这部分可能需要提醒用户安装onnxruntime包。 另外,用户提到最佳实践,我应该包括哪些内容呢?比如输入样例的规范化、使用最新版本的库、验证模型、文档记录这些点。可能还要提到模型简化工具,比如onnx-simplifier,这对优化模型结构有帮助。 用户的问题里还引用了两个站内引用,看起来是关于模型训练数据处理的。虽然模型转换不直接相关,但可能在转换前的模型准备阶段需要注意数据集的处理,比如输入输出的形状是否符合预期。例如引用[2]提到训练集的形状是(25000,),这里可能需要确认模型输入是否符合ONNX的要求,比如是否添加了批次维度。 还要考虑用户可能的错误,比如模型中有动态操作导致导出失败。这时候需要建议他们检查模型结构,或者使用opset_version参数调整。另外,是否支持所有PyTorch层?可能需要提醒用户某些自定义层可能需要额外处理。 最后,生成相关问题的时候,要围绕ONNX的部署、性能优化、支持的框架、转换失败处理等方面提问,这样能帮助用户进一步探索相关主题。</think>以下是将PyTorch模型转换为ONNX格式的完整指南及最佳实践: ### 一、基础转换流程 1. **安装依赖** ```bash pip install onnx onnxruntime ``` 2. **核心转换代码** ```python import torch import torch.onnx model = YourPyTorchModel().eval() # 确保模型处于评估模式 dummy_input = torch.randn(1, 3, 224, 224) # 匹配模型输入尺寸的虚拟数据 # 执行转换 torch.onnx.export( model, dummy_input, "model.onnx", export_params=True, # 包含模型参数 opset_version=13, # 推荐使用最新稳定版本 input_names=["input"], output_names=["output"], dynamic_axes={ "input": {0: "batch_size"}, # 支持动态批量维度 "output": {0: "batch_size"} } ) ``` 3. **验证转换结果** ```python import onnx onnx_model = onnx.load("model.onnx") onnx.checker.check_model(onnx_model) ``` ### 二、最佳实践 1. **输入样例标准化** - 使用与生产环境一致的输入维度 - 处理动态维度时显式声明dynamic_axes参数[^1] 2. **版本控制** ```python print(torch.__version__) # 推荐≥1.8.0 print(onnx.__version__) # 推荐≥1.10.0 ``` 3. **模型验证** ```python import onnxruntime as ort # 对比原始模型与ONNX模型的输出 ort_session = ort.InferenceSession("model.onnx") outputs = ort_session.run(None, {"input": dummy_input.numpy()}) torch_output = model(dummy_input) print("最大误差:", np.max(np.abs(outputs[0] - torch_output.detach().numpy()))) ``` 4. **模型优化** - 使用onnx-simplifier简化模型结构: ```bash python -m onnxsim input.onnx output.onnx ``` ### 三、常见问题处理 | 问题类型 | 解决方案 | |---------|----------| | 动态控制流报错 | 重写模型逻辑避免动态分支 | | 自定义算子缺失 | 实现自定义算子或选择支持的操作集 | | 形状推断错误 | 显式指定输入/输出形状 |
评论 4
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值