关于pytorch交叉熵损失函数使用weight参数来平衡各个类别,即加权交叉熵损失函数

本文详细解释了PyTorch中的CrossEntropyLoss函数在处理带权重的交叉熵时的计算方式,指出与传统计算方法的区别,特别是当使用reduction=mean时,实际的计算涉及将每个样本对应类别权重加总作为分母。作者还提供了示例代码以验证结果的一致性。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

总结

torch.nn.CrossEntropyLoss(weight, reduction=“mean”)的计算方式和一般做法不一样,进行加权计算后,不是直接loss除以batch数量,而是batch中每个数据的标签对应的类别权重 把batch个权重加起来作为除数,被除数是loss

交叉熵损失函数原理

这里就不做介绍了,很多博客和视频已经讲的很清楚了。这里引用一下@b站up同济子豪兄之前的一次动态截图
具体可以搜索这个动态,关键词交叉熵

不带权重的交叉熵计算

这里已经也有很多介绍了,不带权重的情况下,torch.nn.CrossEntropyLoss(reduction=“mean”)的计算和公式里一样。
另外,虽然pytorch里面CrossEntropyLoss的target输入要求是标签或者Probabilities,但是target从标签(batch_size, 1(num_classes))转换成独热编码(torch.nn.functional.one_hot),输出结果是一样的。我感觉是转换成独热编码就相当于target是那个Probabilities了,具体计算过程没细看,结果和手搓的代码一样。

import torch.nn as nn
import torch.nn.functional as F

weight=torch.tensor([1,1,1,1]).float()  # 这里权重都是1,所以结果一样
loss_func_mean = nn.CrossEntropyLoss(weight, reduction="mean")
def ce_loss(y_pred, y_true, weight):

    # 计算 log_softmax,这是更稳定的方式
    log_probs = F.log_softmax(y_pred, dim=-1)

    # 计算加权损失
    loss = -(weight * y_true * log_probs).sum(dim=-1).mean()

    return loss
pre_data = torch.tensor([[0.8, 0.5, 0.2, 0.5],
                         [0.2, 0.9, 0.3, 0.2],
                         [0.4, 0.3, 0.7, 0.1],
                         [0.1, 0.2, 0.4, 0.8]], dtype=torch.float)
tgt_index_data 
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值