报错是这样的:Weights only load failed. This file can still be loaded, to do so you have two options, do those steps only if you trust the source of the checkpoint.
时间: 2025-05-31 09:52:14 浏览: 28
### 解决 `Weights only load failed` 错误的方法
在加载 PyTorch 的检查点文件时,如果遇到 `_pickle.UnpicklingError: Weights only load failed` 错误,通常是由于以下原因之一引起的:PyTorch 版本的变化、检查点的安全性设置以及模型类定义的缺失。以下是详细的解决方案。
---
#### 原因分析
1. **默认行为变化**
自 PyTorch 2.6 起,默认情况下 `torch.load` 函数启用了 `weights_only=True` 参数[^2]。此参数旨在防止任意代码执行的风险,但如果检查点文件包含非张量数据(例如自定义类或函数),则会触发该错误。
2. **安全性限制**
默认情况下,PyTorch 只允许加载有限类型的全局对象(如内置 Python 类型)。如果检查点中包含了未被显式授权的自定义类(例如 `ultralytics.nn.tasks.DetectionModel`),就会抛出 `Unsupported global` 错误[^2]。
3. **模型类定义丢失**
如果检查点文件是在另一个环境中创建的,并且其中存储了模型的结构信息(而非仅权重),那么当前运行环境必须能够识别这些模型类。否则即使禁用 `weights_only`,也会导致加载失败。
---
#### 解决方案
##### 方法一:禁用 `weights_only`
如果信任检查点来源,可以直接将 `weights_only` 设置为 `False` 来绕过这一限制。需要注意的是,这样做可能会带来一定的安全隐患,因为它允许加载任何序列化的 Python 对象。
```python
checkpoint = torch.load(file_path, map_location=torch.device('cpu'), weights_only=False)
```
这种方式适合于从可信来源获取的检查点文件。
##### 方法二:使用上下文管理器扩展安全范围
为了更安全地加载包含自定义类的检查点,可以使用 `torch.serialization.safe_globals()` 上下文管理器或将所需类添加到白名单中。
```python
from ultralytics.nn.tasks import DetectionModel
with torch.serialization.safe_globals([DetectionModel]):
checkpoint = torch.load(file_path, map_location=torch.device('cpu'), weights_only=True)
```
这种方法既保持了一定程度的安全性,又支持加载复杂的模型结构。
##### 方法三:重新定义模型类
如果无法修改加载选项,也可以通过手动重建模型类的方式来解决问题。具体做法如下:
1. 确保本地代码中有与检查点对应的模型定义;
2. 使用 `strict=False` 参数加载权重,忽略多余的键值对。
```python
model = MyCustomModel(...) # 替换为你自己的模型定义
checkpoint = torch.load(file_path, map_location=torch.device('cpu'))
model.load_state_dict(checkpoint['state_dict'], strict=False)
```
这种方式适用于需要严格控制模型结构的场景[^3]。
##### 方法四:升级或降级 PyTorch 版本
如果上述方法均不可行,考虑调整 PyTorch 的版本以匹配检查点文件的生成环境。例如,对于较老的检查点文件,可以暂时回退到低于 2.6 的版本进行加载[^1]。
---
### 总结
- 若完全信任检查点来源,可通过关闭 `weights_only` 实现快速加载。
- 需要在保障安全的前提下加载复杂模型时,推荐采用上下文管理器机制。
- 缺少必要模型定义的情况下,需先补充相关类后再尝试恢复权重。
---
阅读全文
相关推荐




