Thursday, February 21, 2019

Meta-Learning in 50 Lines of JAX

Github repo here:

Adaptive behavior in humans and animals occurs at many time scales: when I use a new shower handle for the first time, it takes me a few seconds to figure out how to adjust the water temperature to my liking. Upon reading a news article, I obtain new information that I didn't have before. More difficult skills, such as mastering a musical instrument, are acquired over a lifetime of deliberate practice.

Learning is hardly restricted to animal-level intelligence; it can be found in every living creature. Multicellular developmental programs are highly plastic and can even store epigenetic “memories” between generations. At the longest time-scales, evolution itself can be thought of as “learning” on the genomic level, whereby favorable genetic codes are discovered and remembered over the course of many generations. At the shortest of timescales, a single ion channel activating in response to a stimulus can also be thought of as “learning”, as it is an adaptive, stateful response to the environment. Biological intelligence blurs the boundaries between “behavior” (responding to the environment), “learning” (acquiring information about the world in order to improve fitness), and “optimization” (improving fitness).

The focus of Machine Learning (ML) is to imbue computers with the ability to learn from data, so that they may accomplish tasks that humans have difficulty expressing in pure code. However, what most ML researchers call “learning” right now is but a very small subset of the vast range of behavioral adaptability encountered in biological life! Deep Learning models are powerful, but require a large amount of data and many iterations of stochastic gradient descent (SGD). This learning procedure is time-consuming and once a deep model is trained, its behavior is fairly rigid; at deployment time, one cannot really change the behavior of the system (e.g. correcting mistakes) without an expensive retraining process. Can we build systems that can learn faster, and with less data?

“Meta-learning”, one of the most exciting ML research topics right now, addresses this problem by optimizing a model not just for the ability to “predict well”, but also the ability to “learn well”. Although Meta-Learning has attracted a lot of research attention in recent years, related ideas and algorithms have been around for some time (see Hugo Larochelle's slides and Lilian Weng’s blog post for an excellent overview of related concepts).

This blog post won’t cover all the possible ways in which one can build a meta-learning system; instead, this is a practical tutorial on how to get your feet wet in meta-learning research. Specifically, I'll show you how to implement the MAML meta-learning algorithm in about 50 lines of Python code, using Google's awesome JAX library.

You can find a self-contained Jupyter notebook here reproducing this tutorial.

An Operator Perspective on Learning and Meta-Learning

“Meta-learning” is used in so many different research contexts nowadays that it's difficult to communicate to other researchers exactly what I’m working on when I say “Meta-Learning”. Part of this confusion stems from the blurred semantics between “optimization”, “learning”, “adaptation”, and “memory”, and from how these terms are employed in wildly different applications.

This section is my attempt to make the definition of “learning” and “meta-learning” more mathematically precise, and explain why seemingly different algorithms are all branded as “meta-learning” these days. Feel free to skip to the next section if you want to dive straight into the MAML+JAX coding tutorial.

We define a learning operator $f : F_\theta \to F_\theta$ as a function that improves a model function $f_\theta$ with respect to some task. A common learning operator in the deep learning and reinforcement learning literature is stochastic gradient descent (SGD) on a loss function. In standard DL contexts, learning occurs over hundreds of thousands or even millions of gradient steps, but more generally, “learning” can also occur on shorter timescales (conditioning) or longer ones (hyperparameter search). In addition to explicit optimization, learning can also be implemented implicitly via a dynamical system (recurrent neural networks conditioning on the past) or via probabilistic inference.

A meta-learning operator $f_o(f_i(f_\theta))$ is a composite operator of two learning operators: an “inner loop” $f_i \in F_i$ and an “outer loop” $f_o \in F_o$. Furthermore, $f_i$ is a model itself, and $f_o : F_i \to F_i$ is an operator over the inner learning rule $f_i$. In other words, $f_o$ learns the learning rule $f_i$, and $f_i$ learns a model for a given task, where we define “task” to be a self-contained family of problems for which $f_i$ can adequately update $f_\theta$ to solve. At meta-training time, $f_o$ is applied to select for $f_i$ across a variety of training tasks. At meta-test time, we evaluate the generalization properties of $f_i$ and $f_\theta$ to holdout tasks.

The choice of $f_o$ and $f_i$ depends largely on the problem domain. In the architecture search literature (also called “learning to learn”), $f_i$ is a relatively slow procedure that trains a neural network from scratch, while $f_o$ can be a neural controller, a random search algorithm, or a Gaussian Process bandit.

A wide variety of machine learning problems can be formulated in terms of meta-learning operators. In (meta) imitation learning (or goal-conditioned reinforcement learning), $f_i$ is used to relay instructions to the RL agent, such as conditioning on a task embedding or human demonstrations. In meta-reinforcement learning (MRL), $f_i$ instead implements a “fast reinforcement learning” algorithm by which an agent improves itself after trying the task a couple of times. It’s worth reiterating here that I don’t see a distinction between “learning” and “conditioning”, because they both rely on inputs that are supplied at test time (i.e. “new information provided by the environment”).

MAML is a meta-learning algorithm that implements $f_i$ via SGD, i.e. $\theta := \theta - \alpha \nabla_{\theta}\mathcal{L}(\theta)$. This SGD update is differentiable with respect to $\theta$, allowing $f_o$ to optimize $f_i$ via backpropagation without requiring many additional parameters to express $f_i$.
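To make the differentiability claim concrete, here is a short chain-rule derivation of the gradient that $f_o$ receives for a single inner SGD step. The shorthand $\mathcal{L}_1$ (inner loss on the training split), $\mathcal{L}_2$ (outer loss on the validation split), and $H$ (Hessian of the inner loss) is introduced here for illustration only.

```latex
\begin{aligned}
\theta' &= \theta - \alpha \nabla_{\theta} \mathcal{L}_1(\theta),
  \qquad H = \nabla^2_{\theta} \mathcal{L}_1(\theta) \\
\frac{\partial \theta'}{\partial \theta} &= I - \alpha H \\
\nabla_{\theta} \mathcal{L}_2(\theta')
  &= \left(\frac{\partial \theta'}{\partial \theta}\right)^{\top} \nabla_{\theta'} \mathcal{L}_2(\theta')
   = (I - \alpha H)\, \nabla_{\theta'} \mathcal{L}_2(\theta')
\end{aligned}
```

Because $H$ is symmetric, the transpose drops out. The Hessian term is why MAML involves "gradients of gradients"; an autodiff system differentiates through the inner update and only needs Hessian-vector products, never the full matrix $H$. As a sanity check with the toy objective $g(x, y) = x^2 + y$ used in the numerics test later in this post and $\alpha = 1$: $H = 2$ and $\theta' = -\theta$, so the meta-gradient is $(1 - 2)\cdot(-2\theta) = 2\theta$, which equals 4 at $\theta = 2$.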

Exploring JAX: Gradients

We begin the tutorial by importing JAX’s numpy drop-in and the gradient operator, grad.

import jax.numpy as np
from jax import grad

The gradient operator grad transforms a Python function into another function that computes its gradient. Here, we compute the first derivative of $e^x$ and the first, second, and third derivatives of $x^2$:

f = lambda x : np.exp(x)
g = lambda x : np.square(x)
print(grad(f)(1.)) # d/dx e^x = e^x, so this is e^1

print(grad(g)(2.)) # d/dx x^2 = 2x = 4
print(grad(grad(g))(2.)) # d^2/dx^2 x^2 = 2
print(grad(grad(grad(g)))(2.)) # d^3/dx^3 x^2 = 0
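The same pattern extends to higher derivatives of $e^x$, which are all again $e^x$. The sketch below (using the `jax.numpy as jnp` alias and a hypothetical helper `nth_derivative`) composes `grad` repeatedly and also checks the first derivative against a central finite difference, a cheap sanity test for any gradient computation.

```python
import jax
import jax.numpy as jnp

def f(x):
    return jnp.exp(x)

def nth_derivative(fun, n):
    # Compose jax.grad n times: grad(grad(...grad(fun)))
    for _ in range(n):
        fun = jax.grad(fun)
    return fun

# Every derivative of e^x is e^x, so each order should evaluate to e^1 at x = 1.
derivs = [float(nth_derivative(f, n)(1.0)) for n in range(1, 4)]

# Compare the first derivative with a central finite difference (float32, so loose).
eps = 1e-3
fd = float((f(1.0 + eps) - f(1.0 - eps)) / (2 * eps))
print(derivs, fd)
```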

Exploring JAX: Auto-Vectorization with vmap

Now let’s consider a toy regression problem in which we try to fit the function $\sin(x)$ with a neural network $f_\theta(x)$. The goal here is to get familiar with defining and training models. JAX provides some lightweight helper functions to make it easy to set up a neural network.

from jax import vmap # for auto-vectorizing functions
from functools import partial # for use with vmap
from jax import jit # for compiling functions for speedup
from jax.experimental import stax # neural network library
from jax.experimental.stax import Conv, Dense, MaxPool, Relu, Flatten, LogSoftmax # neural network layers
import matplotlib.pyplot as plt # visualization

We’ll define a simple neural network with 2 hidden layers. We’ve specified an in_shape of (-1, 1), which means that the model takes in a variable-size batch dimension and has a feature dimension of 1 scalar (since this is a 1-D regression task). JAX’s helper libraries all adopt a functional API (unlike TensorFlow, which maintains graph state), so we get back one function that initializes parameters and another that applies the forward pass of the network. The parameters are plain lists and tuples of numpy arrays, a simple data structure for storing network weights.

# Use stax to set up network initialization and evaluation functions
net_init, net_apply = stax.serial(
    Dense(40), Relu,
    Dense(40), Relu,
    Dense(1)  # scalar regression output
)
in_shape = (-1, 1,)
out_shape, net_params = net_init(in_shape)
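Under the hood, `stax.serial` just hands back a pair of pure functions. As a rough illustration of that functional pattern (not stax's actual implementation), here is a hand-written init/apply pair for the same 1-40-40-1 architecture; `mlp_init` and `mlp_apply` are hypothetical names, and the scaled-normal initialization is an assumption. Note that newer JAX releases moved stax to `jax.example_libraries.stax`, where the init function also takes a PRNG key.

```python
import jax
import jax.numpy as jnp

def mlp_init(key, sizes=(1, 40, 40, 1)):
    # One (W, b) tuple per Dense layer; W is scaled by 1/sqrt(fan_in).
    keys = jax.random.split(key, len(sizes) - 1)
    params = []
    for k, n_in, n_out in zip(keys, sizes[:-1], sizes[1:]):
        W = jax.random.normal(k, (n_in, n_out)) / jnp.sqrt(n_in)
        params.append((W, jnp.zeros(n_out)))
    return params

def mlp_apply(params, x):
    # ReLU after each hidden layer, linear output layer.
    for W, b in params[:-1]:
        x = jax.nn.relu(x @ W + b)
    W, b = params[-1]
    return x @ W + b

params = mlp_init(jax.random.PRNGKey(0))
out = mlp_apply(params, jnp.zeros((7, 1)))  # a batch of 7 scalar inputs
print(out.shape)
```

Because parameters are ordinary arguments rather than hidden state, the same `mlp_apply` can be differentiated, vectorized, or compiled without modification.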

Next, we define the model loss to be Mean-Squared Error (MSE) across a batch of inputs.

def loss(params, inputs, targets):
   # Computes average loss for the batch
   predictions = net_apply(params, inputs)
   return np.mean((targets - predictions)**2)

We evaluate the randomly initialized (untrained) network across a range of inputs:

# batch the inference across K=100
xrange_inputs = np.linspace(-5,5,100).reshape((100, 1)) # (100, 1)
targets = np.sin(xrange_inputs)
predictions = vmap(partial(net_apply, net_params))(xrange_inputs)
losses = vmap(partial(loss, net_params))(xrange_inputs, targets) # per-input loss
plt.plot(xrange_inputs, predictions, label='prediction')
plt.plot(xrange_inputs, losses, label='loss')
plt.plot(xrange_inputs, targets, label='target')

As expected, at random initialization, the model’s predictions (blue) are totally off the target function (green).
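The `vmap(partial(net_apply, net_params))` idiom maps a function over the leading axis of its inputs. An equivalent way to hold an argument fixed is `in_axes=None` for that argument. The sketch below uses a hypothetical single-example model `predict` to show that `vmap` produces the same result as an explicit Python loop.

```python
import jax
import jax.numpy as jnp

# Hypothetical single-example model: maps one (1,)-shaped input to a (1,) output.
def predict(params, x):
    w, b = params
    return jnp.tanh(w * x + b)

params = (jnp.array(1.5), jnp.array(-0.2))
xs = jnp.linspace(-5.0, 5.0, 100).reshape((100, 1))

# in_axes=(None, 0): share params across the batch, map over the rows of xs.
# Equivalent to jax.vmap(functools.partial(predict, params))(xs).
batched = jax.vmap(predict, in_axes=(None, 0))(params, xs)

# Reference: an explicit Python loop over the 100 examples.
looped = jnp.stack([predict(params, x) for x in xs])
print(batched.shape)
```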

Let’s train the network via gradient descent. JAX’s random number generator is set up differently than Numpy’s, so to initialize network parameters we’ll use the original Numpy library (onp) to generate random numbers. We’ll also import the tree_multimap utility to easily manipulate collections of per-parameter gradients (for TensorFlow users, this is analogous to nest.map_structure for Tensors).

import numpy as onp
from jax.experimental import optimizers
from jax.tree_util import tree_multimap  # Element-wise manipulation of collections of numpy arrays
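`tree_multimap` applies a function leaf by leaf across several identically structured collections of arrays, which is exactly what an SGD update over a list of `(W, b)` tuples needs. In later JAX releases `tree_multimap` was deprecated and removed, and `jax.tree_util.tree_map` accepts multiple trees instead; the sketch below assumes that newer API and uses made-up parameter values.

```python
import jax
import jax.numpy as jnp

# Parameters shaped like a stax network: one (W, b) tuple per Dense layer and an
# empty tuple for each parameter-free layer such as Relu.
params = [(jnp.ones((1, 40)), jnp.zeros(40)), (),
          (jnp.ones((40, 1)), jnp.zeros(1))]
grads = jax.tree_util.tree_map(jnp.ones_like, params)  # stand-in gradients

alpha = 0.1
# tree_map walks both structures in lockstep and applies the update to each leaf.
new_params = jax.tree_util.tree_map(lambda p, g: p - alpha * g, params, grads)
```

The empty tuple contributes no leaves, so it passes through unchanged and the result keeps the same structure as `params`.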

We initialize the parameters and optimizer, and run the curve fitting for 100 steps. Note that adding the @jit decorator to the “step” function uses XLA to compile the entire training step into machine code, applying optimizations such as kernel fusion and memory and layout optimization. TensorFlow can also use XLA to accelerate statically defined graphs. XLA makes the computation very fast and amenable to hardware acceleration because the entire step can execute without returning to the Python interpreter (or, for TensorFlow without XLA, the graph interpreter). The code in this tutorial runs unchanged on CPU, GPU, or TPU.

opt_init, opt_update = optimizers.adam(step_size=1e-2)
opt_state = opt_init(net_params)
# Define a compiled update step
@jit
def step(i, opt_state, x1, y1):
    p = optimizers.get_params(opt_state)
    g = grad(loss)(p, x1, y1)
    return opt_update(i, g, opt_state)

for i in range(100):
   opt_state = step(i, opt_state, xrange_inputs, targets)
net_params = optimizers.get_params(opt_state)
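The effect of `jit` can be seen in isolation on a hypothetical one-weight linear regression: `jax.jit(sgd_step)` compiles the whole update (forward pass, gradient, and parameter update) with XLA on the first call for a given input shape and dtype, reuses the compiled version afterwards, and produces the same numbers as the uncompiled function. The model, learning rate, and data are illustrative assumptions.

```python
import jax
import jax.numpy as jnp

def loss(w, x, y):
    # Mean-squared error of a linear model without bias.
    return jnp.mean((x @ w - y) ** 2)

def sgd_step(w, x, y, lr=0.1):
    return w - lr * jax.grad(loss)(w, x, y)

# Traced and compiled on first call; later calls with the same shapes reuse it.
fast_step = jax.jit(sgd_step)

x = jnp.linspace(-1.0, 1.0, 32).reshape((32, 1))
y = 3.0 * x  # target slope is 3
w = jnp.zeros((1, 1))
for _ in range(100):
    w = fast_step(w, x, y)
print(w)
```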

Evaluating our network again, we see that the sinusoid curve has been correctly approximated.

This result is nothing to write home about, but in just a moment we’ll re-use a lot of these functions to implement MAML.

Exploring JAX: Checking MAML Numerics

When implementing ML algorithms, it’s important to unit-test implementations against cases where the true values can be computed analytically. The following example does this for MAML on a toy objective $g$. Note that by default JAX computes gradients with respect to the first argument of the function.

# gradients of gradients test for MAML
# check numerics
g = lambda x, y : np.square(x) + y
x0 = 2.
y0 = 1.
print('grad(g)(x0) = {}'.format(grad(g)(x0, y0))) # 2x = 4
print('x0 - grad(g)(x0) = {}'.format(x0 - grad(g)(x0, y0))) # x - 2x = -2
def maml_objective(x, y):
   return g(x - grad(g)(x, y), y)
print('maml_objective(x,y)={}'.format(maml_objective(x0, y0))) # x**2 + 1 = 5
print('x0 - grad(maml_objective)(x0, y0) = {}'.format(x0 - grad(maml_objective)(x0, y0))) # x - (2x) = -2.
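Because `grad` differentiates with respect to the first positional argument by default, differentiating with respect to another argument requires the `argnums` parameter. A minimal sketch on the same toy $g$:

```python
import jax
import jax.numpy as jnp

g = lambda x, y: jnp.square(x) + y

dg_dx = jax.grad(g)(2.0, 1.0)                    # default argnums=0: dg/dx = 2x = 4
dg_dy = jax.grad(g, argnums=1)(2.0, 1.0)         # dg/dy = 1
dg_both = jax.grad(g, argnums=(0, 1))(2.0, 1.0)  # tuple of both partials: (4., 1.)
print(dg_dx, dg_dy, dg_both)
```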

Implementing MAML with JAX

Now let’s extend our sinusoid regression task to a multi-task problem, in which the sinusoid function can have varying phases and amplitudes. This task was proposed in the MAML paper as a way to illustrate how MAML works on a toy problem. Below are some points sampled from two different tasks, divided into “train” (used to compute the inner loss) and “validation” splits (sampled from the same task, used to compute the outer loss).

Suppose a task loss function $\mathcal{L}$ is defined with respect to model parameters $\theta$, input features $X$, and labels $Y$. Let $x_1, y_1$ and $x_2, y_2$ be identically distributed task instance data sampled from $X, Y$. Then MAML minimizes the following objective with respect to $\theta$, where $\alpha$ is the inner-loop step size:

$\mathcal{L}(\theta - \alpha \nabla_{\theta} \mathcal{L}(\theta, x_1, y_1), x_2, y_2)$

MAML’s inner update operator is just gradient descent on the regression loss. The outer loss, maml_loss, is simply the original loss applied after the inner_update operator has been applied. One interpretation of the MAML objective is that it is a differentiable estimate of a cross-validation loss with respect to a learner. Meta-training results in an inner_update that minimizes the cross-validation loss.

def inner_update(p, x1, y1, alpha=.1):
   grads = grad(loss)(p, x1, y1)
   inner_sgd_fn = lambda g, state: (state - alpha*g)
   return tree_multimap(inner_sgd_fn, grads, p)

def maml_loss(p, x1, y1, x2, y2):
   p2 = inner_update(p, x1, y1)
   return loss(p2, x2, y2)
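One way to unit-test the second-order term is to compare `grad(maml_loss)` against the chain-rule form $(I - \alpha H)\nabla_{\theta'}\mathcal{L}(\theta', x_2, y_2)$, where $\theta'$ is the inner-updated parameter vector and $H$ is the Hessian of the inner loss. The sketch below does this on a hypothetical three-parameter model `tanh(x @ theta)` with random data, so that `jax.hessian` can form $H$ explicitly; it defines small stand-ins for `loss`, `inner_update`, and `maml_loss` rather than reusing the stax network.

```python
import jax
import jax.numpy as jnp

alpha = 0.1  # inner-loop step size (assumed, matching inner_update's default)

def loss(theta, x, y):
    return jnp.mean((jnp.tanh(x @ theta) - y) ** 2)

def inner_update(theta, x1, y1):
    return theta - alpha * jax.grad(loss)(theta, x1, y1)

def maml_loss(theta, x1, y1, x2, y2):
    return loss(inner_update(theta, x1, y1), x2, y2)

key1, key2, key3 = jax.random.split(jax.random.PRNGKey(0), 3)
theta = jax.random.normal(key1, (3,))
x1 = jax.random.normal(key2, (5, 3))
y1 = jnp.sin(x1[:, 0])
x2 = jax.random.normal(key3, (5, 3))
y2 = jnp.sin(x2[:, 0])

# Meta-gradient via autodiff through the inner update.
autodiff = jax.grad(maml_loss)(theta, x1, y1, x2, y2)

# Meta-gradient via the explicit chain rule with the full Hessian.
H = jax.hessian(loss)(theta, x1, y1)
theta_prime = inner_update(theta, x1, y1)
manual = (jnp.eye(3) - alpha * H) @ jax.grad(loss)(theta_prime, x2, y2)
print(autodiff, manual)
```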

In each iteration of optimizing the MAML objective, we sample a single new task (an amplitude and a phase), then sample separate sets of input features and labels for the training and validation splits.

K = 20  # number of examples per task in the inner split (K-shot); value assumed here

opt_init, opt_update = optimizers.adam(step_size=1e-3)  # this LR seems to be better than 1e-2 and 1e-4
out_shape, net_params = net_init(in_shape)
opt_state = opt_init(net_params)

@jit
def step(i, opt_state, x1, y1, x2, y2):
    p = optimizers.get_params(opt_state)
    g = grad(maml_loss)(p, x1, y1, x2, y2)
    l = maml_loss(p, x1, y1, x2, y2)
    return opt_update(i, g, opt_state), l

np_maml_loss = []

# Adam optimization
for i in range(20000):
    # define the task
    A = onp.random.uniform(low=0.1, high=.5)
    phase = onp.random.uniform(low=0., high=np.pi)
    # meta-training inner split (K examples)
    x1 = onp.random.uniform(low=-5., high=5., size=(K,1))
    y1 = A * onp.sin(x1 + phase)
    # meta-training outer split (1 example). Like cross-validating with respect to one example.
    x2 = onp.random.uniform(low=-5., high=5.)
    y2 = A * onp.sin(x2 + phase)
    opt_state, l = step(i, opt_state, x1, y1, x2, y2)
    np_maml_loss.append(l)  # record the outer loss for plotting
    if i % 1000 == 0:
        print(i)
net_params = optimizers.get_params(opt_state)

At meta-training time, the network learns to “quickly adapt” to x1, y1 in order to minimize cross-validation error on a new set of points x2. At deployment time (shown in the plot above), when we have a new task (new amplitude and phase not seen at training time), the model can apply the inner_update operator to fit the target sinusoid much faster and with fewer data samples than simply re-training the parameters with SGD.

Why is inner_update a more effective learning rule than retraining with SGD on a new dataset? The magic here is that by training in a multi-task setting, the inner_update operator has generalized across tasks into a learning rule that is specially adapted for sinusoid regression tasks. In the standard data regime of deep learning, generalization is obtained from many examples of a single task (e.g. RL, image classification). In meta-learning, generalization is obtained from a few examples each from many tasks, and a shared learning rule is learned for the task distribution.

# batch the inference across K=100
targets = np.sin(xrange_inputs)
predictions = vmap(partial(net_apply, net_params))(xrange_inputs)
plt.plot(xrange_inputs, predictions, label='pre-update predictions')
plt.plot(xrange_inputs, targets, label='target')

x1 = onp.random.uniform(low=-5., high=5., size=(K,1))
y1 = 1. * onp.sin(x1 + 0.)

for i in range(1,5):
    # each iteration applies one more inner gradient step on the same K examples
    net_params = inner_update(net_params, x1, y1)
    predictions = vmap(partial(net_apply, net_params))(xrange_inputs)
    plt.plot(xrange_inputs, predictions, label='{}-step predictions'.format(i))
plt.legend()

Batching MAML Gradients Across Tasks with vmap

We can compute the MAML gradients across multiple tasks at once to reduce the variance of gradients of the learning operator. This was proposed in the MAML paper, and is analogous to how increasing minibatch size of standard SGD reduces variance of the parameter gradients (leading to more efficient learning).

Thanks to the vmap operator, we can automatically transform our single-task MAML implementation into a “batched version” that operates across tasks. From a software engineering & testing perspective, vmap is extremely nice because the "task-batched" MAML implementation simply re-uses code from the non-task batched MAML algorithm, without losing any vectorization benefits. This means that when unit-testing code, we can test the single-task MAML algorithm for numerical correctness, then scale up to a more advanced batched version (e.g. for handling harder tasks such as robotic learning) for efficiency. 

# vmapped version of maml loss.
# returns scalar for all tasks.
def batch_maml_loss(p, x1_b, y1_b, x2_b, y2_b):
   task_losses = vmap(partial(maml_loss, p))(x1_b, y1_b, x2_b, y2_b)
   return np.mean(task_losses)
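Following that testing strategy, a quick check is that the `vmap`-ed loss (and its gradient) agrees with a plain Python loop over tasks. The sketch below defines small stand-ins for `loss` and `maml_loss` on a hypothetical two-parameter model `tanh(x @ theta)`, generates random data for four tasks, and compares both versions; `looped_maml_loss` is a hypothetical reference implementation.

```python
from functools import partial

import jax
import jax.numpy as jnp

alpha = 0.1  # inner-loop step size (assumed)

def loss(theta, x, y):
    return jnp.mean((jnp.tanh(x @ theta) - y) ** 2)

def maml_loss(theta, x1, y1, x2, y2):
    # One inner SGD step on (x1, y1), then evaluate on (x2, y2).
    theta2 = theta - alpha * jax.grad(loss)(theta, x1, y1)
    return loss(theta2, x2, y2)

def batch_maml_loss(theta, x1_b, y1_b, x2_b, y2_b):
    task_losses = jax.vmap(partial(maml_loss, theta))(x1_b, y1_b, x2_b, y2_b)
    return jnp.mean(task_losses)

def looped_maml_loss(theta, x1_b, y1_b, x2_b, y2_b):
    # Reference implementation: iterate over tasks in Python.
    losses = [maml_loss(theta, x1_b[t], y1_b[t], x2_b[t], y2_b[t])
              for t in range(x1_b.shape[0])]
    return jnp.mean(jnp.stack(losses))

keys = jax.random.split(jax.random.PRNGKey(1), 5)
theta = jax.random.normal(keys[0], (2,))
x1_b = jax.random.normal(keys[1], (4, 10, 2))  # 4 tasks, 10 points, 2 features
y1_b = jax.random.normal(keys[2], (4, 10))
x2_b = jax.random.normal(keys[3], (4, 10, 2))
y2_b = jax.random.normal(keys[4], (4, 10))

args = (theta, x1_b, y1_b, x2_b, y2_b)
same_loss = jnp.allclose(batch_maml_loss(*args), looped_maml_loss(*args), atol=1e-6)
same_grad = jnp.allclose(jax.grad(batch_maml_loss)(*args),
                         jax.grad(looped_maml_loss)(*args), atol=1e-6)
print(same_loss, same_grad)
```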

Below is a function that samples a batch of tasks, where outer_batch_size is the number of tasks we meta-train on in each step, and inner_batch_size is the number of data points per-task. 

def sample_tasks(outer_batch_size, inner_batch_size):
    # Select amplitude and phase for the task
    As = []
    phases = []
    for _ in range(outer_batch_size):
        As.append(onp.random.uniform(low=0.1, high=.5))
        phases.append(onp.random.uniform(low=0., high=np.pi))
    def get_batch():
        xs, ys = [], []
        for A, phase in zip(As, phases):
            x = onp.random.uniform(low=-5., high=5., size=(inner_batch_size, 1))
            y = A * onp.sin(x + phase)
            xs.append(x)
            ys.append(y)
        return np.stack(xs), np.stack(ys)
    x1, y1 = get_batch()  # inner (training) split
    x2, y2 = get_batch()  # outer (validation) split, same tasks
    return x1, y1, x2, y2
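`sample_tasks` builds each batch with Python loops. As an alternative sketch (the name `sample_tasks_vectorized` and the `rng` parameter are hypothetical), the same sampling can be done with NumPy broadcasting: drawing amplitudes and phases with shape `(tasks, 1, 1)` lets them broadcast over each task's examples, and reusing them for both splits keeps the training and validation points on the same sinusoid.

```python
import numpy as onp

def sample_tasks_vectorized(outer_batch_size, inner_batch_size, rng=None):
    rng = onp.random.default_rng() if rng is None else rng
    # One amplitude and phase per task, shaped to broadcast over that task's examples.
    As = rng.uniform(0.1, 0.5, size=(outer_batch_size, 1, 1))
    phases = rng.uniform(0.0, onp.pi, size=(outer_batch_size, 1, 1))

    def get_batch():
        x = rng.uniform(-5.0, 5.0, size=(outer_batch_size, inner_batch_size, 1))
        return x, As * onp.sin(x + phases)

    x1, y1 = get_batch()  # inner (training) split
    x2, y2 = get_batch()  # outer (validation) split, same tasks
    return x1, y1, x2, y2

x1_b, y1_b, x2_b, y2_b = sample_tasks_vectorized(4, 20, rng=onp.random.default_rng(0))
print(x1_b.shape, y1_b.shape)
```

This version returns plain NumPy arrays, which JAX functions such as the vmapped MAML loss accept as inputs.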

Now for the training loop, which closely resembles the single-task one. Note that gradient-based meta-learning has to contend with two kinds of variance: the intra-task variance of the inner-loss gradients, and the inter-task variance of the outer-loss gradients.

opt_init, opt_update = optimizers.adam(step_size=1e-3)
out_shape, net_params = net_init(in_shape)
opt_state = opt_init(net_params)

@jit
def step(i, opt_state, x1, y1, x2, y2):
    p = optimizers.get_params(opt_state)
    g = grad(batch_maml_loss)(p, x1, y1, x2, y2)
    l = batch_maml_loss(p, x1, y1, x2, y2)
    return opt_update(i, g, opt_state), l

np_batched_maml_loss = []
for i in range(20000):
    x1_b, y1_b, x2_b, y2_b = sample_tasks(4, K)  # 4 tasks per meta-batch
    opt_state, l = step(i, opt_state, x1_b, y1_b, x2_b, y2_b)
    np_batched_maml_loss.append(l)  # record the outer loss for plotting
    if i % 1000 == 0:
        print(i)
net_params = optimizers.get_params(opt_state)

When we plot the MAML objective as a function of training step, we see that the batched MAML trains much faster (as a function of gradient steps) and also has lower variance during training.


In this tutorial we explored the MAML algorithm and reproduced the Sinusoid regression task from the paper in about 50 lines of Python code. I was very pleasantly surprised to find how easy grad, vmap, and jit made it to implement MAML, and I am excited to continue using JAX for my own meta-learning research.

So, what are the distinctions between “optimization”, “learning”, “adaptation”, and “memory”? I believe they are all equivalent, because it is possible to implement memory capabilities with optimization techniques (MAML) and vice versa (e.g. RNN-based meta reinforcement learning). In reinforcement learning, imitating a teacher, conditioning on a user-specified goal, or recovering from a failure can all use the same machinery.

Thinking about precise definitions of “learning” and “meta-learning”, and attempting to reconcile them with the capabilities of biological intelligence have led me to realize that every process in Life itself, spanning molecular reaction to behavioral adaptation to genetic evolution, is nothing more than learning happening at many time scales. I’ll have much more to say on the topic of Artificial Life and Machine Learning in the future, but for now, thank you for reading this humble tutorial on fitting sinusoidal functions!


Thanks to Matthew Johnson for helping to proofread this post and helping me to resolve JAX questions.