Learn how to generate fake faces using Deep Convolutional Generative Adversarial Networks (DCGAN). This blog explains the theory behind GANs and DCGANs, and how to implement them using PyTorch.
What is a GAN?
GANs are a framework for teaching a deep learning model to capture the training data’s distribution so we can generate new data from that same distribution. GANs were invented by Ian Goodfellow in 2014 and first described in the paper `Generative Adversarial Nets <https://papers.nips.cc/paper/5423-generative-adversarial-nets.pdf>`__. They are made of two distinct models, a *generator* and a *discriminator*. The job of the generator is to produce ‘fake’ images that look like the training images. The job of the discriminator is to look at an image and output whether it is a real training image or a fake image from the generator. During training, the generator constantly tries to outsmart the discriminator by generating better and better fakes, while the discriminator works to become a better detective and correctly classify the real and fake images. The equilibrium of this game is reached when the generator produces perfect fakes that look as if they came directly from the training data, and the discriminator is left to always guess at 50% confidence whether the generator output is real or fake.

Now, let’s define some notation to be used throughout the tutorial, starting with the discriminator. Let x be data representing an image. D(x) is the discriminator network, which outputs the (scalar) probability that x came from the training data rather than the generator. Here, since we are dealing with images, the input to D(x) is an image of CHW size 3x64x64. Intuitively, D(x) should be HIGH when x comes from the training data and LOW when x comes from the generator. D(x) can also be thought of as a traditional binary classifier.

For the generator’s notation, let z be a latent space vector sampled from a standard normal distribution. G(z) represents the generator function, which maps the latent vector z to data space. The goal of G is to estimate the distribution that the training data comes from (Pdata) so it can generate fake samples from that estimated distribution (Pg).
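To make the notation concrete, here is a minimal sketch of the shapes involved. The two toy networks are single linear layers used only as stand-ins (they are not the DCGAN architecture), and the latent size `nz = 100`, the batch size, and the random placeholder "real" batch are assumptions made for illustration.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

nz = 100          # latent vector length (assumed value)
batch_size = 4    # small batch, just for the shape check


class ToyGenerator(nn.Module):
    """Stand-in for G: maps a latent vector z to a 3x64x64 image."""

    def __init__(self, nz):
        super().__init__()
        self.fc = nn.Linear(nz, 3 * 64 * 64)

    def forward(self, z):
        # tanh keeps pixel values in [-1, 1]; reshape to CHW = 3x64x64
        return torch.tanh(self.fc(z)).view(-1, 3, 64, 64)


class ToyDiscriminator(nn.Module):
    """Stand-in for D: maps a 3x64x64 image to one probability."""

    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(3 * 64 * 64, 1)

    def forward(self, x):
        # sigmoid turns the raw score into a probability in (0, 1)
        return torch.sigmoid(self.fc(x.flatten(start_dim=1))).view(-1)


G = ToyGenerator(nz)
D = ToyDiscriminator()

z = torch.randn(batch_size, nz)                   # z ~ standard normal
fake = G(z)                                       # G(z): latent space -> data space
x = torch.rand(batch_size, 3, 64, 64) * 2 - 1     # placeholder for a real batch

d_real = D(x)       # D(x): probability that x is real
d_fake = D(fake)    # D(G(z)): probability that a generated image is real

print(fake.shape)    # torch.Size([4, 3, 64, 64])
print(d_real.shape)  # torch.Size([4])
```

With untrained weights the probabilities carry no meaning yet; the point is only that G(z) has the shape of an image and D returns one scalar per image.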
So, D(G(z)) is the probability (scalar) that the output of the generator G is a real image. As described in `Goodfellow’s paper <https://papers.nips.cc/paper/5423-generative-adversarial-nets.pdf>`__, D and G play a minimax game in which D tries to maximize the probability it correctly classifies reals and fakes (log D(x)), and G tries to minimize the probability that D will predict its outputs are fake (log(1 - D(G(z)))). In theory, the solution to this minimax game is where Pg = Pdata, and the discriminator guesses randomly whether the inputs are real or fake. However, the convergence theory of GANs is still being actively researched, and in reality models do not always train to this point.

What is a DCGAN?
A DCGAN is a direct extension of the GAN described above, except that it
explicitly uses convolutional and convolutional-transpose layers in the
discriminator and generator, respectively. It was first described by
Radford et al. in the paper
`Unsupervised Representation Learning With Deep Convolutional Generative Adversarial Networks <https://arxiv.org/pdf/1511.06434.pdf>`__. The discriminator
is made up of strided convolution layers,
`batch norm <https://pytorch.org/docs/stable/nn.html#torch.nn.BatchNorm2d>`__
layers, and LeakyReLU activations. The input is a 3x64x64 input image and the output is a
scalar probability that the input is from the real data distribution.
The generator consists of convolutional-transpose
layers, batch norm layers, and
`ReLU <https://pytorch.org/docs/stable/nn.html#relu>`__ activations. The
input is a latent vector, z, that is drawn from a standard
normal distribution and the output is a 3x64x64 RGB image. The strided
conv-transpose layers allow the latent vector to be transformed into a
volume with the same shape as an image. In the paper, the authors also
give some tips about how to set up the optimizers, how to calculate the
loss functions, and how to initialize the model weights, which
will be explained in the coming projects.
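The layer pattern described above can be sketched as follows. This is an illustrative layout under stated assumptions, not necessarily the exact model built later: the channel widths (`ngf`, `ndf`), the latent size `nz`, the 4x4 kernels, and the LeakyReLU slope of 0.2 are values chosen here for the example, and `g_block` and `d_block` are hypothetical helper names.

```python
import torch
import torch.nn as nn

# Assumed sizes for illustration: latent length, feature widths, image channels
nz, ngf, ndf, nc = 100, 64, 64, 3


def g_block(c_in, c_out, stride, pad):
    # conv-transpose -> batch norm -> ReLU
    return [nn.ConvTranspose2d(c_in, c_out, 4, stride, pad, bias=False),
            nn.BatchNorm2d(c_out),
            nn.ReLU(True)]


def d_block(c_in, c_out, batch_norm=True):
    # strided conv (halves height and width) -> optional batch norm -> LeakyReLU
    layers = [nn.Conv2d(c_in, c_out, 4, 2, 1, bias=False)]
    if batch_norm:
        layers.append(nn.BatchNorm2d(c_out))
    layers.append(nn.LeakyReLU(0.2, inplace=True))
    return layers


generator = nn.Sequential(
    *g_block(nz, ngf * 8, 1, 0),        # 1x1   -> 4x4
    *g_block(ngf * 8, ngf * 4, 2, 1),   # 4x4   -> 8x8
    *g_block(ngf * 4, ngf * 2, 2, 1),   # 8x8   -> 16x16
    *g_block(ngf * 2, ngf, 2, 1),       # 16x16 -> 32x32
    nn.ConvTranspose2d(ngf, nc, 4, 2, 1, bias=False),  # 32x32 -> 64x64
    nn.Tanh(),                          # pixel values in [-1, 1]
)

discriminator = nn.Sequential(
    *d_block(nc, ndf, batch_norm=False),  # 64x64 -> 32x32
    *d_block(ndf, ndf * 2),               # 32x32 -> 16x16
    *d_block(ndf * 2, ndf * 4),           # 16x16 -> 8x8
    *d_block(ndf * 4, ndf * 8),           # 8x8   -> 4x4
    nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=False),  # 4x4 -> 1x1
    nn.Sigmoid(),                         # scalar probability per image
)

z = torch.randn(2, nz, 1, 1)          # latent vectors shaped as 1x1 "images"
fake = generator(z)
prob = discriminator(fake).view(-1)

print(fake.shape)  # torch.Size([2, 3, 64, 64])
print(prob.shape)  # torch.Size([2])
```

Each stride-2 conv-transpose layer doubles the spatial size, which is how a 1x1 latent volume grows into a 64x64 image; the discriminator mirrors this by halving the size at each strided convolution until a single value remains.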
```python
from __future__ import print_function
#%matplotlib inline
import argparse
import os
import random
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim as optim
import torch.utils.data
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torchvision.utils as vutils
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from IPython.display import HTML

# Set random seed for reproducibility
manualSeed = 999
#manualSeed = random.randint(1, 10000) # use if you want new results
print("Random Seed: ", manualSeed)
random.seed(manualSeed)
torch.manual_seed(manualSeed)
```
Random Seed: 999
<torch._C.Generator at 0x7f2fad0741c8>
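As a small check of what seeding buys you, the sketch below re-seeds the generators and draws a latent vector twice. The helper name `sample_latent` and the latent length of 100 are assumptions made for this example only.

```python
import random
import torch


def sample_latent(seed, nz=100):
    """Seed Python's and PyTorch's RNGs, then draw one latent vector."""
    random.seed(seed)
    torch.manual_seed(seed)
    return torch.randn(nz)


a = sample_latent(999)
b = sample_latent(999)
c = sample_latent(1000)

print(torch.equal(a, b))  # True: same seed, same draw
print(torch.equal(a, c))  # False: a different seed gives a different draw
```

Fixing the seed makes the sampled noise repeatable on the same setup; some GPU operations can still be nondeterministic, so results across runs may not match bit for bit.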
Let’s define some inputs for the run:
`here <https://github.com/pytorch/examples/issues/70>`__ for more details