CrossEntropyLoss和手写交叉熵的区别
总结
torch.nn.CrossEntropyLoss(weight, reduction=“mean”)的计算方式和一般做法不一样,进行加权计算后,不是直接loss除以batch数量,而是batch中每个数据的标签对应的类别权重 把batch个权重加起来作为除数,被除数是loss
交叉熵损失函数原理
这里就不做介绍了,很多博客和视频已经讲的很清楚了。这里引用一下@b站up同济子豪兄之前的一次动态截图
不带权重的交叉熵计算
这里已经也有很多介绍了,不带权重的情况下,torch.nn.CrossEntropyLoss(reduction=“mean”)的计算和公式里一样。
另外,虽然pytorch里面CrossEntropyLoss的target输入要求是标签或者Probabilities,但是target从标签(batch_size, 1(num_classes))转换成独热编码(torch.nn.functional.one_hot),输出结果是一样的。我感觉是转换成独热编码就相当于target是那个Probabilities了,具体计算过程没细看,结果和手搓的代码一样。
import torch.nn as nn
import torch.nn.functional as F
weight=torch.tensor([1,1,1,1]).float() # 这里权重都是1,所以结果一样
loss_func_mean = nn.CrossEntropyLoss(weight, reduction="mean")
def ce_loss(y_pred, y_true, weight):
# 计算 log_softmax,这是更稳定的方式
log_probs = F.log_softmax(y_pred, dim=-1)
# 计算加权损失
loss = -(weight * y_true * log_probs).sum(dim=-1).mean()
return loss
pre_data = torch.tensor([[0.8, 0.5, 0.2, 0.5],
[0.2, 0.9, 0.3, 0.2],
[0.4, 0.3, 0.7, 0.1],
[0.1, 0.2, 0.4, 0.8]], dtype=torch.float)
tgt_index_data