In this post, I discuss our recent paper, Categorical Reparameterization with Gumbel-Softmax, which introduces a simple technique for training neural networks with discrete latent variables. I'm really excited to share this because (1) I believe it will be quite useful for a variety of Machine Learning research problems, and (2) this is my first published paper ever (on arXiv, and submitted to a NIPS workshop and ICLR as well).

**TL;DR:** if you want categorical features in your neural nets, just let `sample = softmax((logits+gumbel noise)/temperature)`, and then backprop as usual using your favorite automatic differentiation software (e.g. TensorFlow, Torch, Theano).

You can find the code for this article here

## Introduction

One of the main themes in Deep Learning is to “let the neural net figure out all the intermediate features”. For example: training convolutional neural networks results in the self-organization of a feature detector hierarchy, while Neural Turing Machines automatically “discover” copying and sorting algorithms.

The workhorse of Deep Learning is the backpropagation algorithm, which uses dynamic programming to compute parameter gradients of the network. These gradients are then used to minimize the optimization objective via gradient descent. In order for this to work, all of the layers in our neural network — i.e. our learned intermediate features — must be continuous-valued functions.

What happens if we want to learn intermediate representations that are discrete? Many "codes" we want to learn are fundamentally discrete - musical notes on a keyboard, object classes (“kitten”, “balloon”, “truck”), and quantized addresses (“index 423 in memory”).

We can use stochastic neural networks, where each layer computes the parameters of some (discrete) distribution, and its forward pass consists of taking a sample from that parametric distribution. However, the difficulty is that we can’t backpropagate through samples. As shown below, there is a stochastic node (blue circle) in between $f(z)$ and $\theta$.

Left: in continuous neural nets, you can use backprop to compute parameter gradients. Right: backpropagation is not possible through stochastic nodes.

## Gumbel-Softmax Distribution

The problem of backpropagating through stochastic nodes can be circumvented if we can re-express the sample $z \sim p_\theta(z)$, such that gradients can flow from $f(z)$ to $\theta$ without encountering stochastic nodes. For example, samples from the normal distribution $z \sim \mathcal{N}(\mu,\sigma)$ can be re-written as $z = \mu + \sigma \cdot \epsilon$, where $\epsilon \sim \mathcal{N}(0,1)$. This is also known as the “reparameterization trick”, and is commonly used to train variational autoencoders with Gaussian latent variables.
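To make the Gaussian case concrete, here is a minimal sketch in plain Python (standard library only). The function name `sample_gaussian_reparam` and the values of `mu` and `sigma` are illustrative assumptions, not code from the paper. The point is that all the randomness sits in `eps`, so `z` is an ordinary differentiable function of `mu` and `sigma`.

```python
import random


def sample_gaussian_reparam(mu, sigma, rng):
    """Draw z ~ N(mu, sigma^2) as a deterministic function of (mu, sigma) and noise."""
    eps = rng.gauss(0.0, 1.0)  # all randomness lives in eps ~ N(0, 1)
    z = mu + sigma * eps       # differentiable in mu and sigma
    # With eps held fixed: dz/dmu = 1 and dz/dsigma = eps.
    return z, eps


rng = random.Random(0)
samples = [sample_gaussian_reparam(2.0, 0.5, rng)[0] for _ in range(20000)]
mean = sum(samples) / len(samples)
var = sum((s - mean) ** 2 for s in samples) / len(samples)
print(mean, var)  # should be close to mu = 2.0 and sigma^2 = 0.25
```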

The Gumbel-Softmax distribution is reparameterizable, allowing us to avoid the stochastic node during backpropagation.

**The main contribution of this work is a “reparameterization trick” for the categorical distribution.** Well, not quite – it’s actually a reparameterization trick for a distribution that we can *smoothly deform* into the categorical distribution. We use the Gumbel-Max trick, which provides an efficient way to draw samples $z$ from the Categorical distribution with class probabilities $\pi_i$:

$$ z = \text{one\_hot}\left(\underset{i}{\arg\max}\left[g_i + \log(\pi_i)\right]\right), $$

where $g_1, \ldots, g_k$ are i.i.d. samples drawn from $\text{Gumbel}(0,1)$.
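The Gumbel-Max trick can be sketched in a few lines of plain Python: add independent Gumbel(0, 1) noise to each log-probability and take the argmax. The helper names and the example probabilities below are illustrative assumptions, and the function returns the class index rather than a one-hot vector for brevity.

```python
import math
import random


def sample_gumbel(rng):
    """Draw g ~ Gumbel(0, 1) by inverse transform: g = -log(-log(u)), u ~ Uniform(0, 1)."""
    u = rng.random()
    while u == 0.0:  # rng.random() can return exactly 0.0; avoid log(0)
        u = rng.random()
    return -math.log(-math.log(u))


def gumbel_max_sample(pi, rng):
    """Return a class index distributed according to the probabilities pi."""
    scores = [math.log(p) + sample_gumbel(rng) for p in pi]
    return max(range(len(pi)), key=scores.__getitem__)


rng = random.Random(0)
pi = [0.1, 0.6, 0.3]
n = 20000
counts = [0] * len(pi)
for _ in range(n):
    counts[gumbel_max_sample(pi, rng)] += 1
freqs = [c / n for c in counts]
print(freqs)  # empirical frequencies should be close to pi
```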

argmax is not differentiable, so we simply use the softmax function as a continuous approximation of argmax:

$$ y_i = \frac{\exp((\log(\pi_i)+g_i)/\tau)}{\sum_{j=1}^k \exp((\log(\pi_j)+g_j)/\tau)} \qquad \text{for } i=1, \ldots, k. $$

Hence, we call this the Gumbel-*Soft*Max distribution*. $\tau$ is a temperature parameter that allows us to control how closely samples from the Gumbel-Softmax distribution approximate those from the categorical distribution. As $\tau \to 0$, the softmax becomes an argmax and the Gumbel-Softmax distribution becomes the categorical distribution. During training, we let $\tau > 0$ to allow gradients past the sample, then gradually anneal the temperature $\tau$ (but not completely to 0, as the gradients would blow up).
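One way to see why the temperature should not be annealed all the way to zero is to differentiate the softmax expression for $y_i$ with respect to the perturbed logits. Writing $x_j = \log(\pi_j) + g_j$ (a shorthand introduced here for convenience), a short calculation gives

$$ \frac{\partial y_i}{\partial x_j} = \frac{1}{\tau}\, y_i \left(\delta_{ij} - y_j\right), $$

where $\delta_{ij}$ is 1 if $i = j$ and 0 otherwise. When $y$ is close to one-hot, the factor $y_i(\delta_{ij} - y_j)$ is close to zero, but near a tie between two classes it is about $1/4$ in magnitude and the $1/\tau$ prefactor takes over. Small temperatures therefore tend to produce gradients that are mostly tiny and occasionally very large.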

Below is an interactive widget that draws samples from the Gumbel-Softmax distribution. Keep in mind that samples are vectors, and a one-hot vector (i.e. one of the elements is 1.0 and the others are 0.0) corresponds to a discrete category. Click "re-sample" to generate a new sample, and try dragging the slider and see what samples look like when the temperature $\tau$ is small!

## TensorFlow Implementation

Using this technique is extremely simple, and only requires 12 lines of Python code:
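As a minimal sketch of the sampling step, the code below perturbs the logits with Gumbel noise, divides by the temperature and applies a softmax. It is written in plain Python (standard library only) rather than TensorFlow, so it computes no gradients. The function names and example probabilities are illustrative assumptions, not the paper's code.

```python
import math
import random


def sample_gumbel(rng, eps=1e-20):
    """Draw g ~ Gumbel(0, 1); eps guards against log(0)."""
    u = rng.random()
    return -math.log(-math.log(u + eps) + eps)


def softmax(xs):
    """Numerically stable softmax over a list of floats."""
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def gumbel_softmax_sample(logits, temperature, rng):
    """Draw one relaxed (continuous) sample: softmax((logits + gumbel noise) / temperature)."""
    perturbed = [(l + sample_gumbel(rng)) / temperature for l in logits]
    return softmax(perturbed)


logits = [math.log(p) for p in (0.1, 0.6, 0.3)]
for tau in (5.0, 1.0, 0.1):
    # Same seed for every temperature, so only tau changes between rows.
    y = gumbel_softmax_sample(logits, tau, random.Random(0))
    print(tau, [round(v, 3) for v in y])
```

Reusing the seed keeps the Gumbel noise fixed across the three temperatures, so the printed vectors show a single sample sharpening toward a one-hot vector as $\tau$ shrinks.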

Despite its simplicity, Gumbel-Softmax works surprisingly well - we benchmarked it against other stochastic gradient estimators for a couple tasks and Gumbel-Softmax outperformed them for both Bernoulli (K=2) and Categorical (K=10) latent variables. We can also use it to train semi-supervised classification models much faster than previous approaches. See our paper for more details.

## Categorical VAE with Gumbel-Softmax

To demonstrate this technique in practice, here's a categorical variational autoencoder for MNIST, implemented in less than 100 lines of Python + TensorFlow code.

In standard Variational Autoencoders, we learn an encoding function that maps the data manifold to an isotropic Gaussian, and a decoding function that transforms it back to the sample. The data manifold is projected into a Gaussian ball; this can be hard to interpret if you are trying to learn the categorical structure within your data.

First, we declare the encoding network:

Next, we sample from the Gumbel-Softmax posterior and decode it back into our MNIST image.

Variational autoencoders minimize reconstruction error of the data by maximizing an evidence lower bound (ELBO) on the likelihood of the data, under a generative model $p_\theta(x)$. For a derivation, see this tutorial on variational methods.

$$\log p_\theta(x) \geq \mathbb{E}_{q_\phi(y|x)}[\log p_\theta(x|y)] - KL[q_\phi(y|x)||p_\theta(y)]$$

Finally, we train our VAE:
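As a rough illustration of the objective that each training step evaluates, here is a minimal sketch of the two ELBO terms for a single example, in plain Python. It rests on several assumptions that are not stated in the text above: binary pixels with a Bernoulli likelihood, a posterior that factorizes over independent categorical variables, and a uniform prior $p_\theta(y)$ over the classes of each variable. All function names and toy values are hypothetical.

```python
import math


def kl_to_uniform(q):
    """KL[q || Uniform(K)] for one categorical distribution q over K classes."""
    k = len(q)
    return sum(p * (math.log(p) + math.log(k)) for p in q if p > 0.0)


def bernoulli_log_likelihood(x, x_prob, eps=1e-12):
    """log p(x | y) for binary pixels x, given decoder output probabilities x_prob."""
    return sum(
        xi * math.log(p + eps) + (1 - xi) * math.log(1.0 - p + eps)
        for xi, p in zip(x, x_prob)
    )


def elbo(x, x_prob, q_ys):
    """Single-sample ELBO estimate: reconstruction term minus the summed KL terms."""
    kl = sum(kl_to_uniform(q) for q in q_ys)
    return bernoulli_log_likelihood(x, x_prob) - kl


# Toy values (hypothetical): 4 "pixels" and 2 latent variables with 3 classes each.
x = [1, 0, 1, 1]
x_prob = [0.9, 0.2, 0.8, 0.7]  # would come from decoding a Gumbel-Softmax sample
q_ys = [[0.7, 0.2, 0.1], [1 / 3, 1 / 3, 1 / 3]]
value = elbo(x, x_prob, q_ys)
print(value)
```

In a real model, `x_prob` would be the decoder's output for a relaxed sample $y$, and the negative of this value would be minimized with an automatic differentiation library.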

...and, that's it! Now we can sample randomly from our latent categorical code and decode it back into MNIST images:

Code can be found here. Thank you for reading, and let me know if you find this technique useful!

## Acknowledgements

I'm sincerely grateful to my co-authors, Shane Gu and Ben Poole, for collaborating with me on this work. Shane introduced me to the Gumbel-Max trick back in August, and supplied the framework for comparing Gumbel-Softmax with existing stochastic gradient estimators. Ben suggested and implemented the semi-supervised learning aspect of the paper, did the math derivations in the Appendix, and helped me a lot with editing the paper. Finally, thanks to Vincent Vanhoucke and the Google Brain team for encouraging me to pursue this idea.

*Chris J. Maddison, Andriy Mnih, and Yee Whye Teh at DeepMind discovered this technique independently and published their own paper on it; they call it the “Concrete Distribution”. We only found out about each other’s work right as we were submitting our papers to conferences (oops!). If you use this technique in your work, please cite both of our papers! They deserve just as much credit.