cyclegan的模型保存与加载

模型保存:

if epoch % 2 ==0:
    state = {'netG_A2B': netG_A2B.state_dict(),'netG_B2A': netG_B2A.state_dict(),'netD_A': netD_A.state_dict(),'netD_B':netD_B.state_dict(),
             'optimizer_G': optimizer_G.state_dict(), 'optimizer_D_A': optimizer_D_A.state_dict(),'optimizer_D_B': optimizer_D_B.state_dict(),'epoch': epoch}
    torch.save(state, f"saved_models/_epoch_{epoch + 1}.pth")

模型加载:

if opt.epoch != 0:
    # Load pretrained models
    log_dir = "saved_models/_epoch_{}.pth".format(opt.epoch)
    checkpoint = torch.load(log_dir)
    netG_A2B.load_state_dict(checkpoint['netG_A2B'])
    netG_B2A.load_state_dict(checkpoint['netG_B2A'])
    netD_A.load_state_dict(checkpoint['netD_A'])
    netD_B.load_state_dict(checkpoint['netD_B'])
    optimizer_G.load_state_dict(checkpoint['optimizer_G'])
    optimizer_D_A.load_state_dict(checkpoint['optimizer_D_A'])
    optimizer_D_B.load_state_dict(checkpoint['optimizer_D_B'])
    epoch = checkpoint['epoch']
    start_epoch =epoch + 1
    print('加载 epoch {} 成功!'.format(epoch))

else:
    # Initialize weights
    start_epoch = epoch
    netG_A2B.apply(weights_init_normal)  # netG是我们给写的神经网络定义的类实例。apply函数会递归地搜索网络内的所有module并把参数表示的函数应用到所有的module上。也就是说apply函数,会一层一层的去拜访Generator网络层。
    netG_B2A.apply(weights_init_normal)
    netD_A.apply(weights_init_normal)
    netD_B.apply(weights_init_normal)

多卡并行:

     os.environ["CUDA_VISIBLE_DEVICES"] = "4,5,6"
    
     USE_CUDA = torch.cuda.is_available()
     device = torch.device("cuda:0" if USE_CUDA else "cpu")

     netG_A2B = torch.nn.DataParallel(netG_A2B, device_ids=[0, 1, 2])
     netG_A2B.to(device)
     netG_B2A = torch.nn.DataParallel(netG_B2A, device_ids=[0, 1, 2])
     netG_B2A.to(device)
     netD_A = torch.nn.DataParallel(netD_A, device_ids=[0, 1, 2])
     netD_A.to(device)
     netD_B = torch.nn.DataParallel(netD_B, device_ids=[0, 1, 2])
     netD_B.to(device)

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值