HuggingFace Transformers 模型导出到 ONNX 格式

本文介绍了如何使用HuggingFace的Optimum库将Transformer模型转换为ONNX格式,以便于部署和跨框架兼容。提供了两种方法:通过命令行接口和编程方式导出,以及ONNX的优势和应用示例。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

HuggingFace Transformers 模型导出到 ONNX 格式

0. 引言

ONNX 是一种开放格式,用于表示机器学习模型。ONNX 定义了一组通用运算符(机器学习和深度学习模型的构建基块)和通用文件格式,使 AI 开发人员能够将模型与各种框架、工具、运行时和编译器一起使用。了解更多信息

官网地址: https://onnx.ai/

1. 使用 Optimum

在生产环境中部署 🤗 Transformers 模型通常需要将模型导出为可在专用运行时和硬件上加载和执行的序列化格式,或者可以从中受益。

🤗 Optimum 是 Transformer 的扩展,它允许通过其 exporters 模块将模型从 PyTorch 或 TensorFlow 导出为 ONNX 和 TFLite 等序列化格式。🤗 Optimum还提供了一组性能优化工具,以最高效率在目标硬件上训练和运行模型。

今天尝试使用 🤗 Optimum 将 Transformer 模型导出🤗到 ONNX。

ONNX(开放神经网络交换)是一种开放标准,它定义了一组通用运算符和通用文件格式,用于表示各种框架(包括 PyTorch 和 TensorFlow)中的深度学习模型。当模型导出为 ONNX 格式时,这些运算符用于构造计算图(通常称为中间表示),该图表示通过神经网络的数据流。

通过公开具有标准化运算符和数据类型的图形,ONNX 可以轻松地在框架之间切换。例如,可以在 PyTorch 中训练的模型导出为 ONNX 格式,然后导入到 TensorFlow 中(反之亦然)。

导出为 ONNX 格式后,模型可以:

  • 通过图形优化和量化等技术针对推理进行了优化。

  • 通过 ORTModelForXXX 类与 ONNX 运行时一起运行,这些类遵循与您在 Transformer 中🤗习惯使用的 API 相同的 AutoModel API。

<think>嗯,用户想了解如何将机器学习模型导出ONNX格式。让我先理清楚已有的信息和引用内容。根据提供的引用,有几个方法可以导出ONNX模型,比如使用scikit-learn的skl2onnx库,或者HuggingFace的Optimum库,还有PyTorch自带的导出功能。 首先,用户的问题比较广泛,可能需要分步骤回答。我应该先介绍通用的步骤,再具体到不同的框架。比如,对于scikit-learn模型,可以用skl2onnx,而PyTorch模型可以用torch.onnx.export。同时,HuggingFacetransformers模型可以通过Optimum或transformers.onnx来处理。 然后,我需要确保每个步骤都符合用户的要求,比如行内公式用$...$,独立公式用$$...$$。例如,提到线性回归模型时,可以插入公式$y = WX + b$。此外,要注意引用标识的正确添加,比如在提到安装依赖的地方引用[^1],或者在优化kv_cache时引用[^4]。 可能用户还会关心导出时的常见问题,比如数据类型换,如引用[5]提到的INT64错误,需要换为INT32或FP32。这也是需要注意的点,应该在注意事项里提到。 另外,用户提供的示例中包括了Python代码,所以在回答中也应该包含具体的代码示例,比如scikit-learn导出ONNX的代码,以及PyTorch模型导出代码。同时,要确保代码格式正确,使用反引号包裹。 还要考虑用户可能的后续问题,比如如何验证导出模型是否正确,或者如何进行性能优化。这些可以作为相关问题列出,比如导出后的验证步骤,或者如何利用ONNX Runtime进行推理。 最后,需要检查所有引用是否正确对应,例如引用[1]是关于安装依赖,引用[2]是关于scikit-learn导出,引用[3]和[4]涉及HuggingFace模型和LLM导出,引用[5]处理换中的错误。确保每个引用的地方都是自然衔接的,比如在提到模型部署优化时引用,在数据类型问题时引用[^5]。 总结一下,回答的结构应该是:简介ONNX,分框架介绍导出方法(scikit-learn、PyTorch、HuggingFace),注意事项,常见问题处理,最后相关问题。每个部分都插入对应的代码示例、公式和引用标识。</think>### 如何将机器学习模型导出ONNX格式 #### 一、ONNX格式简介 ONNX(Open Neural Network Exchange)是一种开放的模型表示格式,支持跨框架和平台部署。通过ONNX,用户可以将PyTorch、TensorFlow、Scikit-learn等框架训练的模型换为统一格式,方便在ONNX Runtime等推理引擎中运行[^2]。 --- #### 二、通用导出步骤 1. **安装依赖库** 根据框架选择安装对应工具: ```bash pip install onnx onnxruntime skl2onnx # 通用库 pip install torch # PyTorch模型 pip install optimum # HuggingFace模型[^3] ``` 2. **模型训练与准备** 以线性回归为例,模型公式为: $$y = WX + b$$ 训练完成后需确保模型已序列化(如`.pkl`文件)。 3. **导出ONNX格式** **方法1:Scikit-learn模型** 使用`skl2onnx`库: ```python from skl2onnx import convert_sklearn from skl2onnx.common.data_types import FloatTensorType # 训练模型 from sklearn.linear_model import LinearRegression model = LinearRegression().fit(X_train, y_train) # 定义输入类型(假设输入维度为3) initial_type = [('float_input', FloatTensorType([None, 3]))] onnx_model = convert_sklearn(model, initial_types=initial_type) # 保存文件 with open("model.onnx", "wb") as f: f.write(onnx_model.SerializeToString()) ``` [^2] **方法2:PyTorch模型** 使用`torch.onnx.export`: ```python import torch # 定义示例输入 dummy_input = torch.randn(1, 3) # 输入维度为3 model = torch.nn.Linear(3, 1) # 定义模型 # 导出ONNX torch.onnx.export( model, dummy_input, "model.onnx", input_names=["input"], output_names=["output"], dynamic_axes={"input": {0: "batch_size"}} # 支持动态批次 ) ``` **方法3:HuggingFace Transformers模型** 使用`optimum`库: ```python from optimum.onnxruntime import ORTModelForSequenceClassification from transformers import AutoTokenizer, AutoModelForSequenceClassification model = AutoModelForSequenceClassification.from_pretrained("bert-base-uncased") tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased") # 导出ONNX model.save_pretrained("onnx_model/", use_onnx=True) ``` [^3] 4. **验证ONNX模型** 使用ONNX Runtime推理测试: ```python import onnxruntime as ort import numpy as np sess = ort.InferenceSession("model.onnx") input_name = sess.get_inputs()[0].name output = sess.run(None, {input_name: np.random.randn(1,3).astype(np.float32)}) ``` --- #### 三、注意事项 1. **数据类型兼容性** ONNX不支持INT64类型权重,需强制换为INT32或FP32(参考代码): ```python from onnx import helper, numpy_helper # 遍历模型节点替换INT64类型 ``` 2. **优化导出结构** - 对复杂模型(如LLM),需简化冗余操作(如Qwen模型的`rearrange`操作)。 - 使用`opset_version`参数指定兼容的操作集版本。 --- #### 四、常见问题 1. **动态维度支持** 在导出时通过`dynamic_axes`参数定义可变维度(如动态批次大小)。 2. **输入/输出对齐** 确保导出时的输入名称和维度与原始模型一致,避免推理错误。 ---
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值