Saturday, July 6, 2019

Normalizing Flows in 100 Lines of JAX

JAX is a great linear algebra and automatic differentiation library for fast experimentation in, and teaching of, machine learning. Here is a lightweight example, in just 75 lines of JAX, of how to implement Real-NVP.

This post is based on a tutorial on normalizing flows I gave at the ICML workshop on Invertible Neural Nets and Normalizing Flows. I've already written about how to implement your own flows in TensorFlow using TensorFlow Probability's Bijector API, so to make things interesting I wanted to show how to implement Real-NVP a different way.

By the end of this tutorial you'll be able to reproduce this figure of a normalizing flow "bending" samples from a 2D Normal distribution to samples from the "Two Moons" dataset. Real-NVP forms the basis of a lot of flow-based architectures (as of 2019), so this is a good template to start learning from.

If you are not already familiar with flows at a high level, please check out the 2-part tutorial: [part 1] [part 2], as this tutorial just focuses on how to implement flows in JAX. You can find all the code along with the slides for my talk here.

Install Dependencies

There are just a few dependencies required to reproduce this tutorial. We'll be running everything on the CPU, though you can also build the GPU-enabled versions of JAX if you have the requisite hardware.

pip install --upgrade jax jaxlib scikit-learn matplotlib

Toy Dataset

Scikit-Learn comes with some toy datasets that are useful for small-scale density models.

from sklearn import datasets
from sklearn.preprocessing import StandardScaler
import matplotlib.pyplot as plt
n_samples = 2000
noisy_moons = datasets.make_moons(n_samples=n_samples, noise=.05)
X, y = noisy_moons
X = StandardScaler().fit_transform(X)

Affine Coupling Layer in JAX

TensorFlow Probability defines an object-oriented API for building flows, where a "TransformedDistribution" object is given a base "Distribution" object along with a "Bijector" object that implements the invertible transformation. In pseudocode, it goes something like this:

class TransformedDistribution(Distribution):
  def sample(self):
    x = self.base_distribution.sample()
    return self.bijector.forward(x)
  def log_prob(self, y):
    x = self.bijector.inverse(y)
    ildj = self.bijector.inverse_log_det_jacobian(y)
    return self.base_distribution.log_prob(x) + ildj

However, programming in JAX takes on a functional programming philosophy where functions are stateless and classes are eschewed. That's okay: we can still build a similar API in a functional way. To make everything end-to-end differentiable via JAX's grad() operator, it's convenient to put the parameters that we want gradients for as the first argument of every function. Here are the sample and log_prob implementations of the base distribution.

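import jax.numpy as np  # `np` is jax.numpy for the rest of this post
from jax import random
rng = random.PRNGKey(0)  # global PRNG key, reused below

def sample_n01(N):
  D = 2
  return random.normal(rng, (N, D))
def log_prob_n01(x):
  return np.sum(-np.square(x)/2 - np.log(np.sqrt(2*np.pi)), axis=-1)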
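As a quick sanity check on the formula in log_prob_n01, the log-density of a D-dimensional standard Normal is just the sum of D one-dimensional log-densities. The sketch below re-derives that in plain Python (no JAX needed). The helper name log_prob_std_normal is made up for this illustration.

```python
import math

def log_prob_std_normal(x):
    """Log-density of a standard Normal at a point x (a list of floats).

    Uses the same per-coordinate formula as log_prob_n01,
    -x^2/2 - log(sqrt(2*pi)), summed over the coordinates.
    """
    return sum(-xi ** 2 / 2 - math.log(math.sqrt(2 * math.pi)) for xi in x)

# At the origin of a 2D Normal the density is 1/(2*pi),
# so the log-density should be -log(2*pi).
at_origin = log_prob_std_normal([0.0, 0.0])

# Moving one unit along the first axis lowers the log-density by exactly 1/2.
at_unit = log_prob_std_normal([1.0, 0.0])

print(at_origin, at_unit)
```

If the two printed values differ by anything other than 0.5, the normalizing constant or the sign of the quadratic term is off.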

Below are the forward and inverse functions of Real-NVP, which operates on minibatches (we could also re-implement this to operate over vectors, and use JAX's vmap operator to auto-batch it). Because we are dealing with 2D data, the masking scheme for Real-NVP is very simple: we just switch the masked variable every other flow via the "flip" parameter.

def nvp_forward(net_params, shift_and_log_scale_fn, x, flip=False):
  d = x.shape[-1]//2
  x1, x2 = x[:, :d], x[:, d:]
  if flip:
    x2, x1 = x1, x2
  shift, log_scale = shift_and_log_scale_fn(net_params, x1)
  y2 = x2*np.exp(log_scale) + shift
  if flip:
    x1, y2 = y2, x1
  y = np.concatenate([x1, y2], axis=-1)
  return y

def nvp_inverse(net_params, shift_and_log_scale_fn, y, flip=False):
  d = y.shape[-1]//2
  y1, y2 = y[:, :d], y[:, d:]
  if flip:
    y1, y2 = y2, y1
  shift, log_scale = shift_and_log_scale_fn(net_params, y1)
  x2 = (y2-shift)*np.exp(-log_scale)
  if flip:
    y1, x2 = x2, y1
  x = np.concatenate([y1, x2], axis=-1)
  return x, log_scale

The "forward" NVP transformation takes in a callable shift_and_log_scale_fn (an arbitrary neural net that takes the masked variables as inputs), applies it to obtain the shift and log scale parameters, transforms the un-masked inputs, and then stitches the masked scalar and the transformed scalar back together in the right order. The inverse undoes the affine transformation and also returns log_scale, which we need for the log-determinant term.
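As mentioned earlier, the coupling layer could also be written for a single example and batched with vmap. Here is a minimal sketch of that idea, which doubles as a round-trip check that the inverse really undoes the forward pass. The per-example functions nvp_forward_single and nvp_inverse_single and the tiny affine toy_shift_and_log_scale_fn are made up for this illustration and stand in for the real neural net. Note that flip is bound with functools.partial, because vmap maps over keyword arguments and a Python bool cannot be mapped.

```python
from functools import partial

import jax
import jax.numpy as jnp

def toy_shift_and_log_scale_fn(params, x1):
    # Hypothetical stand-in for the neural net: one affine map from the
    # masked half (shape (1,)) to a shift and a log-scale (shape (1,) each).
    w, b = params
    s = x1 @ w + b
    return s[:1], s[1:]

def nvp_forward_single(params, x, flip=False):
    # Same logic as nvp_forward, but for ONE example of shape (2,).
    d = x.shape[-1] // 2
    x1, x2 = x[:d], x[d:]
    if flip:
        x1, x2 = x2, x1
    shift, log_scale = toy_shift_and_log_scale_fn(params, x1)
    y2 = x2 * jnp.exp(log_scale) + shift
    if flip:
        x1, y2 = y2, x1
    return jnp.concatenate([x1, y2], axis=-1)

def nvp_inverse_single(params, y, flip=False):
    # Same logic as nvp_inverse (minus the log_scale return value).
    d = y.shape[-1] // 2
    y1, y2 = y[:d], y[d:]
    if flip:
        y1, y2 = y2, y1
    shift, log_scale = toy_shift_and_log_scale_fn(params, y1)
    x2 = (y2 - shift) * jnp.exp(-log_scale)
    if flip:
        y1, x2 = x2, y1
    return jnp.concatenate([y1, x2], axis=-1)

# Arbitrary toy parameters and a small batch of 2D points.
params = (jnp.array([[0.5, -0.3]]), jnp.array([0.1, 0.2]))
x = jnp.array([[0.0, 1.0], [2.0, -1.0], [-0.5, 0.25]])

outputs, errors = {}, {}
for flip in (False, True):
    # in_axes=(None, 0): share params across the batch, map over rows of x.
    forward = jax.vmap(partial(nvp_forward_single, flip=flip), in_axes=(None, 0))
    inverse = jax.vmap(partial(nvp_inverse_single, flip=flip), in_axes=(None, 0))
    y = forward(params, x)
    outputs[flip] = y
    errors[flip] = float(jnp.max(jnp.abs(inverse(params, y) - x)))

print(errors)  # both entries should be at float32 round-off level
```

With flip=False the first column of the output equals the first column of the input, and with flip=True the second column is the one left untouched, which is exactly the alternating masking described above.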

Here are the corresponding sampling (forward) and log-prob (inverse) implementations for a single Real-NVP coupling layer. The ILDJ (inverse log-det-Jacobian) term is computed directly, as it is just the negative sum of the log_scale terms.

def sample_nvp(net_params, shift_log_scale_fn, base_sample_fn, N, flip=False):
  x = base_sample_fn(N)
  return nvp_forward(net_params, shift_log_scale_fn, x, flip=flip)

def log_prob_nvp(net_params, shift_log_scale_fn, base_log_prob_fn, y, flip=False):
  x, log_scale = nvp_inverse(net_params, shift_log_scale_fn, y, flip=flip)
  ildj = -np.sum(log_scale, axis=-1)
  return base_log_prob_fn(x) + ildj
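For readers who want to see why the ILDJ is just the negative sum of the log_scale terms, here is a short derivation for one coupling layer. The notation is my own shorthand for this sketch: s is log_scale, t is shift, and subscripts 1 and 2 denote the masked and transformed halves.

```latex
% Notation (assumed for this sketch): s = log_scale, t = shift,
% both computed from the masked half only.
\begin{aligned}
y_1 &= x_1, \qquad y_2 = x_2 \odot \exp\big(s(x_1)\big) + t(x_1) \\
\frac{\partial y}{\partial x} &=
\begin{pmatrix}
I & 0 \\
\frac{\partial y_2}{\partial x_1} & \operatorname{diag}\big(\exp s(x_1)\big)
\end{pmatrix} \\
\log\left|\det \frac{\partial y}{\partial x}\right| &= \sum_j s_j(x_1) \\
\mathrm{ILDJ}(y) = \log\left|\det \frac{\partial x}{\partial y}\right| &= -\sum_j s_j(y_1) \\
\log p_Y(y) &= \log p_X\big(f^{-1}(y)\big) - \sum_j s_j(y_1)
\end{aligned}
```

The Jacobian is triangular, so its determinant is the product of the diagonal entries and the off-diagonal block never has to be computed. When flip is set, the two halves swap roles and the Jacobian becomes upper triangular instead, with the same determinant.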

What should we use for our shift_and_log_scale_fn? I've found that for 2D data + NVP, wide and shallow neural nets tend to train more stably. We'll use some JAX helper libraries to build a function that initializes the parameters and the callable for an MLP with two hidden layers of 512 units each and ReLU activations.

# Note: newer JAX releases ship stax under jax.example_libraries instead of jax.experimental.
from jax.experimental import stax # neural network library
from jax.experimental.stax import Dense, Relu # neural network layers

def init_nvp():
  D = 2
  net_init, net_apply = stax.serial(
    Dense(512), Relu, Dense(512), Relu, Dense(D))
  in_shape = (-1, D//2)
  out_shape, net_params = net_init(rng, in_shape)
  def shift_and_log_scale_fn(net_params, x1):
    s = net_apply(net_params, x1)
    return np.split(s, 2, axis=1)
  return net_params, shift_and_log_scale_fn

Stacking Coupling Layers

TensorFlow Probability's object-oriented API is convenient because it allows us to "stack" multiple TransformedDistributions on top of each other for more expressive - yet tractable - transformations. 

dist1 = TransformedDistribution(base_dist, bijector1)
dist2 = TransformedDistribution(dist1, bijector2)
dist2.sample() # member variables reference dist1, which references base_dist

For "bipartite" flows like Real-NVP which leave some variables untouched, it is critical to be able to stack multiple flows so that all variables get a chance to be "transformed". 

Here's the functional way to do the same thing in JAX. We have a function "init_nvp_chain" that returns neural net parameters, callable shift_and_log_scale_fns, and masking parameters for each flow. We then pass this big bag of parameters to the sample_nvp_chain function. 

In log_prob_nvp_chain, a loop repeatedly overrides log_prob_fn, which is initially set to base_log_prob_fn. This mirrors how TransformedDistribution.log_prob is defined with respect to the log_prob function of the distribution beneath it. Python variable binding can be a bit tricky at times, and it's easy to make a mistake here that results in infinite recursion: a lambda defined inside the loop looks up log_prob_fn when it is called, not when it is defined. The solution is to make a function factory (make_log_prob_fn) that returns a function with the correct base log_prob_fn bound to the base_log_prob_fn argument of log_prob_nvp. Thanks to David Bieber for pointing this fix out to me.

def init_nvp_chain(n=2):
  flip = False
  ps, configs = [], []
  for i in range(n):
    p, f = init_nvp()
    ps.append(p)
    configs.append((f, flip))
    flip = not flip
  return ps, configs

def sample_nvp_chain(ps, configs, base_sample_fn, N):
  x = base_sample_fn(N)
  for p, config in zip(ps, configs):
    shift_log_scale_fn, flip = config
    x = nvp_forward(p, shift_log_scale_fn, x, flip=flip)
  return x

def make_log_prob_fn(p, log_prob_fn, config):
  shift_log_scale_fn, flip = config
  return lambda x: log_prob_nvp(p, shift_log_scale_fn, log_prob_fn, x, flip=flip)

def log_prob_nvp_chain(ps, configs, base_log_prob_fn, y):
  log_prob_fn = base_log_prob_fn
  for p, config in zip(ps, configs):
    log_prob_fn = make_log_prob_fn(p, log_prob_fn, config)
  return log_prob_fn(y)
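The variable-binding pitfall is easy to reproduce in plain Python, without any JAX. In the sketch below, the names wrap_buggy, wrap_fixed and make_wrapper are made up for this illustration; each "layer" simply adds 1 to the result of the function beneath it, standing in for log_prob_nvp adding its ILDJ term.

```python
def wrap_buggy(base_fn, n):
    fn = base_fn
    for _ in range(n):
        # BUG: the lambda looks up `fn` when it is CALLED, not when it is
        # defined. By then `fn` refers to the last lambda, i.e. to itself.
        fn = lambda x: fn(x) + 1
    return fn

def make_wrapper(inner_fn):
    # `inner_fn` is a fresh local variable for every call to make_wrapper,
    # so each returned lambda keeps its own reference to the layer below.
    return lambda x: inner_fn(x) + 1

def wrap_fixed(base_fn, n):
    fn = base_fn
    for _ in range(n):
        fn = make_wrapper(fn)
    return fn

base = lambda x: x

# Three wrappers around the identity: 0 -> 1 -> 2 -> 3.
result_fixed = wrap_fixed(base, 3)(0)

# The buggy version calls itself forever until Python gives up.
try:
    wrap_buggy(base, 3)(0)
    buggy_recursed = False
except RecursionError:
    buggy_recursed = True

print(result_fixed, buggy_recursed)
```

make_log_prob_fn plays the role of make_wrapper here: it freezes the current log_prob_fn into its own scope before the loop variable is reassigned.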

Training Real-NVP

Finally, we are ready to train this thing! 

We initialize our Real-NVP with 4 affine coupling layers (each variable is transformed twice) and define the optimization objective to be the model's negative log-likelihood averaged over minibatches (more precisely, a minibatch estimate of the cross entropy between the data distribution and the model).

from jax.experimental import optimizers  # jax.example_libraries.optimizers in newer JAX releases
from jax import jit, grad
import numpy as onp
ps, cs = init_nvp_chain(4)

def loss(params, batch):
  return -np.mean(log_prob_nvp_chain(params, cs, log_prob_n01, batch))
opt_init, opt_update, get_params = optimizers.adam(step_size=1e-4)
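Written out, the objective computed by loss looks roughly like this. The notation is an assumption of this sketch rather than something defined earlier in the post: K is the number of coupling layers, f_k is the forward map of layer k, s^{(k)} is that layer's log_scale, and B is the minibatch size.

```latex
% Notation (assumed): K coupling layers f_1, ..., f_K; s^{(k)} = log_scale of layer k;
% B = minibatch size; x^{(0)} is the point in base-distribution space.
\begin{aligned}
x^{(K)} &= y, \qquad x^{(k-1)} = f_k^{-1}\big(x^{(k)}\big), \quad k = K, \dots, 1 \\
\log p_\theta(y) &= \log \mathcal{N}\big(x^{(0)};\, 0, I\big) - \sum_{k=1}^{K} \sum_j s^{(k)}_j \\
\mathcal{L}(\theta) &= -\frac{1}{B} \sum_{i=1}^{B} \log p_\theta(y_i)
\;\approx\; \mathbb{E}_{y \sim p_{\mathrm{data}}}\big[-\log p_\theta(y)\big] = H(p_{\mathrm{data}}, p_\theta)
\end{aligned}
```

The first line is what the nested log_prob_fn calls do when unwinding from the last layer to the first, and the last line is why the negative log-likelihood is a cross entropy.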

Next, we declare a single optimization step where we retrieve the current parameters from the optimizer state, compute gradients of the loss with respect to our big list of Real-NVP parameters, and then update the optimizer state. The cool thing about JAX is that we can "jit" (just-in-time compile) the step function into a single XLA computation so that the entire optimization step happens without returning to the (relatively slow) Python interpreter. We could even JIT the entire optimization process if we wanted to!

@jit  # compile the whole update step with XLA
def step(i, opt_state, batch):
  params = get_params(opt_state)
  g = grad(loss)(params, batch)
  return opt_update(i, g, opt_state)

iters = int(1e4)
data_generator = (X[onp.random.choice(X.shape[0], 100)] for _ in range(iters))
opt_state = opt_init(ps)
for i in range(iters):
  opt_state = step(i, opt_state, next(data_generator))
ps = get_params(opt_state)


Here's the code snippet that will visualize each of the 4 affine coupling layers transforming samples from the base Normal distribution, in sequence. Is it just me, or does anyone else find themselves constantly having to Google "How to make a Matplotlib animation?"

from matplotlib import animation, rc
from IPython.display import HTML, Image

x = sample_n01(1000)
values = [x]
for p, config in zip(ps, cs):
  shift_log_scale_fn, flip = config
  x = nvp_forward(p, shift_log_scale_fn, x, flip=flip)
  values.append(x)  # keep the samples after every coupling layer

# First set up the figure, the axis, and the plot element we want to animate
fig, ax = plt.subplots()

y = values[0]
paths = ax.scatter(y[:, 0], y[:, 1], s=10, color='red')

def animate(i):
  l = i//48
  t = (float(i%48))/48
  y = (1-t)*values[l] + t*values[l+1]
  paths.set_offsets(y)  # move the scatter points to the interpolated positions
  return (paths,)
anim = animation.FuncAnimation(fig, animate, frames=48*len(cs), interval=1, blit=False)
anim.save('anim.gif', writer='imagemagick', fps=60)
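
The frame arithmetic in animate can be checked in isolation. Here is a small plain-Python sketch of it. The names FRAMES_PER_LAYER, frame_to_layer_and_t and lerp are made up for this illustration, and points are plain lists rather than arrays.

```python
FRAMES_PER_LAYER = 48  # same constant that animate uses

def frame_to_layer_and_t(i):
    # Which coupling layer frame i belongs to, and how far (0 <= t < 1)
    # we are between that layer's input and its output.
    layer = i // FRAMES_PER_LAYER
    t = (i % FRAMES_PER_LAYER) / FRAMES_PER_LAYER
    return layer, t

def lerp(a, b, t):
    # Linear interpolation between two points, coordinate by coordinate.
    return [(1 - t) * ai + t * bi for ai, bi in zip(a, b)]

n_layers = 4
first_frame = frame_to_layer_and_t(0)
halfway_second_layer = frame_to_layer_and_t(72)
last_frame = frame_to_layer_and_t(FRAMES_PER_LAYER * n_layers - 1)

# A quarter of the way from (0, 0) to (2, -4).
point = lerp([0.0, 0.0], [2.0, -4.0], 0.25)

print(first_frame, halfway_second_layer, last_frame, point)
```

Because the last frame still reads values[l+1] with l = 3, the values list has to hold five arrays for four layers: the base samples plus the output of every coupling layer.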