From Scratch: GAN

In this post, we build a generative adversarial network (GAN) from scratch using only the primitive building blocks of tensorflow.js.

Theory

Today, generative models are ubiquitous in the form of generative adversarial networks (GANs) [1], variational autoencoders (VAEs) [3], transformer-based generative architectures [2], and diffusion-based generative models [4].

Let’s start with the basic idea of a GAN from the original paper [1]. A GAN is defined in terms of a discriminator \(D(x)\) and a generator \(G(z)\) that play a min-max game: the generator tries to generate data according to the data distribution \(p_{\text{data}}(x)\), while the discriminator tries to differentiate between real and generated samples. Both the discriminator \(D\) and the generator \(G\) are differentiable functions implemented as neural networks \(D(x; \theta_d)\) and \(G(z; \theta_g)\) with parameters \(\theta_d\) and \(\theta_g\). The two networks are trained in tandem: the generator learns to create samples that more closely match the true distribution, while the discriminator becomes better at distinguishing generated from real samples. A key feature of the GAN formulation is that the generator never has direct access to the data distribution and is trained only through gradients that flow through the discriminator.

$$ \min_G \max_D \, \underbrace{\mathbb{E}_{x \sim p_{\text{data}}(x)}\left[\log D(x)\right]}_{V_{\text{real}}(D)} + \underbrace{\mathbb{E}_{z \sim p_{z}(z)}\left[\log(1 - D(G(z)))\right]}_{V_{\text{fake}}(D,G)} $$

We can separate this training objective into two parts:

$$ V_{\text{real}}(D) = \mathbb{E}_{x \sim p_{\text{data}}(x)}\left[\log D(x)\right] $$

The first part is concerned with priming the discriminator to recognize samples from the data distribution \(p_{\text{data}}(x)\). The discriminator \(D(x)\) is a binary classifier that tries to output a value close to \(1\) if \(x\) is real and a value close to \(0\) if \(x\) is fake, i.e., generated by the generator. With this term, we train the discriminator to maximize the expected log-probability it assigns to real samples, i.e., we want the discriminator to become good at detecting real samples.

Using only this formulation has an obvious drawback: a shortcut to solving this task would be to simply output \(1\) on every input sample. Since \(\log D(x) \le 0\) for \(D(x) \in [0, 1]\), a constant output of \(1\) already achieves the maximum \(V_{\text{real}}(D) = 0\). What we are missing is a way to penalize the discriminator for (mis-)classifying generated samples as real.

$$ V_{\text{fake}}(D,G) = \mathbb{E}_{z \sim p_{z}(z)}\left[\log(1 - D(G(z)))\right] $$

This is done by the second part of the GAN training objective. Here, we sample from a noise distribution \(z \sim p_{z}(z)\) that is used as input for the generator \(G(z)\), whose output is in turn fed into the discriminator \(D(G(z))\). In the GAN game, where the discriminator tries to maximize this expected value, this penalizes the discriminator for assigning values close to \(1\) to fake samples.

Implementation

We use a one-dimensional input space \(x \in \mathbb{R}^1\) and latent space \(z \in \mathbb{R}^1\) in our GAN implementation.

Our objective is to learn the data distribution $$ p_{\text{data}} = \mathcal{N}(-3, 0.75) $$ with our GAN, i.e., a normal distribution with mean \(-3\) and standard deviation \(0.75\). We use the built-in tensorflow.js function tf.randomNormal to implement a function with which we can draw samples from our data distribution:

import * as tf from "@tensorflow/tfjs";

const d_input_dim = 1;
const p_data_mean = -3;
const p_data_std = 0.75;

// Draws m samples from p_data as a tensor of shape [m, 1].
function p_data(m: number) {
    return tf.randomNormal([m, d_input_dim], p_data_mean, p_data_std);
}

Next, we have our latent distribution. This is simply the standard normal distribution with mean zero and standard deviation one, $$ p_{z} = \mathcal{N}(0, 1), $$ which, too, is implemented using the tf.randomNormal function:

const z_dim = 1;

// Draws m latent samples from p_z as a tensor of shape [m, 1].
function p_z(m: number) {
    return tf.randomNormal([m, z_dim], 0, 1);
}

Formally, the discriminator \(D\) implements the function: $$ D(x; \theta_d): \mathbb{R}^{1} \rightarrow [0, 1] $$ In our implementation, we use a small fully-connected network: a dense hidden layer with 16 units and tanh activation, followed by a final sigmoid layer that normalizes the output to a value between 0 and 1.

const D = new Discriminator(d_input_dim, 16);

We additionally use a dropout layer after each hidden dense layer in the discriminator, as suggested in [5, Section 20.10.4]:

Dropout seems to be important in the discriminator network. In particular, units should be stochastically dropped while computing the gradient for the generator network to follow. [...] never using dropout seems to yield poor results.

We have to be careful here, and this is specific to the design of tensorflow.js: during training of the generator, we do not want the discriminator to update its weights, so we set the model to non-trainable with D.trainable = false. However, during the forward pass we still want the dropout layers to be active, which is why we pass {training: true} to the apply function.
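
A minimal sketch of this pattern, where x_fake stands in for a batch of generated samples (hypothetical here; the generator producing it is defined below):

D.trainable = false;  // freeze the discriminator's weights for the generator update
// Dropout remains active in the forward pass despite the frozen weights:
const y = D.apply(x_fake, { training: true }) as tf.Tensor;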

class Discriminator extends tf.LayersModel {
    constructor(input_dim: number, ...hidden_dims: number[]) {
        const x: any = tf.input({ shape: [input_dim] });

        // Hidden dense layers with tanh activations, each followed by dropout.
        const layers: tf.layers.Layer[] = [];
        for (const hidden_dim of hidden_dims) {
            layers.push(tf.layers.dense({
                units: hidden_dim,
                activation: "tanh",
                kernelInitializer: "glorotNormal",
            }));
            layers.push(tf.layers.dropout({ rate: 0.5 }));
        }
        // Final sigmoid layer that normalizes the output to [0, 1].
        layers.push(tf.layers.dense({
            units: 1,
            activation: "sigmoid",
            kernelInitializer: "glorotNormal",
        }));

        // Wire the layers into a graph and register it as this model.
        const y = layers.reduce((x_i, layer) => layer.apply(x_i), x);
        super({ inputs: x, outputs: y });
    }
}

Formally, the generator \(G\) implements the function: $$ G(z; \theta_g): \mathbb{R}^{1} \rightarrow \mathbb{R}^{1} $$ In our implementation, we use a small fully-connected network: a dense hidden layer with 16 units and ReLU activation, followed by a final, unnormalized (linear) output layer.

const G = new Generator(z_dim, 16, d_input_dim);

class Generator extends tf.LayersModel {
    constructor(input_dim: number, dim: number, ...more_dims: number[]) {
        const x: any = tf.input({ shape: [input_dim] });
        // The last entry is the output dimension; all entries before it
        // define the hidden layers.
        const hidden_dims = [dim].concat(more_dims);
        const output_dim = hidden_dims.splice(-1, 1)[0];

        // Hidden dense layers with ReLU activations.
        const layers: tf.layers.Layer[] = [];
        for (const hidden_dim of hidden_dims) {
            layers.push(tf.layers.dense({
                units: hidden_dim,
                activation: "relu",
                kernelInitializer: "glorotNormal",
            }));
        }
        // Final dense layer without an activation, i.e., unnormalized output.
        layers.push(tf.layers.dense({
            units: output_dim,
            kernelInitializer: "glorotNormal",
        }));

        // Wire the layers into a graph and register it as this model.
        const y = layers.reduce((x_i, layer) => layer.apply(x_i), x);
        super({ inputs: x, outputs: y });
    }
}

We combine the generator and discriminator into the final GAN:

class GenerativeAdversarialNet extends tf.LayersModel {
    constructor(
        public input_dim: number,
        public G: Generator,
        public D: Discriminator
    ) {
        // Chain the generator into the discriminator: z -> G(z) -> D(G(z)).
        const input: any = tf.input({ shape: [input_dim] });
        super({ inputs: input, outputs: D.apply(G.apply(input)) });
    }
}
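
With all the pieces in place, we can instantiate the combined model, which maps latent samples directly to discriminator scores \(D(G(z))\) (a minimal usage sketch; the variable name gan is our own):

const gan = new GenerativeAdversarialNet(z_dim, G, D);
// Discriminator scores for a batch of eight generated samples:
const scores = gan.predict(p_z(8)) as tf.Tensor;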

Training
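
Training alternates between a discriminator step and a generator step, each on a fresh mini-batch. The setup below is a minimal sketch: the choice of the Adam optimizer, the learning rate, and the batch size are assumptions of ours, not prescribed by the GAN formulation.

const batch_size = 64;                     // assumption: mini-batch size m
const d_optimizer = tf.train.adam(0.001);  // assumption: optimizer and learning rate
const g_optimizer = tf.train.adam(0.001);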

Discriminator Step
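
In the discriminator step, we maximize \(V_{\text{real}}(D) + V_{\text{fake}}(D, G)\) with respect to \(\theta_d\). Since tensorflow.js optimizers minimize, we minimize the negated objective instead. The following is a sketch under the assumptions above; the small constant 1e-7 guards against \(\log 0\):

function trainDiscriminatorStep(m: number) {
    const x_real = p_data(m);
    // Sample fake data outside the loss closure so that no gradients
    // flow back into the generator during this step.
    const x_fake = G.apply(p_z(m)) as tf.Tensor;

    D.trainable = true;  // undo the freeze from the generator step
    d_optimizer.minimize(() => {
        const y_real = D.apply(x_real, { training: true }) as tf.Tensor;
        const y_fake = D.apply(x_fake, { training: true }) as tf.Tensor;
        // Maximize E[log D(x)] + E[log(1 - D(G(z)))] by minimizing its negation.
        const v_real = tf.log(y_real.add(1e-7)).mean();
        const v_fake = tf.log(tf.scalar(1).sub(y_fake).add(1e-7)).mean();
        return v_real.add(v_fake).neg() as tf.Scalar;
    });
}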

Generator Step
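
In the generator step, we minimize \(V_{\text{fake}}(D, G)\) with respect to \(\theta_g\) while \(\theta_d\) stays fixed. This is where the pattern from above comes in: we freeze the discriminator with D.trainable = false but keep its dropout layers active by passing {training: true}. Again a sketch under the same assumptions; note that [1] suggests maximizing \(\log D(G(z))\) instead early in training for stronger gradients, while we keep the plain minimax form here:

function trainGeneratorStep(m: number) {
    D.trainable = false;  // freeze theta_d; only theta_g is updated
    g_optimizer.minimize(() => {
        const x_fake = G.apply(p_z(m)) as tf.Tensor;
        const y_fake = D.apply(x_fake, { training: true }) as tf.Tensor;
        // Minimize E[log(1 - D(G(z)))] with respect to the generator.
        return tf.log(tf.scalar(1).sub(y_fake).add(1e-7)).mean() as tf.Scalar;
    });
}

A full training run then simply alternates the two steps:

for (let step = 0; step < 10000; step++) {  // assumption: number of steps
    trainDiscriminatorStep(batch_size);
    trainGeneratorStep(batch_size);
}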