
Image Generation using Generative Adversarial Networks (GANs) with TensorFlow

Last Updated : 29 May, 2025

Generative Adversarial Networks (GANs) revolutionized AI image generation by creating realistic and high-quality images from random noise. In this article, we will train a GAN model on the MNIST dataset to generate handwritten digit images.

Training GANs for Image Generation

Generative Adversarial Networks (GANs) consist of two neural networks, the Generator and the Discriminator, that compete with each other. The Generator creates images from random noise, while the Discriminator classifies images as real or fake. This competition steadily improves the quality of the generated samples.

Training the Discriminator

The Discriminator is trained on real images from the dataset together with fake images produced by the Generator. Its goal is to tell the two apart. Through backpropagation and gradient descent, it adjusts its parameters to classify real and generated images more accurately.

Training the Generator

In parallel, the Generator is trained to produce images that are increasingly difficult for the Discriminator to distinguish from real images. Initially its output looks like random noise, but as training progresses it learns to generate images that resemble those in the training dataset. The Generator's parameters are adjusted based on feedback from the Discriminator, which pushes it toward more realistic, higher-quality images.
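
The two objectives described above can be illustrated with a small NumPy sketch. It assumes the binary cross-entropy loss used later in this article; the discriminator scores below are made-up values for illustration, not results from a trained model.

```python
import numpy as np

def binary_cross_entropy(y_true, y_pred, eps=1e-7):
    """Mean binary cross-entropy between labels and predicted probabilities."""
    y_pred = np.clip(y_pred, eps, 1.0 - eps)  # avoid log(0)
    losses = -(y_true * np.log(y_pred) + (1.0 - y_true) * np.log(1.0 - y_pred))
    return float(np.mean(losses))

# Hypothetical discriminator outputs (probability that an image is real).
d_on_real = np.array([0.9, 0.8, 0.95])  # scores for real images
d_on_fake = np.array([0.2, 0.1, 0.3])   # scores for generated images

# Discriminator objective: real images are labelled 1, generated images 0.
d_loss = 0.5 * (binary_cross_entropy(np.ones(3), d_on_real)
                + binary_cross_entropy(np.zeros(3), d_on_fake))

# Generator objective: make the discriminator output 1 for generated images.
g_loss = binary_cross_entropy(np.ones(3), d_on_fake)

print(f"d_loss={d_loss:.3f}, g_loss={g_loss:.3f}")
```

In this example the discriminator separates the two groups well, so its loss is low and the generator's loss is high. As the generator improves, the scores on generated images rise and the balance shifts.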

Implementing Generative Adversarial Networks (GANs) for Image Generation

Let's walk through the steps involved in this implementation.

Step 1: Import Necessary Libraries

We will be using TensorFlow, Keras, NumPy and Matplotlib.

Python
import numpy as np
import tensorflow as tf
from tensorflow.keras import layers, models, optimizers
import matplotlib.pyplot as plt

Step 2: Dataset Preparation

Load the MNIST training images, reshape them to the (28, 28, 1) format expected by the convolutional layers and normalize pixel values to the range [0, 1]. Normalization helps stabilize training by keeping input values within a small range.

Python
(x_train, _), (_, _) = tf.keras.datasets.mnist.load_data()
x_train = x_train.reshape((-1, 28, 28, 1)).astype('float32') / 255.0
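
As a quick sanity check, the same transformation can be tried on a small synthetic batch without downloading MNIST. The random array below is only a stand-in for the real data. The [0, 1] range matters because the generator's sigmoid output layer produces values in that same range.

```python
import numpy as np

# Stand-in for raw MNIST data: 4 grayscale images with uint8 pixels in 0..255.
rng = np.random.default_rng(0)
raw = rng.integers(0, 256, size=(4, 28, 28), dtype=np.uint8)

# Same transformation as in this step: add a channel axis, then scale to [0, 1].
scaled = raw.reshape((-1, 28, 28, 1)).astype('float32') / 255.0

print(scaled.shape, scaled.dtype, scaled.min(), scaled.max())
```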

Step 3: Building the Models

Define the architectures of the generator and discriminator using convolutional neural networks (CNNs) designed to efficiently process and generate images.

Generator Model with CNN Layers:

  • Dense Layer: Converts the 100-dimensional noise vector into a high-dimensional feature map.
  • Reshape: Transforms the feature map into a 3D tensor suitable for convolutional processing.
  • Conv2DTranspose Layers: Perform upsampling and convolution in a single step, gradually increasing the image resolution.
  • BatchNormalization: Stabilizes training and speeds convergence.
  • Activation Functions: ReLU for hidden layers and sigmoid in the output layer to constrain pixel values between 0 and 1.

Discriminator Model with CNN Layers:

  • Conv2D Layers: Apply stride-2 convolutions to downsample images, which reduces spatial dimensions and increases the receptive field.
  • BatchNormalization: Helps to maintain stable training.
  • Flatten: Converts feature maps into 1D vectors for classification.
  • Dense Output Layer: Outputs a single probability indicating whether an image is real or fake.

Python
def build_discriminator_cnn():
    model = models.Sequential([
        layers.Input(shape=(28, 28, 1)),
        # Downsample 28x28 -> 14x14
        layers.Conv2D(64, kernel_size=3, strides=2, padding='same', activation='relu'),
        # Downsample 14x14 -> 7x7
        layers.Conv2D(128, kernel_size=3, strides=2, padding='same', activation='relu'),
        layers.BatchNormalization(),
        # Flatten the feature maps and output the probability that the image is real
        layers.Flatten(),
        layers.Dense(1, activation='sigmoid')
    ])
    return model
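
Step 4 calls build_generator_cnn(), but its definition is not listed here. A minimal sketch consistent with the generator description above might look like the following. The filter counts (128 and 64), the 7×7 starting resolution and the kernel size are assumptions, chosen so that two stride-2 upsampling steps produce a 28×28 image; they are not taken from the original code.

```python
import numpy as np
import tensorflow as tf
from tensorflow.keras import layers, models

def build_generator_cnn(latent_dim=100):
    """Sketch of a generator mapping a noise vector to a 28x28x1 image."""
    model = models.Sequential([
        layers.Input(shape=(latent_dim,)),
        # Project the noise vector to a small feature map (assumed size 7x7x128)
        layers.Dense(7 * 7 * 128, activation='relu'),
        layers.Reshape((7, 7, 128)),
        # Upsample 7x7 -> 14x14
        layers.Conv2DTranspose(64, kernel_size=3, strides=2, padding='same', activation='relu'),
        layers.BatchNormalization(),
        # Upsample 14x14 -> 28x28; sigmoid keeps pixel values in [0, 1]
        layers.Conv2DTranspose(1, kernel_size=3, strides=2, padding='same', activation='sigmoid'),
    ])
    return model

# Usage: a batch of 2 noise vectors yields 2 images of shape (28, 28, 1)
sample_noise = np.random.normal(0, 1, (2, 100)).astype('float32')
sample_images = build_generator_cnn()(sample_noise, training=False)
print(sample_images.shape)
```

Other layer sizes would work as well, provided the output shape matches the discriminator's (28, 28, 1) input.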

Step 4: Compiling the Models

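Compile the discriminator on its own, then build the combined GAN model by feeding the generator's output into the discriminator. In the combined model the discriminator's weights are frozen, so that only the generator is updated during generator training.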

  • discriminator_cnn.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.01), loss='binary_crossentropy', metrics=['accuracy']): Compiles the discriminator with Adam optimizer, binary cross-entropy loss, and accuracy metric.
  • gan_input = layers.Input(shape=(100,)): Defines the input layer for the GAN, expecting a 100-dimensional noise vector.
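
Python
generator_cnn = build_generator_cnn()
discriminator_cnn = build_discriminator_cnn()

# The discriminator is compiled on its own so it can be trained directly
discriminator_cnn.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.01),
                          loss='binary_crossentropy',
                          metrics=['accuracy'])

# Freeze the discriminator before building the combined model,
# so that training gan_cnn is meant to update only the generator
discriminator_cnn.trainable = False

# Combined model: noise -> generator -> discriminator -> probability
gan_input = layers.Input(shape=(100,))
gan_output = discriminator_cnn(generator_cnn(gan_input))
gan_cnn = models.Model(gan_input, gan_output)
gan_cnn.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.01),
                loss='binary_crossentropy')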

Step 5: Model Training and Visualizing

Train the GAN by alternating between:

  • Training the discriminator to correctly classify real and fake images.
  • Training the generator to fool the discriminator by producing more realistic images.

Visualize generated images periodically to monitor training progress.

  • idx = np.random.randint(0, x_train.shape[0], batch_size): Randomly selects indices to sample a batch of real images from the training data.
  • g_loss = gan_cnn.train_on_batch(noise, valid_labels): Trains the GAN model (generator + frozen discriminator) to improve generator performance.
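
Python
epochs = 100000   # number of training iterations (one batch each), not full passes over the data
batch_size = 64

# Labels do not change between iterations: 1 = real, 0 = fake
real_labels = np.ones((batch_size, 1))
fake_labels = np.zeros((batch_size, 1))
valid_labels = np.ones((batch_size, 1))

for epoch in range(epochs + 1):
    # Generate a batch of fake images (verbose=0 suppresses the progress bar)
    noise = np.random.normal(0, 1, (batch_size, 100))
    generated_images = generator_cnn.predict(noise, verbose=0)

    # Sample a random batch of real images
    idx = np.random.randint(0, x_train.shape[0], batch_size)
    real_images = x_train[idx]

    # Train the discriminator; train_on_batch returns [loss, accuracy] here
    d_loss_real = discriminator_cnn.train_on_batch(real_images, real_labels)
    d_loss_fake = discriminator_cnn.train_on_batch(generated_images, fake_labels)
    d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)

    # Train the generator through the combined model, labelling fakes as real
    noise = np.random.normal(0, 1, (batch_size, 100))
    g_loss = gan_cnn.train_on_batch(noise, valid_labels)

    if epoch % 100 == 0:
        print(f"Epoch {epoch}: D Loss: {d_loss[0]}, G Loss: {g_loss}")

    # Show one generated sample every 1000 iterations
    if epoch % 1000 == 0:
        test_noise = np.random.normal(0, 1, (1, 100))
        test_img = generator_cnn.predict(test_noise, verbose=0)[0].reshape(28, 28)
        plt.imshow(test_img, cmap='gray')
        plt.axis('off')
        plt.show()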

Output:

[Figure: Generated Images]

If the generated image is not clear, you can tune hyperparameters such as the learning rate or batch size to get better results.
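
A single sample says little about overall quality, so it can help to look at several generated digits at once. The helper below is a hypothetical sketch (plot_image_grid is not part of the original code). It uses random arrays as stand-in data so that it runs on its own.

```python
import numpy as np
import matplotlib.pyplot as plt

def plot_image_grid(images, rows=4, cols=4):
    """Show up to rows*cols images of shape (N, 28, 28, 1) in a grid."""
    fig, axes = plt.subplots(rows, cols, figsize=(cols, rows), squeeze=False)
    for ax in axes.flat:
        ax.axis('off')  # hide ticks and frames, including for unused cells
    for ax, img in zip(axes.flat, images):
        ax.imshow(np.squeeze(img), cmap='gray', vmin=0.0, vmax=1.0)
    fig.tight_layout()
    return fig

# Stand-in data; with a trained model this could instead be
# generator_cnn.predict(np.random.normal(0, 1, (16, 100)), verbose=0)
samples = np.random.rand(16, 28, 28, 1).astype('float32')
fig = plot_image_grid(samples)
plt.show()
```

If every cell of the grid shows nearly the same digit, that is a hint of the mode collapse problem discussed in the next section.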

Challenges and Considerations

Training GANs comes with a few common challenges that can affect performance and stability:

  • Mode Collapse: The generator produces only a limited or repetitive set of outputs and fails to capture the full variety of the training data.
  • Training Instability: The generator and discriminator can oscillate or diverge instead of improving together.
  • Hyperparameter Sensitivity: GAN performance is highly sensitive to choices like learning rate, optimizer settings, and model architecture, often requiring careful tuning.
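
As one illustration of hyperparameter tuning, the sketch below sets up Adam with a lower learning rate and a reduced beta_1. These values (2e-4 and 0.5) are an assumption based on settings commonly used for DCGAN-style models. They are not the settings used in this article and are not guaranteed to work best here.

```python
import tensorflow as tf

# Assumed starting values, commonly used for DCGAN-style models;
# treat them as a starting point for tuning, not a recommendation.
LEARNING_RATE = 2e-4
BETA_1 = 0.5

def make_gan_optimizer():
    """Return a fresh Adam optimizer; each model gets its own instance."""
    return tf.keras.optimizers.Adam(learning_rate=LEARNING_RATE, beta_1=BETA_1)

d_optimizer = make_gan_optimizer()  # could be passed to discriminator_cnn.compile(...)
g_optimizer = make_gan_optimizer()  # could be passed to gan_cnn.compile(...)
```

Using separate optimizer instances keeps the internal state of the two training processes independent.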

As we train GANs to generate handwritten digits we take a meaningful step toward teaching machines not just to learn from data but to imagine new possibilities from it.

