YOLOV8中ciou损失函数原理
时间: 2025-04-29 13:30:06 浏览: 75
### YOLOv8 中 CIoU 损失函数工作原理
#### 定义与背景
CIoU (Complete Intersection over Union) 是一种改进型的目标检测损失函数,旨在解决传统 IoU 及其变体 GIoU 的局限性。相比其他版本的 IoU, CIoU 不仅考虑了预测边界框和真实边界框之间的重叠区域,还引入了额外的因素来提高定位精度。
#### 数学表达式
CIoU 计算公式如下:
\[
\text{CIoU} = \text{IoU} - \frac{\rho^2(b,b_{gt})}{c^2}-\alpha v,
\]
其中,
- \(b\) 和 \(b_{gt}\) 分别表示预测框和地面实况框;
- \(\rho\) 表示两个中心点间的欧几里得距离;
- \(c\) 代表覆盖两框最小区域的对角线长度;
- \(\alpha\) 是权重系数,\(v=(\frac{\pi}{4})(arctan(w/h)- arctan(w_{gt}/h_{gt}))\) 描述宽高比例差异的影响[^2]。
#### 工作机制
在训练过程中,模型会尝试最小化这个损失值。具体来说,
- **IoU 部分**:衡量的是预测框与实际标注框之间交集面积占并集的比例;当两者完全吻合时取最大值1。
- **闭包惩罚项**:通过减去封闭区域内两点间相对位置偏差平方除以其直径平方的形式,鼓励更紧密贴合目标物体的真实轮廓而不是简单扩大包围盒尺寸以增加交集部分。
- **长宽比一致性调整因子**:考虑到不同类别对象形状各异的特点,在此基础上加入了一个基于角度差别的修正量用来补偿因尺度变换带来的误差影响。
这种设计使得 CIoU 能够更加全面地评估候选框的质量,并引导网络学习更好的特征表征从而提升最终性能表现。
```python
def ciou_loss(pred_boxes, target_boxes):
"""
Calculate the CIoU loss between predicted boxes and target boxes.
Args:
pred_boxes: Predicted bounding box coordinates as tensors of shape [..., 4].
Format should be [center_x, center_y, width, height].
target_boxes: Ground truth bounding box coordinates with same format.
Returns:
Tensor containing element-wise CIoU losses.
"""
# Extracting centers and sizes from both sets of boxes
b1_x1, b1_y1, b1_w, b1_h = torch.split(pred_boxes, 1, dim=-1)
b2_x1, b2_y1, b2_w, b2_h = torch.split(target_boxes, 1, dim=-1)
# Calculating intersection area
inter_rect_x1 = torch.max(b1_x1 - b1_w / 2, b2_x1 - b2_w / 2)
inter_rect_y1 = torch.max(b1_y1 - b1_h / 2, b2_y1 - b2_h / 2)
inter_rect_x2 = torch.min(b1_x1 + b1_w / 2, b2_x1 + b2_w / 2)
inter_rect_y2 = torch.min(b1_y1 + b1_h / 2, b2_y1 + b2_h / 2)
inter_area = torch.clamp(inter_rect_x2 - inter_rect_x1, min=0) * \
torch.clamp(inter_rect_y2 - inter_rect_y1, min=0)
# Computing union areas
b1_area = b1_w * b1_h
b2_area = b2_w * b2_h
union_area = b1_area + b2_area - inter_area
iou = inter_area / union_area
enclose_left_up = torch.min(torch.cat([pred_boxes[..., :2], target_boxes[..., :2]], dim=-1), dim=-1)[0]
enclose_right_down = torch.max(torch.cat([pred_boxes[..., 2:], target_boxes[..., 2:]], dim=-1), dim=-1)[0]
c = ((enclose_right_down[:, None, :] - enclose_left_up[:, :, :]) ** 2).sum(dim=-1)**0.5
p_center_dist_sqrd = (((target_boxes[..., :2] - pred_boxes[..., :2])**2)).sum(-1)
atan_pred_wh_ratio = torch.atan((pred_boxes[..., 2]/(pred_boxes[..., 3]+1e-16)))
atan_target_wh_ratio = torch.atan((target_boxes[..., 2]/(target_boxes[..., 3]+1e-16)))
v = (4/(np.pi**2)) * ((atan_pred_wh_ratio - atan_target_wh_ratio)**2)
alpha = v/(((1-iou)+v+1e-7))
ciou_term_2 = p_center_dist_sqrd/c**2
ci
阅读全文
相关推荐


















