Generative Adversarial Networks in PyTorch

Deep neural networks are used mainly for supervised learning: classification or regression. Generative Adversarial Networks (GANs), however, use neural networks for a very different purpose: generative modeling.

Generative modeling is an unsupervised learning task in machine learning that involves automatically discovering and learning the regularities or patterns in input data in such a way that the model can be used to generate or output new examples that plausibly could have been drawn from the original dataset. - Source

While there are many approaches to generative modeling, a Generative Adversarial Network takes the following approach:

GAN Flowchart

There are two neural networks: a Generator and a Discriminator. The generator generates a "fake" sample given a random vector/matrix, and the discriminator attempts to detect whether a given sample is "real" (picked from the training data) or "fake" (generated by the generator). Training happens in tandem: we train the discriminator for a few epochs, then train the generator for a few epochs, and repeat. This way both the generator and the discriminator get better at doing their jobs. This rather simple approach can lead to some astounding results. The following images (source), for instance, were all generated using GANs:

gans_results
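
The alternating scheme described above can be sketched on a toy problem before we move to images. This is a minimal illustration under stated assumptions, not the code used later in this tutorial: the "real" data are scalars drawn from a normal distribution centred at 4, the network sizes, learning rate and step count are arbitrary, the names G, D, latent_size and batch_size are hypothetical, and the two networks take turns after every batch rather than after several epochs.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

latent_size = 8    # size of the random input vector (arbitrary choice)
batch_size = 32

# Toy generator: random vector -> one scalar "sample"
G = nn.Sequential(nn.Linear(latent_size, 16), nn.ReLU(), nn.Linear(16, 1))
# Toy discriminator: scalar sample -> probability that it is real
D = nn.Sequential(nn.Linear(1, 16), nn.LeakyReLU(0.2), nn.Linear(16, 1), nn.Sigmoid())

criterion = nn.BCELoss()
d_opt = torch.optim.Adam(D.parameters(), lr=2e-4)
g_opt = torch.optim.Adam(G.parameters(), lr=2e-4)

real_labels = torch.ones(batch_size, 1)
fake_labels = torch.zeros(batch_size, 1)

for step in range(100):
    # 1) Discriminator step: push D(real) towards 1 and D(fake) towards 0
    real = torch.randn(batch_size, 1) + 4.0
    fake = G(torch.randn(batch_size, latent_size)).detach()  # no gradients into G here
    d_loss = criterion(D(real), real_labels) + criterion(D(fake), fake_labels)
    d_opt.zero_grad()
    d_loss.backward()
    d_opt.step()

    # 2) Generator step: push D(fake) towards 1, i.e. try to fool the discriminator
    fake = G(torch.randn(batch_size, latent_size))
    g_loss = criterion(D(fake), real_labels)
    g_opt.zero_grad()
    g_loss.backward()
    g_opt.step()

print('d_loss: {:.4f}, g_loss: {:.4f}'.format(d_loss.item(), g_loss.item()))
```

Note the `.detach()` in the discriminator step: it keeps the discriminator's loss from updating the generator, so each network is trained only on its own objective.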

GANs, however, can be notoriously difficult to train and are extremely sensitive to hyperparameters, activation functions, and regularization. In this tutorial, we'll train a GAN to generate images of handwritten digits similar to those from the MNIST database.


Most of the code for this tutorial has been borrowed from this excellent repository of PyTorch tutorials: github.com/yunjey/pytorch-tutorial. Here's what we're going to do:

  • Define the problem statement
  • Load the data (with transforms and normalization)
    • Denormalize for visual inspection of samples
  • Define the Discriminator network
    • Study the activation function: Leaky ReLU
  • Define the Generator network
    • Explain the output activation function: TanH
    • Look at some sample outputs
  • Define losses, optimizers and helper functions for training
    • For discriminator
    • For generator
  • Train the model
    • Save intermediate generated images to file
  • Look at some outputs
  • Save the models
  • Commit to Jovian.ml

Load the Data

We begin by downloading and importing the data as a PyTorch dataset using the MNIST helper class from torchvision.datasets.

import torch
import torchvision
from torchvision.transforms import ToTensor, Normalize, Compose
from torchvision.datasets import MNIST

mnist = MNIST(root='data', 
              train=True, 
              download=True,
              transform=Compose([ToTensor(), Normalize(mean=(0.5,), std=(0.5,))]))
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to data/MNIST/raw/train-images-idx3-ubyte.gz
100.1%
Extracting data/MNIST/raw/train-images-idx3-ubyte.gz to data/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to data/MNIST/raw/train-labels-idx1-ubyte.gz
113.5%
Extracting data/MNIST/raw/train-labels-idx1-ubyte.gz to data/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to data/MNIST/raw/t10k-images-idx3-ubyte.gz
100.4%
Extracting data/MNIST/raw/t10k-images-idx3-ubyte.gz to data/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to data/MNIST/raw/t10k-labels-idx1-ubyte.gz
180.4%
Extracting data/MNIST/raw/t10k-labels-idx1-ubyte.gz to data/MNIST/raw
Processing...
Done!

Note that we are transforming the pixel values from the range [0, 1] to the range [-1, 1]. The reason for doing this will become clear when we define the generator network. Let's look at a sample tensor from the data.

img, label = mnist[0]
print('Label: ', label)
print(img[:,10:15,10:15])
torch.min(img), torch.max(img)
Label:  5
tensor([[[-0.9922,  0.2078,  0.9843, -0.2941, -1.0000],
         [-1.0000,  0.0902,  0.9843,  0.4902, -0.9843],
         [-1.0000, -0.9137,  0.4902,  0.9843, -0.4510],
         [-1.0000, -1.0000, -0.7255,  0.8902,  0.7647],
         [-1.0000, -1.0000, -1.0000, -0.3647,  0.8824]]])
(tensor(-1.), tensor(1.))
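
The minimum of -1 and maximum of 1 follow from the arithmetic that Normalize applies to each channel: (x - mean) / std. With mean = 0.5 and std = 0.5, a pixel value of 0 maps to -1 and a value of 1 maps to 1. The short sketch below works through that arithmetic in plain Python. The helper names normalize and denormalize are hypothetical, written only for illustration, and are not part of torchvision.

```python
def normalize(x, mean=0.5, std=0.5):
    """Same arithmetic as Normalize(mean, std) for one channel: (x - mean) / std."""
    return (x - mean) / std


def denormalize(y, mean=0.5, std=0.5):
    """Inverse of normalize: maps [-1, 1] back to [0, 1] for these defaults."""
    return y * std + mean


# The end points and midpoint of the [0, 1] pixel range
for pixel in (0.0, 0.5, 1.0):
    print(pixel, '->', normalize(pixel))

# A nearly black pixel with raw intensity 1 out of 255
print(round(normalize(1 / 255), 4))  # -0.9922, the first value in the slice above
```

Because both helpers use only elementwise arithmetic, they should work unchanged on a PyTorch tensor such as img. Applying denormalize to it would bring the values back into [0, 1], which is the range needed to view a sample as an image.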