torch中的inplace操作问题解决方法

本文介绍了PyTorch中的inplace操作,它是一种高效内存使用方法,但使用不当会在计算梯度或反传时出错。列举了典型的inplace操作,如+=、*=和不合适的原地切片赋值。通过出错案例对比,还给出判断inplace操作的方法,即看切片[]位置。

提示:文章写完后,目录可以自动生成,如何生成可参考右边的帮助文档


一、inplace操作是什么?

可以参考:PyTorch中的In-place操作是什么?为什么要避免使用这种操作?
简单来说,就是你写的代码采用了一种更为高效的内存使用方法,有点像copy(),deepcopy(),切片等操作的关系,如果你不小心使用了inplace操作的话,则在torch的计算梯度autograd()或者反传backward()位置处出现以下提示:
Runtime Error: Found an in-place operation that changes one the variables needed by gradient computation

典型的inplace操作有哪些?

1.+=,*=等操作
例如:

 a[:,0,:,:]+=b[:,0,:,:]

2.不合适的原地切片并赋值
例如:

 a[:,0,:,:]=do_sth(a[:,0,:,:])#do_sth是一个函数

这也是我这次出错的原因
这也是我这次出错的原因

二、出错案例

代码如下(示例):

def forward(self, x):
        x_ = torch.zeros(self.layer(x[..., 0]).shape + (steps,), device=x.device)
        for step in range(steps):
            x_[..., step] = self.layer(x[..., step])
        if self.bn is not None:
            for i in range(steps):
                x_[:,:,:,:,i] = self.bn(x_[:,:,:,:,i])
        return x_

正确的代码如下:

def forward(self, x):

        x_ = torch.zeros(self.layer(x[..., 0]).shape + (steps,), device=x.device)
        for step in range(steps):
            x_[..., step] = self.layer(x[..., step])

        x_2 = torch.zeros(self.layer(x[..., 0]).shape + (steps,), device=x.device)
        if self.bn is not None:
            for i in range(steps):
                x_2[:,:,:,:,i] = self.bn(x_[:,:,:,:,i])
        return x_2 if self.bn is not None else x_

对比两段代码可以发现,正确的代码重新生成了一个tensor x_2,并进行切片赋值,而错误的代码是在原地进行的赋值修改,是torch里面的一种典型inplace操作。


总结

在查看其他博客的时候,看到一个有意思的总结,torch的inplace操作判断只要看切片的[]在哪里,如果只出现在等号的右边,那么一定不是inplace操作,如果等号的左边出现了切片[],那么就需要思考了。
再举两个例子吧

例子1:a=a[:,1:,:,:]

不是in-place.

例子2:a=a.permute(0,3,1,2)

不是in-place.

例子3:a[:,1:,:,:]=函数1(a[:,1:,:,:])

是in-place.

例子4:b[:,1:,:,:]=函数1(a[:,1:,:,:])

不是in-place.

评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值