在具体使用pytorch框架进行训练的时候,发现实验室的服务器是多GPU服务器,因此需要在训练过程中,将网络参数都放入多GPU中进行训练。
正文开始:
涉及的代码为torch.nn.DataParallel,而且官方推荐使用nn.DataParallel而不是使用multiprocessing。官方代码文档如下:nn.DataParallel 教程文档如下:tutorial
torch.nn.DataParallel(module, device_ids=None, output_device=None, dim=0)
该函数实现了在module级别上的数据并行使用,注意batch size要大于GPU的数量。
参数 : module:需要多GPU训练的网络模型
device_ids: GPU的编号(默认全部GPU)
output_device:(默认是device_ids[0])
dim:tensors被分散的维度,默认是0
在代码文档中使用方法为: