欢迎收藏Star我的Machine Learning Blog:https://github.com/purepisces/Wenqing-Machine_Learning_Blog。如果收藏star, 有问题可以随时与我交流, 谢谢大家!
处理类别不平衡
在欺诈检测、点击预测或垃圾邮件检测等机器学习用例中,通常会遇到标签不平衡的问题。根据具体用例,可以使用以下几种策略来处理不平衡。
在损失函数中使用类别权重
例如,在垃圾邮件检测问题中,非垃圾邮件数据占95%,而垃圾邮件数据仅占5%。我们希望对非垃圾邮件类别给予更高的惩罚。在这种情况下,可以通过权重修改熵损失函数。
// w0是类别0的权重,w1是类别1的权重
loss_function = -w0 * ylog(p) - w1*(1-y)*log(1-p)
情况1:正确分类(垃圾邮件)
- 真实标签: 垃圾邮件 (y = 1)
- 预测概率: p = 0.9 (对垃圾邮件的高置信度)
- 类别权重: w_0 = 1.05 (非垃圾邮件), w_1 = 20 (垃圾邮件)
由于真实标签是垃圾邮件且模型预测为垃圾邮件,这是正确分类。
loss = − w 1 ⋅ y ⋅ log ( p ) − w 0 ⋅ ( 1 − y ) ⋅ log ( 1 − p ) \text{loss} = - w_1 \cdot y \cdot \log(p) - w_0 \cdot (1 - y) \cdot \log(1 - p) loss=−w1⋅y⋅log(p)−w0⋅(1−y)⋅log(1−p)
loss = − 20 ⋅ 1 ⋅ log ( 0.9 ) − 1.05 ⋅ 0 ⋅ log ( 0.1 ) \text{loss} = - 20 \cdot 1 \cdot \log(0.9) - 1.05 \cdot 0 \cdot \log(0.1) loss=−20⋅1⋅log(0.9)−1.05⋅0⋅log(0.1)
loss = − 20 ⋅ log ( 0.9 ) \text{loss} = - 20 \cdot \log(0.9) loss=−20⋅log(0.9)
loss ≈ − 20 ⋅ ( − 0.105 ) \text{loss} \approx - 20 \cdot (-0.105) loss≈−20⋅(−0.105)
loss ≈ 2.1 \text{loss} \approx 2.1 loss≈2.1
情况2:错误分类(垃圾邮件)
- 真实标签: 垃圾邮件 (y = 1)
- 预测概率: p = 0.3 (对垃圾邮件的低置信度)
- 类别权重: w_0 = 1.05 (非垃圾邮件), w_1 = 20 (垃圾邮件)
由于真实标签是垃圾邮件但模型预测为非垃圾邮件,这是错误分类。
loss = − w 1 ⋅ y ⋅ log ( p ) − w 0 ⋅ ( 1 − y ) ⋅ log ( 1 − p ) \text{loss} = - w_1 \cdot y \cdot \log(p) - w_0 \cdot (1 - y) \cdot \log(1 - p) loss=