ONNX 输入batch修改

ONNX 输入batch修改

导出的onnx模型分为静态和动态输入两种,但一般用户会在导出后进行onnxsim操作,导致某些非全卷积的模型修改batch失败,比如transformer类其中reshape的attr属性会固定,修改相当麻烦,需要从源头重新导出是最佳选择,本文讨论在没有源头的情况下,尽量完成修改batch的修改,并固化每个op的输入输出tensor为正确shape。

依赖

  • onnx
  • onnxsim
import sys
import onnx
from onnx import shape_inference, helper
from onnxsim import simplify


filepath = sys.argv[1]
batch = int(sys.argv[2])
opset_version = int(sys.argv[3])

model = onnx.load(filepath)

# 创建一个新的模型图
new_graph = helper.make_graph(
    nodes=model.graph.node,                # 保留所有节点
    name="new_graph",                      # 保留原图名称
    inputs=model.graph.input,              # 保留输入
    outputs=model.graph.output,            # 保留输出
    initializer=model.graph.initializer,   # 保留初始化器
    value_info=None
)

# 构造一个新的模型
if not any(opset.domain == "" for opset in model.opset_import):
    model.opset_import.append(helper.make_opsetid(domain="", version=opset_version))
    
new_model = helper.make_model(new_graph, producer_name=model.producer_name, opset_imports=model.opset_import)
for idx, _input in enumerate(new_model.graph.input):
    _input.type.tensor_type.shape.dim[0].dim_value = batch
for idx, _output in enumerate(new_model.graph.output):
    _output.type.tensor_type.shape.dim[0].dim_value = batch
    
model_sim, success = simplify(new_model)
if not success:
    print("simplify failed")
    onnx.save(new_model, filepath.replace(".onnx", f"_b{batch}.onnx"))
else:
    onnx.save(model_sim, filepath.replace(".onnx", f"_b{batch}.onnx"))


ONNX-GraphSurgeon是一个用于操作ONNX模型的强大工具,它允许你在Python环境中对ONNX模型进行诸如添加、删除节点,修改连接等高级操作。如果你想通过它来改变模型的输入和输出维度,你可以按照以下步骤操作: 1. 首先,你需要安装`onnx`和`onnx-graphsurgeon`库: ``` pip install onnx onnx-graphsurgeon ``` 2. 导入所需的模块: ```python import onnx from onnx_graphsurgeon.importers import Importer from onnx_graphsurgeon.nodes import Node from onnx_graphsurgeon.graph import Graph ``` 3. 加载原始ONNX模型: ```python model = onnx.load("your_model.onnx") graph = Importer.from_onnx(model) ``` 4. 查找需要修改输入输出维度的特定节点。通常,这可能是卷积层、全连接层等有固定输入输出维度的节点。例如,假设你想增加一个新通道数到某个张量的输入: ```python input_node = graph.inputs[0] # 查找输入节点 input_node.shape = [batch_size, new_channel_count, input_height, input_width] # 修改输入形状 ``` 5. 对于输出维度的修改,如果模型有一个明确的输出张量,同样可以找到该节点,并更新其shape属性。 6. 调整完成后,保存修改后的模型: ```python modified_graph = graph.to_dot() new_model = onnx.save(modified_graph, "modified_model.onnx") ``` 请注意,这个过程依赖于模型的具体结构,特别是你如何能找到正确的输入和输出节点。如果你不确定哪个节点对应,可能需要查看ONNX模型的元数据或者使用额外的工具帮助定位。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

上单之光

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值