CGAN-条件GAN
原始GAN的缺点:
生成的图像是随机的,不可预测的,无法控制网络的输出特定的图片,生成目标不明确,可控性不强。
针对原始GAN不能生成具有特定属性的图片的问题
CGAN核心在于将属性信息融入生成器G和判别器D中,属性y可以使任何标签信息,例如图像的类别、人脸图像的面部表情等
CGAN中心思想是希望可以控制GAN生成的图片,而不是单出的随机生成图片。具体来说,Conditional GAN在生成器和判别器的输入中增加了额外的条件信息,生成器生成的图片只有足够真实且与条件相符,才能够通过判别器
模型部分
在判别器和生成器中都添加了额外信息y, y可以是类别标签或者其他类型的数据,可以将y作为一个额外的输入层丢入判别器和生成器
生成器中,z和y连在一起隐含表示,带约束条件这个简单直接的改进被证明非常有效。
论文作者在cGAN对于图像自动标注的多模态学习上的应用,在MIRFlickr25000数据集上,以图像特征为条件变量,生成该图像的tag的词向量。
缺陷:
cGAN生成的图像边缘模糊,分辨率低,但是为后面的pix2pixGAN和Cycle-GAN开拓了道路
cGAN将无监督学习转变为有监督学习
模型实现
由于引入类别y,将数据集中target字段值变为独热标签形式。(注意返回向量格式)
并在创建数据集时,写入target_transfrom字段中
注意:由于有y的存在,所以数据集加载过程中都需要完整的批次,加入drop_last字段值
def onehot(x, class_count=10):
return torch.eye(class_count)[x, :]
real_dataset = datasets.MNIST("mnist",
train=True,
target_transform=onehot,
transform=transforms.Compose([transforms.ToTensor(),
transforms.Normalize(0.5, 0.5)]))
real_dataload = DataLoader(real_dataset,batch_size=64, shuffle=True, drop_last=True)
生成器
由于添加了类别标签,数据集只有10个类别,类别标签y大小也为[10]
噪声100长度向量通过线性扩大为128 * 7 * 7,y同理扩大为128 * 7 * 7,易于合并成[256, 7, 7]大小向量,之后和DCGAN模型同理。
class Generate(nn.Module):
def __init__(self) -> None:
super().__init__()
self.genModel1 = nn.Sequential(
nn.Linear(100, 128 * 7 * 7 ),
nn.ReLU(),
nn.BatchNorm1d(128 * 7 * 7),
)
self.genModel2 = nn.Sequential(
nn.Linear(10, 128 * 7 * 7 ),
nn.ReLU(),
nn.BatchNorm1d(128 * 7 * 7),
)
self.genModel3 = nn.Sequential(
# [128, 28, 28]
nn.ConvTranspose2d(256, 128, kernel_size=3, stride=1, padding=1),
nn.ReLU(),
nn.BatchNorm2d(128),
# [64, 14, 14]
nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),
nn.ReLU(),
nn.BatchNorm2d(64),
# [1, 28, 28]
nn.ConvTranspose2d(64, 1, kernel_size=4, stride=2, padding=1),
nn.Tanh()
)
def forward(self, z, tar):
z = self.genModel1(z)
z = torch.reshape(z, (-1, 128, 7, 7))
tar = self.genModel2(tar)
tar = torch.reshape(tar, (-1, 128, 7, 7))
input = torch.cat([z, tar], 1)
input = self.genModel3(input)
return input
判别器
判别器同样需要输入标签y,处理思想一样,将y和输入img放缩成相同大小向量,这里通过线性层将y扩大为28 * 28长度,reshape成28 * 28大小图片,将img和y合并为[2, 28, 28]放入判别器,之后和DCGAN处理相同
class Discriminator(nn.Module):
def __init__(self) -> None:
super().__init__()
self.linear = nn.Linear(10, 1 * 28 * 28)
self.dismodel1 = nn.Sequential(
nn.Conv2d(2, 64, kernel_size=3, stride=2),
nn.LeakyReLU(),
nn.Dropout2d(),
nn.Conv2d(64, 128, kernel_size=3, stride=2),
nn.LeakyReLU(),
nn.Dropout2d(),
nn.BatchNorm2d(128),
nn.Flatten(),
nn.Linear(128 * 6 * 6,1),
nn.Sigmoid()
)
def forward(self, img, tar):
tar = self.linear(tar)
tar = torch.reshape(tar,(-1, 1, 28, 28))
input = torch.cat([img, tar], 1)
input = self.dismodel1(input)
return input
可视化
和DCGAN处理有点区别,需要将标签y作为参数放入
def generate_and_save_images(model, epoch, label_input, noise_input):
predictions = np.squeeze(model(noise_input, label_input).detach().cpu().numpy())
fig = plt.figure(figsize=(4, 4))
for i in range(predictions.shape[0]):
plt.subplot(4, 4, i+1)
plt.imshow((predictions[i] + 1)/2, cmap='gray')
plt.axis('off')
plt.savefig(f'./train_mnist/image_at_epoch_{format(epoch)}.png')
plt.show()
优化函数及测试用例
gen = Generate()
dis = Discriminator()
gen_optim = optim.Adam(gen.parameters(), lr = 1e-4)
dis_optim = optim.Adam(dis.parameters(), lr = 1e-5)
loss = nn.BCELoss()
write = SummaryWriter('cGAN_log')
noise_input = torch.randn(16, 100)
label_input = torch.randint(0, 10,size=(16,))
label_input_onehot = onehot(label_input)
训练数据集,和GAN没啥区别就是添加了y标签进去
for epoch in range(100):
train_step = 0
for item in real_dataload:
dis_loss = 0
dis_optim.zero_grad()
img,target = item
img_dis = dis(img, target)
img_dis_loss = loss(img_dis, torch.ones_like(img_dis))
img_dis_loss.backward()
z = torch.randn((64, 100))
z_gen = gen(z, target)
z_dis = dis(z_gen.detach(), target.detach())
z_dis_loss = loss(z_dis, torch.zeros_like(z_dis))
z_dis_loss.backward()
dis_optim.step()
dis_loss = z_dis_loss + img_dis_loss
gen_optim.zero_grad()
z_out = dis(z_gen, target)
z_gen_loss = loss(z_out, torch.ones_like(z_out))
z_gen_loss.backward()
gen_optim.step()
train_step += 1
write.add_scalar("分类器", dis_loss, train_step)
write.add_scalar("生产器", z_gen_loss, train_step)
print(label_input)
generate_and_save_images(gen, epoch, label_input_onehot, noise_input)
if epoch > 20 and epoch % 20 == 0:
torch.save(gen, f"gen_method_{epoch}.pth")
torch.save(dis, f"dis_method_{epoch}.pth")
write.close()
训练结果
[3, 3, 5, 1, 3, 7, 6, 1, 9, 7, 5, 3, 9, 0, 6, 5]
epoch1
epoch40
epoch80:
epoch 120
损失变化: