深度了解detach(),这一篇足矣!

1 detach()的详细用法

new_tensor = tensor.detach()
特点:
  • 返回的新张量与原始张量共享数据存储
  • 新张量的requires_grad=False
  • 原始张量的梯度计算不受影响

2 实战

for item in eval_dataloader:
    inputs = _prepare_input(item, device=args.device)

    with torch.no_grad():
        outputs = model(**inputs, return_dict=True)
        loss = outputs.loss
        loss_list.append(loss.detach().cpu().item())

        preds = torch.argmax(outputs.logits.cpu(), dim=-1).numpy()
        preds_list.append(preds)

        labels_list.append(inputs['labels'].cpu().numpy()) 

第7行使用detach()是因为:

  1. loss是在模型前向传播时计算得到的,默认会带有梯度信息
  2. 但我们只是想在评估阶段记录损失值,不需要保留梯度(因为评估时不更新模型参数)
  3. 如果不detach,整个计算图会一直保留在内存中,可能导致内存泄漏
2.1.cpu().item()的链式调用
  • .detach():断开计算图
  • .cpu():将张量从GPU移到CPU
  • .item():将单元素张量转为Python标量
2.2 no_grad()和detach()的双重保护

with torch.no_grad()是禁用范围内所有梯度计算,而仍然使用detach() 是一个防御性编程策略,确保即使在其他代码修改时,评估指标的计算也不会意外保留计算图。

2.3 loss_list.append(loss.detach())和loss_list.append(loss)的区别

场景

内存持有内容

Python引用关系

不detach

整个计算图(x→y→loss)

loss_list[0]lossyx

使用detach

仅存储最终数值

loss_list[0] → 纯数值

假如在进行评估时,把测试数据分成了10个批次,那么就会产生10个loss

如果直接 loss_list.append(loss)

    • 列表中将存储 10个带有完整计算图引用的 loss 张量
    • 每个 loss 都通过 grad_fn 回溯到模型参数,形成 10个独立的计算图分支
    • 内存中实际保存的不是梯度值,而是构建梯度所需的计算图结构
    • 反向传播时,这些计算图可以分别生成梯度(但评估阶段通常不会反向传播)

如果使用 loss_list.append(loss.detach())

    • 列表仅存储 10个纯数值(标量)
    • 原始计算图在每轮迭代后被及时释放

2.4 如果使用 loss_sum += loss 会不会累积梯度计算图?

answer:Yes

  1. 每次迭代时:
    • 每个loss都会创建一个新的计算图(因为每次model(data)都是独立的前向传播)
    • 当执行loss_sum += loss时,PyTorch会构建一个新的计算节点(加法操作)
  1. 最终loss_sum包含的内容:
    • 不是"多个完整梯度",而是一个动态生长的聚合计算图

计算图结构大致如下:

AddBackward
├── Loss1Backward (来自第1个batch)
│   └── ModelForward1
└── AddBackward
    ├── Loss2Backward (来自第2个batch) 
    │   └── ModelForward2
    └── AddBackward
        ├── ...
        └── LossNBackward (来自第N个batch)
            └── ModelForwardN

loss_sum += loss会创建一个不断扩展的聚合计算图,包含所有batch的计算历史

2.5 原始计算图什么时候会被释放?

虽然在使用 detach() 后原始计算图仍然存在,但如果没有其他代码持有 loss 的引用,Python的引用计数机制会很快回收原始计算图

例如在评估循环中:

for data in eval_dataloader:
    loss = model(data)  # 新loss覆盖旧loss
    loss_list.append(loss.detach())  # 只存值
    # 上一轮的loss计算图在此处已无引用,会被回收

如果坚持不detach:

loss_list.append(loss)  # 列表持续持有计算图
# 即使loss变量被覆盖,列表内的引用仍保持计算图存活

3 item() detach()的区别

    for batch_idx, (inputs, targets) in enumerate(trainloader):
        inputs, targets = inputs.to(device), targets.to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = loss_func(outputs, targets)
        loss.backward()
        optimizer.step()
        
        print('loss:',loss)
        print('loss.item():',loss.item())
        print('loss.detach():',loss.detach())
        train_loss += loss.item()			# <----关键
loss: tensor(2.3391, device='cuda:0', grad_fn=<NllLossBackward>)
loss.item(): 2.3391051292419434
loss.detach(): tensor(2.3391, device='cuda:0')

很明显,loss.backward()在上面已经进行过了,下面去计算train_loss的时候就不要再带有梯度信息才合适。故有两种解决方案:

使用loss.detach()来获取不需要梯度回传的部分。

detach()通过重新声明一个变量,指向原变量的存放位置,但是requires_grad变为False。

使用loss.item()直接获得对应的python数据类型。

建议: 把除了loss.backward()之外的loss调用都改成loss.item()

4 参考文章

【item() detach()用法】神经网络训练显存越来越大的原因之一_loss.detach()-CSDN博客

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值