focal loss 损失函数
时间: 2025-01-13 11:58:31 浏览: 60
### Focal Loss Function in Machine Learning
Focal损失函数是一种专门用于处理类别不平衡问题的改进型交叉熵损失函数[^2]。传统交叉熵损失对于容易分类的例子给予过多的关注,这在面对高度偏斜的数据分布时可能导致训练过程偏向多数类而忽略少数类。
为了缓解这一现象,Focal损失引入了一个调节因子\((1-p_t)^\gamma\),其中\(p_t\)代表模型预测的概率,当真实标签为正样本时取预测概率;反之,则取其补数即\(1-\hat{y}\)。参数\(\gamma \geq 0\)用来控制简单样例的重要性衰减速度——随着\(\gamma\)增大,易分样本的影响逐渐减弱,从而使得网络能够更加专注于难以区分的负样本上。
具体来说,给定一个二元分类任务中的单一样本及其对应的ground truth \(y ∈ {−1,+1}\),以及该样本属于正类别的估计概率\(\hat{y} = σ(x)\),那么原始形式下计算得到的标准二元交叉熵可以写作:
\[CE(p,y)=\left\{
\begin{array}{ll}
-log(p), & y=+1 \\
-log(1-p),& y=-1
\end{array}
\right.
\]
而在应用了Focal损失之后变为:
\[FL(p,y)=(1-p_{t})^{\gamma }*CE(p,y)\]
这里特别强调的是,通过调整超参\(\alpha\)和\(\gamma\),可以使算法更好地适应不同应用场景下的需求变化。例如,在某些极端不平衡的情况下适当增加这两个系数可以帮助提升对稀少事件的学习效果。
```python
import torch.nn.functional as F
def focal_loss(input, target, alpha=0.25, gamma=2):
ce_loss = F.binary_cross_entropy_with_logits(input, target, reduction='none')
p_t = (target * input.sigmoid() + (1 - target) * (1 - input.sigmoid())).clamp(min=eps)
fl = -(alpha * ((1 - p_t)**gamma) * ce_loss).mean()
return fl
```
阅读全文
相关推荐
















