基于GAN的风格迁移:从配对到非配对图像翻译
立即解锁
发布时间: 2025-09-05 01:41:22 阅读量: 12 订阅数: 26 AIGC 


生成式AI实战:Python与PyTorch
### 基于GAN的风格迁移:从配对到非配对图像翻译
#### 1. 配对图像翻译与pix2pix GAN
在图像到图像的翻译任务中,配对图像翻译是一种强大的方法。它能在给定源域和目标域数据集对的情况下,实现跨域转移。pix2pix GAN架构在这个任务中表现出色,它利用编码器 - 解码器架构开发生成器,能够产生高保真的输出。
pix2pix的创新之处在于使用了U - Net风格的生成器,通过跳过连接增强了信息传递,同时引入了PatchGAN判别器,为不同风格迁移用例提供了更好的反馈信号。通过实验,我们可以从零开始构建并训练pix2pix GAN,将卫星图像转换为类似谷歌地图的输出,且仅需少量训练样本和迭代就能获得高质量的结果。
#### 2. CycleGAN:非配对风格迁移的解决方案
配对风格迁移虽然强大,但受限于配对数据集的可用性。在许多实际场景中,很难获取配对的训练样本。CycleGAN应运而生,它通过放宽输入和输出图像的约束,探索了非配对风格迁移的范式。
CycleGAN的训练数据集由源集和目标集的非配对样本组成,模型尝试学习源域和目标域之间的风格差异,而无需输入和输出图像的显式配对。为了减少搜索空间并确保学习到的映射足够好,CycleGAN引入了循环一致性的概念。
假设存在两个生成器G和F,在理想情况下,它们应该是彼此的逆映射且是双射。通过同时训练这两个生成器,并施加循环一致性约束,使得 $F(G(x)) \approx x$ 和 $G(F(y)) \approx y$,从而实现非配对风格迁移GAN的成功训练。
#### 3. CycleGAN的组件与损失函数
- **生成器和判别器**:CycleGAN使用两组生成器和判别器,分别为G和 $D_Y$ 、F和 $D_X$。这种生成器 - 判别器对的设置,使得模型能够在正向和反向都学习到最佳的翻译映射。
- **对抗损失**:与典型GAN不同,CycleGAN需要对对抗损失进行调整。对于第一组生成器 - 判别器(G和 $D_Y$),对抗损失定义为:
- $\mathcal{L}_{GAN}(G, X, Y) = \min_G \max_{D_Y} V(G, D_Y, X, Y)$
- $\Rightarrow \mathcal{L}_{GAN} = E_{y \sim p_{data}(y)}[\log D_Y(y)] + E_{x \sim p_{data}(x)}[\log(1 - D_Y(G(x)))]$
对于第二组(F和 $D_X$),有类似的定义。同时,为了提高稳定性和输出质量,采用了最小二乘损失。
- **循环损失**:循环一致性损失是CycleGAN的核心创新之一。它通过确保经过生成器处理后的图像能够尽可能接近原始输入,减少了搜索空间。整体循环一致性损失是一个L1损失,定义为:
- $\mathcal{L}_{cyc}(G, F) = E_{x \sim p_{data}(x)}[\|F(G(x)) - x\|_1] + E_{y \sim p_{data}(y)}[\|G(F(y)) - y\|_1]$
- **身份损失**:在处理彩色对象时,CycleGAN的生成器可能会引入不必要的色调变化。为了减少这种不良行为,引入了身份损失。身份损失定义为:
- $\mathcal{L}_{idt}(G, F) = E_{x \sim p_{data}(x)}[\|F(x) - x\|_1] + E_{y \sim p_{data}(y)}[\|G(y) - y\|_1]$
- **总体损失**:CycleGAN的总体目标是对抗损失、循环一致性损失和身份损失的加权和,定义为:
- $\mathcal{L}_{total}(G, F, D_X, D_Y) = \mathcal{L}_{GAN}(G, D_Y, X, Y) + \mathcal{L}_{GAN}(F, D_X, Y, X) + \lambda \mathcal{L}_{cyc}(G, F) + \eta \mathcal{L}_{idt}(G, F)$
#### 4. CycleGAN的实现步骤
- **生成器设置**:CycleGAN使用U - Net生成器,与pix2pix不同的是,它使用实例归一化代替批量归一化。以下是上采样和下采样块的代码实现:
```python
# Upsampling Block
class UpSampleBlock(nn.Module):
def __init__(self, input_channels, output_channels):
super(UpSampleBlock, self).__init__()
layers = [
nn.ConvTranspose2d(
input_channels,
output_channels,
kernel_size=4,
stride=2,
padding=1,
bias=False),
]
layers.append(nn.InstanceNorm2d(output_channels))
layers.append(nn.ReLU(inplace=True))
layers.append(nn.Dropout(0.5))
self.model = nn.Sequential(*layers)
def forward(self, x, skip_connection):
x = self.model(x)
x = torch.cat((x, skip_connection), 1)
return x
# Downsampling block
class DownSampleBlock(nn.Module):
def __init__(self, input_channels, output_channels, normalize=True):
super(DownSampleBlock, self).__init__()
layers = [
nn.Conv2d(
input_channels,
output_channels,
kernel_size=4,
stride=2,
padding=1,
bias=False)
]
if normalize:
layers.
```
0
0
复制全文
相关推荐










