Python map 惰性求值导致 torch.no_grad 失效

import torch
x = [
    torch.tensor(i, dtype=torch.float, requires_grad=True)
    for i in range(1, 10)
]
with torch.no_grad():
    y = map(lambda x: x * 2, x)
print(sum(y))  # tensor(90., grad_fn=<AddBackward0>)

虽然 y 是在 torch.no_grad 环境下创建的,但是由于 map 具有惰性求值的特性,所以 y 实际是在 sum(y) 被调用时才被计算出来。这时已经退出 torch.no_grad 环境,因此 sum(y) 有梯度。

在 PyTorch 中,torch.no_grad 主要用于禁止计算图的构建,以减少内存占用和计算开销。另一方面,Python 内置函数 map 具有惰性求值的特性,这可能会导致 torch.no_grad 失效,从而影响计算结果。

torch.no_grad

torch.no_grad 是 PyTorch 提供的一个上下文管理器,它的作用是在代码执行时禁止计算梯度。这样可以减少不必要的计算,提高执行效率。例如

import torch

torch.manual_seed(0)
x = torch.randn(3, requires_grad=True)
with torch.no_grad():
    y = x * 2
print(y.requires_grad)  # False

torch.no_grad() 作用域下计算 y = x * 2,由于梯度被禁用,y.requires_gradFalse,即 PyTorch 不会为 y 构建计算图。

Python map 的惰性求值

在 Python 中,map 是一个惰性计算的函数,它返回一个迭代器,并不会立即执行计算。例如

def square(x: int) -> int:
    print(f"Computing square of {x}")
    return x * x

numbers = [1, 2, 3]
squares = map(square, numbers)

print(squares)
print("-" * 10)
print(list(squares))

执行上述代码,将会输出

<map object at 0x7f10fdf17730>
----------
Computing square of 1
Computing square of 2
Computing square of 3
[1, 4, 9]

可以看到,squares 并不是一个列表,而是一个 map 类型的对象。此外,Computing square of ... 的在 ---------- 之后被打印,证明 square 函数在 map(square, numbers) 中没有被执行,而是在 list(squares) 中才被执行的。这是因为 map 对象只有在被遍历(例如 listsum)时才真正执行。

map 惰性求值导致 torch.no_grad 失效

结合 torch.no_gradmap,可能会出现意想不到的行为。例如

import torch

x = [
    torch.tensor(i, dtype=torch.float, requires_grad=True)
    for i in range(1, 10)
]
with torch.no_grad():
    y = map(lambda x: x * 2, x)
print(sum(y).requires_grad)  # True

为什么 y 仍然有梯度?

因为 map(lambda x: x * 2, x) 只创建了一个 map 对象,并没有真正执行 x * 2 计算。torch.no_grad() 作用域结束后,计算 sum(y) 时,x * 2 计算才真正发生。这时 torch.no_grad() 已经失效,因此 PyTorch 仍然会追踪梯度。

解决方案

如果希望 torch.no_grad 正确生效,可以强制 map 立即执行

with torch.no_grad():
    y = list(map(lambda x: x * 2, x))
print(sum(y).requires_grad)  # False

这样 ytorch.no_grad 作用域内就完成计算,确保 sum(y) 不会被追踪梯度。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

LutingWang

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值