帮我写一下一个基于PyTorch的CycleGAN实现水下模糊图像增强的简化代码框架,其中生成器使用U-Net结构,判别器采用PatchGAN设计
时间: 2025-06-13 21:23:07 浏览: 18
<think>嗯,用户让我写一个基于PyTorch的CycleGAN的简化代码框架,用于水下模糊图像增强。生成器要用U-Net,判别器是PatchGAN。首先,我需要回忆CycleGAN的基本结构和组件。CycleGAN包括两个生成器和两个判别器,用于将图像从一个域转换到另一个域,这里应该是从模糊的水下图像到清晰的图像,以及反向的转换。但用户可能只需要单向的增强,不过按照CycleGAN的原始设计,可能需要保持循环一致性,所以还是需要两个生成器和判别器。
生成器部分要求使用U-Net结构。U-Net通常有编码器和解码器,中间有跳跃连接。我需要考虑如何用PyTorch实现U-Net。可能包括下采样层、上采样层,以及跳跃连接的拼接。编码器部分可能由多个卷积层组成,逐步减少空间维度,增加通道数。解码器部分则通过转置卷积或上采样加卷积来恢复尺寸,同时合并跳跃连接的特征。
判别器是PatchGAN,也就是全卷积网络,输出一个矩阵,每个元素对应输入图像的一个patch是否为真。这通常由多个卷积层堆叠而成,最后用Sigmoid激活。需要注意使用LeakyReLU,并且最后一层卷积的通道数可能为1,然后通过Sigmoid输出概率。
接下来是损失函数部分。CycleGAN需要对抗损失(生成器和判别器的对抗)、循环一致性损失,可能还有身份损失。对抗损失可以用BCELoss或者MSELoss(如LSGAN)。循环一致性损失是L1损失,计算原始图像和重建图像之间的差异。身份损失可能可选,但用户没有提到,可能暂时不考虑。
数据需要成对的吗?CycleGAN不需要成对数据,所以数据加载器应该能处理两个不同域的无配对图像。但用户的应用是水下图像增强,可能只需要一个方向的转换,不过为了Cycle结构,可能需要两个域的数据。
然后是训练循环部分。需要交替训练生成器和判别器。可能先训练判别器,然后生成器。同时,在生成器训练时要考虑对抗损失和循环损失的总和。
代码结构方面,需要定义生成器、判别器、损失函数、优化器、数据加载、训练循环等模块。
可能的代码框架步骤:
1. 导入必要的库:torch, nn, optim, DataLoader,可能还有transforms。
2. 定义U-Net生成器。编码器部分用Conv2d,步长2下采样,ReLU激活。中间可能有残差块?或者直接跳跃连接。解码器用转置卷积或上采样+卷积。跳跃连接将编码器的特征图与解码器对应层的特征图拼接。
例如,U-Net的结构可能有4层下采样,4层上采样,每层下采样后通道数翻倍,直到某个最大值,比如64->128->256->512。解码器则反过来,512->256->128->64,每次上采样后和编码器对应的特征图拼接。
3. 定义PatchGAN判别器。多个卷积层,每层步长2,通道数增加,激活函数用LeakyReLU(0.2)。最后一层卷积输出1通道,用Sigmoid。输入是图像,输出是NxN的矩阵,每个元素代表一个patch的真伪。
4. 定义损失函数:对抗损失用BCEWithLogitsLoss或者MSE,循环一致性用L1Loss,可能需要权重参数。
5. 初始化生成器G_A2B(模糊到清晰)和G_B2A(清晰到模糊),判别器D_A和D_B(判别真实A和真实B)。
6. 定义优化器,可能需要为生成器和判别器分别设置优化器,使用Adam。
7. 数据加载:需要两个数据加载器,分别加载域A(模糊)和域B(清晰)的图像。可能需要自定义数据集类,假设图像放在不同的文件夹中。
8. 训练循环:对于每个epoch,遍历两个数据加载器,获取一批A和B的图像。训练判别器D_A:用真实A的图像和生成的A图像(通过G_B2A生成的假A)来计算损失,类似地训练D_B。然后训练生成器:生成假B和假A,计算对抗损失(让判别器认为生成的图像为真),以及循环一致性损失(A->B->A应该接近原A,B->A->B接近原B)。
注意在训练判别器时,需要固定生成器的参数,避免梯度计算到生成器。所以在训练判别器时,可能需要设置requires_grad为False。
另外,可能需要使用目标标签平滑(比如真实标签用0.9代替1)来稳定训练。
代码中可能还需要定义权重初始化,比如用正态分布初始化参数。
测试阶段,只需要用生成器G_A2B将输入的模糊图像转换为清晰图像。
现在,将这些步骤转化为代码框架。需要注意简化,所以可能不需要完整的残差块,而是简单的卷积层。例如,U-Net的每个下采样块可能包含卷积、实例归一化、ReLU。上采样块用转置卷积或上采样后接卷积,同样实例归一化和ReLU,最后输出层用Tanh激活。
对于生成器的代码,可能如下:
class UNetGenerator(nn.Module):
def __init__(self):
super().__init__()
# 编码器
self.down1 = nn.Sequential(nn.Conv2d(3, 64, 4, 2, 1), nn.InstanceNorm2d(64), nn.ReLU())
self.down2 = nn.Sequential(nn.Conv2d(64, 128, 4, 2, 1), nn.InstanceNorm2d(128), nn.ReLU())
self.down3 = nn.Sequential(nn.Conv2d(128, 256, 4, 2, 1), nn.InstanceNorm2d(256), nn.ReLU())
self.down4 = nn.Sequential(nn.Conv2d(256, 512, 4, 2, 1), nn.InstanceNorm2d(512), nn.ReLU())
# 解码器
self.up1 = nn.Sequential(nn.ConvTranspose2d(512, 256, 4, 2, 1), nn.InstanceNorm2d(256), nn.ReLU())
self.up2 = nn.Sequential(nn.ConvTranspose2d(512, 128, 4, 2, 1), nn.InstanceNorm2d(128), nn.ReLU())
self.up3 = nn.Sequential(nn.ConvTranspose2d(256, 64, 4, 2, 1), nn.InstanceNorm2d(64), nn.ReLU())
self.up4 = nn.Sequential(nn.ConvTranspose2d(128, 3, 4, 2, 1), nn.Tanh())
def forward(self, x):
# 下采样
d1 = self.down1(x)
d2 = self.down2(d1)
d3 = self.down3(d2)
d4 = self.down4(d3)
# 上采样并拼接跳跃连接
u1 = self.up1(d4)
u1 = torch.cat([u1, d3], dim=1)
u2 = self.up2(u1)
u2 = torch.cat([u2, d2], dim=1)
u3 = self.up3(u2)
u3 = torch.cat([u3, d1], dim=1)
u4 = self.up4(u3)
return u4
不过可能这里各层的通道数需要调整,比如上采样层的输入应该考虑跳跃连接后的通道数。例如,在down4之后是512通道,up1是转置卷积,输出256通道,然后与down3的输出(256)拼接,得到512通道。这可能需要在转置卷积的输出通道数上调整,例如up1的输出应该是256,这样拼接后变为256+256=512?或者我可能在设计时需要考虑每一层的通道数是否正确。
可能这个例子中的通道数设计有问题,比如down3的输出是256,那么下采样到d3的尺寸是H/8 x W/8 x256。然后经过down4变成H/16 x W/16 x512。up1是转置卷积,输入512,输出256,尺寸变为H/8 x W/8 x256,然后和d3(H/8 x W/8 x256)拼接,得到512通道。这样up2的输入应该是512,转置卷积输出128,然后和d2(H/4 x W/4 x128)拼接,得到256通道。以此类推。
那在代码中,每个上采样层的输入通道数应该等于上一层输出的通道数加上跳跃连接的通道数。例如,up1的输入是512(d4经过转置卷积后变为256),然后拼接d3的256,所以总通道数为256+256=512,此时up2的转置卷积的输入通道数应该是512,输出128。这样设计是否正确?
可能我之前写的示例代码的通道数设置有问题,需要修正。比如:
down1: 3 ->64
down2:64->128
down3:128->256
down4:256->512
然后上采样:
up1的输入是512(d4经过下采样后的结果),转置卷积后输出256通道,尺寸变为H/8 x W/8 x256。然后拼接d3(此时是H/8 x W/8 x256),所以拼接后的通道数是256+256=512。因此,up2的输入应为512,转置卷积输出128。然后拼接d2(128),得到128+128=256通道。up3输入256,转置输出64,拼接d1的64,得到128。up4输入128,转置输出3。这可能需要调整各层的参数。
或者,可能每个上采样层在转置卷积之后,需要处理跳跃连接的合并,因此转置后的通道数应等于跳跃连接的通道数。例如:
假设编码器各层的输出通道为:64, 128, 256, 512
解码器的各层在转置卷积后的通道数应为:256, 128, 64, 3
这样,在up1时,转置后的通道是256,然后拼接编码器的256(d3),总通道数256+256=512。此时,下一层转置卷积的输入应该是512,但需要输出128,这可能就需要调整转置卷积的输入通道数。或者,可能我的设计有误,需要更仔细的通道数计算。
或者,可能采用更简单的结构,比如每个上采样块将通道数减半,同时尺寸翻倍。例如,在U-Net中,每个下采样块将尺寸减半,通道数翻倍;每个上采样块将尺寸加倍,通道数减半,并拼接对应下采样层的通道数。例如:
下采样层1: 3 → 64
下采样层2:64 →128
下采样层3:128→256
下采样层4:256→512
中间层:可能一个瓶颈层?
上采样层1:512→256,拼接下采样3的256 → 总通道数512
上采样层2:512→128,拼接下采样2的128 → 总通道数256
上采样层3:256→64,拼接下采样1的64 → 总通道数128
上采样层4:128→3
这样,每个上采样块的输入通道数需要处理拼接后的结果。例如,第一个上采样块接收512通道(中间层的输出),转置卷积后输出256,然后拼接下采样层3的256,得到512通道。因此,下一个上采样块的输入是512,转置卷积输出128,拼接下采样层2的128,得到256,依此类推。这样,每个转置卷积层的输入通道数需要与前一层的输出相匹配。
因此,在代码中,各层的定义可能如下:
编码器部分:
down1: Conv2d(3, 64, 4, 2, 1) → 输出64通道,尺寸减半
down2: Conv2d(64, 128, 4, 2, 1) → 128通道,尺寸减半
down3: Conv2d(128, 256, 4, 2, 1) →256通道,尺寸减半
down4: Conv2d(256, 512, 4, 2, 1) →512通道,尺寸减半
解码器部分:
up1: ConvTranspose2d(512, 256, 4, 2, 1) →输出256,尺寸加倍,然后拼接down3的256 →总通道512
up2: ConvTranspose2d(512, 128, 4, 2, 1) →输出128,拼接down2的128 →总256
up3: ConvTranspose2d(256, 64, 4, 2, 1) →输出64,拼接down1的64 →总128
up4: ConvTranspose2d(128, 3, 4, 2, 1) →输出3,尺寸与原图相同,最后用Tanh激活
这样,每个上采样层的输入通道数需要是前一层处理后的通道数。比如,up2的输入是512(来自up1的输出拼接down3后的512),所以ConvTranspose2d的输入通道是512,输出128。拼接后得到128+128=256,所以下一个转置的输入是256。
这样,生成器的结构应该正确。然后每个卷积层后面可能有InstanceNorm和ReLU,除了最后一层用Tanh。
判别器部分,PatchGAN的结构通常是多层的卷积,最后输出一个NxN的矩阵。例如,70x70 PatchGAN。每一层卷积的通道数逐步增加,而尺寸逐步减小。例如:
输入3通道图像,经过:
Conv2d(3, 64, 4, 2, 1) → LeakyReLU(0.2)
Conv2d(64, 128, 4, 2, 1) → InstanceNorm → LeakyReLU
Conv2d(128, 256, 4, 2, 1) → InstanceNorm → LeakyReLU
Conv2d(256, 512, 4, 1, 1) → InstanceNorm → LeakyReLU
最后,Conv2d(512, 1, 4, 1, 1) → Sigmoid
这样输出的特征图大小取决于输入尺寸。例如,假设输入是256x256,经过三次下采样(每次减半)到32x32,但最后一个卷积的步长是1,所以输出可能是30x30或保持尺寸。可能需要调整padding。或者使用更标准的PatchGAN结构,输出为比如70x70的patch,这取决于网络结构的设计。
不过这里可能简化判别器结构,确保输出是一个矩阵。例如:
class Discriminator(nn.Module):
def __init__(self):
super().__init__()
self.model = nn.Sequential(
nn.Conv2d(3, 64, 4, 2, 1),
nn.LeakyReLU(0.2, True),
nn.Conv2d(64, 128, 4, 2, 1),
nn.InstanceNorm2d(128),
nn.LeakyReLU(0.2, True),
nn.Conv2d(128, 256, 4, 2, 1),
nn.InstanceNorm2d(256),
nn.LeakyReLU(0.2, True),
nn.Conv2d(256, 512, 4, 1, 1),
nn.InstanceNorm2d(512),
nn.LeakyReLU(0.2, True),
nn.Conv2d(512, 1, 4, 1, 1),
nn.Sigmoid()
)
def forward(self, x):
return self.model(x)
这样输入3通道,输出1通道的特征图,每个元素对应一个patch的真伪。例如,输入256x256,经过三次下采样(每次步长2),尺寸变为32x32,然后两次步长1的卷积,最终输出可能是30x30,但具体尺寸可能需要计算。或者,可能使用padding='same'来保持尺寸?不过PyTorch的Conv2d的padding参数是整数,所以可能需要调整。例如,当kernel_size=4,stride=2,padding=1时,输出尺寸为 (H_in + 2*padding - kernel_size)/stride + 1。例如,输入H=256,经过Conv2d(3,64,4,2,1) →输出H= (256 +2*1 -4)/2 +1 = (254)/2 +1= 127 +1=128?可能这里需要更精确的尺寸计算,但作为框架,可能先忽略具体尺寸,假设判别器能正确输出一个矩阵。
接下来是损失函数和优化器。对抗损失用BCELoss,循环一致性用L1Loss。可能需要设置lambda_cycle参数,如10,来加权循环损失。
优化器使用Adam,学习率可能设置为0.0002,betas=(0.5, 0.999)。
训练循环的大致步骤:
对于每批数据:
1. 加载真实A(模糊)和真实B(清晰)的图像。
2. 训练判别器D_A:
a. 生成假A:G_B2A(real_B)
b. 计算D_A对real_A的输出和假A的输出的损失,目标标签为1和0。
c. 反向传播更新D_A的参数。
3. 同样训练判别器D_B:
a. 生成假B:G_A2B(real_A)
b. 计算D_B对real_B的输出和假B的输出的损失。
c. 更新D_B的参数。
4. 训练生成器G_A2B和G_B2A:
a. 生成假B和假A,再重建A和B。
b. 计算对抗损失(D_B对假B的输出的损失,目标为1;同理D_A对假A的损失)
c. 计算循环损失:L1(real_A, reconstructed_A)和L1(real_B, reconstructed_B)
d. 总损失 = 对抗损失 + lambda_cycle * 循环损失
e. 反向传播更新生成器的参数。
在代码中,需要计算梯度,并分步优化。注意在训练判别器时,生成器的参数不需要更新,因此可能需要使用with torch.no_grad()来避免计算生成器的梯度。或者在训练生成器时,再次前向传播生成假图像。
此外,可能需要使用GAN的标签平滑技巧,比如将真实标签设为0.9,假标签设为0.1,以增加稳定性。
数据加载部分需要两个数据集,分别来自域A和域B。假设图像存放在不同的文件夹中,可以使用ImageFolder加载,并应用相应的预处理(调整大小、归一化等)。例如:
transform = transforms.Compose([
transforms.Resize((256, 256)),
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
dataset_A = ImageFolder('path/to/A', transform=transform)
dataset_B = ImageFolder('path/to/B', transform=transform)
loader_A = DataLoader(dataset_A, batch_size=batch_size, shuffle=True)
loader_B = DataLoader(dataset_B, batch_size=batch_size, shuffle=True)
在训练循环中,可能需要同时迭代两个数据加载器,使用zip,或者分别取批次。需要注意如果两个数据集的长度不同,可能需要使用itertools.cycle来处理其中一个加载器,或者每次取min(len(A), len(B))。
现在,把这些部分整合成代码框架。由于用户要求简化,可能不需要完整的数据加载和训练循环的详细代码,而是框架性的代码,关键部分如生成器、判别器、损失计算等。
可能存在的疑问:生成器的跳跃连接是否正确实现?在U-Net中,每个上采样层应该与对应的下采样层拼接。例如,在forward过程中,保存每个下采样层的输出,然后在解码时进行拼接。之前的示例代码可能需要修正,例如:
class UNetGenerator(nn.Module):
def __init__(self):
super().__init__()
# 下采样
self.down1 = nn.Sequential(
nn.Conv2d(3, 64, 4, 2, 1),
nn.InstanceNorm2d(64),
nn.ReLU(inplace=True)
)
self.down2 = nn.Sequential(
nn.Conv2d(64, 128, 4, 2, 1),
nn.InstanceNorm2d(128),
nn.ReLU(inplace=True)
)
self.down3 = nn.Sequential(
nn.Conv2d(128, 256, 4, 2, 1),
nn.InstanceNorm2d(256),
nn.ReLU(inplace=True)
)
self.down4 = nn.Sequential(
nn.Conv2d(256, 512, 4, 2, 1),
nn.InstanceNorm2d(512),
nn.ReLU(inplace=True)
)
# 上采样
self.up1 = nn.Sequential(
nn.ConvTranspose2d(512, 256, 4, 2, 1),
nn.InstanceNorm2d(256),
nn.ReLU(inplace=True)
)
self.up2 = nn.Sequential(
nn.ConvTranspose2d(512, 128, 4, 2, 1), # 输入是512,因为拼接后的通道数
nn.InstanceNorm2d(128),
nn.ReLU(inplace=True)
)
self.up3 = nn.Sequential(
nn.ConvTranspose2d(256, 64, 4, 2, 1), # 输入是拼接后的256
nn.InstanceNorm2d(64),
nn.ReLU(inplace=True)
)
self.up4 = nn.Sequential(
nn.ConvTranspose2d(128, 3, 4, 2, 1), # 输入是拼接后的128
nn.Tanh()
)
def forward(self, x):
# 下采样并保存各层结果
d1 = self.down1(x) # 64, 128x128
d2 = self.down2(d1) # 128, 64x64
d3 = self.down3(d2) # 256, 32x32
d4 = self.down4(d3) # 512, 16x16
# 上采样并拼接
u1 = self.up1(d4) # 256, 32x32
u1 = torch.cat([u1, d3], dim=1) # 256+256=512, 32x32
u2 = self.up2(u1) # 128, 64x64
u2 = torch.cat([u2, d2], dim=1) # 128+128=256, 64x64
u3 = self.up3(u2) # 64, 128x128
u3 = torch.cat([u3, d1], dim=1) # 64+64=128, 128x128
u4 = self.up4(u3) # 3, 256x256
return u4
这里需要注意每个上采样层的输入通道数是否正确。例如,up2的输入是u1,其通道数是512(因为拼接了d3的256和up1输出的256),所以ConvTranspose2d的输入通道数应该是512,输出128。这样,up2的输出是128,拼接d2的128,得到256通道,作为up3的输入。这样设计是否正确?
假设输入图像是256x256x3,经过down1(Conv2d 4x4,stride2,padding1)后,尺寸变为128x128x64。down2到64x64x128,down3到32x32x256,down4到16x16x512。然后up1是转置卷积,输入512,输出256,kernel 4x4,stride2,padding1 → 输出尺寸变为32x32x256。然后拼接d3(32x32x256),得到32x32x512。接着up2的输入是512,转置卷积输出128,尺寸变为64x64x128,拼接d2的64x64x128,得到64x64x256。up3输入256,转置卷积输出64,尺寸变为128x128x64,拼接d1的128x128x64,得到128x128x128。up4输入128,转置卷积输出3,尺寸256x256x3,Tanh激活。这样是正确的。
所以这个生成器的结构应该可以正确运行。
判别器的代码如前所述,输出是一个特征图,比如对于输入256x256,经过三次下采样(每次减半)到32x32,然后两次步长1的卷积,可能最终输出30x30的特征图,每个元素对应一个patch的判断结果。
损失函数部分,对抗损失可以使用BCELoss,但需要注意判别器的输出可能是一个矩阵,所以需要将预测结果和目标标签扩展成相同的形状。例如,判别器输出形状是(batch_size, 1, 30, 30),而目标标签可以是相同形状的全1或全0张量。或者,可能使用MSE损失,如LSGAN,这可能更稳定。
例如:
criterion_GAN = nn.MSELoss()
criterion_cycle = nn.L1Loss()
在训练判别器时,对于真实图像,目标标签为1;假图像为0。生成器的对抗损失目标是让判别器输出1。
训练循环的大致代码:
for epoch in range(num_epochs):
for i, (real_A, real_B) in enumerate(zip(loader_A, loader_B)):
# 转换为CUDA张量
real_A = real_A.to(device)
real_B = real_B.to(device)
# 训练判别器D_A和D_B
# 生成假图像
fake_B = G_A2B(real_A)
fake_A = G_B2A(real_B)
# 训练D_A
optimizer_D_A.zero_grad()
pred_real = D_A(real_A)
loss_D_real = criterion_GAN(pred_real, torch.ones_like(pred_real))
pred_fake = D_A(fake_A.detach()) # 阻止梯度传到生成器
loss_D_fake = criterion_GAN(pred_fake, torch.zeros_like(pred_fake))
loss_D_A = (loss_D_real + loss_D_fake) * 0.5
loss_D_A.backward()
optimizer_D_A.step()
# 训练D_B同理
optimizer_D_B.zero_grad()
pred_real = D_B(real_B)
loss_D_real = criterion_GAN(pred_real, torch.ones_like(pred_real))
pred_fake = D_B(fake_B.detach())
loss_D_fake = criterion_GAN(pred_fake, torch.zeros_like(pred_fake))
loss_D_B = (loss_D_real + loss_D_fake) * 0.5
loss_D_B.backward()
optimizer_D_B.step()
# 训练生成器
optimizer_G.zero_grad()
# 对抗损失
pred_fake_B = D_B(fake_B)
loss_GAN_A2B = criterion_GAN(pred_fake_B, torch.ones_like(pred_fake_B))
pred_fake_A = D_A(fake_A)
loss_GAN_B2A = criterion_GAN(pred_fake_A, torch.ones_like(pred_fake_A))
# 循环一致性损失
reconstructed_A = G_B2A(fake_B)
loss_cycle_A = criterion_cycle(reconstructed_A, real_A)
reconstructed_B = G_A2B(fake_A)
loss_cycle_B = criterion_cycle(reconstructed_B, real_B)
total_loss = loss_GAN_A2B + loss_GAN_B2A + lambda_cycle * (loss_cycle_A + loss_cycle_B)
total_loss.backward()
optimizer_G.step()
这里需要注意,生成器优化器应该同时优化G_A2B和G_B2A的参数,所以可能将它们放在同一个优化器中,或者分别设置。假设G_A2B和G_B2A是两个独立的生成器,可能需要为它们分别定义优化器,或者合并参数。
在代码中,可能将两个生成器的参数合并到同一个优化器中:
optimizer_G = torch.optim.Adam(
list(G_A2B.parameters()) + list(G_B2A.parameters()),
lr=lr, betas=(0.5, 0.999)
)
同样,判别器D_A和D_B各自的优化器。
最后,用户可能需要一个简化的代码框架,所以可能不需要完整的数据加载和训练循环,而是关键部分的代码。
综上,整合后的代码框架可能如下:</think>以下是一个基于PyTorch的CycleGAN简化代码框架,用于水下模糊图像增强。生成器使用U-Net结构,判别器采用PatchGAN设计:
```python
import torch
import torch.nn as nn
import torch.optim as optim
# 生成器U-Net结构
class UNetGenerator(nn.Module):
def __init__(self):
super().__init__()
# 编码器 (下采样)
self.down1 = nn.Sequential(
nn.Conv2d(3, 64, 4, 2, 1), nn.InstanceNorm2d(64), nn.ReLU(inplace=True))
self.down2 = nn.Sequential(
nn.Conv2d(64, 128, 4, 2, 1), nn.InstanceNorm2d(128), nn.ReLU(inplace=True))
self.down3 = nn.Sequential(
nn.Conv2d(128, 256, 4, 2, 1), nn.InstanceNorm2d(256), nn.ReLU(inplace=True))
self.down4 = nn.Sequential(
nn.Conv2d(256, 512, 4, 2, 1), nn.InstanceNorm2d(512), nn.ReLU(inplace=True))
# 解码器 (上采样)
self.up1 = nn.Sequential(
nn.ConvTranspose2d(512, 256, 4, 2, 1), nn.InstanceNorm2d(256), nn.ReLU(inplace=True))
self.up2 = nn.Sequential(
nn.ConvTranspose2d(512, 128, 4, 2, 1), nn.InstanceNorm2d(128), nn.ReLU(inplace=True))
self.up3 = nn.Sequential(
nn.ConvTranspose2d(256, 64, 4, 2, 1), nn.InstanceNorm2d(64), nn.ReLU(inplace=True))
self.up4 = nn.Sequential(
nn.ConvTranspose2d(128, 3, 4, 2, 1), nn.Tanh())
def forward(self, x):
d1 = self.down1(x) # 64x128x128
d2 = self.down2(d1) # 128x64x64
d3 = self.down3(d2) # 256x32x32
d4 = self.down4(d3) # 512x16x16
u1 = self.up1(d4) # 256x32x32
u1 = torch.cat([u1, d3], 1) # 512x32x32
u2 = self.up2(u1) # 128x64x64
u2 = torch.cat([u2, d2], 1) # 256x64x64
u3 = self.up3(u2) # 64x128x128
u3 = torch.cat([u3, d1], 1) # 128x128x128
return self.up4(u3) # 3x256x256
# 判别器PatchGAN
class Discriminator(nn.Module):
def __init__(self):
super().__init__()
self.model = nn.Sequential(
nn.Conv2d(3, 64, 4, 2, 1), nn.LeakyReLU(0.2, True),
nn.Conv2d(64, 128, 4, 2, 1), nn.InstanceNorm2d(128), nn.LeakyReLU(0.2, True),
nn.Conv2d(128, 256, 4, 2, 1), nn.InstanceNorm2d(256), nn.LeakyReLU(0.2, True),
nn.Conv2d(256, 512, 4, 1, 1), nn.InstanceNorm2d(512), nn.LeakyReLU(0.2, True),
nn.Conv2d(512, 1, 4, 1, 1), nn.Sigmoid()
)
def forward(self, x):
return self.model(x)
# 初始化网络
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
G_A2B = UNetGenerator().to(device) # 模糊→清晰
G_B2A = UNetGenerator().to(device) # 清晰→模糊
D_A = Discriminator().to(device) # 判别模糊图像
D_B = Discriminator().to(device) # 判别清晰图像
# 损失函数和优化器
criterion_GAN = nn.BCELoss() # 对抗损失
criterion_cycle = nn.L1Loss() # 循环一致性损失
lambda_cycle = 10
optimizer_G = optim.Adam(
list(G_A2B.parameters()) + list(G_B2A.parameters()),
lr=0.0002, betas=(0.5, 0.999)
)
optimizer_D_A = optim.Adam(D_A.parameters(), lr=0.0002, betas=(0.5, 0.999))
optimizer_D_B = optim.Adam(D_B.parameters(), lr=0.0002, betas=(0.5, 0.999))
# 训练循环框架
def train(loader_A, loader_B, num_epochs):
for epoch in range(num_epochs):
for batch_idx, (real_A, real_B) in enumerate(zip(loader_A, loader_B)):
real_A = real_A.to(device)
real_B = real_B.to(device)
# 训练判别器
fake_B = G_A2B(real_A)
fake_A = G_B2A(real_B)
# 训练D_A
optimizer_D_A.zero_grad()
loss_D_A_real = criterion_GAN(D_A(real_A), torch.ones_like(D_A(real_A)))
loss_D_A_fake = criterion_GAN(D_A(fake_A.detach()), torch.zeros_like(D_A(fake_A)))
loss_D_A = (loss_D_A_real + loss_D_A_fake) * 0.5
loss_D_A.backward()
optimizer_D_A.step()
# 训练D_B(同上逻辑)
# 训练生成器
optimizer_G.zero_grad()
# 对抗损失
loss_GAN_A2B = criterion_GAN(D_B(fake_B), torch.ones_like(D_B(fake_B)))
loss_GAN_B2A = criterion_GAN(D_A(fake_A), torch.ones_like(D_A(fake_A)))
# 循环一致性损失
reconstructed_A = G_B2A(fake_B)
loss_cycle_A = criterion_cycle(reconstructed_A, real_A)
reconstructed_B = G_A2B(fake_A)
loss_cycle_B = criterion_cycle(reconstructed_B, real_B)
total_loss = loss_GAN_A2B + loss_GAN_B2A + lambda_cycle * (loss_cycle_A + loss_cycle_B)
total_loss.backward()
optimizer_G.step()
```
### 关键说明:
1. **数据准备**:需要准备两个数据集文件夹,分别存放模糊水下图像和清晰目标图像,使用`ImageFolder`加载数据
2. **训练技巧**:
- 建议使用LSGAN(最小二乘GAN)替代普通GAN以获得更好效果
- 可以添加身份损失(identity loss)提高色彩保真度
- 使用学习率衰减策略改善收敛
3. **部署应用**:训练完成后,使用`G_A2B`生成器对水下模糊图像进行增强:
```python
enhanced_img = G_A2B(blurred_img)
```
注意:这是一个简化框架,实际应用需要调整超参数、添加数据增强、实现完整训练逻辑等。建议参考官方CycleGAN实现补充细节。
阅读全文