Model saving:
if epoch % 2 == 0:  # checkpoint every other epoch
    state = {
        'netG_A2B': netG_A2B.state_dict(),
        'netG_B2A': netG_B2A.state_dict(),
        'netD_A': netD_A.state_dict(),
        'netD_B': netD_B.state_dict(),
        'optimizer_G': optimizer_G.state_dict(),
        'optimizer_D_A': optimizer_D_A.state_dict(),
        'optimizer_D_B': optimizer_D_B.state_dict(),
        'epoch': epoch,
    }
    # The filename records the *next* epoch to run: loading _epoch_{k}.pth
    # below yields checkpoint['epoch'] = k - 1 and start_epoch = k.
    torch.save(state, f"saved_models/_epoch_{epoch + 1}.pth")
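One practical note: torch.save raises FileNotFoundError if the target directory does not exist yet, so it is worth creating it once before the training loop. A minimal guard, assuming the same saved_models path as above:

import os

# Create the checkpoint directory once, before the first torch.save.
os.makedirs("saved_models", exist_ok=True)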
Model loading:
if opt.epoch != 0:
    # Resume from a pretrained checkpoint
    log_dir = "saved_models/_epoch_{}.pth".format(opt.epoch)
    checkpoint = torch.load(log_dir)
    netG_A2B.load_state_dict(checkpoint['netG_A2B'])
    netG_B2A.load_state_dict(checkpoint['netG_B2A'])
    netD_A.load_state_dict(checkpoint['netD_A'])
    netD_B.load_state_dict(checkpoint['netD_B'])
    optimizer_G.load_state_dict(checkpoint['optimizer_G'])
    optimizer_D_A.load_state_dict(checkpoint['optimizer_D_A'])
    optimizer_D_B.load_state_dict(checkpoint['optimizer_D_B'])
    epoch = checkpoint['epoch']
    start_epoch = epoch + 1  # resume at the epoch after the checkpointed one (= opt.epoch)
    print('Loaded epoch {} successfully!'.format(epoch))
else:
    # Train from scratch: initialize weights
    start_epoch = opt.epoch  # opt.epoch is 0 in this branch
    # apply() recursively visits every submodule of the network, layer by
    # layer, and calls the given function on each one, so weights_init_normal
    # is applied to all layers of the generators and discriminators.
    netG_A2B.apply(weights_init_normal)
    netG_B2A.apply(weights_init_normal)
    netD_A.apply(weights_init_normal)
    netD_B.apply(weights_init_normal)
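weights_init_normal is referenced here but not defined in the snippet. A minimal sketch of the initializer commonly used in CycleGAN implementations, which the apply() calls above assume: conv weights drawn from N(0, 0.02^2), batch-norm scales from N(1.0, 0.02^2):

import torch.nn as nn

def weights_init_normal(m):
    # apply() calls this once per submodule; dispatch on the layer type.
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find('BatchNorm2d') != -1:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0.0)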
Multi-GPU parallelism:
# Must be set before CUDA is first initialized; physical GPUs 4,5,6 are
# then visible to this process as cuda:0, cuda:1, cuda:2.
os.environ["CUDA_VISIBLE_DEVICES"] = "4,5,6"
USE_CUDA = torch.cuda.is_available()
device = torch.device("cuda:0" if USE_CUDA else "cpu")
# Wrap each network in DataParallel; device_ids 0,1,2 are the remapped
# indices of the physical GPUs 4,5,6 selected above.
netG_A2B = torch.nn.DataParallel(netG_A2B, device_ids=[0, 1, 2])
netG_A2B.to(device)
netG_B2A = torch.nn.DataParallel(netG_B2A, device_ids=[0, 1, 2])
netG_B2A.to(device)
netD_A = torch.nn.DataParallel(netD_A, device_ids=[0, 1, 2])
netD_A.to(device)
netD_B = torch.nn.DataParallel(netD_B, device_ids=[0, 1, 2])
netD_B.to(device)
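Note an interaction with the checkpointing code above: DataParallel prefixes every key in state_dict() with "module.", so a checkpoint saved from wrapped models will not load directly into unwrapped ones (and vice versa). A minimal sketch of one way to keep checkpoints portable, saving the underlying module instead; unwrap is a hypothetical helper, not part of the code above:

# Save the underlying module's weights so the checkpoint loads the same
# way with or without the DataParallel wrapper.
def unwrap(net):
    return net.module if isinstance(net, torch.nn.DataParallel) else net

state = {'netG_A2B': unwrap(netG_A2B).state_dict()}  # ...likewise for the other three networks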