本文是对EnlightenGAN 无监督低照度图像增强的代码解读,原文解读请看EnlightenGAN 无监督低照度图像增强。
1、原文概要
EnlightenGAN是一种无监督的暗光增强方案,主要运用了以下几个策略实现无监督增强。
- 全局 - 局部判别器:全局判别器处理整体光照,局部判别器确保局部区域的真实感,避免过曝或欠曝。
- 自正则化感知损失:利用预训练 VGG 模型约束低光输入与增强输出的特征距离,保留图像内容。
- 注意力引导 U-Net 生成器:使用输入图像的光照信息生成注意力图,指导网络重点增强暗区域,抑制亮区域。
2、代码结构
代码整体结构如下:
models 目录下主要是网络模型实现图像的生成和判别,train.py是训练脚本,predict.py是推理脚本。
3 、核心代码模块
核心代码位于networks.py文件中,主要定义了生成器、判别器网络以及相关的损失函数,以下是对该文件中核心代码的详细介绍:
define_G
- 功能:根据指定的参数定义生成器网络。生成器有多种网络结构可选,比如 resnet、unet 等,此论文中使用的是 unet 结构。
- 参数:
  - input_nc:输入图像的通道数。
  - output_nc:输出图像的通道数。
  - ngf:生成器中第一个卷积层的滤波器数量。
  - which_model_netG:生成器的模型名称,可选值为 'resnet_9blocks'、'resnet_6blocks'、'unet_128' 等。
  - norm:归一化类型,默认为 'batch'。
  - use_dropout:是否使用 Dropout 层,默认为 False。
  - gpu_ids:使用的 GPU 编号列表,默认为 []。
  - skip:是否使用跳跃连接,默认为 False。
  - opt:其他选项,默认为 None。
- 返回值:生成器网络。
def define_G(input_nc, output_nc, ngf, which_model_netG, norm='batch',
             use_dropout=False, gpu_ids=[], skip=False, opt=None):
    """Build the generator selected by ``which_model_netG`` and initialise it.

    Args:
        input_nc: Number of input channels, forwarded to the ResNet/U-Net
            generators.
        output_nc: Number of output channels, forwarded to the ResNet/U-Net
            generators.
        ngf: Base filter count, forwarded to the ResNet/U-Net generators.
        which_model_netG: One of 'resnet_9blocks', 'resnet_6blocks',
            'unet_128', 'unet_256', 'unet_512', 'sid_unet',
            'sid_unet_shuffle', 'sid_unet_resize' or 'DnCNN'.
        norm: Normalisation type handed to ``get_norm_layer``.
        use_dropout: Forwarded to the ResNet/U-Net generators.
        gpu_ids: GPU ids. When non-empty the network is moved to
            ``gpu_ids[0]`` and wrapped in ``torch.nn.DataParallel``; when
            empty the network is left where it was constructed.
        skip: Forwarded to 'unet_256', 'unet_512' and the ``sid_unet*``
            variants.
        opt: Options object forwarded to the U-Net, ``sid_unet*`` and DnCNN
            constructors.

    Returns:
        The constructed generator with ``weights_init`` applied.

    Raises:
        NotImplementedError: If ``which_model_netG`` is not recognised.
        AssertionError: If ``gpu_ids`` is non-empty but CUDA is unavailable.
    """
    netG = None
    use_gpu = len(gpu_ids) > 0
    norm_layer = get_norm_layer(norm_type=norm)

    if use_gpu:
        assert(torch.cuda.is_available())

    if which_model_netG == 'resnet_9blocks':
        netG = ResnetGenerator(input_nc, output_nc, ngf, norm_layer=norm_layer,
                               use_dropout=use_dropout, n_blocks=9, gpu_ids=gpu_ids)
    elif which_model_netG == 'resnet_6blocks':
        netG = ResnetGenerator(input_nc, output_nc, ngf, norm_layer=norm_layer,
                               use_dropout=use_dropout, n_blocks=6, gpu_ids=gpu_ids)
    elif which_model_netG == 'unet_128':
        # `opt` must be forwarded: UnetSkipConnectionBlock reads
        # `opt.use_norm`, so leaving it at the default None made this variant
        # fail during construction. `skip` is still not forwarded here.
        netG = UnetGenerator(input_nc, output_nc, 7, ngf, norm_layer=norm_layer,
                             use_dropout=use_dropout, gpu_ids=gpu_ids, opt=opt)
    elif which_model_netG == 'unet_256':
        netG = UnetGenerator(input_nc, output_nc, 8, ngf, norm_layer=norm_layer,
                             use_dropout=use_dropout, gpu_ids=gpu_ids, skip=skip, opt=opt)
    elif which_model_netG == 'unet_512':
        netG = UnetGenerator(input_nc, output_nc, 9, ngf, norm_layer=norm_layer,
                             use_dropout=use_dropout, gpu_ids=gpu_ids, skip=skip, opt=opt)
    elif which_model_netG == 'sid_unet':
        netG = Unet(opt, skip)
    elif which_model_netG == 'sid_unet_shuffle':
        netG = Unet_pixelshuffle(opt, skip)
    elif which_model_netG == 'sid_unet_resize':
        netG = Unet_resize_conv(opt, skip)
    elif which_model_netG == 'DnCNN':
        netG = DnCNN(opt, depth=17, n_channels=64, image_channels=1,
                     use_bnorm=True, kernel_size=3)
    else:
        raise NotImplementedError('Generator model name [%s] is not recognized' % which_model_netG)

    # The previous guard was `len(gpu_ids) >= 0`, which is always true and
    # therefore indexed `gpu_ids[0]` on an empty list (IndexError on CPU).
    # Use the same `use_gpu` guard as define_D.
    if use_gpu:
        netG.cuda(device=gpu_ids[0])
        netG = torch.nn.DataParallel(netG, gpu_ids)
    netG.apply(weights_init)
    return netG
define_D
- 功能:根据指定的参数定义判别器网络。
- 参数:
  - input_nc:输入图像的通道数。
  - ndf:判别器中第一个卷积层的滤波器数量。
  - which_model_netD:判别器的模型名称,可选值为 'basic'、'n_layers'、'no_norm' 等。
  - n_layers_D:判别器中卷积层的数量,默认为 3。
  - norm:归一化类型,默认为 'batch'。
  - use_sigmoid:是否在输出层使用 Sigmoid 函数,默认为 False。
  - gpu_ids:使用的 GPU 编号列表。
  - patch:是否使用 PatchGAN,默认为 False。
- 返回值:判别器网络。
def define_D(input_nc, ndf, which_model_netD,
             n_layers_D=3, norm='batch', use_sigmoid=False, gpu_ids=[], patch=False):
    """Build the discriminator selected by ``which_model_netD`` and initialise it.

    Args:
        input_nc: Number of input channels.
        ndf: Base filter count of the discriminator.
        which_model_netD: One of 'basic', 'n_layers', 'no_norm', 'no_norm_4'
            or 'no_patchgan'. 'basic' always uses three layers; 'no_norm' and
            'no_norm_4' build the same network.
        n_layers_D: Layer count for every variant except 'basic'.
        norm: Normalisation type handed to ``get_norm_layer``.
        use_sigmoid: Forwarded to the discriminator constructor.
        gpu_ids: GPU ids. When non-empty the network is moved to
            ``gpu_ids[0]`` and wrapped in ``torch.nn.DataParallel``.
        patch: Forwarded to ``FCDiscriminator`` only ('no_patchgan').

    Returns:
        The constructed discriminator with ``weights_init`` applied.

    Raises:
        NotImplementedError: If ``which_model_netD`` is not recognised.
        AssertionError: If ``gpu_ids`` is non-empty but CUDA is unavailable.
    """
    discriminator = None
    on_gpu = len(gpu_ids) > 0
    norm_layer = get_norm_layer(norm_type=norm)

    if on_gpu:
        assert(torch.cuda.is_available())

    if which_model_netD == 'basic':
        discriminator = NLayerDiscriminator(
            input_nc, ndf, n_layers=3, norm_layer=norm_layer,
            use_sigmoid=use_sigmoid, gpu_ids=gpu_ids)
    elif which_model_netD == 'n_layers':
        discriminator = NLayerDiscriminator(
            input_nc, ndf, n_layers_D, norm_layer=norm_layer,
            use_sigmoid=use_sigmoid, gpu_ids=gpu_ids)
    elif which_model_netD in ('no_norm', 'no_norm_4'):
        # Both names are constructed with exactly the same arguments.
        discriminator = NoNormDiscriminator(
            input_nc, ndf, n_layers_D,
            use_sigmoid=use_sigmoid, gpu_ids=gpu_ids)
    elif which_model_netD == 'no_patchgan':
        discriminator = FCDiscriminator(
            input_nc, ndf, n_layers_D,
            use_sigmoid=use_sigmoid, gpu_ids=gpu_ids, patch=patch)
    else:
        raise NotImplementedError('Discriminator model name [%s] is not recognized' %
                                  which_model_netD)

    if on_gpu:
        discriminator.cuda(device=gpu_ids[0])
        discriminator = torch.nn.DataParallel(discriminator, gpu_ids)
    discriminator.apply(weights_init)
    return discriminator
GANLoss
- 功能:定义 GAN 损失函数,可以选择使用 LSGAN 或普通的 GAN 损失,本文使用的是LSGAN。
- 方法:
  - get_target_tensor(self, input, target_is_real):根据输入和目标标签生成目标张量。
  - __call__(self, input, target_is_real):计算损失。
class GANLoss(nn.Module):
    """GAN criterion that builds its own real/fake target tensors.

    Uses ``nn.MSELoss`` when ``use_lsgan`` is true and ``nn.BCELoss``
    otherwise. Target tensors are cached per label and rebuilt whenever the
    shape of the discriminator output changes.

    Args:
        use_lsgan: Select MSE (True) or BCE (False) as the underlying loss.
        target_real_label: Fill value of the "real" target tensor.
        target_fake_label: Fill value of the "fake" target tensor.
        tensor: Tensor constructor used to allocate targets, e.g.
            ``torch.FloatTensor`` or ``torch.cuda.FloatTensor``.
    """

    def __init__(self, use_lsgan=True, target_real_label=1.0, target_fake_label=0.0,
                 tensor=torch.FloatTensor):
        super(GANLoss, self).__init__()
        self.real_label = target_real_label
        self.fake_label = target_fake_label
        # Lazily created caches, one per label.
        self.real_label_var = None
        self.fake_label_var = None
        self.Tensor = tensor
        self.loss = nn.MSELoss() if use_lsgan else nn.BCELoss()

    def _cached_or_new(self, cached, input, fill_value):
        """Return ``cached`` if it matches ``input``'s shape, else a new target."""
        # Compare full shapes, not numel(): two tensors with the same element
        # count but different shapes (e.g. 2x3 vs 3x2) must not share a
        # target, otherwise the loss receives a mis-shaped target.
        if cached is None or cached.size() != input.size():
            filled = self.Tensor(input.size()).fill_(fill_value)
            cached = Variable(filled, requires_grad=False)
        return cached

    def get_target_tensor(self, input, target_is_real):
        """Return a target tensor shaped like ``input`` filled with the label.

        Args:
            input: Discriminator output the target has to match.
            target_is_real: True for the real label, False for the fake label.
        """
        if target_is_real:
            self.real_label_var = self._cached_or_new(
                self.real_label_var, input, self.real_label)
            return self.real_label_var
        self.fake_label_var = self._cached_or_new(
            self.fake_label_var, input, self.fake_label)
        return self.fake_label_var

    def __call__(self, input, target_is_real):
        """Compute the loss between ``input`` and the matching target tensor."""
        target_tensor = self.get_target_tensor(input, target_is_real)
        return self.loss(input, target_tensor)
UnetGenerator
- 功能:定义基于 U-Net 结构的生成器网络。
class UnetGenerator(nn.Module):
    """U-Net generator assembled from nested ``UnetSkipConnectionBlock``s.

    The network is built from the innermost block outwards: one innermost
    block, ``num_downs - 5`` blocks at ``ngf * 8`` channels, three blocks
    that halve the width down to ``ngf``, and one outermost block.

    Args:
        input_nc: Input channel count; must equal ``output_nc``.
        output_nc: Output channel count.
        num_downs: Controls how many ``ngf * 8`` blocks are inserted.
        ngf: Base filter count.
        norm_layer: Normalisation layer factory passed to every block.
        use_dropout: Passed to the ``ngf * 8`` intermediate blocks only.
        gpu_ids: GPU ids used by ``forward`` for data-parallel execution.
        skip: When equal to ``True`` the network is wrapped in ``SkipModule``.
        opt: Options object passed to every block and to ``SkipModule``.
    """

    def __init__(self, input_nc, output_nc, num_downs, ngf=64,
                 norm_layer=nn.BatchNorm2d, use_dropout=False, gpu_ids=[], skip=False, opt=None):
        super(UnetGenerator, self).__init__()
        self.gpu_ids = gpu_ids
        self.opt = opt
        # Only input_nc == output_nc is supported for now.
        assert(input_nc == output_nc)

        widest = ngf * 8
        # Innermost block first, then wrap it layer by layer.
        block = UnetSkipConnectionBlock(widest, widest, norm_layer=norm_layer,
                                        innermost=True, opt=opt)
        for _ in range(num_downs - 5):
            block = UnetSkipConnectionBlock(widest, widest, block, norm_layer=norm_layer,
                                            use_dropout=use_dropout, opt=opt)
        # Three blocks whose outer width is ngf*4, ngf*2, ngf.
        for mult in (4, 2, 1):
            block = UnetSkipConnectionBlock(ngf * mult, ngf * mult * 2, block,
                                            norm_layer=norm_layer, opt=opt)
        block = UnetSkipConnectionBlock(output_nc, ngf, block, outermost=True,
                                        norm_layer=norm_layer, opt=opt)

        # Explicit comparison with True is kept on purpose (same test as before).
        self.model = SkipModule(block, opt) if skip == True else block

    def forward(self, input):
        """Run the network, using data-parallel execution for CUDA float input."""
        run_parallel = self.gpu_ids and isinstance(input.data, torch.cuda.FloatTensor)
        if not run_parallel:
            return self.model(input)
        return nn.parallel.data_parallel(self.model, input, self.gpu_ids)
SkipModule
- 功能:定义带有跳跃连接的模块。
class SkipModule(nn.Module):
    """Wrap a sub-network and add a scaled copy of its input to its output.

    Args:
        submodule: The wrapped network.
        opt: Options object; ``opt.skip`` is the weight applied to the input.
    """

    def __init__(self, submodule, opt):
        super(SkipModule, self).__init__()
        self.opt = opt
        self.submodule = submodule

    def forward(self, x):
        """Return ``(opt.skip * x + submodule(x), submodule(x))``."""
        residual = self.submodule(x)
        blended = self.opt.skip * x + residual
        return blended, residual
UnetSkipConnectionBlock
- 功能:定义 U-Net 中的跳跃连接块。
class UnetSkipConnectionBlock(nn.Module):
    """One U-Net level: downsample, run the sub-block, upsample.

    Non-outermost blocks concatenate their input with their output along the
    channel dimension in ``forward``.

    Args:
        outer_nc: Channel count on the outside of this block.
        inner_nc: Channel count passed to the sub-block.
        submodule: The nested block (unused for the innermost block).
        outermost: Build the outermost variant (no skip concat, Tanh output).
        innermost: Build the innermost variant (no sub-block).
        norm_layer: Normalisation layer factory.
        use_dropout: Append ``nn.Dropout(0.5)`` on intermediate blocks.
        opt: Options object; ``opt.use_norm == 0`` disables normalisation
            layers. Required in practice even though it defaults to None,
            because ``opt.use_norm`` is read unconditionally.
    """

    def __init__(self, outer_nc, inner_nc,
                 submodule=None, outermost=False, innermost=False,
                 norm_layer=nn.BatchNorm2d, use_dropout=False, opt=None):
        super(UnetSkipConnectionBlock, self).__init__()
        self.outermost = outermost

        # Layers are created in the same order as before so that parameter
        # initialisation consumes the RNG identically.
        down_conv = nn.Conv2d(outer_nc, inner_nc, kernel_size=4,
                              stride=2, padding=1)
        down_act = nn.LeakyReLU(0.2, True)
        down_norm = norm_layer(inner_nc)
        up_act = nn.ReLU(True)
        up_norm = norm_layer(outer_nc)

        with_norm = not (opt.use_norm == 0)
        # Only a pure innermost block has no concatenated skip input.
        up_in = inner_nc if (innermost and not outermost) else inner_nc * 2
        up_conv = nn.ConvTranspose2d(up_in, outer_nc,
                                     kernel_size=4, stride=2,
                                     padding=1)

        if outermost:
            # Identical with and without normalisation.
            layers = [down_conv, submodule, up_act, up_conv, nn.Tanh()]
        elif innermost:
            layers = [down_act, down_conv, up_act, up_conv]
            if with_norm:
                layers.append(up_norm)
        else:
            layers = [down_act, down_conv]
            if with_norm:
                layers.append(down_norm)
            layers += [submodule, up_act, up_conv]
            if with_norm:
                layers.append(up_norm)
            if use_dropout:
                layers.append(nn.Dropout(0.5))

        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the block; non-outermost blocks append ``x`` on dim 1."""
        out = self.model(x)
        if self.outermost:
            return out
        return torch.cat([out, x], 1)
NLayerDiscriminator
- 功能:定义基于 PatchGAN 的判别器网络。
class NLayerDiscriminator(nn.Module):
    """Convolutional discriminator with a normalisation layer per stage.

    Structure: one stride-2 conv + LeakyReLU, then ``n_layers - 1`` stride-2
    stages and one stride-1 stage (each conv + norm + LeakyReLU), then a
    single-channel stride-1 conv, optionally followed by a Sigmoid. The
    channel multiplier doubles per stage and is capped at 8.

    Args:
        input_nc: Input channel count.
        ndf: Base filter count.
        n_layers: Number of conv stages after the first layer.
        norm_layer: Normalisation layer factory.
        use_sigmoid: Append ``nn.Sigmoid`` to the output.
        gpu_ids: Stored on the instance; not used by ``forward``.
    """

    def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d,
                 use_sigmoid=False, gpu_ids=[]):
        super(NLayerDiscriminator, self).__init__()
        self.gpu_ids = gpu_ids

        kernel = 4
        pad = int(np.ceil((kernel - 1) / 2))

        layers = [
            nn.Conv2d(input_nc, ndf, kernel_size=kernel, stride=2, padding=pad),
            nn.LeakyReLU(0.2, True),
        ]

        # Channel multipliers: 1, then 2**n capped at 8 for each strided stage.
        mults = [1] + [min(2 ** n, 8) for n in range(1, n_layers)]
        # (in_mult, out_mult, stride): strided stages, then one stride-1 stage.
        stages = [(prev, cur, 2) for prev, cur in zip(mults, mults[1:])]
        stages.append((mults[-1], min(2 ** n_layers, 8), 1))

        for in_mult, out_mult, stride in stages:
            layers.append(nn.Conv2d(ndf * in_mult, ndf * out_mult,
                                    kernel_size=kernel, stride=stride, padding=pad))
            layers.append(norm_layer(ndf * out_mult))
            layers.append(nn.LeakyReLU(0.2, True))

        last_mult = stages[-1][1]
        layers.append(nn.Conv2d(ndf * last_mult, 1, kernel_size=kernel, stride=1, padding=pad))

        if use_sigmoid:
            layers.append(nn.Sigmoid())

        self.model = nn.Sequential(*layers)

    def forward(self, input):
        """Apply the discriminator to ``input``."""
        return self.model(input)
NoNormDiscriminator
- 功能:定义不使用归一化层的判别器网络。
class NoNormDiscriminator(nn.Module):
    """Convolutional discriminator without any normalisation layers.

    Structure: one stride-2 conv + LeakyReLU, then ``n_layers - 1`` stride-2
    stages and one stride-1 stage (each conv + LeakyReLU), then a
    single-channel stride-1 conv, optionally followed by a Sigmoid. The
    channel multiplier doubles per stage and is capped at 8.

    Args:
        input_nc: Input channel count.
        ndf: Base filter count.
        n_layers: Number of conv stages after the first layer.
        use_sigmoid: Append ``nn.Sigmoid`` to the output.
        gpu_ids: Stored on the instance; not used by ``forward``.
    """

    def __init__(self, input_nc, ndf=64, n_layers=3, use_sigmoid=False, gpu_ids=[]):
        super(NoNormDiscriminator, self).__init__()
        self.gpu_ids = gpu_ids

        kernel = 4
        pad = int(np.ceil((kernel - 1) / 2))

        layers = [
            nn.Conv2d(input_nc, ndf, kernel_size=kernel, stride=2, padding=pad),
            nn.LeakyReLU(0.2, True),
        ]

        # Channel multipliers: 1, then 2**n capped at 8 for each strided stage.
        mults = [1] + [min(2 ** n, 8) for n in range(1, n_layers)]
        # (in_mult, out_mult, stride): strided stages, then one stride-1 stage.
        stages = [(prev, cur, 2) for prev, cur in zip(mults, mults[1:])]
        stages.append((mults[-1], min(2 ** n_layers, 8), 1))

        for in_mult, out_mult, stride in stages:
            layers.append(nn.Conv2d(ndf * in_mult, ndf * out_mult,
                                    kernel_size=kernel, stride=stride, padding=pad))
            layers.append(nn.LeakyReLU(0.2, True))

        last_mult = stages[-1][1]
        layers.append(nn.Conv2d(ndf * last_mult, 1, kernel_size=kernel, stride=1, padding=pad))

        if use_sigmoid:
            layers.append(nn.Sigmoid())

        self.model = nn.Sequential(*layers)

    def forward(self, input):
        """Apply the discriminator to ``input`` (no data-parallel dispatch here)."""
        return self.model(input)
PerceptualLoss
- 功能:原文中自正则化感知损失实现,其实就是一个感知损失,不过此论文作者是将输入和网络推理输出做感知损失的约束。
class PerceptualLoss(nn.Module):
    """Mean squared distance between VGG features of two images.

    Args:
        opt: Options object. It is forwarded to ``vgg_preprocess`` and to the
            ``vgg`` callable; ``opt.no_vgg_instance`` selects whether the
            features are compared raw or after instance normalisation.
    """

    def __init__(self, opt):
        super(PerceptualLoss, self).__init__()
        self.opt = opt
        # Non-affine instance norm configured for 512 feature channels.
        self.instancenorm = nn.InstanceNorm2d(512, affine=False)

    def compute_vgg_loss(self, vgg, img, target):
        """Return the mean squared feature difference between ``img`` and ``target``.

        Args:
            vgg: Feature extractor, called as ``vgg(x, opt)``.
            img: First image batch.
            target: Second image batch.
        """
        # Preprocess both inputs first, then extract features, in that order.
        prepared = [vgg_preprocess(batch, self.opt) for batch in (img, target)]
        img_fea, target_fea = [vgg(batch, self.opt) for batch in prepared]

        if not self.opt.no_vgg_instance:
            img_fea = self.instancenorm(img_fea)
            target_fea = self.instancenorm(target_fea)
        return torch.mean((img_fea - target_fea) ** 2)
以上代码是此论文生成器和判别器网络的定义,以及相应的损失函数。
4、总结
代码实现核心的部分讲解完毕,总体来看该代码比较简单,主要是常见的unet网络结构和LSGAN方面的内容。
感谢阅读,欢迎留言或私信,一起探讨和交流。
如果对你有帮助的话,也希望可以给博主点一个关注,感谢。