pytorch中模型的保存与加载
时间: 2025-02-10 17:13:17 浏览: 58
在PyTorch中,模型的保存与加载是非常常见的任务,因为这允许你在训练过程中暂停并恢复,以及在需要的时候分享模型。以下是模型保存和加载的基本步骤:
1. **保存模型**:
- 使用`torch.save()`函数可以保存整个模型,包括其参数和结构。例如:
```python
torch.save(model.state_dict(), 'model.pth')
```
或者你可以只保存模型的参数而不是整个模型:
```python
torch.save(model.module.state_dict() if isinstance(model, nn.DataParallel) else model.state_dict(), 'model_params.pth')
```
2. **加载模型**:
- 如果你知道模型的具体结构,可以创建一个新的模型实例,并通过`.load_state_dict()`方法加载权重:
```python
new_model = MyModelClass()
new_model.load_state_dict(torch.load('model.pth'))
```
- 如果你想在不同的设备上加载(如从GPU迁移到CPU),记得调整`map_location`参数:
```python
new_model.load_state_dict(torch.load('model.pth', map_location='cpu'))
```
阅读全文
相关推荐



















