yolov8原生detect.py
时间: 2025-04-24 14:24:37 浏览: 33
### YOLOv8 `detect.py` 文件解析
#### 概述
YOLOv8 的 `detect.py` 脚本用于执行对象检测任务。此脚本不仅提供了便捷的方式来进行推理,还支持多种配置选项来调整模型行为以适应不同需求。
#### 主要功能模块
##### 参数解析
命令行参数通过 Python 的 argparse 库进行管理,允许用户指定诸如权重路径、输入源(图片/视频流)、置信度阈值等重要设置[^1]。
```python
parser.add_argument('--weights', nargs='+', type=str, default='yolov8s.pt', help='model path(s)')
parser.add_argument('--source', type=str, default='data/images', help='file/dir/URL/glob')
```
##### 初始化环境与加载模型
该部分负责初始化必要的运行时环境并加载预训练的模型权重到内存中。对于 PyTorch 用户而言,这通常涉及调用 torch.hub 或者直接读取本地 .pt 文件中的序列化数据结构[^2]。
```python
device = select_device(opt.device)
model = DetectMultiBackend(opt.weights, device=device, dnn=opt.dnn, data=opt.data, fp16=opt.half)
stride, names, pt = model.stride, model.names, model.pt
imgsz = check_img_size(imgsz=imgsz, s=stride) # check image size
```
##### 数据前处理
为了使原始图像能够被神经网络接受,在送入模型之前需对其进行标准化操作——包括但不限于尺寸缩放、通道顺序变换以及数值范围映射至 [-1, 1] 或 [0, 1] 区间内。这部分逻辑同样适用于批量处理多张图片的情况[^3]。
```python
dataset = LoadStreams(source, img_size=imgsz, stride=stride, auto=pt and not jit)
for path, im, im0s, vid_cap, s in dataset:
im = torch.from_numpy(im).to(device)
im = im.half() if half else im.float()
im /= 255
```
##### 推理过程
一旦准备工作完成,就可以利用已加载好的模型对每帧画面实施预测工作了。这里会涉及到非极大抑制 (NMS) 来筛选最终输出框,并依据类别标签绘制边界矩形于原图之上以便可视化展示结果。
```python
pred = model(im, augment=augment, visualize=visualize)[0]
if isinstance(pred, tuple): pred = pred[0]
pred = non_max_suppression(pred, conf_thres, iou_thres, classes, agnostic_nms, max_det=max_det)
for det in pred:
gn = torch.tensor(im0.shape)[[1, 0, 1, 0]]
annotator = Annotator(im0, line_width=line_thickness, example=str(names))
if len(det):
det[:, :4] = scale_coords(im.shape[2:], det[:, :4], im0.shape).round()
for *xyxy, conf, cls in reversed(det):
c = int(cls)
label = None if hide_labels else (names[c] if hide_conf else f'{names[c]} {conf:.2f}')
annotator.box_label(xyxy, label, color=colors(c, True))
```
##### 后续处理与保存成果
最后一步则是将带有标注信息的新版本图像存储下来或是实时显示给用户查看。如果选择了录制模式,则还会创建相应的输出文件夹并将每一帧按序写入其中形成完整的视频片段。
```python
cv2.imshow(str(p), im0)
cv2.waitKey(1) # 1 millisecond
if save_vid:
if vid_path != save_path: # new video
vid_path = save_path
if isinstance(vid_writer, cv2.VideoWriter):
vid_writer.release()
if vid_cap: # video
fps = vid_cap.get(cv2.CAP_PROP_FPS)
w = int(vid_cap.get(cv2.CAP_PROP_FRAME_WIDTH))
h = int(vid_cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
vid_writer = cv2.VideoWriter(save_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (w, h))
else: # stream
fps, w, h = 30, im0.shape[1], im0.shape[0]
vid_writer = cv2.VideoWriter(save_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (w, h))
vid_writer.write(im0)
```
阅读全文
相关推荐

















