提示:文章写完后,目录可以自动生成,如何生成可参考右边的帮助文档
torch中的inplace操作问题解决方法
一、inplace操作是什么?
可以参考:PyTorch中的In-place操作是什么?为什么要避免使用这种操作?
简单来说,就是你写的代码采用了一种更为高效的内存使用方法,有点像copy(),deepcopy(),切片等操作的关系,如果你不小心使用了inplace操作的话,则在torch的计算梯度autograd()或者反传backward()位置处出现以下提示:
Runtime Error: Found an in-place operation that changes one the variables needed by gradient computation
典型的inplace操作有哪些?
1.+=,*=等操作
例如:
a[:,0,:,:]+=b[:,0,:,:]
2.不合适的原地切片并赋值
例如:
a[:,0,:,:]=do_sth(a[:,0,:,:])#do_sth是一个函数
这也是我这次出错的原因
这也是我这次出错的原因
二、出错案例
代码如下(示例):
def forward(self, x):
x_ = torch.zeros(self.layer(x[..., 0]).shape + (steps,), device=x.device)
for step in range(steps):
x_[..., step] = self.layer(x[..., step])
if self.bn is not None:
for i in range(steps):
x_[:,:,:,:,i] = self.bn(x_[:,:,:,:,i])
return x_
正确的代码如下:
def forward(self, x):
x_ = torch.zeros(self.layer(x[..., 0]).shape + (steps,), device=x.device)
for step in range(steps):
x_[..., step] = self.layer(x[..., step])
x_2 = torch.zeros(self.layer(x[..., 0]).shape + (steps,), device=x.device)
if self.bn is not None:
for i in range(steps):
x_2[:,:,:,:,i] = self.bn(x_[:,:,:,:,i])
return x_2 if self.bn is not None else x_
对比两段代码可以发现,正确的代码重新生成了一个tensor x_2,并进行切片赋值,而错误的代码是在原地进行的赋值修改,是torch里面的一种典型inplace操作。
总结
在查看其他博客的时候,看到一个有意思的总结,torch的inplace操作判断只要看切片的[]在哪里,如果只出现在等号的右边,那么一定不是inplace操作,如果等号的左边出现了切片[],那么就需要思考了。
再举两个例子吧
例子1:a=a[:,1:,:,:]
不是in-place.
例子2:a=a.permute(0,3,1,2)
不是in-place.
例子3:a[:,1:,:,:]=函数1(a[:,1:,:,:])
是in-place.
例子4:b[:,1:,:,:]=函数1(a[:,1:,:,:])
不是in-place.