Different loss functions: CIoU
Date: 2025-04-20 13:37:36
### CIoU Loss Function in Object Detection Algorithms
In object detection, ensuring accurate bounding box predictions is crucial for model performance. The Complete Intersection over Union (CIoU) loss function addresses this by improving upon previous methods like IoU and GIoU losses.
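The CIoU loss considers more than the overlap area between the predicted and ground-truth boxes. It also adds two geometric factors: the distance between the two center points and the consistency of the two aspect ratios[^1]. The center-distance term still produces a nonzero penalty when the boxes do not overlap, which is exactly where plain IoU is 0 and gives no gradient. Because of this, CIoU is reported to converge faster during training and to reach higher accuracy than simpler IoU-based losses.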
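To see why the extra terms matter, consider two predictions that do not overlap the ground truth at all. The plain-Python sketch below assumes boxes in corner format (x1, y1, x2, y2). The helper names `iou` and `center_distance_penalty` are hypothetical. IoU is 0 for both predictions, but the normalized center-distance term that CIoU adds still tells the near box from the far one.

```python
def iou(a, b):
    """IoU of two axis-aligned boxes given as (x1, y1, x2, y2)."""
    iw = max(0.0, min(a[2], b[2]) - max(a[0], b[0]))
    ih = max(0.0, min(a[3], b[3]) - max(a[1], b[1]))
    inter = iw * ih
    area_a = (a[2] - a[0]) * (a[3] - a[1])
    area_b = (b[2] - b[0]) * (b[3] - b[1])
    return inter / (area_a + area_b - inter)


def center_distance_penalty(a, b):
    """Squared center distance divided by the squared diagonal of the smallest enclosing box."""
    ax, ay = (a[0] + a[2]) / 2, (a[1] + a[3]) / 2
    bx, by = (b[0] + b[2]) / 2, (b[1] + b[3]) / 2
    cw = max(a[2], b[2]) - min(a[0], b[0])
    ch = max(a[3], b[3]) - min(a[1], b[1])
    return ((ax - bx) ** 2 + (ay - by) ** 2) / (cw ** 2 + ch ** 2)


gt = (0.0, 0.0, 2.0, 2.0)
near = (3.0, 0.0, 5.0, 2.0)  # no overlap, close to gt
far = (8.0, 0.0, 10.0, 2.0)  # no overlap, far from gt

print(iou(gt, near), iou(gt, far))  # 0.0 0.0: IoU cannot rank these predictions
print(center_distance_penalty(gt, near), center_distance_penalty(gt, far))  # roughly 0.31 vs 0.62
```

Both boxes here have the same aspect ratio as the ground truth, so only the distance term differs between them.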
Mathematically, CIoU is defined as shown below. The loss that is minimized is \(\mathcal{L}_{\text{CIoU}} = 1 - \text{CIoU}\).
\[
\text{CIoU} = \text{IoU} - \frac{\rho^2(b, b_{gt})}{c^2} - \alpha v
\]
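Where:
- \(b\) and \(b_{gt}\) denote the center points of the predicted box and the ground-truth box,
- \(\rho(\cdot)\) is the Euclidean distance, so \(\rho^2(b, b_{gt})\) is the squared distance between the two centers,
- \(c\) is the diagonal length of the smallest box that encloses both boxes,
- \(\alpha = \frac{v}{(1 - \text{IoU}) + v}\) is a positive trade-off weight that gives the aspect-ratio term more influence as the overlap grows,
- \(v = \frac{4}{\pi^2}\left(\arctan\frac{w_{gt}}{h_{gt}} - \arctan\frac{w}{h}\right)^2\) quantifies the aspect-ratio discrepancy, where \(w, h\) and \(w_{gt}, h_{gt}\) are the widths and heights of the predicted and ground-truth boxes.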
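As a worked example of the formula, the plain-Python sketch below computes each term of CIoU for a single pair of boxes. It assumes center format (cx, cy, w, h), and the function name `ciou` is hypothetical. Identical boxes give a CIoU of 1. Shifting or reshaping the prediction lowers the score through the distance term and the aspect-ratio term respectively.

```python
import math


def ciou(pred, gt):
    """CIoU of two boxes given as (cx, cy, w, h), term by term as in the formula."""
    px, py, pw, ph = pred
    gx, gy, gw, gh = gt

    # Corner coordinates of both boxes
    p_x1, p_y1, p_x2, p_y2 = px - pw / 2, py - ph / 2, px + pw / 2, py + ph / 2
    g_x1, g_y1, g_x2, g_y2 = gx - gw / 2, gy - gh / 2, gx + gw / 2, gy + gh / 2

    # IoU term
    iw = max(0.0, min(p_x2, g_x2) - max(p_x1, g_x1))
    ih = max(0.0, min(p_y2, g_y2) - max(p_y1, g_y1))
    inter = iw * ih
    iou = inter / (pw * ph + gw * gh - inter)

    # rho^2 / c^2: squared center distance over squared enclosing-box diagonal
    rho2 = (px - gx) ** 2 + (py - gy) ** 2
    c2 = (max(p_x2, g_x2) - min(p_x1, g_x1)) ** 2 + (max(p_y2, g_y2) - min(p_y1, g_y1)) ** 2

    # Aspect-ratio term v and trade-off weight alpha
    v = (4 / math.pi ** 2) * (math.atan(gw / gh) - math.atan(pw / ph)) ** 2
    alpha = v / ((1 - iou) + v) if v > 0 else 0.0

    return iou - rho2 / c2 - alpha * v


gt_box = (1.0, 1.0, 2.0, 2.0)
print(ciou((1.0, 1.0, 2.0, 2.0), gt_box))  # identical boxes -> 1.0
print(ciou((2.0, 1.0, 2.0, 2.0), gt_box))  # shifted: IoU 1/3 minus distance term 1/13
print(ciou((1.0, 1.0, 2.0, 1.0), gt_box))  # same center, flatter box: IoU 0.5 minus a small alpha * v
```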
Implementing CIoU in an object detector means changing the loss computed for bounding-box regression. Here is a PyTorch snippet that implements it for boxes given in center format (cx, cy, w, h):
```python
import math

import torch


def ciou_loss(preds, targets, eps=1e-7):
    """CIoU loss for boxes of shape (N, 4) in (center_x, center_y, width, height) format."""
    # Unpack center-format boxes
    preds_cx, preds_cy, preds_w, preds_h = preds[:, 0], preds[:, 1], preds[:, 2], preds[:, 3]
    target_cx, target_cy, target_w, target_h = targets[:, 0], targets[:, 1], targets[:, 2], targets[:, 3]

    # Convert to corner coordinates (x1, y1, x2, y2)
    p_x1, p_x2 = preds_cx - preds_w / 2, preds_cx + preds_w / 2
    p_y1, p_y2 = preds_cy - preds_h / 2, preds_cy + preds_h / 2
    t_x1, t_x2 = target_cx - target_w / 2, target_cx + target_w / 2
    t_y1, t_y2 = target_cy - target_h / 2, target_cy + target_h / 2

    # Intersection area (zero when the boxes do not overlap)
    w_intersect = (torch.min(p_x2, t_x2) - torch.max(p_x1, t_x1)).clamp(min=0)
    h_intersect = (torch.min(p_y2, t_y2) - torch.max(p_y1, t_y1)).clamp(min=0)
    inter_area = w_intersect * h_intersect

    # IoU
    union = preds_w * preds_h + target_w * target_h - inter_area + eps
    ious = inter_area / union

    # Width and height of the smallest enclosing box
    g_w = torch.max(p_x2, t_x2) - torch.min(p_x1, t_x1)
    g_h = torch.max(p_y2, t_y2) - torch.min(p_y1, t_y1)

    # Distance penalty rho^2 / c^2 (always <= 1)
    center_distance_squared = (preds_cx - target_cx) ** 2 + (preds_cy - target_cy) ** 2
    diag_enclosing_box_squared = g_w ** 2 + g_h ** 2 + eps
    u = center_distance_squared / diag_enclosing_box_squared

    # Aspect-ratio consistency term v (eps guards against zero heights)
    v = (4 / math.pi ** 2) * torch.pow(
        torch.atan(target_w / (target_h + eps)) - torch.atan(preds_w / (preds_h + eps)), 2
    )
    # Trade-off weight alpha, treated as a constant (no gradient) as many implementations do
    with torch.no_grad():
        alpha = v / ((1 - ious) + v).clamp(min=eps)

    cious = ious - u - alpha * v
    return 1 - cious.mean()
```