BEGAN论文阅读

BEGAN是一种解决GAN训练问题的方法,它使用自编码器作为分类器,通过Wasserstein距离调整生成器与分类器的平衡。论文提出通过控制超参数γ和变量Kt来平衡分类器的两个目标,实现生成样本多样性和模型训练的稳定性。实验表明,γ值影响生成图像的多样性,较低γ值可能导致图像单一,而增加γ则能提升多样性。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

论文:

Berthelot D, Schumm T, Metz L. BEGAN: Boundary Equilibrium Generative Adversarial Networks[J]. arXiv preprint arXiv:1703.10717, 2017.

1、介绍

GANs是一类方法,这些方法能够学习数据分布并且实现从该分部中采样的模型。GANs的架构基于两个函数:生成器G(z),它可以将取自均匀分布的样本z映射到数据分布;分类器D(x),用来测定样本x是否来自数据分布。通常,基于博弈理论,生成器和分类器可以通过交替训练来实现联合学习。

GANs可以产生较为清晰的图像,但它仍面临许多问题:1、通常非常难以训练;2、难以选择正确的超参数;3、难以控制生成样本的多样性;4、难以平衡分类器和生成器的收敛速度:分类器往往在训练开始就可以取得优势;4、容易遇到模型崩溃问题:模型只学到单一的图像模式。

2、相关工作

介绍了几种GANs的变体,包括DCGANs,EBGANs以及WGANs

3、主要方法

BEGANs将一个自编码器作为分类器,通过基于Wasserstein距离的损失来匹配自编码器的损失分布。采用神经网络结构,训练中添加额外的均衡过程来平衡生成器与分类器。

用下式计算自编码器在样本上的误差:

20170825160454716.png

BEGAN采用Wasserstein距离计算D在真实数据与生成数据损失分布之间的距离。

对于两个正态分布,二者间的Wasserstein 距离计算公式为(文中证明两个损失分布都服从正态分布):

这两个话题有些不相关,我可以分别介绍一下。 首先,tqdm 是一个 Python 进度条库,可以在命令行中显示循环进度,非常方便。 安装: ``` pip install tqdm ``` 使用: ```python from tqdm import tqdm for i in tqdm(range(100)): # do something ``` 接下来是 PyTorch BEGAN 的实现。BEGAN 是一种生成式对抗网络(GAN)的变体,它可以生成高质量的图像。 安装 PyTorch: ``` pip install torch torchvision ``` BEGAN 的 PyTorch 实现可以在 GitHub 上找到。这里提供一个简单的实例: ```python import torch import torch.nn as nn import torch.optim as optim from torch.autograd import Variable from torchvision import datasets, transforms from tqdm import tqdm # define the model class Generator(nn.Module): def __init__(self): super(Generator, self).__init__() # define the layers def forward(self, x): # define the forward pass class Discriminator(nn.Module): def __init__(self): super(Discriminator, self).__init__() # define the layers def forward(self, x): # define the forward pass # define the loss function criterion = nn.BCELoss() # define the optimizer optimizer_G = optim.Adam(generator.parameters(), lr=0.0002, betas=(0.5, 0.999)) optimizer_D = optim.Adam(discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999)) # prepare the data transform = transforms.Compose([ transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,)) ]) train_loader = torch.utils.data.DataLoader( datasets.MNIST('data', train=True, download=True, transform=transform), batch_size=64, shuffle=True) # train the model for epoch in range(epochs): for i, (images, _) in enumerate(tqdm(train_loader)): # train the discriminator # train the generator # generate a sample image z = Variable(torch.randn(64, 100)) sample = generator(z) # save the sample image ``` 以上代码中,需要自己实现 Generator 和 Discriminator 的定义和 forward 方法。在训练过程中,需要分别训练 Generator 和 Discriminator,具体实现可以参考 BEGAN 论文中的算法。在循环中加入 tqdm,可以显示训练进度。最后,可以生成一张样本图片并保存。
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值