Focal EIoU
时间: 2025-03-07 14:02:24 浏览: 90
### Focal EIoU的定义
Focal EIoU (Extended Intersection over Union) 是一种改进版的目标检测损失函数,旨在解决传统 IoU 及其变体在处理非重叠情况下的局限性。该方法不仅考虑了边界框之间的交并比,还加入了额外的距离惩罚项以提高定位精度[^1]。
具体来说,Focal EIoU 将标准 IoU 和扩展距离因子相结合,从而使得即使当预测框与真实框完全不相交时也能提供有效的梯度信号用于优化。这种机制有助于改善模型的学习过程,在复杂场景下获得更精确的结果。
### 数学表达式
设 \( B \) 表示预测边框,\( G \) 表示地面实况边框,则 Focal EIoU 定义如下:
\[ L_{\text{focal-eiou}}(B,G)=L_{\text{eiou}}(B,G)-\alpha\left(\frac{\sigma(B,G)}{|G|}\right)^2\log{(1-\text{IOU}(B,G)+\epsilon)} \]
其中,
- \( L_{\text{eiou}} \) 代表原始 Extended-IoU 损失;
- \( \alpha \) 控制焦点效应强度;
- \( \sigma(B,G) \) 测量两个矩形中心点间的欧几里得距离;
- \( |G| \) 是实际物体大小;
- \( \epsilon \) 防止数值不稳定的小常数;
此公式表明,随着 IoU 减少(即匹配程度降低),损失会增加更多,这鼓励网络更加关注难以分类的例子。
### Python实现代码
下面是一个简单的 PyTorch 版本的 Focal-EIoU 损失计算函数:
```python
import torch
from torch import nn
class FocalEIoULoss(nn.Module):
def __init__(self, alpha=0.5, epsilon=1e-6):
super(FocalEIoULoss, self).__init__()
self.alpha = alpha
self.epsilon = epsilon
def forward(self, boxes_pred, boxes_gt):
# Compute ious between each pair of predicted and ground truth bounding box.
ious = compute_ious(boxes_pred, boxes_gt)
# Calculate centers' distances squared normalized by gt area size.
dist_centers_normalized = calc_dist_centers_squared_over_area_size(
boxes_pred[:, :2], boxes_gt[:, :2], boxes_gt[:, 2:] - boxes_gt[:, :2])
# Apply focal term to emphasize hard negatives.
loss_focal_term = -(dist_centers_normalized * ((1 - ious + self.epsilon).clamp(min=self.epsilon)).log())
return (loss_focal_term * (ious < 0.5).float()).mean()
def compute_ious(preds, targets):
"""Compute pairwise IOUs."""
overlap_areas = get_overlap_areas(preds, targets)
pred_areas = calculate_box_areas(preds)
target_areas = calculate_box_areas(targets)
union_areas = pred_areas.unsqueeze(-1) + target_areas.t() - overlap_areas
return overlap_areas / union_areas.clamp(min=1e-7)
def get_overlap_areas(box_a, box_b):
max_xy = torch.min(box_a[:, None, 2:], box_b[None, :, 2:])
min_xy = torch.max(box_a[:, None, :2], box_b[None, :, :2])
inter_wh = (max_xy - min_xy).clamp(min=0.)
return inter_wh[..., 0] * inter_wh[..., 1]
def calculate_box_areas(boxes):
wh = boxes[..., 2:] - boxes[..., :2]
return wh[..., 0] * wh[..., 1]
def calc_dist_centers_squared_over_area_size(center_preds, center_gts, areas_gts):
diff_center = center_preds - center_gts
sq_distances = (diff_center ** 2).sum(dim=-1)
normed_sq_dists = sq_distances / areas_gts.prod(dim=-1)
return normed_sq_dists
```
阅读全文
相关推荐
















