Building an Auxiliary Classifier GAN using Keras and TensorFlow
Last Updated :
15 Sep, 2021
Prerequisites: Generative Adversarial Network
This article demonstrates how to build an Auxiliary Classifier Generative Adversarial Network (AC-GAN) using the Keras and TensorFlow libraries. The dataset used is the MNIST image dataset, which comes pre-loaded with Keras.
Step 1: Setting up the environment
1. Open the Anaconda prompt in Administrator mode.
2. Create a virtual environment using the command: conda create --name acgan python=3.7
3. Activate the environment using the command: conda activate acgan
4. Install the following libraries:
4.1 TensorFlow: pip install tensorflow==2.1
4.2 Keras: pip install keras==2.3.1
Step 2: Importing the required libraries
Python3
from keras.datasets import mnist
from keras.layers import Input, Dense, Reshape, Flatten, Dropout, multiply
from keras.layers.convolutional import UpSampling2D, Conv2D
from keras.layers import BatchNormalization, Activation, Embedding, ZeroPadding2D
from keras.models import Sequential, Model
from keras.layers.advanced_activations import LeakyReLU
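# Import Adam from the same standalone keras package as the layers and models;
# mixing keras and tensorflow.keras objects is a common source of errors
from keras.optimizers import Adam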
import matplotlib.pyplot as plt
import numpy as np
Step 3: Defining parameters to be used in later processes
Python3
# Defining the Input shape
image_shape = (28, 28, 1)
classes = 10
latent_dim = 100
# Defining the optimizer and the losses
optimizer = Adam(0.0002, 0.5)
losses = ['binary_crossentropy','sparse_categorical_crossentropy']
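To make the two entries in losses concrete, the following sketch computes both for a single sample in plain Python. The function names and the sample probabilities are illustrative assumptions, not Keras calls; Keras applies the same formulas over a batch, with clipping and averaging. Binary cross-entropy scores the real/fake output, and sparse categorical cross-entropy scores the class output against an integer label.

```python
import math

def binary_crossentropy(y_true, p):
    """Loss for the real/fake output: y_true is 0 or 1, p is the predicted probability of 'real'."""
    return -(y_true * math.log(p) + (1 - y_true) * math.log(1 - p))

def sparse_categorical_crossentropy(label, probs):
    """Loss for the class output: label is an integer class index, probs is a softmax vector."""
    return -math.log(probs[label])

# Hypothetical discriminator outputs for one real image of the digit 3
validity_prob = 0.9
class_probs = [0.01] * 10
class_probs[3] = 0.91  # the ten probabilities sum to 1.0

validity_loss = binary_crossentropy(1, validity_prob)
label_loss = sparse_categorical_crossentropy(3, class_probs)
print(round(validity_loss, 4), round(label_loss, 4))
```

The "sparse" variant is what allows the digit labels to stay as plain integers instead of being one-hot encoded.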
Step 4: Defining a utility function to build the Generator
Python3
def build_generator():
    model = Sequential()

    # Project the latent vector onto a 7x7x128 feature map
    model.add(Dense(128 * 7 * 7, activation="relu", input_dim=latent_dim))
    model.add(Reshape((7, 7, 128)))
    model.add(BatchNormalization(momentum=0.82))

    # Upsample 7x7 -> 14x14
    model.add(UpSampling2D())
    model.add(Conv2D(128, (3,3), padding="same"))
    model.add(BatchNormalization(momentum=0.82))
    model.add(Activation("relu"))

    # Upsample 14x14 -> 28x28
    model.add(UpSampling2D())
    model.add(Conv2D(64, (3,3), padding="same"))
    model.add(BatchNormalization(momentum=0.82))
    model.add(Activation("relu"))

    # Reduce to a single-channel image with values in [-1, 1]
    model.add(Conv2D(1, (3,3), padding='same'))
    model.add(Activation("tanh"))

    # Condition the noise on the class label: embed the label as a
    # latent_dim-sized vector and multiply it element-wise with the noise
    noise = Input(shape=(latent_dim,))
    label = Input(shape=(1,), dtype='int32')
    z = Flatten()(Embedding(classes, latent_dim)(label))
    model_input = multiply([noise, z])

    # Generating the output image
    image = model(model_input)

    return Model([noise, label], image)
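The label conditioning in build_generator can be illustrated without Keras. The sketch below is a simplified stand-in under stated assumptions: embedding_table is a hypothetical random table with one row per class (the role played by Embedding(classes, latent_dim)), and condition is a hypothetical helper that performs the element-wise product done by multiply. In the real model the table is learned during training.

```python
import random

random.seed(0)
classes = 10
latent_dim = 4  # kept small for readability; the article uses 100

# Hypothetical embedding table: one latent_dim-sized vector per class
embedding_table = [[random.gauss(0, 1) for _ in range(latent_dim)]
                   for _ in range(classes)]

def condition(noise, label):
    """Element-wise product of the noise vector and the label's embedding row."""
    z = embedding_table[label]
    return [n * e for n, e in zip(noise, z)]

noise = [random.gauss(0, 1) for _ in range(latent_dim)]
input_for_3 = condition(noise, 3)
input_for_7 = condition(noise, 7)

# Same noise, different labels: the generator body receives different inputs
print(len(input_for_3), input_for_3 != input_for_7)
```

Because the label changes the vector that reaches the convolutional stack, the same noise sample can be steered toward different digit classes.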
Step 5: Defining a utility function to build the Discriminator
Python3
def build_discriminator():
    model = Sequential()

    # Building the input layer
    model.add(Conv2D(16, (3,3), strides=2, input_shape=image_shape, padding="same"))
    model.add(LeakyReLU(alpha=0.2))
    model.add(Dropout(0.25))

    model.add(Conv2D(32, (3,3), strides=2, padding="same"))
    model.add(ZeroPadding2D(padding=((0,1),(0,1))))
    model.add(LeakyReLU(alpha=0.2))
    model.add(Dropout(0.25))
    model.add(BatchNormalization(momentum=0.8))

    model.add(Conv2D(64, (3,3), strides=2, padding="same"))
    model.add(LeakyReLU(alpha=0.2))
    model.add(Dropout(0.25))
    model.add(BatchNormalization(momentum=0.8))

    model.add(Conv2D(128, (3,3), strides=1, padding="same"))
    model.add(LeakyReLU(alpha=0.2))
    model.add(Dropout(0.25))

    model.add(Flatten())

    image = Input(shape=image_shape)

    # Extract features from images
    features = model(image)

    # Building the output layer
    validity = Dense(1, activation="sigmoid")(features)
    label = Dense(classes, activation="softmax")(features)

    return Model(image, [validity, label])
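As a sanity check on the discriminator architecture, the sketch below traces the spatial size of a 28 x 28 image through the layers above. It relies on the rule that a convolution with padding="same" produces ceil(size / stride) outputs per dimension; conv_same is a hypothetical helper written for this illustration.

```python
import math

def conv_same(size, stride):
    """Spatial size after a Conv2D with padding='same'."""
    return math.ceil(size / stride)

size = 28                   # MNIST input is 28 x 28
size = conv_same(size, 2)   # Conv2D(16, strides=2)  -> 14
size = conv_same(size, 2)   # Conv2D(32, strides=2)  -> 7
size = size + 1             # ZeroPadding2D(((0,1),(0,1))) -> 8
size = conv_same(size, 2)   # Conv2D(64, strides=2)  -> 4
size = conv_same(size, 1)   # Conv2D(128, strides=1) -> 4

flattened = size * size * 128  # width of the feature vector after Flatten()
print(size, flattened)  # 4 2048
```

Both output heads, validity and label, therefore read from the same 2048-dimensional feature vector.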
Step 6: Defining a utility function to display the generated images
Python3
def display_images():
    r = 10
    c = 10
    noise = np.random.normal(0, 1, (r * c, latent_dim))

    # Labels 0..9 repeated for every row, shaped (r * c, 1) to match the label input
    new_labels = np.array([num for _ in range(r) for num in range(c)]).reshape(-1, 1)
    gen_images = generator.predict([noise, new_labels])

    # Rescale images from [-1, 1] to [0, 1]
    gen_images = 0.5 * gen_images + 0.5

    fig, axs = plt.subplots(r, c)
    count = 0
    for i in range(r):
        for j in range(c):
            axs[i,j].imshow(gen_images[count,:,:,0], cmap='gray')
            axs[i,j].axis('off')
            count += 1
    plt.show()
    plt.close()
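The layout of the displayed grid follows from how the labels are built. The sketch below reproduces the label comprehension from display_images in plain Python and checks which label lands in each cell; label_at and to_unit_range are hypothetical helpers added for illustration only.

```python
r, c = 10, 10

# Same comprehension as in display_images: labels 0..9 repeated for each row
new_labels = [num for _ in range(r) for num in range(c)]

def label_at(i, j):
    """Label of the image drawn at row i, column j (count advances row by row)."""
    return new_labels[i * c + j]

def to_unit_range(x):
    """Map a tanh output in [-1, 1] to [0, 1] for imshow."""
    return 0.5 * x + 0.5

print(new_labels[:12])
print([label_at(i, 4) for i in range(r)])
print(to_unit_range(-1.0), to_unit_range(0.0), to_unit_range(1.0))
```

Each column of the grid therefore shows ten samples of the same digit, which makes it easy to judge whether the generator respects the requested class.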
Step 7: Defining the training loop for the AC-GAN
Python3
def train_acgan(epochs, batch_size=128, sample_interval=50):
    # Load the dataset
    (X, y), (_, _) = mnist.load_data()

    # Scale pixel values to [-1, 1] and add a channel axis
    X = X.astype(np.float32)
    X = (X - 127.5) / 127.5
    X = np.expand_dims(X, axis=3)
    y = y.reshape(-1, 1)

    # Adversarial ground truths
    valid = np.ones((batch_size, 1))
    fake = np.zeros((batch_size, 1))

    # Each "epoch" here processes one random batch
    for epoch in range(epochs):
        # Select a random batch of images
        index = np.random.randint(0, X.shape[0], batch_size)
        images = X[index]

        # Sample noise as generator input
        noise = np.random.normal(0, 1, (batch_size, latent_dim))

        # The labels of the digits that the generator tries to create an
        # image representation of
        new_labels = np.random.randint(0, classes, (batch_size, 1))

        # Generate a batch of new images
        gen_images = generator.predict([noise, new_labels])
        image_labels = y[index]

        # Training the discriminator on real and generated images
        disc_loss_real = discriminator.train_on_batch(
            images, [valid, image_labels])
        disc_loss_fake = discriminator.train_on_batch(
            gen_images, [fake, new_labels])
        disc_loss = 0.5 * np.add(disc_loss_real, disc_loss_fake)

        # Training the generator
        gen_loss = combined.train_on_batch(
            [noise, new_labels], [valid, new_labels])

        # Print the validity accuracy and the class-label accuracy
        print("%d [acc.: %.2f%%, op_acc: %.2f%%]" % (
            epoch, 100 * disc_loss[3], 100 * disc_loss[4]))

        # Display at every defined epoch interval
        if epoch % sample_interval == 0:
            display_images()
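The indices 3 and 4 used in the print statement come from the layout of the list that train_on_batch returns for a two-output model compiled with metrics=['accuracy']: the total loss, then one loss per output, then one accuracy per output. The sketch below shows the averaging step with made-up numbers; the entries in names are hypothetical labels chosen for readability, since Keras derives the real metric names from the layer names.

```python
# Hypothetical names for the five values; Keras derives the real names from layer names
names = ["total_loss", "validity_loss", "label_loss", "validity_acc", "label_acc"]

# Made-up values standing in for the two train_on_batch results
disc_loss_real = [1.20, 0.40, 0.80, 0.90, 0.70]
disc_loss_fake = [1.60, 0.60, 1.00, 0.70, 0.50]

# Element-wise mean, equivalent to 0.5 * np.add(disc_loss_real, disc_loss_fake)
disc_loss = [0.5 * (a + b) for a, b in zip(disc_loss_real, disc_loss_fake)]
report = dict(zip(names, disc_loss))

print("acc.: %.2f%%, op_acc: %.2f%%" % (100 * disc_loss[3], 100 * disc_loss[4]))
```

Checking discriminator.metrics_names in an actual run is a simple way to confirm the ordering before relying on fixed indices.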
Step 8: Building the AC-GAN and running the training
Python3
# Build and compile the discriminator
discriminator = build_discriminator()
discriminator.compile(loss=losses,
                      optimizer=optimizer,
                      metrics=['accuracy'])

# Build the generator
generator = build_generator()

# Define the inputs for the generator
# and generate the images
noise = Input(shape=(latent_dim,))
label = Input(shape=(1,))
image = generator([noise, label])
# Disable the Discriminator
# For the combined model we will only train the generator
discriminator.trainable = False
# The discriminator takes in the generated image
# as input and determines validity
# and the label of that image
valid, target_label = discriminator(image)
# The combined model (both generator and discriminator)
# Training the generator to fool the discriminator
combined = Model([noise, label], [valid, target_label])
combined.compile(loss=losses, optimizer=optimizer)
train_acgan(epochs=14000, batch_size=32, sample_interval=2000)
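The arguments passed to train_acgan determine when sample grids are drawn and how much data the model sees. The following sketch works this out in plain Python; mnist_train_size is an assumption written in by hand (the MNIST training split has 60,000 images), and the other values are taken from the call above.

```python
epochs, batch_size, sample_interval = 14000, 32, 2000
mnist_train_size = 60000  # number of training images in MNIST

# Iterations at which display_images() is called inside train_acgan
display_points = [e for e in range(epochs) if e % sample_interval == 0]

# Each "epoch" in train_acgan draws one random batch, so this is the total
# number of sampled images and the approximate number of passes over the data
images_seen = epochs * batch_size
approx_passes = images_seen / mnist_train_size

print(display_points)
print(images_seen, round(approx_passes, 2))
```

Since range(epochs) stops at 13999, the last grid is drawn at iteration 12000; passing epochs=14001 would add one at iteration 14000.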
Output (at every 2000-epoch interval):
[Figures: grids of generated digits at epochs 0, 2000, 4000, 6000, 8000, 10000, 12000, and 14000, followed by the final result]
Visual inspection of the generated images over the course of training suggests that the network performs at an acceptable level. Image quality can be improved by training the network for longer or by tuning its hyperparameters.