pytorch使用记录(三) 多GPU训练

  在具体使用pytorch框架进行训练的时候,发现实验室的服务器是多GPU服务器,因此需要在训练过程中,将网络参数都放入多GPU中进行训练。

   正文开始:

   涉及的代码为torch.nn.DataParallel,而且官方推荐使用nn.DataParallel而不是使用multiprocessing。官方代码文档如下:nn.DataParallel   教程文档如下:tutorial

torch.nn.DataParallel(module, device_ids=None, output_device=None, dim=0)

该函数实现了在module级别上的数据并行使用,注意batch size要大于GPU的数量。

参数 : module:需要多GPU训练的网络模型

device_ids: GPU的编号(默认全部GPU)

output_device:(默认是device_ids[0])

dim:tensors被分散的维度,默认是0

在代码文档中使用方法为: