Focal Loss损失函数代码
时间: 2025-06-30 15:55:01 浏览: 8
### Focal Loss 损失函数的实现
#### PyTorch 实现
以下是基于 PyTorch 的 Focal Loss 损失函数实现代码:
```python
import torch
import torch.nn as nn
class FocalLoss(nn.Module):
def __init__(self, alpha=0.25, gamma=2, reduction='mean'):
super(FocalLoss, self).__init__()
self.alpha = alpha
self.gamma = gamma
self.reduction = reduction
def forward(self, inputs, targets):
"""
:param inputs: 预测的概率分布 (经过 sigmoid 或 softmax 后的结果),形状为 [batch_size, num_classes]
:param targets: 真实标签,形状为 [batch_size, num_classes](one-hot 编码)
:return: 计算后的焦点损失值
"""
eps = 1e-7
BCE_loss = -targets * torch.log(inputs + eps) - (1 - targets) * torch.log(1 - inputs + eps)
pt = torch.exp(-BCE_loss) # 获取预测概率
FL_loss = self.alpha * (1 - pt)**self.gamma * BCE_loss # 应用 Focal Loss 公式
if self.reduction == 'mean':
return torch.mean(FL_loss)
elif self.reduction == 'sum':
return torch.sum(FL_loss)
else:
return FL_loss
```
上述代码实现了二分类或多分类场景下的 Focal Loss 函数[^1]。`alpha` 参数用于调整正负样本之间的不平衡问题,而 `gamma` 控制难分样本的影响程度。
---
#### TensorFlow 实现
以下是基于 TensorFlow 的 Focal Loss 损失函数实现代码:
```python
import tensorflow as tf
def focal_loss(gamma=2., alpha=.25):
def focal_loss_fixed(y_true, y_pred):
"""
:param y_true: 真实标签,形状为 [batch_size, num_classes](one-hot 编码)
:param y_pred: 预测的概率分布,形状为 [batch_size, num_classes]
:return: 计算后的焦点损失值
"""
eps = 1e-7
y_pred = tf.clip_by_value(y_pred, eps, 1. - eps) # 防止 log(0) 导致梯度爆炸
ce = -y_true * tf.math.log(y_pred) # 交叉熵计算
weight = tf.pow(1. - y_pred, gamma) * (alpha * y_true + (1 - alpha) * (1 - y_true))
fl = weight * ce # 应用 Focal Loss 公式
return tf.reduce_mean(fl)
return focal_loss_fixed
```
此代码适用于 Keras 和 TensorFlow 中的模型训练过程。通过设置参数 `gamma` 和 `alpha` 可以灵活控制损失函数的行为[^3]。
---
### 关键点解析
1. **Focal Loss 的作用**
Focal Loss 是一种改进版的交叉熵损失函数,主要用于解决类别不均衡问题以及降低简单样本对总损失的影响。它通过对难分样本施加更大的权重来提高模型性能[^2]。
2. **Alpha 权重的作用**
Alpha 参数可以用来平衡正负类别的比例差异。当数据集中某一类别的样本数量远多于另一类别时,可以通过调节 `alpha` 值缓解这种不平衡带来的影响。
3. **Gamma 超参的意义**
Gamma 参数决定了容易分类样本的下降速度。较大的 `gamma` 值会使简单样本的贡献进一步减小,从而让模型更加关注困难样本的学习。
---
阅读全文
相关推荐

















