1.BatchNorm2d归一化图像时所使用的均值(running_mean,batch_mean)和方差(running_var,batch_var)
BatchNorm2d在 eval 模式下归一化图像使用的就是训练过程中训练样本累计的running_mean和running_var。
然而,在train模式下:(1)BatchNorm2d归一化图像使用的均值和方差为当前批次数据的batch_mean和batch_var;(2)即便不使用反向传播,只要数据经过了模型,模型中BatchNorm2d的均值和方差就会按照(1-momentum) * running_mean + momentum*batch_var(默认momentum是0.1)进行更新。也就是说,看上去在train模式下你没有反向传播更新参数,但是每一个batch经过模型以后,BatchNorm2d的running_mean和running_var持续更新。不过这个更新对输出并不影响,如前面说的,train模式下图像归一化仅与批次数据的batch_mean和batch_var有关。
这事实上也就能够理解为什么很多人说:就BatchNorm2d而言,训练时使用大的batch_size是更好的。因为训练时使用batch_mean和batch_var归一化,而测试时使用running_mean和running_var归一化。训练时更大的batch_size使batch_mean和batch_var更接近于running_mean和running_var,因此训练时的结果和测试时的结果更加相近。另外,更大的batch_size使得训练时的batch_mean和batch_var更加稳定。
还值得注意的是,因为eval模式下使用running_mean和running_var,所以无论你的batch_size设置成多大,它的结果都是一样的。然而,如果你强行用train模式进行测试,此时使用的是batch_mean和batch_var。在batch_size设置较大时,batch_mean和batch_var和running_mean和running_var相近,性能接近于真实性能,而当batch_size设置较小时,batch_mean和batch_var和running_mean和running_var相差很大,那么性能就可能出现很大的问题。
2.多卡训练且模型中包含了BatchNorm2d
如果你使用DataParallel进行了多卡训练,同时训练的模型包含了BatchNorm2d层,这可能导致每张卡维护的running_mean和running_var是不同的,最后保存的时候通常只是保存了“当前主卡”(rank 0)上的那一份running_mean和running_var,这就导致了错误的发生。特别是你在测试的时候,你会发现train和eval的结果差距可能非常大。
为了避免这一情况,应使用DDP进行多卡训练,并在多卡训练时将BatchNorm2d转换成SyncBatchNorm。关于多卡DDP的使用,博客https://www.cnblogs.com/liyier/p/18136458有详细的说明。
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DistributedSampler
import traceback
def setup_ddp(rank, world_size)
dist.init_process_group(backend="nccl", init_method="env://")
torch.cuda.set_device(rank)
def cleanup_ddp():
dist.destroy_process_group()
def main():
rank = int(os.environ["LOCAL_RANK"])
world_size = int(os.environ["WORLD_SIZE"])
setup_ddp(rank, world_size)
...
dataset = ... #就是正常加载数据集
sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank, shuffle=True)
dataloader = DataLoader(dataset, batch_size=64, sampler=sampler, numworkers=8) # 注意,这里不能再手动设置shuffle了,不然就会和sampler冲突报错
...
model = model.to(rank)
model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
model = DDP(model, device_ids=[rank], output_device=rank)
...
model.train()
for epoch in range(100):
sampler.set_epoch(epoch) # shuffle different on each epoch
for inputs, labels in dataloader:
inputs = inputs.to(rank)
labels = labels.to(rank)
if rank == 0:
torch.save(...)
cleanup_ddp()
if __name__ == "__main__":
try:
main()
except Exception as e:
print(f"[Rank {os.environ.get('RANK')}] error:")
traceback.print_exc()
raise e
运行时使用如下代码:
export OMP_NUM_THREADS=4
torchrun --nproc_per_node=2 --master_port=28999 train.py
要注意的是,这里的master_port建议设置20000到30000之间的数,如果报错端口被占用了,就换一个数。
总结起来,以上两个点是造成BatchNorm2d对train和eval模式下结果差距大的主要原因。如果能够理解这两个点,那么关于BatchNorm2d应该就不会有太大问题。