Sunday, March 10, 2019

What I Cannot Control, I Do not Understand

Xiaoyi Yin has graciously translated this blog post to 中文.

I often hear the remark around the proverbial AI watering hole that there are no examples of reinforcement learning (RL) deployed in commercial settings that couldn’t be replaced by simpler algorithms.

This is somewhat true. If one takes RL to mean “neural networks trained with DQN / PPO / Soft Actor-Critic etc.”, then indeed, there are no commercial products (yet!) whose success relies on Deep RL algorithmic breakthroughs in the last 5 years [1].

However, if one interprets “reinforcement learning” to mean the notion of “learning from repeated trial and error”, then commercial applications abound, especially in pharmaceuticals, finance, TV show recommendations, and other endeavors based on scientific experimentation and intervention.

I’ll explain in this post how Reinforcement Learning is a general approach to solving the Causal Inference problem, a desideratum of nearly all machine learning systems. In this sense, many high-impact problems are already tackled using ideas from RL, but under different terminology and engineering processes.

Doctor, Won’t You Help Me Live Longer


Let’s suppose you are a doctor tasked with helping your patients live longer. You know a thing or two about data science, so you fit a model on a lot of patient records to predict life expectancy, and make a shocking finding: people who drink red wine every day have a 90% likelihood of living over 80 years, compared to the base probability of 50% for non-drinkers.

In the parlance of causal inference, you’ve found the following observational distribution:

p(patient lives > 80 yrs | patient drinks red wine daily) = .9

Furthermore, your model has high accuracy on holdout datasets, which increases your confidence that your model has discovered the secret to longevity. Elated, you start telling your patients to drink red wine daily. After all, as a doctor, it is insufficient to predict; we must also prescribe! And what’s not to like about living longer and drinking red wine on the daily?

Many decades later, you follow up on your patients and -- with great disappointment -- observe the following interventional distribution:

p(patient lives > 80 yrs | do(patient drinks red wine daily)) = .5

The life expectancy of patients on the red wine has not increased! What gives?

Finding the Causal Model

The core problem here lies in confounding variables. When we decided to prescribe red wine to patients based on the observational model, we made a strong hypothesis about the causality diagram:



The directed edges between these random variables here denote causality, which can also be thought of as "the arrow of time". Changing the value of the “Drinks Red Wine” variable ought to have an effect on “Lives > 80 years”, but changing “Lives > 80 years” has no effect on drinking red wine.

If this causal diagram were correct, then our intervention should have increased the lifespan of patients. But the actual experiment does not support this, so we must reject this hypothetical causal model and reach for alternative hypotheses to explain the data. Perhaps there are one or more variables that cause both a higher propensity for red wine drinking AND a longer life, thus correlating those two variables?

We make the educated guess that the confounding variable might be wealth: wealthy people tend to both live longer and drink more wine. Combing through the data again, we find that p(drinks red wine | is wealthy) = 0.9 and p(lives > 80 | is wealthy) = 1.0. So our hypothesis now takes the form:



If our understanding of the world is correct, then do(is wealthy) should make people live > 80 years and drink more red wine. And indeed, we find that once we give patients $1M cash infusions to make them wealthy (by USA standards), they end up living longer and drinking red wine daily (this is a hypothetical result, fabricated for the sake of this blog post).
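To see how the numbers in this story can fit together, here is a small worked calculation. It relies on values that are my own assumptions and are not stated above: half the population is wealthy, p(drinks red wine | not wealthy) = 0.1, p(lives > 80 | not wealthy) = 0, and longevity depends only on wealth. Under those assumptions, the observational and interventional quantities come out to the .9 and .5 from the story:

```latex
\begin{align*}
p(\text{wealthy} \mid \text{wine})
  &= \frac{0.9 \cdot 0.5}{0.9 \cdot 0.5 + 0.1 \cdot 0.5} = 0.9 \\
p(\text{lives} > 80 \mid \text{wine})
  &= 1.0 \cdot p(\text{wealthy} \mid \text{wine}) + 0 \cdot p(\text{not wealthy} \mid \text{wine}) = 0.9 \\
p(\text{lives} > 80 \mid do(\text{wine}))
  &= \sum_{w} p(\text{lives} > 80 \mid w)\, p(w) = 1.0 \cdot 0.5 + 0 \cdot 0.5 = 0.5
\end{align*}
```

Conditioning on wine drinking shifts the mix of wealthy patients (from 50% to 90%), whereas forcing everyone to drink wine leaves that mix untouched. That difference is the entire gap between the two distributions in this toy setup.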

RL as Automated Causal Inference

ML models are increasingly used to drive decision making in recommender systems, self-driving cars, pharmaceutical R&D, and experimental physics. In many cases, we desire an outcome event $y$, for which we attempt to learn a model $p(y|x_1, \ldots, x_N)$ and then choose inputs $x_1, \ldots, x_N$ to maximize $p(y|x_1, \ldots, x_N)$.

It should be quite obvious from the previous medical example that to avoid causality when building decision-making systems is to risk overfitting models that are not useful for prescribing intervention. Suppose we automated the causal model discovery process in the following manner:
  1. Fit an observational model p(y|x_1, x_2, …, x_N) to the data.
  2. Assume the observational model captures the causal model. Prescribe an intervention do(x_i) that maximizes p(y|x_1, …, x_N), and gather a new dataset in which 50% of the samples receive the intervention on x_i and 50% do not.
  3. Fit an observational model p(y|x_i) to the new data.
  4. Repeat steps 1-3 until the observational model matches the interventional model: p(y|do(x_i)) = p(y|x_i).
To return to the red wine case study as a test case:
  1. You would initially have p(live > 80 years | drink red wine daily) = .9. 
  2. Upon gathering a new dataset, you would obtain p(live > 80 years | do(drink red wine daily)) = .5. The model has not converged, but at least your observational model no longer believes that drinking red wine explains living longer. Furthermore, it now pays attention to the right variable: p(live > 80 years | is_wealthy) = 1.
  3. The subsequent iteration of this procedure then finds that p(live > 80 years | do(is wealthy)) = 1, so we are done.

The act of gathering a randomized trial (the 50% split of intervention vs. non-intervention) and re-training a new observational model is one of the most powerful ways to do general causal inference, because it uses data from reality (which “knows” the true causal model) to stamp out incorrect hypotheses.
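The procedure above can be sketched in a few lines of plain Python. This is only an illustration under assumptions of my own: the simulated "world" (wealth causes both wine drinking and longevity, with made-up probabilities), the helper names sample_patient, p_long_given and check_variable, and the tolerance are all hypothetical.

```python
import random

# Hypothetical ground-truth world (all probabilities are assumptions):
# wealth causes both wine drinking and longevity; wine itself does nothing.
def sample_patient(rng, do=None):
    do = do or {}
    wealthy = do.get("wealthy", rng.random() < 0.5)
    drinks_wine = do.get("drinks_wine", rng.random() < (0.9 if wealthy else 0.1))
    lives_long = wealthy
    return {"wealthy": wealthy, "drinks_wine": drinks_wine, "lives_long": lives_long}

def p_long_given(rows, var):
    """Empirical p(lives_long | var is True)."""
    selected = [r for r in rows if r[var]]
    return sum(r["lives_long"] for r in selected) / len(selected)

def check_variable(var, n=20000, tol=0.05, seed=0):
    rng = random.Random(seed)
    # Step 1: observational data, nobody is told what to do.
    observed = [sample_patient(rng) for _ in range(n)]
    p_obs = p_long_given(observed, var)
    # Step 2: randomized trial. A coin flip sets the variable for each
    # patient (True = intervention arm, False = control arm).
    trial = [sample_patient(rng, do={var: rng.random() < 0.5}) for _ in range(n)]
    p_do = p_long_given(trial, var)
    # Step 4: the causal hypothesis survives only if the two estimates agree.
    return p_obs, p_do, abs(p_obs - p_do) < tol

for var in ["drinks_wine", "wealthy"]:
    p_obs, p_do, survives = check_variable(var)
    print(var, round(p_obs, 2), round(p_do, 2), survives)
```

In this toy world the wine hypothesis is rejected because its observational and interventional estimates disagree, while the wealth hypothesis survives. A real system would also have to propose which variables to test, which is the part the sketch leaves to a hard-coded list.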

Repeatedly training observational models and suggesting interventions is what RL algorithms are all about: solving optimal control for sequential decision-making problems. Control is the operative word here - the true test of whether an agent understands its environment is whether it can solve it.

For ML models whose predictions are used to infer interventions (so as to manipulate some downstream random variable), I argue that the overfitting problem is nothing more than a causal inference problem. This also explains why RL tends to be much harder as a machine learning problem than supervised learning - not only are there fewer bits of supervision per observation, but the RL agent must also figure out the causal, interventionist distribution required to behave optimally.

One salient case of “overfitting” arises in offline RL. RL algorithms can theoretically be trained “offline” -- that is, by learning entirely from off-policy data without gathering new data samples from the environment. However, without periodically gathering new experience from the environment, agents can overfit to finite-size datasets or dataset imbalances, and propose interventions that do not generalize past their offline data. The best way to check if an agent is “learning the right thing” is to deploy it in the world and verify its hypotheses under the interventionist distribution. Indeed, for our robotic grasping research at Google, we often find that fine-tuning with “online” experience improves performance substantially. This is equivalent to re-training an observational model on new data p(grasp success | do(optimal_action)).


Production "RL"

The A/B testing framework often used in production engineering is a manual version of the "automated causal inference" pipeline, where a random 50% of users (assumed to be identically distributed) are shown one intervention and the other 50% are shown the control.
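As a rough illustration of that 50/50 split, here is a minimal simulated A/B test in plain Python. The function name ab_test, the conversion rates, and the choice of a pooled two-proportion z-statistic are assumptions of mine for the sketch, not a description of any particular company's pipeline.

```python
import math
import random

def ab_test(n_users, p_control, p_treatment, seed=0):
    """Simulate a 50/50 randomized experiment; return (lift, z-statistic)."""
    rng = random.Random(seed)
    counts = {"control": [0, 0], "treatment": [0, 0]}  # [successes, users]
    for _ in range(n_users):
        # Random assignment is what makes this an intervention, do(arm),
        # rather than an observation of which users chose which arm.
        arm = "treatment" if rng.random() < 0.5 else "control"
        p = p_treatment if arm == "treatment" else p_control
        counts[arm][0] += rng.random() < p
        counts[arm][1] += 1
    (s_c, n_c), (s_t, n_t) = counts["control"], counts["treatment"]
    rate_c, rate_t = s_c / n_c, s_t / n_t
    # Two-proportion z-statistic with a pooled variance estimate.
    pooled = (s_c + s_t) / (n_c + n_t)
    std_err = math.sqrt(pooled * (1 - pooled) * (1 / n_c + 1 / n_t))
    return rate_t - rate_c, (rate_t - rate_c) / std_err

lift, z = ab_test(100_000, p_control=0.10, p_treatment=0.12)
print("estimated lift:", round(lift, 4), "z:", round(z, 1))
```

The human-in-the-loop parts mentioned below (which intervention to try, when to stop) are exactly the arguments this function takes as given.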

This is the cornerstone of data-driven decision making, and is used widely at hedge funds, Netflix, StitchFix, Google, Walmart, and so on. Although this process has humans in the loop (specifically for proposing interventions and choosing the stopping criterion), there are many related nuances to these methodologies that also arise in RL literature like data non-stationarity, the difficulty of obtaining truly randomized experiments, and long-term credit assignment. I’m just starting to learn about causal inference myself, and hope that in the next few years there will be more cross-fertilization of ideas between the RL, Data Science, and Causal Inference research communities.

For a more technical introduction to Causal Inference, see this great blog series by Ferenc Huszar.



[1] A footnote on why I think RL hasn’t had much commercial deployment yet. Feel free to clue me in if there are indeed companies using RL in production that I don’t know about!

In order for a company to be justified in adopting RL technology, the problem at hand needs to be 1) commercially useful and 2) feasible for current Deep RL algorithms, and 3) the marginal utility of optimal control must be worth the technical risks of Deep RL.

Let’s consider deep image understanding by comparison: 1) everything from surveillance to self-driving cars to FaceID is highly commercially interesting 2) current models are highly accurate and scale well to a variety of image datasets 3) the models generally work as expected and do not require great expertise to train and deploy.

As for RL, it doesn’t take a great imagination to realize that general RL algorithms would eventually enable robots to learn skills entirely on their own, or help companies make complex financial decisions like stock buybacks and hiring, or enable far richer NPC behavior in games. Unfortunately, these problem domains don’t meet criteria (2) - the technology simply isn’t ready and requires many more years of R&D.

For problems where RL is plausible, it is difficult to justify being the first user of a technology whose marginal utility to your problem of choice is unproven. Example problems might include datacenter cooling or air traffic control. Even for domains where RL has been shown clearly to work (e.g. low-dimensional control or pixel-level control), RL still requires a lot of research skill to build a working system.

Thursday, February 21, 2019

Meta-Learning in 50 Lines of JAX

Github repo here: https://github.com/ericjang/maml-jax

Adaptive behavior in humans and animals occurs at many time scales: when I use a new shower handle for the first time, it takes me a few seconds to figure out how to adjust the water temperature to my liking. Upon reading a news article, I obtain new information that I didn't have before. More difficult skills, such as mastering a musical instrument, are acquired over a lifetime of deliberate practice.

Learning is hardly restricted to animal-level intelligence; it can be found in every living creature. Multi-cellular developmental programs are highly plastic and can even store epigenetic “memories” between generations. At the longest time-scales, evolution itself can be thought of as “learning” on the genomic level, whereby favorable genetic codes are discovered and remembered over the course of many generations. At the shortest of timescales, a single ion channel activating in response to a stimulus can also be thought of as “learning”, as it is an adaptive, stateful response to the environment. Biological intelligence blurs the boundaries between “behavior” (responding to the environment), “learning” (acquiring information about the world in order to improve fitness), and “optimization” (improving fitness).

The focus of Machine Learning (ML) is to imbue computers with the ability to learn from data, so that they may accomplish tasks that humans have difficulty expressing in pure code. However, what most ML researchers call “learning” right now is but a very small subset of the vast range of behavioral adaptability encountered in biological life! Deep Learning models are powerful, but require a large amount of data and many iterations of stochastic gradient descent (SGD). This learning procedure is time-consuming and once a deep model is trained, its behavior is fairly rigid; at deployment time, one cannot really change the behavior of the system (e.g. correcting mistakes) without an expensive retraining process. Can we build systems that can learn faster, and with less data?

“Meta-learning”, one of the most exciting ML research topics right now, addresses this problem by optimizing a model not just for the ability to “predict well”, but also the ability to “learn well”. Although Meta-Learning has attracted a lot of research attention in recent years, related ideas and algorithms have been around for some time (see Hugo Larochelle's slides and Lilian Weng’s blog post for an excellent overview of related concepts).

This blog post won’t cover all the possible ways in which one can build a meta-learning system; instead, this is a practical tutorial on how to get your feet wet in meta-learning research. Specifically, I'll show you how to implement the MAML meta-learning algorithm in about 50 lines of Python code, using Google's awesome JAX library.

You can find a self-contained Jupyter notebook here reproducing this tutorial.

An Operator Perspective on Learning and Meta-Learning


“Meta-learning” is used in so many different research contexts nowadays that it's difficult to communicate to other researchers exactly what I’m working on when I say “Meta-Learning”. Much of this confusion stems from the blurred semantics between “optimization”, “learning”, “adaptation”, and “memory”, and from how these terms can be employed in wildly different applications.

This section is my attempt to make the definition of “learning” and “meta-learning” more mathematically precise, and explain why seemingly different algorithms are all branded as “meta-learning” these days. Feel free to skip to the next section if you want to dive straight into the MAML+JAX coding tutorial.

We define a learning operator $f : F_\theta \to F_\theta$ as a function that improves a model function $f_\theta$ with respect to some task. A common learning operator used in deep learning and reinforcement learning literature is the stochastic gradient descent algorithm, with respect to a loss function. In standard DL contexts, learning occurs over hundreds of thousands or even millions of gradient steps, but generally, “learning” can also occur on shorter (conditioning) or longer timescales (hyperparameter search). In addition to explicit optimization, learning can also be implemented implicitly via a dynamical system (recurrent neural networks conditioning on the past) or probabilistic inference.

A meta-learning operator $f_o(f_i(f_\theta))$ is a composite operator of two learning operators: an “inner loop” $f_i \in F_i$ and an “outer loop” $f_o \in F_o$. Furthermore, $f_i$ is a model itself, and $f_o : F_i \to F_i$ is an operator over the inner learning rule $f_i$. In other words, $f_o$ learns the learning rule $f_i$, and $f_i$ learns a model for a given task, where we define “task” to be a self-contained family of problems for which $f_i$ can adequately update $f_\theta$ to solve. At meta-training time, $f_o$ is applied to select for $f_i$ across a variety of training tasks. At meta-test time, we evaluate the generalization properties of $f_i$ and $f_\theta$ to holdout tasks.

The choice of $f_o$ and $f_i$ depends largely on the problem domain. In architecture search literature (also called “learning to learn”), $f_i$ is a relatively slow training procedure of a neural network from scratch, while $f_o$ can be a neural controller, random search algorithm, or a Gaussian Process Bandit.

A wide variety of machine learning problems can be formulated in terms of meta-learning operators. In (meta) imitation learning (or goal-conditioned reinforcement learning), $f_i$ is used to relay instructions to the RL agent, such as conditioning on a task embedding or human demonstrations. In meta-reinforcement learning (MRL), $f_i$ instead implements a “fast reinforcement learning” algorithm by which an agent improves itself after trying the task a couple times. It’s worth re-iterating here that I don’t see a distinction between “learning” and “conditioning”, because they both rely on inputs that are supplied at test time (i.e. “new information provided by the environment”).

MAML is a meta-learning algorithm that implements $f_i$ via SGD, i.e. $\theta := \theta - \alpha \nabla_{\theta}(\mathcal{L}(\theta))$. This SGD update is differentiable with respect to $\theta$, allowing $f_o$ to effectively optimize $f_i$ via backpropagation without requiring many additional parameters to express $f_i$.
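To make "differentiable with respect to $\theta$" concrete, here is the chain-rule expansion of the outer gradient for a single inner step. The notation is a sketch of my own: $\theta'$ denotes the adapted parameters, $(x_1, y_1)$ the inner-loop data and $(x_2, y_2)$ the outer-loop data, matching the naming used in the code later in the post.

```latex
\begin{align*}
\theta' &= \theta - \alpha \nabla_\theta \mathcal{L}(\theta, x_1, y_1) \\
\nabla_\theta \mathcal{L}(\theta', x_2, y_2)
  &= \left(I - \alpha \nabla^2_\theta \mathcal{L}(\theta, x_1, y_1)\right)
     \nabla_{\theta'} \mathcal{L}(\theta', x_2, y_2)
\end{align*}
```

The Hessian term $\nabla^2_\theta \mathcal{L}$ is why MAML needs gradients of gradients, which is precisely what composing grad with itself provides.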


Exploring JAX: Gradients


We begin the tutorial by importing JAX’s numpy drop-in and the gradient operator, grad.


import jax.numpy as np
from jax import grad

The gradient operator grad transforms a Python function into another function that computes its gradient. Here, we compute first, second, and third order derivatives of $e^x$ and $x^2$:



f = lambda x : np.exp(x)
g = lambda x : np.square(x)
print(grad(f)(1.)) # = e^{1}
print(grad(grad(f))(1.))
print(grad(grad(grad(f)))(1.))

print(grad(g)(2.)) # 2x = 4
print(grad(grad(g))(2.)) # second derivative of x^2 is the constant 2
print(grad(grad(grad(g)))(2.)) # third derivative of x^2 is 0

Exploring JAX: Auto-Vectorization with vmap


Now let’s consider a toy regression problem in which we try to learn the function $f_\theta(x) = \sin(x)$ with a neural network. The goal here is to get familiar with defining and training models. JAX provides some lightweight helper functions to make it easy to set up a neural network.


from jax import vmap # for auto-vectorizing functions
from functools import partial # for use with vmap
from jax import jit # for compiling functions for speedup
from jax.experimental import stax # neural network library
from jax.experimental.stax import Conv, Dense, MaxPool, Relu, Flatten, LogSoftmax # neural network layers
import matplotlib.pyplot as plt # visualization


We’ll define a simple neural network with 2 hidden layers. We’ve specified an in_shape of (-1, 1), which means that the model takes in a variable-size batch dimension, and has a feature dimension of 1 scalar (since this is a 1-D regression task). JAX’s helper libraries all take on a functional API (unlike TensorFlow, which maintains a graph state), so we get back a function that initializes parameters and a function that applies the forward pass of the network. These callables return lists and tuples of numpy arrays - a simple and flat data structure for storing network parameters.


# Use stax to set up network initialization and evaluation functions
net_init, net_apply = stax.serial(
   Dense(40), Relu,
   Dense(40), Relu,
   Dense(1)
)
in_shape = (-1, 1,)
out_shape, net_params = net_init(in_shape)

Next, we define the model loss to be Mean-Squared Error (MSE) across a batch of inputs.


def loss(params, inputs, targets):
   # Computes average loss for the batch
   predictions = net_apply(params, inputs)
   return np.mean((targets - predictions)**2)

We evaluate the untrained (randomly initialized) network across a range of inputs:


# batch the inference across K=100
xrange_inputs = np.linspace(-5,5,100).reshape((100, 1)) # (k, 1)
targets = np.sin(xrange_inputs)
predictions = vmap(partial(net_apply, net_params))(xrange_inputs)
losses = vmap(partial(loss, net_params))(xrange_inputs, targets) # per-input loss
plt.plot(xrange_inputs, predictions, label='prediction')
plt.plot(xrange_inputs, losses, label='loss')
plt.plot(xrange_inputs, targets, label='target')
plt.legend()

As expected, at random initialization, the model’s predictions (blue) are totally off the target function (green).
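As an aside on the vmap calls used above: vmap takes a function written for a single example and returns one that maps over a leading batch axis. The sketch below checks this against an explicit Python loop. The function squared_error and the zero targets are made up for illustration, and I import jax.numpy as jnp here rather than np to keep it distinct from regular NumPy.

```python
import jax.numpy as jnp
from jax import vmap

def squared_error(x, y):
    # Written for a single scalar example, with no batch dimension in mind.
    return (jnp.sin(x) - y) ** 2

xs = jnp.linspace(-5.0, 5.0, 100)
ys = jnp.zeros(100)

# vmap maps squared_error over axis 0 of both arguments.
per_example = vmap(squared_error)(xs, ys)

# The same computation as an explicit (slow) Python loop.
looped = jnp.stack([squared_error(x, y) for x, y in zip(xs, ys)])

print(per_example.shape, bool(jnp.allclose(per_example, looped, atol=1e-5)))
```

The point of the design is that the per-example function stays easy to read and test, while batching is added from the outside.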




Let’s train the network via gradient descent. JAX’s random number generator is set up differently than Numpy’s, so to initialize network parameters we’ll use the original Numpy library (onp) to generate random numbers. We’ll also import the tree_multimap utility to easily manipulate collections of per-parameter gradients (for TensorFlow users, this is analogous to nest.map_structure for Tensors).

import numpy as onp
from jax.experimental import optimizers
from jax.tree_util import tree_multimap  # Element-wise manipulation of collections of numpy arrays


We initialize the parameters and optimizer, and run the curve fitting for 100 steps. Note that adding the @jit decorator to the “step” function uses XLA to compile the entire training step into machine code, along with optimizations like fused accelerator kernels, memory and layout optimization. TensorFlow itself also uses XLA for accelerating statically defined graphs. XLA makes the computation very fast and amenable to hardware acceleration because the entire thing can be executed without returning to a Python interpreter (or Graph interpreter in the case of TensorFlow sans XLA). The code in this tutorial will just work on CPU/GPU/TPU.

opt_init, opt_update = optimizers.adam(step_size=1e-2)
opt_state = opt_init(net_params)
# Define a compiled update step
@jit
def step(i, opt_state, x1, y1):
   p = optimizers.get_params(opt_state)
   g = grad(loss)(p, x1, y1)
   return opt_update(i, g, opt_state)

for i in range(100):
   opt_state = step(i, opt_state, xrange_inputs, targets)
net_params = optimizers.get_params(opt_state)


Evaluating our network again, we see that the sinusoid curve has been correctly approximated.



This result is nothing to write home about, but in just a moment we’ll re-use a lot of these functions to implement MAML.

Exploring JAX: Checking MAML Numerics



When implementing ML algorithms, it’s important to unit-test implementations against test cases where the true values can be computed analytically. The following example does this for MAML on a toy objective $g$. Note that by default JAX computes gradients with respect to the first argument of the function.

# gradients of gradients test for MAML
# check numerics
g = lambda x, y : np.square(x) + y
x0 = 2.
y0 = 1.
print('grad(g)(x0) = {}'.format(grad(g)(x0, y0))) # 2x = 4
print('x0 - grad(g)(x0) = {}'.format(x0 - grad(g)(x0, y0))) # x - 2x = -2
def maml_objective(x, y):
   return g(x - grad(g)(x, y), y)
print('maml_objective(x,y)={}'.format(maml_objective(x0, y0))) # x**2 + 1 = 5
print('x0 - grad(maml_objective)(x,y) = {}'.format(x0 - grad(maml_objective)(x0, y0))) # x - (2x) = -2.
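Since grad differentiates with respect to the first argument by default, it can be useful to know how to select a different one. The sketch below uses grad's argnums parameter on the same toy function; the variable names are mine.

```python
import jax.numpy as jnp
from jax import grad

g = lambda x, y: jnp.square(x) + y

dg_dx = grad(g)(2.0, 1.0)                    # default: d/dx (x**2 + y) = 2x
dg_dy = grad(g, argnums=1)(2.0, 1.0)         # d/dy (x**2 + y) = 1
dg_both = grad(g, argnums=(0, 1))(2.0, 1.0)  # tuple of both partials

print(dg_dx, dg_dy, dg_both)
```

In the MAML code that follows, the parameters are always passed first, so the default of differentiating with respect to argument 0 is what we want.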


Implementing MAML with JAX


Now let’s extend our sinusoid regression task to a multi-task problem, in which the sinusoid function can have varying phases and amplitudes. This task was proposed in the MAML paper as a way to illustrate how MAML works on a toy problem. Below are some points sampled from two different tasks, divided into “train” (used to compute the inner loss) and “validation” splits (sampled from the same task, used to compute the outer loss).




Suppose a task loss function $\mathcal{L}$ is defined with respect to model parameters $\theta$, input features $X$, and output labels $Y$. Let $x_1, y_1$ and $x_2, y_2$ be identically distributed task instance data sampled from $X, Y$. Then MAML optimizes the following:


$\mathcal{L}(\theta - \alpha \nabla_\theta \mathcal{L}(\theta, x_1, y_1), x_2, y_2)$


MAML’s inner update operator is just gradient descent on the regression loss. The outer loss, maml_loss, is simply the original loss applied after the inner_update operator has been applied. One interpretation of the MAML objective is that it is a differentiable estimate of a cross-validation loss with respect to a learner. Meta-training results in an inner_update that minimizes the cross-validation loss.


def inner_update(p, x1, y1, alpha=.1):
   grads = grad(loss)(p, x1, y1)
   inner_sgd_fn = lambda g, state: (state - alpha*g)
   return tree_multimap(inner_sgd_fn, grads, p)

def maml_loss(p, x1, y1, x2, y2):
   p2 = inner_update(p, x1, y1)
   return loss(p2, x2, y2)
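A note for readers on newer JAX releases: to my knowledge, tree_multimap has since been removed, and jax.tree_util.tree_map accepts several trees and plays the same role. The sketch below shows the same inner update written that way. The linear model toy_loss and the name inner_update_tree_map are placeholders of mine standing in for the post's network and inner_update.

```python
import jax.numpy as jnp
from jax import grad
from jax.tree_util import tree_map

def toy_loss(params, x, y):
    # Linear model on a dict pytree; stands in for the post's neural network.
    pred = params["w"] * x + params["b"]
    return jnp.mean((pred - y) ** 2)

def inner_update_tree_map(params, x, y, alpha=0.1):
    grads = grad(toy_loss)(params, x, y)
    # Apply theta <- theta - alpha * grad to every leaf of the pytree.
    return tree_map(lambda p, g: p - alpha * g, params, grads)

params = {"w": jnp.array(0.0), "b": jnp.array(0.0)}
x = jnp.array([1.0, 2.0])
y = jnp.array([2.0, 4.0])

new_params = inner_update_tree_map(params, x, y)
print(new_params, toy_loss(params, x, y), toy_loss(new_params, x, y))
```

Because the update is an ordinary differentiable function of the parameters, grad can be applied to a loss evaluated at new_params, which is all the outer loop needs.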


In each iteration of optimizing the MAML objective, we sample a single new task, then sample separate sets of input features and output labels for the training and validation splits.


opt_init, opt_update = optimizers.adam(step_size=1e-3)  # this LR seems to be better than 1e-2 and 1e-4
out_shape, net_params = net_init(in_shape)
opt_state = opt_init(net_params)

@jit
def step(i, opt_state, x1, y1, x2, y2):
   p = optimizers.get_params(opt_state)
   g = grad(maml_loss)(p, x1, y1, x2, y2)
   l = maml_loss(p, x1, y1, x2, y2)
   return opt_update(i, g, opt_state), l
K=20

np_maml_loss = []

# Adam optimization
for i in range(20000):
   # define the task
   A = onp.random.uniform(low=0.1, high=.5)
   phase = onp.random.uniform(low=0., high=np.pi)
   # meta-training inner split (K examples)
   x1 = onp.random.uniform(low=-5., high=5., size=(K,1))
   y1 = A * onp.sin(x1 + phase)
   # meta-training outer split (1 example). Like cross-validating with respect to one example.
   x2 = onp.random.uniform(low=-5., high=5.)
   y2 = A * onp.sin(x2 + phase)
   opt_state, l = step(i, opt_state, x1, y1, x2, y2)
   np_maml_loss.append(l)
   if i % 1000 == 0:
       print(i)
net_params = optimizers.get_params(opt_state)
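The step function above evaluates maml_loss once for the gradient and once more for the logged loss value. JAX also provides value_and_grad, which returns both from one call. The sketch below shows it on a stand-in objective, since toy_objective and its data are invented for illustration.

```python
import jax.numpy as jnp
from jax import grad, value_and_grad

def toy_objective(w, x, y):
    # Stand-in for maml_loss: mean squared error of a one-parameter model.
    return jnp.mean((w * x - y) ** 2)

w = jnp.array(1.0)
x = jnp.array([1.0, 2.0, 3.0])
y = jnp.array([2.0, 4.0, 6.0])

# One call returns the loss and its gradient w.r.t. the first argument.
loss_value, loss_grad = value_and_grad(toy_objective)(w, x, y)

print(loss_value, loss_grad, grad(toy_objective)(w, x, y))
```

Inside a jitted function the compiler may well remove the duplicated work anyway, so this is mostly a readability choice.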


At meta-training time, the network learns to “quickly adapt” to x1, y1 in order to minimize cross-validation error on a new set of points x2. At deployment time (shown in the plot above), when we have a new task (new amplitude and phase not seen at training time), the model can apply the inner_update operator to fit the target sinusoid much faster and with fewer data samples than simply re-training the parameters with SGD.

Why is inner_update a more effective learning rule than retraining with SGD on a new dataset? The magic here is that by training in a multi-task setting, the inner_update operator has generalized across tasks into a learning rule that is specially adapted for sinusoid regression tasks. In the standard data regime of deep learning, generalization is obtained from many examples of a single task (e.g. RL, image classification). In meta-learning, generalization is obtained from a few examples each from many tasks, and a shared learning rule is learned for the task distribution.



# batch the inference across K=100
targets = np.sin(xrange_inputs)
predictions = vmap(partial(net_apply, net_params))(xrange_inputs)
plt.plot(xrange_inputs, predictions, label='pre-update predictions')
plt.plot(xrange_inputs, targets, label='target')

x1 = onp.random.uniform(low=-5., high=5., size=(K,1))
y1 = 1. * onp.sin(x1 + 0.)

for i in range(1,5):
   net_params = inner_update(net_params, x1, y1)
   predictions = vmap(partial(net_apply, net_params))(xrange_inputs)
   plt.plot(xrange_inputs, predictions, label='{}-shot predictions'.format(i))
plt.legend()


Batching MAML Gradients Across Tasks with vmap

We can compute the MAML gradients across multiple tasks at once to reduce the variance of gradients of the learning operator. This was proposed in the MAML paper, and is analogous to how increasing minibatch size of standard SGD reduces variance of the parameter gradients (leading to more efficient learning).

Thanks to the vmap operator, we can automatically transform our single-task MAML implementation into a “batched version” that operates across tasks. From a software engineering & testing perspective, vmap is extremely nice because the "task-batched" MAML implementation simply re-uses code from the non-task batched MAML algorithm, without losing any vectorization benefits. This means that when unit-testing code, we can test the single-task MAML algorithm for numerical correctness, then scale up to a more advanced batched version (e.g. for handling harder tasks such as robotic learning) for efficiency. 

# vmapped version of maml loss.
# returns scalar for all tasks.
def batch_maml_loss(p, x1_b, y1_b, x2_b, y2_b):
   task_losses = vmap(partial(maml_loss, p))(x1_b, y1_b, x2_b, y2_b)
   return np.mean(task_losses)
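An alternative to partial for keeping the parameters un-batched is vmap's in_axes argument, where None marks an argument that is shared across the batch. The sketch below demonstrates this on a one-parameter toy problem and checks it against a loop over tasks. toy_loss, toy_maml_loss, and the slope-regression tasks are illustrative assumptions, not the sinusoid setup from this post.

```python
import jax.numpy as jnp
from jax import grad, vmap

def toy_loss(w, x, y):
    return jnp.mean((w * x - y) ** 2)

def toy_maml_loss(w, x1, y1, x2, y2, alpha=0.1):
    # One inner SGD step on (x1, y1), then evaluate on (x2, y2).
    w_adapted = w - alpha * grad(toy_loss)(w, x1, y1)
    return toy_loss(w_adapted, x2, y2)

# in_axes=None shares the parameter; 0 maps over the leading task axis.
batched = vmap(toy_maml_loss, in_axes=(None, 0, 0, 0, 0))

w = jnp.array(0.5)
x1 = jnp.array([[1.0, 2.0], [0.5, 1.5], [2.0, 3.0]])  # (tasks, K)
slopes = jnp.array([[1.0], [2.0], [3.0]])             # one slope per task
y1 = slopes * x1
x2 = x1 + 0.1
y2 = slopes * x2

task_losses = batched(w, x1, y1, x2, y2)
looped = jnp.stack(
    [toy_maml_loss(w, a, b, c, d) for a, b, c, d in zip(x1, y1, x2, y2)]
)
# The meta-gradient differentiates through vmap and the inner grad.
meta_grad = grad(lambda w_: jnp.mean(batched(w_, x1, y1, x2, y2)))(w)

print(task_losses, meta_grad)
```

Either spelling should give the same per-task losses; partial is a little shorter, while in_axes makes the batching explicit at the call site.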


Below is a function that samples a batch of tasks, where outer_batch_size is the number of tasks we meta-train on in each step, and inner_batch_size is the number of data points per-task. 


def sample_tasks(outer_batch_size, inner_batch_size):
   # Select amplitude and phase for the task
   As = []
   phases = []
   for _ in range(outer_batch_size):        
       As.append(onp.random.uniform(low=0.1, high=.5))
       phases.append(onp.random.uniform(low=0., high=np.pi))
   def get_batch():
       xs, ys = [], []
       for A, phase in zip(As, phases):
           x = onp.random.uniform(low=-5., high=5., size=(inner_batch_size, 1))
           y = A * onp.sin(x + phase)
           xs.append(x)
           ys.append(y)
       return np.stack(xs), np.stack(ys)
   x1, y1 = get_batch()
   x2, y2 = get_batch()
   return x1, y1, x2, y2

Now for the training loop, which strongly resembles the previous single-task one. As you can see, gradient-based meta-learning requires treating two kinds of variance: that of intra-task gradients for the inner loss, and that of inter-task gradients for the outer loss.

opt_init, opt_update = optimizers.adam(step_size=1e-3)
out_shape, net_params = net_init(in_shape)
opt_state = opt_init(net_params)

@jit
def step(i, opt_state, x1, y1, x2, y2):
   p = optimizers.get_params(opt_state)
   g = grad(batch_maml_loss)(p, x1, y1, x2, y2)
   l = batch_maml_loss(p, x1, y1, x2, y2)
   return opt_update(i, g, opt_state), l

np_batched_maml_loss = []
K=20
for i in range(20000):
   x1_b, y1_b, x2_b, y2_b = sample_tasks(4, K)
   opt_state, l = step(i, opt_state, x1_b, y1_b, x2_b, y2_b)
   np_batched_maml_loss.append(l)
   if i % 1000 == 0:
       print(i)
net_params = optimizers.get_params(opt_state)

When we plot the MAML objective as a function of training step, we see that the batched MAML trains much faster (as a function of gradient steps) and also has lower variance during training.



Conclusions


In this tutorial we explored the MAML algorithm and reproduced the Sinusoid regression task from the paper in about 50 lines of Python code. I was very pleasantly surprised to find how easy grad, vmap, and jit made it to implement MAML, and I am excited to continue using JAX for my own meta-learning research.

So, what are the distinctions between “optimization”, “learning”, “adaptation”, and “memory”? I believe they are all equivalent, because it is possible to implement memory capabilities with optimization techniques (MAML) and vice versa (e.g. RNN-based meta reinforcement learning). In reinforcement learning, imitating a teacher, conditioning on a user-specified goal, and recovering from a failure can all use the same machinery.

Thinking about precise definitions of “learning” and “meta-learning”, and attempting to reconcile them with the capabilities of biological intelligence have led me to realize that every process in Life itself, spanning molecular reaction to behavioral adaptation to genetic evolution, is nothing more than learning happening at many time scales. I’ll have much more to say on the topic of Artificial Life and Machine Learning in the future, but for now, thank you for reading this humble tutorial on fitting sinusoidal functions!

Acknowledgements

Thanks to Matthew Johnson for helping to proofread this post and helping me to resolve JAX questions.