This post is based on a tutorial on normalizing flows that I gave at the ICML workshop on Invertible Neural Nets and Normalizing Flows. I've already written about how to implement your own flows in TensorFlow using TensorFlow Probability's Bijector API, so to make things interesting, I wanted to show how to implement Real-NVP a different way.

By the end of this tutorial you'll be able to reproduce this figure of a normalizing flow "bending" samples from a 2D Normal distribution to samples from the "Two Moons" dataset. Real-NVP forms the basis of a lot of flow-based architectures (as of 2019), so this is a good template to start learning from.

### Install Dependencies

There are just a few dependencies required to reproduce this tutorial. We'll be running everything on the CPU, though you can also build the GPU-enabled versions of JAX if you have the requisite hardware.

`pip install --upgrade jax jaxlib scikit-learn matplotlib`

### Toy Dataset

Scikit-Learn comes with some toy datasets that are useful for small-scale density models.
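A minimal sketch of loading the "Two Moons" data with scikit-learn is shown below. The sample count, noise level, random seed, and the choice to standardize the data are assumptions for illustration, not settings taken from the post. Standardizing puts the data on roughly the same scale as the N(0, I) base distribution, which tends to make the flow's job easier.

```python
# Sketch: load the "Two Moons" toy dataset with scikit-learn.
# Assumptions (not from the post): n_samples=2000, noise=0.05, random_state=0,
# and per-feature standardization to zero mean / unit variance.
from sklearn import datasets, preprocessing

n_samples = 2000
noisy_moons, _labels = datasets.make_moons(n_samples=n_samples, noise=0.05, random_state=0)

# Standardize each column so the data roughly matches the scale of N(0, I).
X = preprocessing.StandardScaler().fit_transform(noisy_moons)
print(X.shape)  # (2000, 2)
```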

### Affine Coupling Layer in JAX

TensorFlow Probability defines an object-oriented API for building flows, in which a "TransformedDistribution" object is given a base "Distribution" object along with a "Bijector" object that implements the invertible transformation. In pseudocode, it goes something like this:
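Here is a small, runnable stand-in for that pseudocode, written in plain Python/NumPy. The classes `Normal`, `AffineBijector`, and `TransformedDistribution` and their method signatures are hypothetical. They only mirror the concepts and are not TensorFlow Probability's actual API.

```python
# Hypothetical object-oriented sketch of the TransformedDistribution pattern.
# These are NOT TensorFlow Probability's classes or signatures.
import numpy as np

class Normal:
    """Standard normal base distribution over d-dimensional vectors."""
    def sample(self, n, d, rng):
        return rng.standard_normal((n, d))

    def log_prob(self, x):
        return np.sum(-0.5 * x**2 - 0.5 * np.log(2 * np.pi), axis=-1)

class AffineBijector:
    """Elementwise invertible map y = x * exp(log_scale) + shift."""
    def __init__(self, shift, log_scale):
        self.shift, self.log_scale = shift, log_scale

    def forward(self, x):
        return x * np.exp(self.log_scale) + self.shift

    def inverse(self, y):
        return (y - self.shift) * np.exp(-self.log_scale)

    def inverse_log_det_jacobian(self, y):
        # The Jacobian is diagonal, so log|det| of the inverse is -sum(log_scale).
        return -np.sum(self.log_scale) * np.ones(y.shape[0])

class TransformedDistribution:
    def __init__(self, distribution, bijector):
        self.distribution, self.bijector = distribution, bijector

    def sample(self, n, d, rng):
        # Push samples from the base distribution forward through the bijector.
        return self.bijector.forward(self.distribution.sample(n, d, rng))

    def log_prob(self, y):
        # Change of variables: log p(y) = log p_base(f^{-1}(y)) + ILDJ(y).
        x = self.bijector.inverse(y)
        return self.distribution.log_prob(x) + self.bijector.inverse_log_det_jacobian(y)

# Stacking: a TransformedDistribution can serve as the base of another one.
inner = TransformedDistribution(Normal(), AffineBijector(np.array([1.0, -1.0]), np.array([0.5, 0.0])))
flow = TransformedDistribution(inner, AffineBijector(np.zeros(2), np.array([0.0, 0.3])))
```

Because `log_prob` recurses into the base distribution's own `log_prob`, nesting objects like this composes the change-of-variables formula automatically.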

However, programming in JAX takes on a functional programming philosophy where functions are stateless and classes are eschewed. That's okay: we can still build a similar API in a functional way. To make everything end-to-end differentiable via JAX's grad() operator, it's convenient to put the parameters that we want gradients for as the first argument of every function. Here are the sample and log_prob implementations of the base distribution.
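A minimal functional sketch of a standard 2D Normal base distribution in JAX follows. The names `sample_n01` and `log_prob_n01` are hypothetical. Passing an explicit PRNG key is an assumption that follows JAX's stateless random-number convention. The base distribution has no trainable parameters, so neither function takes a params argument.

```python
# Sketch of a stateless N(0, I) base distribution in JAX (names are hypothetical).
import jax
import jax.numpy as jnp

D = 2  # dimensionality of the toy data

def sample_n01(key, n):
    """Draw n samples from a standard D-dimensional Normal."""
    return jax.random.normal(key, (n, D))

def log_prob_n01(x):
    """Log-density of N(0, I), summed over the last (event) axis."""
    return jnp.sum(-0.5 * jnp.square(x) - 0.5 * jnp.log(2 * jnp.pi), axis=-1)

samples = sample_n01(jax.random.PRNGKey(0), 4)  # shape (4, 2)
```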

Below are the forward and inverse functions of Real-NVP, which operate on minibatches (we could instead implement them over single vectors and use JAX's vmap operator to auto-batch them). Because we are dealing with 2D data, the masking scheme for Real-NVP is very simple: we just switch which variable is masked every other flow via the "flip" parameter.

The "forward" NVP transformation takes in a callable shift_and_log_scale_fn (an arbitrary neural net that takes the masked variables as inputs), applies it to recover the shift and log scale parameters, transforms the unmasked inputs, and then stitches the masked scalar and the transformed scalar back together in the right order. The inverse undoes these steps. Because the masked variable passes through unchanged, the inverse can recompute the same shift and log scale from it.
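Here is a sketch of a minibatched forward/inverse pair in JAX that follows the description above. `toy_shift_and_log_scale_fn` is a hypothetical fixed function used only to exercise the code; in practice it would be a neural net. Returning `log_scale` from the inverse is a design choice that lets the caller compute the ILDJ without a second network call.

```python
# Sketch of one Real-NVP affine coupling layer on minibatches of shape (N, D).
import jax.numpy as jnp

def nvp_forward(net_params, shift_and_log_scale_fn, x, flip=False):
    d = x.shape[-1] // 2
    x1, x2 = x[:, :d], x[:, d:]
    if flip:  # swap which half is masked (passed through unchanged)
        x1, x2 = x2, x1
    shift, log_scale = shift_and_log_scale_fn(net_params, x1)
    y2 = x2 * jnp.exp(log_scale) + shift  # transform the unmasked half
    if flip:  # restore the original column order
        x1, y2 = y2, x1
    return jnp.concatenate([x1, y2], axis=-1)

def nvp_inverse(net_params, shift_and_log_scale_fn, y, flip=False):
    d = y.shape[-1] // 2
    y1, y2 = y[:, :d], y[:, d:]
    if flip:
        y1, y2 = y2, y1
    # y1 equals the masked input x1, so the conditioner output is identical.
    shift, log_scale = shift_and_log_scale_fn(net_params, y1)
    x2 = (y2 - shift) * jnp.exp(-log_scale)  # undo the affine transform
    if flip:
        y1, x2 = x2, y1
    return jnp.concatenate([y1, x2], axis=-1), log_scale

# Hypothetical toy conditioner (not a neural net), just to exercise the code.
def toy_shift_and_log_scale_fn(params, x1):
    a, b = params
    return a * x1, jnp.tanh(b * x1)

params = (0.7, -0.4)
x = jnp.array([[0.5, -1.0], [2.0, 0.3]])
y = nvp_forward(params, toy_shift_and_log_scale_fn, x, flip=True)
x_rec, _ = nvp_inverse(params, toy_shift_and_log_scale_fn, y, flip=True)
```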

Here are the corresponding sampling (forward) and log-prob (inverse) implementations for a single Real-NVP coupling layer. The inverse log-determinant of the Jacobian (ILDJ) term is computed directly, as it is just the negative sum of the log_scale terms.
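The following derivation shows why the ILDJ reduces to a negative sum. It writes one coupling layer with $t$ for the shift and $s$ for the log scale; these symbols are our notation, not the post's. The Jacobian is block lower-triangular, so its determinant is the product of its diagonal entries:

$$
\begin{aligned}
y_1 &= x_1, \qquad y_2 = x_2 \odot \exp\big(s(x_1)\big) + t(x_1), \\
\frac{\partial y}{\partial x} &=
\begin{pmatrix}
I & 0 \\
\partial y_2 / \partial x_1 & \operatorname{diag}\big(\exp(s(x_1))\big)
\end{pmatrix}, \\
\log \left| \det \frac{\partial y}{\partial x} \right| &= \sum_j s_j(x_1), \\
\log p_Y(y) &= \log p_X(x) - \sum_j s_j(y_1), \qquad x = f^{-1}(y).
\end{aligned}
$$

In code, the last line amounts to evaluating the base log-prob at the inverted point and subtracting `jnp.sum(log_scale, axis=-1)`. Since $y_1 = x_1$, computing $s$ never requires inverting the network.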

What should we use for our shift_and_log_scale_fn? I've found that for 2D data + NVP, wider and shallower neural nets tend to train more stably. We'll use some JAX helper libraries to build a function that initializes the parameters and the callable function for an MLP with two hidden layers of 512 units each and ReLU activations.
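A sketch of such an initializer follows. It uses `jax.example_libraries.stax`, which is where older JAX versions' `jax.experimental.stax` now lives. The function name `init_nvp` is hypothetical. The code also assumes D = 2, so the MLP maps the single masked input to one shift and one log-scale value.

```python
# Sketch: MLP conditioner for one coupling layer, built with stax (assumes D = 2).
import jax
import jax.numpy as jnp
from jax.example_libraries import stax
from jax.example_libraries.stax import Dense, Relu

D = 2  # data dimensionality; the conditioner sees D // 2 masked inputs

def init_nvp(rng):
    """Return (net_params, shift_and_log_scale_fn) for one coupling layer."""
    # D//2 inputs -> 512 -> ReLU -> 512 -> ReLU -> D outputs (shift, log_scale)
    net_init, net_apply = stax.serial(Dense(512), Relu, Dense(512), Relu, Dense(D))
    in_shape = (-1, D // 2)  # -1 stands for an arbitrary batch size
    _, net_params = net_init(rng, in_shape)

    def shift_and_log_scale_fn(net_params, x1):
        s = net_apply(net_params, x1)
        shift, log_scale = jnp.split(s, 2, axis=1)  # split outputs into two halves
        return shift, log_scale

    return net_params, shift_and_log_scale_fn

net_params, shift_and_log_scale_fn = init_nvp(jax.random.PRNGKey(0))
```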

### Stacking Coupling Layers

TensorFlow Probability's object-oriented API is convenient because it allows us to "stack" multiple TransformedDistributions on top of each other for more expressive, yet still tractable, transformations.

For "bipartite" flows like Real-NVP, which leave some variables untouched in each layer, it is critical to be able to stack multiple flows so that every variable gets a chance to be "transformed".

Here's the functional way to do the same thing in JAX. We have a function "init_nvp_chain" that returns neural net parameters, callable shift_and_log_scale_fns, and masking parameters for each flow. We then pass this big bag of parameters to the sample_nvp_chain function.

In log_prob_nvp_chain, a loop repeatedly reassigns log_prob_fn, which starts out as base_log_prob_fn. This mirrors how TransformedDistribution.log_prob is defined in terms of the log_prob function of the base distribution beneath it. Python variable binding can be tricky here, and it's easy to make a mistake that results in infinite recursion. For example, a lambda defined inside the loop looks up log_prob_fn when it is called, not when it is created, so by then the name refers to the lambda itself. The solution is a function generator (make_log_prob_fn) that returns a function in which the correct base log_prob_fn is already bound when passed to log_prob_nvp. Thanks to David Bieber for pointing this fix out to me.
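The sketch below illustrates both the stacking idea and the function-generator fix. To keep it self-contained, simple elementwise affine layers stand in for the Real-NVP coupling layers. The names `log_prob_affine`, `make_log_prob_fn`, `log_prob_chain`, and `sample_chain` are hypothetical and are not the post's exact code. The buggy variant appears only as a comment, because running it recurses until Python raises `RecursionError`.

```python
# Sketch: stacking layers functionally, with elementwise affine layers standing in
# for Real-NVP couplings. All function names here are hypothetical.
import jax.numpy as jnp

def log_prob_n01(x):
    return jnp.sum(-0.5 * x**2 - 0.5 * jnp.log(2 * jnp.pi), axis=-1)

def log_prob_affine(params, base_log_prob_fn, y):
    """One layer y = x * exp(log_scale) + shift: invert, then add the ILDJ."""
    shift, log_scale = params
    x = (y - shift) * jnp.exp(-log_scale)
    return base_log_prob_fn(x) - jnp.sum(log_scale)

def make_log_prob_fn(params, log_prob_fn):
    # log_prob_fn is a parameter here, so the closure captures its current value.
    return lambda y: log_prob_affine(params, log_prob_fn, y)

def log_prob_chain(ps, base_log_prob_fn, y):
    log_prob_fn = base_log_prob_fn
    for p in ps:  # each new layer wraps the distribution built so far
        log_prob_fn = make_log_prob_fn(p, log_prob_fn)
        # Buggy alternative (do not use): the lambda would look up the name
        # `log_prob_fn` at call time, when it already refers to this lambda:
        #   log_prob_fn = lambda y, p=p: log_prob_affine(p, log_prob_fn, y)
    return log_prob_fn(y)

def sample_chain(ps, base_samples):
    """Push base samples through the layers in the same order (first to last)."""
    x = base_samples
    for shift, log_scale in ps:
        x = x * jnp.exp(log_scale) + shift
    return x

ps = [(jnp.array([1.0, -1.0]), jnp.array([0.5, 0.0])),
      (jnp.zeros(2), jnp.array([0.0, 0.3]))]
```

The last layer in `ps` is the outermost one. `log_prob_chain` therefore inverts it first, which matches the order in which `sample_chain` applies the layers.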

### Training Real-NVP

Finally, we are ready to train this thing!

We initialize our Real-NVP with 4 affine coupling layers (each variable is transformed twice) and define the optimization objective to be the model's negative log-likelihood over minibatches (more precisely, the cross entropy between the data distribution and the model).
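To make the "cross entropy" remark concrete: the minibatch negative log-likelihood is a Monte Carlo estimate of the cross entropy between the data distribution and the model. Here $p_\theta$ denotes the flow's density, which is our notation:

$$
\mathcal{L}(\theta) = -\frac{1}{N}\sum_{i=1}^{N} \log p_\theta(x_i)
\;\approx\; \mathbb{E}_{x \sim p_{\text{data}}}\big[-\log p_\theta(x)\big]
= H(p_{\text{data}}, p_\theta)
= H(p_{\text{data}}) + D_{\mathrm{KL}}\big(p_{\text{data}} \,\|\, p_\theta\big).
$$

The entropy $H(p_{\text{data}})$ does not depend on $\theta$. Minimizing this loss is therefore equivalent to minimizing the KL divergence from the data distribution to the model.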

Next, we declare a single optimization step in which we retrieve the current parameters from the optimizer state, compute gradients with respect to our big list of Real-NVP parameters, and then update them. The cool thing about JAX is that we can "jit" (just-in-time compile) the step function into a single XLA computation, so the entire optimization step runs without returning to the (relatively slow) Python interpreter. We could even JIT the entire optimization process if we wanted to!
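The sketch below shows a jitted optimization step. To keep it short and self-contained, a single learnable elementwise affine flow stands in for the 4-layer Real-NVP chain. It uses Adam from `jax.example_libraries.optimizers`. The toy data, step size, batch size, and step count are all assumptions for illustration, not the post's settings. The pattern is the same one described above: read the parameters from the optimizer state, take gradients of the minibatch NLL, and apply the update inside `jax.jit`.

```python
# Sketch of a jitted training step. A toy affine flow stands in for Real-NVP;
# data, optimizer settings, and step counts are illustrative assumptions.
import jax
import jax.numpy as jnp
from jax.example_libraries import optimizers

def log_prob(params, y):
    """Toy flow y = z * exp(log_scale) + shift with z ~ N(0, I)."""
    shift, log_scale = params
    z = (y - shift) * jnp.exp(-log_scale)
    base = jnp.sum(-0.5 * z**2 - 0.5 * jnp.log(2 * jnp.pi), axis=-1)
    return base - jnp.sum(log_scale)

def loss(params, batch):
    return -jnp.mean(log_prob(params, batch))  # minibatch negative log-likelihood

opt_init, opt_update, get_params = optimizers.adam(step_size=1e-2)

@jax.jit
def step(i, opt_state, batch):
    params = get_params(opt_state)  # current parameters from the optimizer state
    value, grads = jax.value_and_grad(loss)(params, batch)
    return opt_update(i, grads, opt_state), value  # one Adam update

# Usage: fit toy Gaussian data with mean [2, -1] and std [0.5, 2].
key = jax.random.PRNGKey(0)
data = jax.random.normal(key, (4096, 2)) * jnp.array([0.5, 2.0]) + jnp.array([2.0, -1.0])
opt_state = opt_init((jnp.zeros(2), jnp.zeros(2)))
losses = []
for i in range(500):
    key, subkey = jax.random.split(key)
    idx = jax.random.randint(subkey, (256,), 0, data.shape[0])
    opt_state, value = step(i, opt_state, data[idx])
    losses.append(float(value))
```

The Python `for` loop could be replaced with `jax.lax.fori_loop` to compile the whole optimization, as the paragraph suggests.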

### Animation

Here's the code snippet that will visualize each of the 4 affine coupling layers transforming samples from the base Normal distribution, in sequence. Is it just me, or does anyone else find themselves constantly having to Google "How to make a Matplotlib animation?"
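For reference, here is a minimal `matplotlib.animation.FuncAnimation` sketch that steps through the stages of a flow. The `stages` list is a hypothetical placeholder built from simple shears. With a trained model, it would hold the base samples followed by the output of each of the 4 coupling layers, each an array of shape (N, 2). The Agg backend is an assumption for headless use.

```python
# Sketch: animate samples moving through successive flow stages.
# `stages` is a hypothetical placeholder, not output from a trained Real-NVP.
import matplotlib
matplotlib.use("Agg")  # headless backend; remove to animate interactively
import matplotlib.pyplot as plt
from matplotlib import animation
import numpy as np

rng = np.random.default_rng(0)
z = rng.standard_normal((1000, 2))
# Placeholder: base samples plus 4 progressively sheared versions of them.
stages = [z @ np.array([[1.0, 0.0], [0.4 * k, 1.0]]) for k in range(5)]

fig, ax = plt.subplots(figsize=(4, 4))
scat = ax.scatter(stages[0][:, 0], stages[0][:, 1], s=2)
ax.set_xlim(-5, 5)
ax.set_ylim(-5, 5)

def update(frame):
    scat.set_offsets(stages[frame])  # move every point to its position at this stage
    ax.set_title(f"after {frame} layer(s)")
    return (scat,)

anim = animation.FuncAnimation(fig, update, frames=len(stages), interval=500)
# To write a GIF (requires Pillow): anim.save("flow.gif", writer=animation.PillowWriter(fps=2))
```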