Pytorch之模型加载/保存

本文详细介绍了PyTorch中两种模型保存方法:保存整个模型和仅保存模型参数。推荐使用后者,虽稍显繁琐但更灵活,适用于预训练和参数迁移。文章还提供了保存模型参数的具体步骤,包括如何保存模型状态字典、优化器状态字典、损失、epoch和其他自定义信息。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

pytorch保存模型有两种方法:

  1. 保存整个模型 (结构+参数)
  2. 只保存参数(官方推荐)

两者都是用torch.save(obj, dir)实现,这个函数的作用是将对象保存到磁盘中,它的内部是使用Python的pickle实现。
两种方法的区别其实就是obj参数的不同:前者的obj是整个model对象,后者的obj是从model里获取的存储了model参数的词典,推荐用第二种,虽然麻烦了一丁点,但是比较灵活,有利于实现预训练、参数迁移等操作。

保存整个模型

这种方法很简单,保存和加载就两行代码,和Python pickle包的用法是一样的,把model当作一个对象直接保存加载就行。

# 保存
model = Mymodel()
torch.save(model, path)
# 加载 
model = torch.load(path)

Note:PyTorch约定使用.pt或.pth后缀命名保存文件。

保存参数

重点介绍一下这种方法,一般训完一个模型之后我们不会单独只保存一个模型的参数,为了方便后续操作,比如恢复训、参数迁移等,我们会保存当前转态的一个快照,具体信息可以根据自己的需要,下面列出几个方面:

  • 模型参数
  • 优化器参数
  • loss
  • epoch
  • args

把这些信息用字典包装起来,然后保存即可。

这种方式保存的模型只是它的参数,所以我们在加载时需要先创建好模型,然后再把参数加载进去,如下:

# 获得保存信息
save_data = {
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'loss': loss,
    'epoch': epoch,
    'args': args
     ...
}
# 保存
torch.save(save_data , path)
load_data = torch.load(path)
model = Mymodel()
optimizer = Myoptimizer()
# 加载参数
model.load_state_dict(load_data ['model_state_dict'])
optimizer.load_state_dict(load_data ['optimizer_state_dict'])
...

Note:PyTorch约定使用.pt或.pth后缀命名保存文件。

### 如何在 PyCharm 中配置和运行深度学习模型 #### 配置 PyCharm 开发环境 为了能够在 PyCharm 中顺利开发 Python Keras 深度学习模型,首先需要确保已经安装了合适的版本的 PyCharm 并创建了一个新的项目[^1]。 #### 安装必要的依赖包 对于基于 Keras 的深度学习应用来说,在开始编写代码之前还需要确保所有必需的软件包都已经正确安装。这通常包括 TensorFlow 或者其他支持 Keras 后端运算的框架。可以通过 pip 工具来完成这些库的安装: ```bash pip install tensorflow keras numpy pandas matplotlib scikit-learn ``` #### 连接到远程服务器并设置部署 如果计划在一个远程服务器上执行训练过程,则可以利用 PyCharm 提供的功能来进行连接。具体操作如下: - **配置 Deployment**:通过 Tools -> Deployment 菜单选项进入相应的界面,按照提示输入 SSH 主机名、用户名以及其他必要信息。 - **配置使用服务器端的 Python 环境**:选择 Run/Debug Configurations 下拉菜单中的 Edit Configurations... ,点击左上方加号按钮添加一个新的 Remote Interpreter via Fabric/SFTP 项,并指定目标主机上的解释器路径[^4]。 #### 编写与调试代码 当一切准备就绪之后便可以在本地编辑源码文件了。值得注意的是,由于涉及到 GPU 加速计算资源分配等问题,实际启动脚本时可能需要用到特定形式的命令行指令,例如 `CUDA_VISIBLE_DEVICES=0 python train.py` 来指明要使用的显卡编号[^3]。 另外,在 PyCharm 内部也可以方便地进行断点调试工作,这对于排查复杂网络结构中存在的 bug 十分有帮助。只需简单地点选左侧边栏即可快速设定断点位置;而一旦程序暂停下来则能够查看当前变量状态甚至临时改变其取值以便进一步测试不同情况下的行为表现。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值