pytorch学习:loss为什么要加item()

本文深入探讨PyTorch0.4.0版本中Variable与Tensor融合后的动态图机制,解析如何避免显存溢出问题,提供正确记录loss信息的方法,并分享作者亲身经历的教训。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

作者:陈诚
链接:https://www.zhihu.com/question/67209417/answer/344752405
来源:知乎
著作权归作者所有。商业转载请联系作者获得授权,非商业转载请注明出处。

PyTorch 0.4.0版本去掉了Variable,将Variable和Tensor融合起来,可以视Variable为requires_grad=True的Tensor。其动态原理还是不变。在获取数据的时候也变得更优雅:使用loss += loss.detach()来获取不需要梯度回传的部分。或者使用loss.item()直接获得所对应的python数据类型。============================================================

以下为原回答:算是动态图的一个坑吧。记录loss信息的时候直接使用了输出的Variable。应该不止我经历过这个吧…久久不用又会不小心掉到这个坑里去…for data, label in trainloader:

    loss = criterion(out, label)
    loss_sum += loss     # <--- 这里

运行着就发现显存炸了观察了一下发现随着每个batch显存消耗在不断增大…参考了别人的代码发现那句loss一般是这样写 /(ㄒoㄒ)/

loss_sum += loss.data[0]

这是因为输出的loss的数据类型是Variable。而PyTorch的动态图机制就是通过Variable来构建图。主要是使用Variable计算的时候,会记录下新产生的Variable的运算符号,在反向传播求导的时候进行使用。如果这里直接将loss加起来,系统会认为这里也是计算图的一部分,也就是说网络会一直延伸变大

那么消耗的显存也就越来越大

总之使用Variable的数据时候要非常小心。不是必要的话尽量使用Tensor来进行计算… 包括数据的输入时候,如果“过早”把数据丢到Variable里面去,那么可能也会被系统视为网络的一部分。所以,要投入的时候再把数据丢到Variable里面去吧~题外话想更多感受动态图的话,可以通过Variable的grad_fun来观察到该Variable是通过什么运算得到的(前提是前面的Variable的required_grad置为True)。

大概是这样

z = x + y
z.grad_fn
out:
<AddBackward1 at 0x107286240>

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值