pytorch模型单机多卡分布式推理
时间: 2025-01-11 17:35:08 浏览: 95
### PyTorch 单机多卡分布式推理最佳实践
#### 初始化环境配置
为了实现高效的单机多卡分布式推理,在启动前需设置合适的环境变量并初始化进程组。通过 `torch.distributed.init_process_group` 函数可以完成这一操作,通常采用 NCCL 作为后端来支持 GPU 设备间的通信。
```python
import os
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
def setup(rank, world_size):
os.environ['MASTER_ADDR'] = 'localhost'
os.environ['MASTER_PORT'] = '12355'
# initialize the process group
dist.init_process_group("nccl", rank=rank, world_size=world_size)
# set device for this process
torch.cuda.set_device(rank)
```
#### 加载预训练模型
当加载用于推理的预训练权重文件时需要注意,由于之前提到的模块名称差异问题[^2],如果是在 DataParallel 或者 DistributedDataParallel 下保存的模型,则需要调整键名映射逻辑以匹配当前使用的封装方式。
```python
model = MyModel().to(rank)
if not isinstance(model, (DDP)):
model = DDP(model, device_ids=[rank])
checkpoint = torch.load('path_to_checkpoint.pth', map_location={'cuda:%d' % 0: 'cuda:%d' % rank})
new_state_dict = {k.replace("module.", ""): v for k, v in checkpoint.items()}
model.load_state_dict(new_state_dict)
```
#### 构建数据集与采样器
对于大规模的数据集处理,建议使用 `DistributedSampler` 来确保每个进程中只负责一部分样本子集,从而提高整体效率并减少内存占用。
```python
dataset = DatasetClass()
sampler = torch.utils.data.distributed.DistributedSampler(dataset, shuffle=False)
data_loader = torch.utils.data.DataLoader(
dataset,
batch_size=batch_size_per_gpu,
sampler=sampler,
num_workers=num_workers,
pin_memory=True
)
```
#### 执行推理过程
最后一步就是编写实际执行预测任务的核心代码片段。这里强调一点,即在多GPU场景下应当避免不必要的同步等待,尽可能让各个设备独立工作直到最终汇总结果阶段才做必要的交互。
```python
model.eval() # 设置为评估模式
with torch.no_grad():
for i, data in enumerate(data_loader):
inputs = data.to(rank)
outputs = model(inputs)
predictions = gather_predictions(outputs) # 自定义函数收集来自各节点的结果
save_or_process(predictions) # 对齐后的输出可进一步分析或存储
```
阅读全文
相关推荐

















