单机多卡训练
时间: 2025-03-18 09:02:10 浏览: 45
### 单机多卡训练的配置与实现
#### TensorFlow 1.x 的单机多卡训练
在 TensorFlow 1.x 中,可以通过 `tf.estimator` 和 `tf.train.MonitoredTrainingSession` 来实现单机多卡训练。官方提供了一个脚本作为参考[^1]。主要思路是通过设置设备分配策略来利用多个 GPU 并行处理数据。
以下是基于 TensorFlow 1.x 的简单实现框架:
```python
import tensorflow as tf
# 定义输入管道
def input_fn():
dataset = tf.data.Dataset.from_tensor_slices((features, labels))
dataset = dataset.shuffle(buffer_size).batch(batch_size)
return dataset
# 构建模型函数
def model_fn(features, labels, mode):
predictions = ...
loss = ...
optimizer = tf.compat.v1.train.AdamOptimizer()
# 使用分布式策略
with tf.device(tf.train.replica_device_setter(cluster=cluster_spec)):
train_op = optimizer.minimize(loss)
return tf.estimator.EstimatorSpec(mode=mode, loss=loss, train_op=train_op)
# 设置集群规格
cluster_spec = tf.train.ClusterSpec({
'worker': ['localhost:2222', 'localhost:2223']
})
# 初始化 Estimator
config = tf.estimator.RunConfig(
session_config=tf.ConfigProto(log_device_placement=True),
train_distribute=tf.contrib.distribute.MirroredStrategy())
estimator = tf.estimator.Estimator(model_fn=model_fn, config=config)
# 开始训练
estimator.train(input_fn=input_fn)
```
#### PyTorch 的单机多卡训练
PyTorch 提供了更灵活的方式来进行单机多卡训练,通常使用 `torch.nn.DataParallel` 或者更为推荐的 `torch.nn.DistributedDataParallel`[^3]。为了确保不同进程间同步初始化完成,可以调用 `torch.distributed.barrier()` 函数。
下面是一个简单的 PyTorch 多卡训练示例:
```python
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
# 初始化进程组
dist.init_process_group(backend='nccl')
# 加载模型到指定设备上
device = torch.device(f'cuda:{rank}')
model = YourModel().to(device)
# 将模型封装成分布式模式
ddp_model = DDP(model, device_ids=[rank])
# 数据加载器需配合 DistributedSampler
train_sampler = torch.utils.data.distributed.DistributedSampler(dataset)
train_loader = torch.utils.data.DataLoader(
dataset,
batch_size=batch_size_per_gpu,
sampler=train_sampler
)
# 训练循环
for epoch in range(num_epochs):
for data, target in train_loader:
output = ddp_model(data.to(device))
loss = criterion(output, target.to(device))
loss.backward()
optimizer.step()
optimizer.zero_grad()
# 同步屏障
dist.barrier()
```
#### Baichuan-2 的 Hybrid Shard Zero2 策略
Baichuan-2 提出了 hybrid shard zero2 方法,这是一种结合 ZeRO-2 和其他分片技术的混合策略[^2]。ZeRO-2 主要用于减少内存占用的同时保持较高的计算效率。Hybrid shard zero2 可能进一步优化了参数、梯度以及优化状态的分布方式,在大规模模型训练中有显著优势。
#### 配置文件中的 DataLoader 设定
对于大多数深度学习框架而言,调整 `dataloader` 是实现高效多卡训练的重要部分之一。例如,在 MMDetection 或 MMCV 等库中,可以在配置文件里定义如下内容[^4]:
```yaml
data:
samples_per_gpu: 4
workers_per_gpu: 2
train:
type: CustomDataset
ann_file: path/to/train.json
pipeline:
- type: LoadImageFromFile
- type: Normalize
```
上述 YAML 文件片段展示了如何设定每张显卡上的样本数 (`samples_per_gpu`) 和工作线程数量 (`workers_per_gpu`)。
---
###
阅读全文
相关推荐


















