troch.distributed
时间: 2025-06-25 08:27:45 浏览: 1
### PyTorch 分布式训练概述
`torch.distributed` 是 PyTorch 提供的一个核心模块,用于支持高效的分布式训练。该模块通过多种方式实现模型和数据的并行化处理,从而显著提升大规模机器学习任务的性能。
#### 初始化分布式环境
在执行任何分布式操作前,必须调用 `torch.distributed.init_process_group` 函数来初始化分布式环境[^2]。此函数允许开发者指定通信后端(如 NCCL 或 Gloo)、初始化方法以及集群配置参数(例如世界大小 `world_size` 和当前进程排名 `rank`)。这些参数对于协调不同设备间的计算至关重要。
#### 启动工具
为了简化跨节点或多 GPU 训练流程,PyTorch 提供了辅助脚本 `torch.distributed.launch`[^1]。这一工具能够自动为每台参与训练的服务器创建多个子进程,并分配相应的资源给它们。具体来说,在每个节点上都会生成若干个独立运行的 Python 进程来进行分布式的梯度更新工作。
#### 数据并行 vs 模型并行
尽管早期版本中存在简单的封装类如 `torch.nn.DataParallel` 来完成基本形式的数据划分任务[^3],但由于其局限性和效率低下等问题已被逐渐淘汰替代。取而代之的是更先进的解决方案——DistributedDataParallel (DDP)[^4] 。相比前者仅限于单一主机内部的不同图形处理器之间共享负载外;后者不仅适用于本地场景同时也兼容远程互联网络环境下更大规模系统的协作运算需求。
以下是基于 DDP 实现简单 MNIST 图像分类器的例子:
```python
import os
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor
from torch.utils.data.dataloader import DataLoader
from torch.utils.data.distributed import DistributedSampler
def setup(rank, world_size):
os.environ['MASTER_ADDR'] = 'localhost'
os.environ['MASTER_PORT'] = '12355'
# initialize the process group
dist.init_process_group("gloo", rank=rank, world_size=world_size)
def cleanup():
dist.destroy_process_group()
class ToyModel(torch.nn.Module):
def __init__(self):
super(ToyModel, self).__init__()
self.net1 = torch.nn.Linear(784, 200)
self.relu = torch.nn.ReLU()
self.net2 = torch.nn.Linear(200, 10)
def forward(self, x):
return self.net2(self.relu(self.net1(x)))
if __name__ == "__main__":
local_rank = int(os.environ["LOCAL_RANK"])
setup(local_rank, 4)
model = ToyModel().to(local_rank)
ddp_model = DDP(model, device_ids=[local_rank])
dataset = MNIST('./data', download=True, transform=ToTensor())
sampler = DistributedSampler(dataset)
dataloader = DataLoader(dataset, batch_size=32, sampler=sampler)
optimizer = torch.optim.SGD(ddp_model.parameters(), lr=0.001)
for epoch in range(10):
sampler.set_epoch(epoch)
for data, target in dataloader:
output = ddp_model(data.to(local_rank).view(-1, 784))
loss = F.nll_loss(output, target.to(local_rank))
optimizer.zero_grad()
loss.backward()
optimizer.step()
cleanup()
```
阅读全文
相关推荐
















