Traceback (most recent call last): File "C:\Users\唐添美.DESKTOP-95B881R\Desktop\MNIST实现\test.py", line 38, in <module> mynet = torch.load('MNIST_4_acc_0.9692999720573425.pth', map_location=torch.device('cpu'), weights_only=True ) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "C:\Users\唐添美.DESKTOP-95B881R\.conda\envs\python高级应用\Lib\site-packages\torch\serialization.py", line 1524, in load raise pickle.UnpicklingError(_get_wo_message(str(e))) from None _pickle.UnpicklingError: Weights only load failed. This file can still be loaded, to do so you have two options, do those steps only if you trust the source of the checkpoint. (1) In PyTorch 2.6, we changed the default value of the `weights_only` argument in `torch.load` from `False` to `True`. Re-running `torch.load` with `weights_only` set to `False` will likely succeed, but it can result in arbitrary code execution. Do it only if you got the file from a trusted source. (2) Alternatively, to load with `weights_only=True` please check the recommended steps in the following error message. WeightsUnpickler error: Unsupported global: GLOBAL __main__.MyNet was not an allowed global by default. Please use `torch.serialization.add_safe_globals([__main__.MyNet])` or the `torch.serialization.safe_globals([__main__.MyNet])` context manager to allowlist this global if you trust this class/function. Check the documentation of torch.load to learn more about types accepted by default with weights_only https://pytorch.org/docs/stable/generated/torch.load.html.
时间: 2025-05-27 13:35:10 浏览: 20
### PyTorch 加载模型时出现 `UnpicklingError` 的原因分析
当在使用 PyTorch 中的 `torch.load()` 函数加载模型权重文件时遇到 `UnpicklingError` 错误,通常是因为序列化和反序列化的环境不一致所导致。具体来说,这可能涉及以下几个方面的原因:
1. **类定义缺失或不同**:如果保存模型时使用的自定义网络结构(如 `__main__.MyNet`),但在加载时该类未被正确定义或者其定义发生了变化,则会引发此错误[^1]。
2. **Python 版本差异**:不同的 Python 版本可能导致 Pickle 序列化协议存在兼容性问题。
3. **PyTorch 版本冲突**:保存模型与加载模型时使用的 PyTorch 版本不匹配也可能引起此类问题。
4. **weights_only 参数设置不当**:虽然设置了 `weights_only=True`,但如果实际存储的是整个模型状态而非仅参数字典,仍可能出现问题。
#### 解决方案一:确保类定义一致性
确认当前环境中已导入并正确实现了用于训练模型的相同神经网络架构类 MyNet 。例如,在脚本开头重新声明如下所示的网络结构:
```python
import torch.nn as nn
import torch.nn.functional as F
class MyNet(nn.Module):
def __init__(self):
super(MyNet, self).__init__()
self.fc = nn.Linear(10, 1)
def forward(self, x):
return F.relu(self.fc(x))
```
#### 解决方案二:调整 weights_only 使用方式
即使指定了 `map_location='cpu'` 和 `weights_only=True` ,也需要验证原始存档确实只包含 state_dict 而非完整的 model object 。可以尝试通过以下方法单独提取参数字典来规避潜在风险:
```python
state_dict = torch.load('model.pth', map_location=torch.device('cpu'), weights_only=True)
new_model = MyNet()
new_model.load_state_dict(state_dict)
```
#### 解决方案三:升级/降级至统一版本
为了消除因框架迭代带来的不确定性因素影响,建议尽量保持开发测试阶段以及部署环节均采用同一套稳定版次组合配置[^2]。
---
### 示例代码展示如何安全处理这种情况下的模型恢复过程
下面给出一段综合考虑以上要点后的改进型实现逻辑供参考:
```python
try:
checkpoint = torch.load(filepath, map_location="cpu", weights_only=True)
except Exception as e:
print(f"Failed to load with error {e}. Trying alternative method...")
if isinstance(checkpoint, dict):
new_instance = MyNet()
try:
new_instance.load_state_dict(checkpoint['model_state_dict'])
except KeyError:
new_instance.load_state_dict(checkpoint) # 如果直接就是state_dict而不是嵌套形式
else:
raise ValueError("Checkpoint format unexpected.")
```
阅读全文
相关推荐
















