H:\soft\conda3\envs\chat311\python.exe "D:/soft/PyCharm 2024.3.3/plugins/python-ce/helpers/pycharm/_jb_pytest_runner.py" --path H:\data\PycharmProjects\my_cmd\AI_test\蒸馏\t4.py Testing started at 19:31 ... Launching pytest with arguments H:\data\PycharmProjects\my_cmd\AI_test\蒸馏\t4.py --no-header --no-summary -q in H:\data\PycharmProjects\my_cmd\AI_test\蒸馏 ============================= test session starts ============================= collecting ... t4.py:None (t4.py) t4.py:34: in <module> teacher_model = torch.load('teacher_model_full.pth') H:\soft\conda3\envs\chat311\Lib\site-packages\torch\serialization.py:1360: in load return _load( H:\soft\conda3\envs\chat311\Lib\site-packages\torch\serialization.py:1848: in _load result = unpickler.load() H:\soft\conda3\envs\chat311\Lib\site-packages\torch\serialization.py:1837: in find_class return super().find_class(mod_name, name) E AttributeError: Can't get attribute 'TeacherModel' on <module '__main__' from 'D:\\soft\\PyCharm 2024.3.3\\plugins\\python-ce\\helpers\\pycharm\\_jb_pytest_runner.py'> collected 0 items / 1 error !!!!!!!!!!!!!!!!!!! Interrupted: 1 error during collection !!!!!!!!!!!!!!!!!!!! ========================= 1 warning, 1 error in 2.09s ========================= Process finished with exit code 2
时间: 2025-06-10 21:43:54 浏览: 7
### 解决方案
在使用 `torch.load` 加载模型时,出现 `AttributeError: Can't get attribute 'TeacherModel' on <module '__main__'>` 错误的原因通常是由于加载的模型类定义未被正确导入或识别。以下是解决该问题的具体方法:
#### 1. 确保模型类定义在同一模块中
如果保存模型和加载模型的代码位于不同的文件中,必须确保保存模型时所使用的模型类定义(如 `TeacherModel`)在加载模型的文件中也可以被正确导入。例如,假设 `TeacherModel` 定义在 `train.py` 文件中,则需要将 `train.py` 中的 `TeacherModel` 类导入到加载模型的文件中[^2]。
```python
# 在加载模型的文件中导入 TeacherModel 类
from train import TeacherModel
model = torch.load('model.pth')
```
#### 2. 使用自定义类实例化后再加载状态字典
另一种方法是先实例化模型类,然后加载模型的状态字典。这种方法可以避免直接加载整个模型对象时引发的错误[^3]。
```python
from train import TeacherModel
# 实例化模型
model = TeacherModel()
# 加载状态字典
model.load_state_dict(torch.load('model.pth'))
```
#### 3. 检查保存模型的方式
在保存模型时,推荐仅保存模型的状态字典而非整个模型对象。这样可以减少依赖于特定模块结构的问题[^1]。
```python
torch.save(model.state_dict(), 'model.pth')
```
#### 4. 调整测试框架的执行环境
如果问题发生在 Pytest 运行时,可能是因为 Pytest 的执行环境与训练脚本的环境不同。可以通过以下方式调整:
- **确保模块路径一致**:确认 `train.py` 和运行 Pytest 的脚本位于相同的 Python 模块路径下。
- **动态导入**:在测试脚本中动态导入模型类。
```python
import importlib
# 动态导入模型类
train_module = importlib.import_module('train')
TeacherModel = getattr(train_module, 'TeacherModel')
# 加载模型
model = TeacherModel()
model.load_state_dict(torch.load('model.pth'))
```
### 注意事项
- 如果模型类定义在主脚本(`__main__`)中,建议将其移至独立的模块文件中以避免加载时的命名空间问题。
- 确保保存和加载模型时使用的 PyTorch 版本一致,否则可能会引发兼容性问题。
```python
# 示例代码:完整加载流程
from train import TeacherModel
def load_model(path):
model = TeacherModel()
model.load_state_dict(torch.load(path))
return model
if __name__ == "__main__":
model = load_model('model.pth')
```
阅读全文
相关推荐


















