cGAN网络的基本实现(Mnist数字集)

本文介绍条件生成对抗网络(CGAN)的基本原理及其在图像生成任务中的应用。CGAN通过引入额外的条件信息y,使得生成器能够生成具有特定属性的图片,提高了生成图像的可控性和目的性。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

CGAN-条件GAN

原始GAN的缺点:

生成的图像是随机的,不可预测的,无法控制网络的输出特定的图片,生成目标不明确,可控性不强。

针对原始GAN不能生成具有特定属性的图片的问题

CGAN核心在于将属性信息融入生成器G和判别器D中,属性y可以使任何标签信息,例如图像的类别、人脸图像的面部表情等
CGAN中心思想是希望可以控制GAN生成的图片,而不是单出的随机生成图片。具体来说,Conditional GAN在生成器和判别器的输入中增加了额外的条件信息,生成器生成的图片只有足够真实且与条件相符,才能够通过判别器

模型部分

在判别器和生成器中都添加了额外信息y, y可以是类别标签或者其他类型的数据,可以将y作为一个额外的输入层丢入判别器和生成器
生成器中,z和y连在一起隐含表示,带约束条件这个简单直接的改进被证明非常有效。
论文作者在cGAN对于图像自动标注的多模态学习上的应用,在MIRFlickr25000数据集上,以图像特征为条件变量,生成该图像的tag的词向量。

缺陷:

cGAN生成的图像边缘模糊,分辨率低,但是为后面的pix2pixGAN和Cycle-GAN开拓了道路
cGAN将无监督学习转变为有监督学习

模型实现

由于引入类别y,将数据集中target字段值变为独热标签形式。(注意返回向量格式)
并在创建数据集时,写入target_transfrom字段中
注意:由于有y的存在,所以数据集加载过程中都需要完整的批次,加入drop_last字段值

def onehot(x, class_count=10):
    return torch.eye(class_count)[x, :]
real_dataset = datasets.MNIST("mnist", 
				train=True, 
				target_transform=onehot,
				transform=transforms.Compose([transforms.ToTensor(), 
         				transforms.Normalize(0.5, 0.5)]))
real_dataload = DataLoader(real_dataset,batch_size=64, shuffle=True, drop_last=True)
生成器

由于添加了类别标签,数据集只有10个类别,类别标签y大小也为[10]
噪声100长度向量通过线性扩大为128 * 7 * 7,y同理扩大为128 * 7 * 7,易于合并成[256, 7, 7]大小向量,之后和DCGAN模型同理。

class Generate(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.genModel1 = nn.Sequential(
            nn.Linear(100,  128 * 7 * 7 ),
            nn.ReLU(),
            nn.BatchNorm1d(128 * 7 * 7),
        )
        self.genModel2 = nn.Sequential(
            nn.Linear(10,  128 * 7 * 7 ),
            nn.ReLU(),
            nn.BatchNorm1d(128 * 7 * 7),
        )
        self.genModel3 = nn.Sequential(
            # [128, 28, 28]
            nn.ConvTranspose2d(256, 128, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.BatchNorm2d(128),
            # [64, 14, 14] 
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.BatchNorm2d(64),
            # [1, 28, 28]
            nn.ConvTranspose2d(64, 1, kernel_size=4, stride=2, padding=1),
            nn.Tanh()
        )

    def forward(self, z, tar):
        z = self.genModel1(z)
        z = torch.reshape(z, (-1, 128, 7, 7))
        tar = self.genModel2(tar)
        tar = torch.reshape(tar, (-1, 128, 7, 7))
        input = torch.cat([z, tar], 1)
        input = self.genModel3(input)
        return input
判别器

判别器同样需要输入标签y,处理思想一样,将y和输入img放缩成相同大小向量,这里通过线性层将y扩大为28 * 28长度,reshape成28 * 28大小图片,将img和y合并为[2, 28, 28]放入判别器,之后和DCGAN处理相同

class Discriminator(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.linear = nn.Linear(10, 1 * 28 * 28)
        self.dismodel1 = nn.Sequential(
             nn.Conv2d(2, 64, kernel_size=3, stride=2),
             nn.LeakyReLU(),
             nn.Dropout2d(),
             nn.Conv2d(64, 128, kernel_size=3, stride=2),
             nn.LeakyReLU(),
             nn.Dropout2d(),
             nn.BatchNorm2d(128),
             nn.Flatten(),
             nn.Linear(128 * 6 * 6,1),
             nn.Sigmoid()
        )
    def forward(self, img, tar):
        tar = self.linear(tar)
        tar = torch.reshape(tar,(-1, 1, 28, 28))
        input = torch.cat([img, tar], 1)
        input = self.dismodel1(input)
        return input
可视化

和DCGAN处理有点区别,需要将标签y作为参数放入

def generate_and_save_images(model, epoch, label_input, noise_input):
    predictions = np.squeeze(model(noise_input, label_input).detach().cpu().numpy())
    fig = plt.figure(figsize=(4, 4))
    for i in range(predictions.shape[0]):
        plt.subplot(4, 4, i+1)
        plt.imshow((predictions[i] + 1)/2, cmap='gray')
        plt.axis('off')
    plt.savefig(f'./train_mnist/image_at_epoch_{format(epoch)}.png')
    plt.show()

优化函数及测试用例

gen = Generate()
dis = Discriminator()
gen_optim = optim.Adam(gen.parameters(), lr = 1e-4)
dis_optim = optim.Adam(dis.parameters(), lr = 1e-5)
loss = nn.BCELoss()
write = SummaryWriter('cGAN_log')
noise_input = torch.randn(16, 100)
label_input = torch.randint(0, 10,size=(16,))
label_input_onehot = onehot(label_input)

训练数据集,和GAN没啥区别就是添加了y标签进去

for epoch in range(100):
    train_step = 0
    for item in real_dataload:
        dis_loss = 0
        dis_optim.zero_grad()
        img,target = item
        img_dis = dis(img, target)
        img_dis_loss = loss(img_dis, torch.ones_like(img_dis))
        img_dis_loss.backward()

        z = torch.randn((64, 100))
        z_gen = gen(z, target)
        z_dis = dis(z_gen.detach(), target.detach())
        z_dis_loss = loss(z_dis, torch.zeros_like(z_dis))
        z_dis_loss.backward()
        dis_optim.step()

        dis_loss = z_dis_loss + img_dis_loss

        gen_optim.zero_grad()
        z_out = dis(z_gen, target)
        z_gen_loss = loss(z_out, torch.ones_like(z_out))
        z_gen_loss.backward()
        gen_optim.step()
        train_step += 1
        write.add_scalar("分类器", dis_loss, train_step)
        write.add_scalar("生产器", z_gen_loss, train_step)
    print(label_input)
    generate_and_save_images(gen, epoch, label_input_onehot, noise_input)

    if epoch > 20 and epoch % 20 == 0:
        torch.save(gen, f"gen_method_{epoch}.pth")
        torch.save(dis, f"dis_method_{epoch}.pth")
write.close()

训练结果

[3, 3, 5, 1, 3, 7, 6, 1, 9, 7, 5, 3, 9, 0, 6, 5]
epoch1
在这里插入图片描述

epoch40
在这里插入图片描述epoch80:
在这里插入图片描述
epoch 120
在这里插入图片描述损失变化:
在这里插入图片描述

评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

鲨鱼狂飙

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值