怎么用yolov5训练好的onnx数据集识别其中的物体 用python实现
时间: 2025-03-19 13:02:47 浏览: 41
### 加载并运行YOLOv5生成的ONNX模型进行目标检测
为了实现通过Python加载YOLOv5导出的ONNX模型来完成目标检测任务,可以按照以下方法操作:
#### 1. 安装必要的依赖库
首先需要安装 `onnxruntime` 和其他可能需要用到的相关库。可以通过pip命令轻松安装这些依赖项。
```bash
pip install onnxruntime numpy opencv-python matplotlib
```
#### 2. 导入所需模块
导入用于处理图像、执行推理以及显示结果所需的各个模块。
```python
import cv2
import numpy as np
import onnxruntime as ort
from PIL import Image, ImageDraw
```
#### 3. 初始化 ONNX 模型
利用 `onnxruntime` 提供的功能加载已经准备好的 `.onnx` 文件,并设置计算设备(CPU 或 GPU)以便后续使用。
```python
def load_onnx_model(model_path):
session = ort.InferenceSession(model_path, providers=['CUDAExecutionProvider' if ort.get_device() == 'GPU' else 'CPUExecutionProvider'])
input_name = session.get_inputs()[0].name
output_names = [output.name for output in session.get_outputs()]
return session, input_name, output_names
```
此函数返回三个重要参数:会话对象 (`session`)、输入张量名称 (`input_name`) 和输出张量列表 (`output_names`) [^2]。
#### 4. 图像预处理
由于神经网络通常接受特定尺寸和格式的数据作为输入,在实际应用前需对原始图片做适当调整。
```python
def preprocess_image(image_path, target_size=(640, 640)):
image = cv2.imread(image_path)
resized_img = cv2.resize(image, target_size)
normalized_img = resized_img / 255.0
transposed_img = np.transpose(normalized_img, axes=[2, 0, 1])
batched_img = np.expand_dims(transposed_img, axis=0).astype(np.float32)
return batched_img
```
上述代码片段实现了读取本地磁盘上的图片文件,将其大小更改为适合模型输入的形式,接着标准化像素值范围至 `[0..1]` 并重新排列通道顺序最后增加批次维度形成最终送入模型的数据形式 [^1]。
#### 5. 执行推理过程
定义一个辅助功能来进行预测工作流管理。
```python
def run_inference(session, input_data, input_name, output_names):
outputs = session.run(output_names, {input_name: input_data})
return outputs
```
该函数接收先前创建好的会话实例连同经过预处理后的测试样本一起传给它即可获得对应的结果向量集合 [^2]。
#### 6. 后处理逻辑
解析来自上一步骤产生的原始数值数组进而提取有用的信息比如边界框坐标类别标签置信度得分等等。
```python
def postprocess(outputs, conf_threshold=0.5, iou_threshold=0.45):
predictions = []
boxes = outputs[0][0][:,:, :4]
scores = outputs[0][0][:,:, 4:]
class_ids = np.argmax(scores, axis=-1)
filtered_indices = cv2.dnn.NMSBoxes(boxes.tolist(),
np.max(scores, axis=-1),
score_threshold=conf_threshold,
nms_threshold=iou_threshold)
for idx in filtered_indices.flatten():
box = boxes[idx]
cls_id = int(class_ids[idx])
confidence = float(np.max(scores[idx]))
prediction = {
"class": cls_id,
"confidence": confidence,
"box": tuple(box.astype(int))
}
predictions.append(prediction)
return predictions
```
这段程序负责筛选掉那些低于设定阈值或者与其他候选区域高度重叠的部分只保留最有可能代表真实物体的目标位置及其属性描述信息 。
#### 7. 结果可视化
为了让用户更加直观理解算法效果可绘制矩形标记覆盖原图之上同时附加文字说明具体分类情况。
```python
def draw_bounding_boxes(image_path, predictions, classes=["watermelon", "wintermelon"]):
img = Image.open(image_path)
draw = ImageDraw.Draw(img)
colors = [(255, 0, 0), (0, 255, 0)]
for pred in predictions:
label = f"{classes[pred['class']]} ({pred['confidence']*100:.2f}%)"
color = colors[pred["class"]]
bbox = pred["box"]
draw.rectangle(bbox, outline=color, width=3)
text_position = (bbox[0], bbox[1]-15)
draw.text(text_position, label, fill=color)
img.show()
```
这里假设只有两类即西瓜与冬瓜因此颜色映射简单明了当然也可以扩展到更多种类的情形下去适应不同的应用场景需求 [^3].
---
阅读全文
相关推荐


















