介绍
GAN一直是深度学习比较火的一种模型,在各个领域都有应用,不论是在CV、NLP、还是AR、VR,GAN的加入都让他们更加立体生动。最近元宇宙这个概念被炒的比较火热,这其中GAN就发挥了巨大的作用。来看几组GAN的例(妹)子。
是心动呀~当然是对GAN的,下面让我们来看一下什么是GAN吧。
什么是GAN
生成式对抗网络(GAN, Generative Adversarial Networks )是一种深度学习模型,是近年来在复杂分布上无监督学习最具前景的方法之一。
在GAN模型中,一般存在两个模块:生成模型(Generative Model)和判别模型(Discriminative Model),两者的相互博弈和学习使得他们共同进步并最终产生出真假难分的目标。
生成器就好比一个制造假币的骗子,而判别器就像是检验假币的警察,在最开始的时候这个骗子制造的假币的质量是很差的,一眼就能被警察识别,这个时候骗子就需要提升自己的造假技术去骗过警察,随着骗子制造出的假币越来越像真的,警察就很难辨别出假币的真假,这时候警察也需要通过学习提升自己的鉴别能力,就这样警察和骗子不断的欺骗和学习,最终生成真假难辨的纸币。
GAN的网络框架
以生成图片为例子,Generator是一个生成网络,它接收一个随机噪声noise,通过这个噪声生成图片,记作Generator Data,Discriminator是一个判别网络,判别一张图片是不是真实的,它的输出是一个概率值,如果为1,就代表100%是真实的图片,如果为0就代表100%是假的图片。
生成器实现代码
以手写体数据集MNIST为例,生成网络的输入是一行正态分布的随机数,因此它的输入是一个长度为N的一维向量,输出是一个(28,28,1)维的图片。下面的代码基于Keras框架。
def build_generator(self):
# --------------------------------- #
# 生成器,输入一串随机数字
# --------------------------------- #
model = Sequential()
model.add(Dense(256, input_dim=self.latent_dim))
model.add(LeakyReLU(alpha=0.2))
model.add(BatchNormalization(momentum=0.8))
model.add(Dense(512))
model.add(LeakyReLU(alpha=0.2))
model.add(BatchNormalization(momentum=0.8))
model.add(Dense(1024))
model.add(LeakyReLU(alpha=0.2))
model.add(BatchNormalization(momentum=0.8))
model.add(Dense(np.prod(self.img_shape), activation='tanh'))
model.add(Reshape(self.img_shape))
noise = Input(shape=(self.latent_dim,))
img = model(noise)
return Model(noise, img)
判别器实现代码
判别器的目的是根据输入图片判断出真假。因此它的是一个(28,28,1)维的图片,输出是0到1之间的数,1代表这个图片是真的,0代表这个图片是假的。
def build_discriminator(self):
# ----------------------------------- #
# 评价器,对输入进来的图片进行评价`在这里插入代码片`
# ----------------------------------- #
model = Sequential()
# 输入一张图片
model.add(Flatten(input_shape=self.img_shape))
model.add(Dense(512))
model.add(LeakyReLU(alpha=