生成对抗网络(GAN)的图像生成技术解析
立即解锁
发布时间: 2025-09-05 01:41:21 阅读量: 14 订阅数: 30 AIGC 


生成式AI实战:Python与PyTorch
### 生成对抗网络(GAN)的图像生成技术解析
#### 1. 基础GAN训练
在GAN的基础训练中,我们使用如下代码更新判别器的损失:
```python
d_loss = (real_loss + fake_loss) / 2
# Update Discriminator loss
d_loss.backward()
optimizer_D.step()
print(f'Epoch: {epoch}/{N_EPOCHS}-Batch: {i}/{len(dataloader)}—D.loss:{d_loss.item():.4f},G.loss:{g_loss.item():.4f}')
batches_done = epoch * len(dataloader) + i
if batches_done % SAMPLE_INTERVAL == 0:
save_image(gen_imgs.data[:25], f"images/{batches_done}.png",
nrow=5,
normalize=True)
```
我们以64的批量大小对基础GAN进行约200个周期的训练。从训练不同阶段的模型输出可以明显看到样本质量逐步提升。不过,基础GAN的结果虽令人鼓舞,但仍有改进空间。
#### 2. 改进的GAN架构
##### 2.1 深度卷积生成对抗网络(DCGAN)
2016年Radford等人提出的DCGAN,除了关注卷积层外,还对GAN的输出进行了多方面改进。在DCGAN出现之前,输出图像的分辨率较为有限。DCGAN通过使用大于1的步幅移动卷积滤波器来提高图像分辨率,同时利用批量归一化稳定训练过程。
判别器模型方面,基于CNN的二分类器进行了改进,使用大于1的步幅对输入进行下采样,替代了池化层,同时依赖批量归一化和Leaky - RELU,并且没有全连接层。
生成器模型与基础GAN不同,仅需输入向量的维度,通过重塑和上采样层将向量转换为二维图像并提高分辨率,除输入层外没有全连接层。以下是构建DCGAN生成器模型的代码:
```python
class Generator(nn.Module):
def __init__(self):
super(Generator, self).__init__()
self.init_size = IMG_DIM // 4
self.l1 = nn.Sequential(
nn.Linear(Z_DIM, 128 * self.init_size ** 2)
)
self.conv_blocks = nn.Sequential(
nn.BatchNorm2d(128),
nn.Upsample(scale_factor=2),
nn.Conv2d(128, 128, 3, stride=1, padding=1),
nn.BatchNorm2d(128, 0.8),
nn.LeakyReLU(0.2, inplace=True),
nn.Upsample(scale_factor=2),
nn.Conv2d(128, 64, 3, stride=1, padding=1),
nn.BatchNorm2d(64, 0.8),
nn.LeakyReLU(0.2, inplace=True),
nn.Conv2d(64, NUM_CHANNELS, 3, stride=1, padding=1),
nn.Tanh(),
)
def forward(self, z):
out = self.l1(z)
out = out.view(out.shape[0], 128, self.init_size, self.init_size)
img = self.conv_blocks(out)
return img
```
DCGAN的训练循环与基础GAN相同,它能在更少的训练周期内生成所需输出,理论上能生成比基础GAN质量更高的图像。
##### 2.2 条件生成对抗网络(C - GAN)
普通GAN虽能从训练域生成逼真样本,但无法控制生成特定类别的样本。而C - GAN由Mirza等人在2014年提出,能提供生成特定类别样本所需的控制。
C - GAN的生成器模型根据所需输出的特定特征生成假样本,判别器不仅要区分真假样本,还要判断生成样本与条件特征是否匹配。在“Conditional Adversarial Networks”的工作中,使用类标签作为额外的条件输入到生成器和判别器模型中。
以下是C - GAN的生成器和判别器模型代码:
```python
# 生成器
class Generator(nn.Module):
def __init__(self):
super(Generator, self).__init__()
self.label_emb = nn.
```
0
0
复制全文
相关推荐










