import torch
x = [
torch.tensor(i, dtype=torch.float, requires_grad=True)
for i in range(1, 10)
]
with torch.no_grad():
y = map(lambda x: x * 2, x)
print(sum(y)) # tensor(90., grad_fn=<AddBackward0>)
虽然 y
是在 torch.no_grad
环境下创建的,但是由于 map
具有惰性求值的特性,所以 y
实际是在 sum(y)
被调用时才被计算出来。这时已经退出 torch.no_grad
环境,因此 sum(y)
有梯度。
在 PyTorch 中,torch.no_grad
主要用于禁止计算图的构建,以减少内存占用和计算开销。另一方面,Python 内置函数 map
具有惰性求值的特性,这可能会导致 torch.no_grad
失效,从而影响计算结果。
torch.no_grad
torch.no_grad
是 PyTorch 提供的一个上下文管理器,它的作用是在代码执行时禁止计算梯度。这样可以减少不必要的计算,提高执行效率。例如
import torch
torch.manual_seed(0)
x = torch.randn(3, requires_grad=True)
with torch.no_grad():
y = x * 2
print(y.requires_grad) # False
在 torch.no_grad()
作用域下计算 y = x * 2
,由于梯度被禁用,y.requires_grad
为 False
,即 PyTorch 不会为 y
构建计算图。
Python map
的惰性求值
在 Python 中,map
是一个惰性计算的函数,它返回一个迭代器,并不会立即执行计算。例如
def square(x: int) -> int:
print(f"Computing square of {x}")
return x * x
numbers = [1, 2, 3]
squares = map(square, numbers)
print(squares)
print("-" * 10)
print(list(squares))
执行上述代码,将会输出
<map object at 0x7f10fdf17730>
----------
Computing square of 1
Computing square of 2
Computing square of 3
[1, 4, 9]
可以看到,squares
并不是一个列表,而是一个 map
类型的对象。此外,Computing square of ...
的在 ----------
之后被打印,证明 square
函数在 map(square, numbers)
中没有被执行,而是在 list(squares)
中才被执行的。这是因为 map
对象只有在被遍历(例如 list
或 sum
)时才真正执行。
map
惰性求值导致 torch.no_grad
失效
结合 torch.no_grad
和 map
,可能会出现意想不到的行为。例如
import torch
x = [
torch.tensor(i, dtype=torch.float, requires_grad=True)
for i in range(1, 10)
]
with torch.no_grad():
y = map(lambda x: x * 2, x)
print(sum(y).requires_grad) # True
为什么 y
仍然有梯度?
因为 map(lambda x: x * 2, x)
只创建了一个 map
对象,并没有真正执行 x * 2
计算。torch.no_grad()
作用域结束后,计算 sum(y)
时,x * 2
计算才真正发生。这时 torch.no_grad()
已经失效,因此 PyTorch 仍然会追踪梯度。
解决方案
如果希望 torch.no_grad
正确生效,可以强制 map
立即执行
with torch.no_grad():
y = list(map(lambda x: x * 2, x))
print(sum(y).requires_grad) # False
这样 y
在 torch.no_grad
作用域内就完成计算,确保 sum(y)
不会被追踪梯度。