pytorch 错误解决:NotImplemented Error

文章讨论了在使用PyTorch时遇到NotImplementedError的常见原因,包括forward函数未定义、拼写错误或缩进问题,特别是当涉及ModuleList时。作者提到了一种特殊情况,即在集成包含ModuleList的预训练模型时,需要在forward函数中手动连接ModuleList的内容以避免错误。为了解决这个问题,文章提供了一个名为model2extractor的函数,该函数能遍历模型并把ModuleList添加到Sequential结构中。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

错误描述

File “”, line 1, in
runfile(‘C:/usr/local/Anaconda3/mylib/Autoencoder_test.py’, wdir=‘C:/usr/local/Anaconda3/mylib’)

File “C:\usr\local\Anaconda3\lib\site-packages\spyder\utils\site\sitecustomize.py”, line 705, in runfile
execfile(filename, namespace)

File “C:\usr\local\Anaconda3\lib\site-packages\spyder\utils\site\sitecustomize.py”, line 102, in execfile
exec(compile(f.read(), filename,exec), namespace)

File “C:/usr/local/Anaconda3/mylib/Autoencoder_test.py”, line 85, in
xhat = model(x)

File “C:\usr\local\Anaconda3\lib\site-packages\torch\nn\modules\module.py”, line 477, in call
result = self.forward(*input, **kwargs)

File “C:\usr\local\Anaconda3\lib\site-packages\torch\nn\modules\module.py”, line 83, in forward
raise NotImplementedError

NotImplementedError

问题分析

  • 通常这种情况下的问题有:
    • 定义自己的 Module 的时候 forward 函数没有实现 / forward 函数名字写错了 / forward 函数前后位置的缩进有空格或者缩进错误
    • 使用了 ModuleList 去存放自己的网络模块。但是没有在 forward 中将 ModuleList 中的模块进行连接

特殊情况

  • 还有一种非常容易疏忽的情况:

    • 在调用别人的预训练模型或者别人封装好的模块的时候,里面包含了 ModuleList 模块,但是你没有对齐中的内容进行连接。比如下面这种情况。
    • 目前我想要使用别人写好的 A 模型作为我的网络的一部分,但是 A 模型自身包含两部分,第一部分是 Sequential 定义好的模块(这个模块不需要我们重新通过 forward 进行连接),第二部分是 ModuleList
    • 那么当我们构建我们自己网络的时候,无论这个 ModuleList 藏得有多深,如果我们不在自己的 forward 函数中进行人为的连接,他都会报错 NotImplemented Error
    • 所以我用下面的代码将 A模型 中的部分进行遍历,如果遇到上述 ModuleList 的情况,就把他们加入到 Sequential 结构中
    def model2extractor(model):
        sequential = nn.Sequential()
        for name, module in model.named_children():
            print(name)
            if isinstance(module, nn.ModuleList):
                mod = nn.Sequential(*module)
                sequential.add_module(name, mod)
            else:
                sequential.add_module(name, module)
        return sequential
    
### 解决 PyTorch 中 `aten::scatter` 操作导致的 `RuntimeError: derivative not implemented` 错误 当遇到 `aten::scatter` 导致的导数未实现错误时,可以采取多种策略来解决问题。以下是几种可能的方法: #### 方法一:使用替代操作 如果 `scatter` 的具体功能可以通过其他可微分的操作实现,则考虑替换这些不可微分的部分。例如,在某些情况下,可以用索引赋值或其他张量运算代替 `scatter`。 ```python import torch def custom_scatter(input_tensor, dim, index, src): output = input_tensor.clone() output.scatter_(dim, index, src) return output ``` 这种方法允许手动创建一个类似的函数并确保其完全可微分[^1]。 #### 方法二:自定义 autograd 函数 对于更复杂的情况,编写自己的 `torch.autograd.Function` 来提供缺失的反向传播路径是一个有效的解决方案。这涉及到重写前向和后向过程中的逻辑。 ```python class ScatterFunction(torch.autograd.Function): @staticmethod def forward(ctx, input_tensor, dim, index, src): ctx.save_for_backward(index, src) result = input_tensor.scatter(dim=dim, index=index, src=src) return result @staticmethod def backward(ctx, grad_output): index, src = ctx.saved_tensors grad_input = grad_output.clone().zero_() grad_src = grad_output.gather(0, index).clone() scatter_grad = grad_input.scatter_add( 0, index.expand_as(grad_src), grad_src) return scatter_grad, None, None, None custom_scatter_op = ScatterFunction.apply ``` 通过这种方式,可以在不改变原有模型结构的前提下修复梯度计算问题[^2]。 #### 方法三:调整环境设置 有时禁用确定性算法可以帮助绕过特定类型的运行时错误。虽然这不是根本性的修正方案,但在调试阶段可能会有所帮助。 ```python torch.use_deterministic_algorithms(False) ``` 需要注意的是,此选项应谨慎使用,并仅限于开发环境中测试目的。 以上三种方式可以根据实际应用场景灵活选用或组合应用以克服由 `aten::scatter` 引起的问题。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

暖仔会飞

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值