多卡训练的模型怎么加载?
时间: 2025-06-28 21:10:06 浏览: 3
### 加载多GPU训练的PyTorch模型
对于PyTorch,在使用多个GPU进行训练之后,保存和加载模型的方式会有所不同。当利用`torch.nn.DataParallel`或`torch.nn.parallel.DistributedDataParallel`来并行化计算时,状态字典中的键将会带有前缀'module.'。
为了正确加载这样的模型到单个GPU或是CPU环境中,可以采用如下方法:
```python
import torch
model = TheModelClass(*args, **kwargs) # 初始化架构相同的模型实例
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
state_dict = torch.load(path_to_model, map_location=device)
# 创建新的有序字典,去掉原有的 "module."
from collections import OrderedDict
new_state_dict = OrderedDict()
for k, v in state_dict.items():
name = k[7:] # 去除 `module.` 的部分
new_state_dict[name] = v
# 将参数加载至模型中
model.load_state_dict(new_state_dict)
model.to(device)
```
如果目标环境同样具备多个可用的GPU,则可以直接加载而无需处理键名[^2]。
### 加载多GPU训练的TensorFlow模型
在TensorFlow框架下,无论是Keras API还是Estimator接口创建的模型都可以方便地跨不同设备配置之间迁移。对于通过`tf.distribute.Strategy`实现分布式的模型而言,其保存机制已经考虑到了这一点,因此通常情况下只需调用简单的API即可完成加载操作。
```python
import tensorflow as tf
strategy = tf.distribute.MirroredStrategy()
with strategy.scope():
model = tf.keras.models.load_model(path_to_saved_model)
# 如果是在MirroredStrategy之外加载的话也可以正常工作,
# TensorFlow会自动适配当前硬件条件下的最优策略。
```
值得注意的是,这里假设所使用的版本为较新版本(如TF 2.x),并且采用了推荐的做法即使用`save_model()`函数保存整个模型而非仅限于权重文件[^1]。
阅读全文
相关推荐


















