dcgan

Source: https://github.com/aymericdamien/TensorFlow-Examples#tutorials

""" Deep Convolutional Generative Adversarial Network (DCGAN).

Using deep convolutional generative adversarial networks (DCGAN) to generate
digit images from a noise distribution.

References:
    - Unsupervised representation learning with deep convolutional generative
    adversarial networks. A Radford, L Metz, S Chintala. arXiv:1511.06434.

Links:
    - [DCGAN Paper](https://arxiv.org/abs/1511.06434).
    - [MNIST Dataset](http://yann.lecun.com/exdb/mnist/).

Author: Aymeric Damien
Project: https://github.com/aymericdamien/TensorFlow-Examples/
"""

from __future__ import division, print_function, absolute_import

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf

# Import MNIST data
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets("/tmp/data/", one_hot=True)

# Training Params
num_steps = 20000
batch_size = 32

# Network Params
image_dim = 784 # 28*28 pixels * 1 channel
gen_hidden_dim = 256
disc_hidden_dim = 256
noise_dim = 200 # Noise data points


# Generator Network
# Input: Noise, Output: Image
def generator(x, reuse=False):
    with tf.variable_scope('Generator', reuse=reuse):
        # TensorFlow Layers automatically create variables and calculate their
        # shape, based on the input.
        x = tf.layers.dense(x, units=6 * 6 * 128)
        x = tf.nn.tanh(x)
        # Reshape to a 4-D array of images: (batch, height, width, channels)
        # New shape: (batch, 6, 6, 128)
        x = tf.reshape(x, shape=[-1, 6, 6, 128])
        # Deconvolution, image shape: (batch, 14, 14, 64)
        x = tf.layers.conv2d_transpose(x, 64, 4, strides=2)
        # Deconvolution, image shape: (batch, 28, 28, 1)
        x = tf.layers.conv2d_transpose(x, 1, 2, strides=2)
        # Apply sigmoid to squash values into the range (0, 1)
        x = tf.nn.sigmoid(x)
        return x


# Discriminator Network
# Input: Image, Output: Prediction Real/Fake Image
def discriminator(x, reuse=False):
    with tf.variable_scope('Discriminator', reuse=reuse):
        # Typical convolutional neural network to classify images.
        x = tf.layers.conv2d(x, 64, 5)
        x = tf.nn.tanh(x)
        x = tf.layers.average_pooling2d(x, 2, 2)
        x = tf.layers.conv2d(x, 128, 5)
        x = tf.nn.tanh(x)
        x = tf.layers.average_pooling2d(x, 2, 2)
        x = tf.contrib.layers.flatten(x)
        x = tf.layers.dense(x, 1024)
        x = tf.nn.tanh(x)
        # Output 2 classes: Real and Fake images
        x = tf.layers.dense(x, 2)
        return x

# Build Networks
# Network Inputs
noise_input = tf.placeholder(tf.float32, shape=[None, noise_dim])
real_image_input = tf.placeholder(tf.float32, shape=[None, 28, 28, 1])

# Build Generator Network
gen_sample = generator(noise_input)

# Build 2 Discriminator Networks (one from real images, one from generated samples)
disc_real = discriminator(real_image_input)
disc_fake = discriminator(gen_sample, reuse=True)
disc_concat = tf.concat([disc_real, disc_fake], axis=0)

# Build the stacked generator/discriminator
stacked_gan = discriminator(gen_sample, reuse=True)

# Build Targets (real or fake images)
disc_target = tf.placeholder(tf.int32, shape=[None])
gen_target = tf.placeholder(tf.int32, shape=[None])

# Build Loss
disc_loss = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(
    logits=disc_concat, labels=disc_target))
gen_loss = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(
    logits=stacked_gan, labels=gen_target))

# Build Optimizers
optimizer_gen = tf.train.AdamOptimizer(learning_rate=0.001)
optimizer_disc = tf.train.AdamOptimizer(learning_rate=0.001)

# Training Variables for each optimizer
# By default in TensorFlow, all variables are updated by each optimizer, so we
# need to specify, for each optimizer, which variables it should update.
# Generator Network Variables
gen_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='Generator')
# Discriminator Network Variables
disc_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='Discriminator')

# Create training operations
train_gen = optimizer_gen.minimize(gen_loss, var_list=gen_vars)
train_disc = optimizer_disc.minimize(disc_loss, var_list=disc_vars)

# Initialize the variables (i.e. assign their default value)
init = tf.global_variables_initializer()

# Start training
with tf.Session() as sess:

    # Run the initializer
    sess.run(init)

    for i in range(1, num_steps+1):

        # Prepare Input Data
        # Get the next batch of MNIST data (only images are needed, not labels)
        batch_x, _ = mnist.train.next_batch(batch_size)
        batch_x = np.reshape(batch_x, newshape=[-1, 28, 28, 1])
        # Generate noise to feed to the generator
        z = np.random.uniform(-1., 1., size=[batch_size, noise_dim])

        # Prepare Targets (Real image: 1, Fake image: 0)
        # The first half of the data fed to the discriminator are real images,
        # the other half are fake images (coming from the generator).
        batch_disc_y = np.concatenate(
            [np.ones([batch_size]), np.zeros([batch_size])], axis=0)
        # Generator tries to fool the discriminator, thus targets are 1.
        batch_gen_y = np.ones([batch_size])

        # Training
        feed_dict = {real_image_input: batch_x, noise_input: z,
                     disc_target: batch_disc_y, gen_target: batch_gen_y}
        _, _, gl, dl = sess.run([train_gen, train_disc, gen_loss, disc_loss],
                                feed_dict=feed_dict)
        if i % 100 == 0 or i == 1:
            print('Step %i: Generator Loss: %f, Discriminator Loss: %f' % (i, gl, dl))

    # Generate images from noise, using the generator network.
    f, a = plt.subplots(4, 10, figsize=(10, 4))
    for i in range(10):
        # Noise input.
        z = np.random.uniform(-1., 1., size=[4, noise_dim])
        g = sess.run(gen_sample, feed_dict={noise_input: z})
        for j in range(4):
            # Generate image from noise. Extend to 3 channels for the matplotlib figure.
            img = np.reshape(np.repeat(g[j][:, :, np.newaxis], 3, axis=2),
                             newshape=(28, 28, 3))
            a[j][i].imshow(img)

    f.show()
    plt.draw()
    plt.waitforbuttonpress()
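
The shape comments in the generator (6 -> 14 -> 28) and the size of the flattened discriminator features can be checked by hand. The sketch below is not part of the original script. It assumes the layers keep their default 'valid' padding, as in the code above, and the helper names `conv_out` and `conv_transpose_out` are hypothetical, introduced only for illustration.

```python
def conv_out(size, kernel, stride=1):
    """Spatial output size of a 'valid' convolution or pooling window."""
    return (size - kernel) // stride + 1


def conv_transpose_out(size, kernel, stride=1):
    """Spatial output size of a 'valid' transposed convolution."""
    return size * stride + max(kernel - stride, 0)


# Generator: dense output reshaped to 6x6, then two transposed convolutions.
gen_sizes = [6]
gen_sizes.append(conv_transpose_out(gen_sizes[-1], kernel=4, stride=2))
gen_sizes.append(conv_transpose_out(gen_sizes[-1], kernel=2, stride=2))

# Discriminator: conv(5) -> avg_pool(2, 2) -> conv(5) -> avg_pool(2, 2).
disc_sizes = [28]
disc_sizes.append(conv_out(disc_sizes[-1], kernel=5))            # conv2d, 64 filters
disc_sizes.append(conv_out(disc_sizes[-1], kernel=2, stride=2))  # average pooling
disc_sizes.append(conv_out(disc_sizes[-1], kernel=5))            # conv2d, 128 filters
disc_sizes.append(conv_out(disc_sizes[-1], kernel=2, stride=2))  # average pooling

# Number of features per image after flattening (height * width * channels).
flat_features = disc_sizes[-1] * disc_sizes[-1] * 128

print('Generator spatial sizes:', gen_sizes)
print('Discriminator spatial sizes:', disc_sizes)
print('Flattened discriminator features:', flat_features)
```

If the kernel sizes or strides are changed, rerunning this check shows quickly whether the generator still produces 28x28 images that match `real_image_input`.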
