pytorch加载保存查看checkpoint文件

目录

  1. 保存方式和加载方式–3种
  2. 跨gpu和cpu保存加载
  3. 查看checkpoint文件内容
  4. 常见问题–多gpu

1. 保存加载checkpoint文件

# 方式一:保存加载整个state_dict(推荐)
# 保存
torch.save(model.state_dict(), PATH)
# 加载
model.load_state_dict(torch.load(PATH))
# 测试时不启用 BatchNormalization 和 Dropout
model.eval()
# 方式二:保存加载整个模型
# 保存
torch.save(model, PATH)
# 加载
model = torch.load(PATH)
model.eval()
# 方式三:保存用于继续训练的checkpoint或者多个模型
# 保存
torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            ...
            }, PATH)
# 加载
checkpoint = torch.load(PATH)
start_epoch=checkpoint['epoch']
model.load_state_dict(checkpoint['model_state_dict'])
# 测试时
model.eval()
# 或者训练时
model.train()

2. 跨gpu和cpu

# GPU上保存,CPU上加载
# 保存
torch.save(model.state_dict(), PATH)
# 加载
device = torch.device('cpu')
model.load_state_dict(torch.load(PATH, map_location=device))
# 如果是多gpu保存,需要去除关键字中的module,见第4部分
# GPU上保存,GPU上加载
# 保存
torch.save(model.state_dict(), PATH)
# 加载
device = torch.device("cuda")
model.load_state_dict(torch.load(PATH))
model.to(device)
# CPU上保存,GPU上加载
# 保存
torch.save(model.state_dict(), PATH)
# 加载
device = torch.device("cuda")
# 选择希望使用的GPU
model.load_state_dict(torch.load(PATH, map_location="cuda:0"))  
model.to(device)

3. 查看checkpoint文件内容

# 打印模型的 state_dict
print("Model's state_dict:")
for param_tensor in model.state_dict():
    print(param_tensor, "\t", model.state_dict()[param_tensor].size())

4. 常见问题

  1. 多gpu
    报错为KeyError: ‘unexpected key “module.conv1.weight” in state_dict’
    原因:当使用多gpu时,会使用torch.nn.DataParallel,所以checkpoint中有module字样
解决1:加载时将module去掉
# 创建一个不包含`module.`的新OrderedDict
from collections import OrderedDict
new_state_dict = OrderedDict()
for k, v in state_dict.items():
    name = k[7:] # 去掉 `module.`
    new_state_dict[name] = v
# 加载参数
model.load_state_dict(new_state_dict)
# 解决2:保存checkpoint时不保存module
torch.save(model.module.state_dict(), PATH)

参考链接:
https://www.jianshu.com/p/4905bf8e06e5
https://blog.csdn.net/LXYTSOS/article/details/90639524

PyTorch 中的 `.ckpt` 文件通常不是 PyTorch 的标准模型保存格式,而是 TensorFlow 中常用的检查点(Checkpoint文件。如果你有一个由 TensorFlow 导出的 `.ckpt` 文件,并想在 PyTorch加载它,你需要借助第三方库,如 `tf2onnx`、`pytorch-model-zoo` 或者直接使用 `tf-nightly` 来完成转换。 首先,你需要将 TensorFlow 模型转换为 ONNX 格式,因为 ONNX 是一种跨框架的模型交换格式,然后你可以使用 `torch.onnx.load` 函数加载 ONNX 模型。以下是基本步骤: 1. 安装必要的库(假设已经安装了 `torch`, `torchvision`, 和 `onnx`): ```bash pip install tf-nightly onnx ``` 2. 使用 TensorFlow 将模型转换为 ONNX: ```python import tensorflow as tf from tensorflow import keras # 加载 .ckpt 检查点 model = keras.models.load_model('path_to_your_ckpt_file') # 冻结并打包模型到静态图(如果需要) concrete_func = model.signatures['serving_default'] tf_representations = concrete_func.graph.as_graph_def() # 使用 tf2onnx 进行转换 import onnx onnx_model = onnx.convert.from_tensorflow(tf_representations) onnx.save(onnx_model, 'model.onnx') ``` 3. 现在你有了 `.onnx` 文件,可以加载PyTorch: ```python import torch from torch.onnx import load_from_string # 加载转换后的 ONNX 模型 onnx_string = open('model.onnx', 'rb').read() model_pytorch = load_from_string(onnx_string) ``` 请注意,由于转换过程中可能会丢失一些信息,转换后的 PyTorch 模型可能无法完美复现原始 TF 模型的所有功能。
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值