Sunday, March 10, 2019

What I Cannot Control, I Do Not Understand

Xiaoyi Yin has graciously translated this blog post to 中文.

I often hear the remark around the proverbial AI watering hole that there are no examples of reinforcement learning (RL) deployed in commercial settings that couldn’t be replaced by simpler algorithms.

This is somewhat true. If one takes RL to mean “neural networks trained with DQN / PPO / Soft Actor-Critic etc.”, then indeed, there are no commercial products (yet!) whose success relies on Deep RL algorithmic breakthroughs in the last 5 years [1].

However, if one interprets “reinforcement learning” to mean the notion of “learning from repeated trial and error”, then commercial applications abound, especially in pharmaceuticals, finance, TV show recommendations, and other endeavors based on scientific experimentation and intervention.

I’ll explain in this post how Reinforcement Learning is a general approach to solving the Causal Inference problem, a desideratum of nearly all machine learning systems. In this sense, many high-impact problems are already tackled using ideas from RL, but under different terminology and engineering processes.

Doctor, Won’t You Help Me Live Longer

Let’s suppose you are a doctor tasked with helping your patients live longer. You know a thing or two about data science, so you fit a model on a lot of patient records to predict life expectancy, and make a shocking finding: people who drink red wine every day have a 90% likelihood of living over 80 years, compared to the base probability of 50% for non-drinkers.

In the parlance of causal inference, you’ve found the following observational distribution:

p(patient lives > 80 yrs | patient drinks red wine daily) = .9

Furthermore, your model has high accuracy on holdout datasets, which increases your confidence that your model has discovered the secret to longevity. Elated, you start telling your patients to drink red wine daily. After all, as a doctor, it is insufficient to predict; we must also prescribe! And what’s not to like about living longer and drinking red wine on the daily?

Many decades later, you follow up on your patients and -- with great disappointment -- observe the following interventional distribution:

p(patient lives > 80 yrs | do(patient drinks red wine daily)) = .5

The life expectancy of patients on the red wine has not increased! What gives?

Finding the Causal Model

The core problem here lies in confounding variables. When we decided to prescribe red wine to patients based on the observational model, we made a strong hypothesis about the causality diagram: Drinks Red Wine → Lives > 80 years.

The directed edges between these random variables denote causality, which can also be thought of as "the arrow of time". Changing the value of the “Drinks Red Wine” variable ought to have an effect on “Lives > 80 years”, but changing “Lives > 80 years” has no effect on drinking red wine.

If this causal diagram were correct, then our intervention should have increased the lifespan of patients. But the actual experiment does not support this, so we must reject this hypothetical causal model and reach for alternative hypotheses to explain the data. Perhaps there are one or more variables that cause both a higher propensity for red wine drinking and a longer life, thus correlating those two variables?

We make the educated guess that a confounding variable might be wealth: wealthy people tend to simultaneously live longer and drink more wine. Combing through the data again, we find that p(drinks red wine | is wealthy) = 0.9 and p(lives > 80 yrs | is wealthy) = 1.0. So our hypothesis now takes the form: Is Wealthy → Drinks Red Wine, and Is Wealthy → Lives > 80 years.

If our understanding of the world is correct, then do(is wealthy) should make people live > 80 years and drink more red wine. And indeed, we find that once we give patients $1M cash infusions to make them wealthy (by USA standards), they end up living longer and drinking red wine daily (this is a hypothetical result, fabricated for the sake of this blog post).
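To make the gap between the observational and interventional distributions concrete, here is a minimal sketch that computes both quantities exactly under the wealth-as-confounder hypothesis. The prior `P_WEALTHY` and the probabilities for non-wealthy patients are assumptions picked purely for illustration (the post does not specify them), so the results will not reproduce the .9 and .5 figures exactly.

```python
# Hypothetical parameters for the diagram: wealthy -> wine, wealthy -> long life.
# Only the two "wealthy" conditionals (0.9 and 1.0) come from the post;
# everything else is an assumption made for this sketch.
P_WEALTHY = 0.2
P_WINE_GIVEN_WEALTHY = {True: 0.9, False: 0.1}
P_LONG_LIFE_GIVEN_WEALTHY = {True: 1.0, False: 0.4}


def p_wealthy(w):
    """Prior probability of being wealthy (w=True) or not (w=False)."""
    return P_WEALTHY if w else 1.0 - P_WEALTHY


def observational_p_long_life_given_wine():
    """p(long life | wine): condition on wine, which shifts belief about wealth."""
    # Joint p(wealthy=w, wine=True) for each value of w.
    joint = {w: p_wealthy(w) * P_WINE_GIVEN_WEALTHY[w] for w in (True, False)}
    p_wine = sum(joint.values())
    return sum(P_LONG_LIFE_GIVEN_WEALTHY[w] * joint[w] for w in (True, False)) / p_wine


def interventional_p_long_life_do_wine():
    """p(long life | do(wine)): forcing wine cuts the wealthy -> wine edge,
    so wealth keeps its prior distribution."""
    return sum(P_LONG_LIFE_GIVEN_WEALTHY[w] * p_wealthy(w) for w in (True, False))


if __name__ == "__main__":
    print("observational :", observational_p_long_life_given_wine())
    print("interventional:", interventional_p_long_life_do_wine())
```

Under these assumed numbers the observational estimate is much higher than the interventional one, because seeing someone drink wine is evidence that they are wealthy, while prescribing wine tells us nothing about their wealth.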

RL as Automated Causal Inference

ML models are increasingly used to drive decision making in recommender systems, self-driving cars, pharmaceutical R&D, and experimental physics. In many cases, we desire an outcome event $y$, for which we attempt to learn a model $p(y|x_1, \ldots, x_N)$ and then choose inputs $x_1, \ldots, x_N$ to maximize $p(y|x_1, \ldots, x_N)$.

It should be quite obvious from the previous medical example that ignoring causality when building decision-making systems risks overfitting to models that are not useful for prescribing interventions. Suppose we automated the causal model discovery process in the following manner:
  1. Fit an observational model p(y|x_1, ..., x_N) to the data.
  2. Assume the observational model captures the causal model. Prescribe an intervention do(x_i) that maximizes p(y|x_1, ..., x_N), and gather a new dataset in which a random 50% of subjects receive the intervention on x_i and the other 50% do not.
  3. Fit an observational model p(y|x_i) to the new data.
  4. Repeat steps 1-3 until the observational model matches the interventional model: p(y|do(x_i)) = p(y|x_i).
To return to the red wine case study as a test case:
  1. You would initially have p(live > 80 years | drink red wine daily) = .9. 
  2. Upon gathering a new dataset, you would obtain p(live > 80 years | do(drink red wine daily)) = .5. The model has not converged, but at least your observational model no longer believes that drinking red wine explains living longer. Furthermore, it now pays attention to the right variable: p(live > 80 years | is_wealthy) = 1.
  3. The subsequent iteration of this procedure then finds that p(live > 80 years | do(is wealthy)) = 1, so we are done.
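
The procedure above can be sketched as a small simulation. This is a simplified reading of the four steps, not a definitive implementation: the simulated world `sample_patient`, its probabilities, the tolerance `tol`, and the choice to test hypotheses in the order a human proposes them are all assumptions made for illustration.

```python
import random


def sample_patient(rng, do=None):
    """Draw one patient from a hypothetical world where wealth is the confounder.
    `do` optionally forces variables, e.g. {"wine": True}."""
    do = do or {}
    wealthy = do["wealthy"] if "wealthy" in do else rng.random() < 0.2
    wine = do["wine"] if "wine" in do else rng.random() < (0.9 if wealthy else 0.1)
    # In this assumed world, wine has no effect on lifespan; only wealth does.
    long_life = rng.random() < (1.0 if wealthy else 0.4)
    return {"wealthy": wealthy, "wine": wine, "long_life": long_life}


def p_long_life_given(data, var):
    """Empirical p(long_life | var=True) from a list of patient records."""
    rows = [row for row in data if row[var]]
    return sum(row["long_life"] for row in rows) / len(rows)


def find_causal_variable(hypotheses, n=20000, tol=0.05, seed=0):
    """Test each hypothesized cause until observation and intervention agree."""
    rng = random.Random(seed)
    logged = [sample_patient(rng) for _ in range(n)]  # passively observed data
    history = []
    for var in hypotheses:
        # Step 1: observational estimate from the logged data.
        p_obs = p_long_life_given(logged, var)
        # Step 2: randomized trial, each patient gets do(var) with probability 0.5.
        trial = [sample_patient(rng, do={var: rng.random() < 0.5}) for _ in range(n)]
        # Step 3: because var was randomized, this estimates p(y | do(var)).
        p_do = p_long_life_given(trial, var)
        history.append((var, p_obs, p_do))
        # Step 4: stop once the observational and interventional estimates match.
        if abs(p_obs - p_do) < tol:
            return var, history
    return None, history


if __name__ == "__main__":
    cause, history = find_causal_variable(["wine", "wealthy"])
    for var, p_obs, p_do in history:
        print(f"{var}: observational={p_obs:.2f} interventional={p_do:.2f}")
    print("accepted cause:", cause)
```

In this toy world the wine hypothesis is rejected because its randomized trial disagrees with the logged data, while the wealth hypothesis survives.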

The act of gathering a randomized trial (the 50% split of intervention vs. non-intervention) and re-training a new observational model is one of the most powerful ways to do general causal inference, because it uses data from reality (which “knows” the true causal model) to stamp out incorrect hypotheses.

Repeatedly training observational models and suggesting interventions is what RL algorithms are all about: they solve optimal control for sequential decision-making problems. Control is the operative word here - the true test of whether an agent understands its environment is whether it can solve it.
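
As a toy illustration of this "intervene, observe, update" loop, here is an epsilon-greedy bandit. It is a deliberately simple stand-in rather than one of the Deep RL algorithms mentioned earlier, and the success probabilities in `arm_probs` are hypothetical. Each pull is an intervention do(arm), so the running averages estimate p(reward | do(arm)) directly.

```python
import random


def epsilon_greedy(arm_probs, steps=5000, epsilon=0.1, seed=0):
    """Learn which intervention (arm) pays off most by trial and error.

    arm_probs: hypothetical true success probability of each intervention,
               unknown to the agent and used only to simulate the world.
    """
    rng = random.Random(seed)
    pulls = [0] * len(arm_probs)    # how often each intervention was tried
    value = [0.0] * len(arm_probs)  # running estimate of p(reward | do(arm))
    for _ in range(steps):
        if rng.random() < epsilon:
            # Explore: try a random intervention.
            arm = rng.randrange(len(arm_probs))
        else:
            # Exploit: pick the intervention currently believed to be best.
            arm = max(range(len(arm_probs)), key=lambda a: value[a])
        reward = 1.0 if rng.random() < arm_probs[arm] else 0.0
        pulls[arm] += 1
        # Incremental mean update of the estimate for the chosen arm.
        value[arm] += (reward - value[arm]) / pulls[arm]
    return value, pulls


if __name__ == "__main__":
    value, pulls = epsilon_greedy([0.5, 0.9])
    print("estimated success rates:", value)
    print("times each arm was tried:", pulls)
```

The exploration step matters for the same reason the randomized trial does: without occasionally trying interventions the current model does not favor, the agent never collects the data that could prove that model wrong.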

For ML models whose predictions are used to infer interventions (so as to manipulate some downstream random variable), I argue that the overfitting problem is nothing more than a causal inference problem. This also explains why RL tends to be much harder as a machine learning problem than supervised learning - not only are there fewer bits of supervision per observation, but the RL agent must also figure out the causal, interventionist distribution required to behave optimally.

One salient case of “overfitting” arises in offline RL. RL algorithms can theoretically be trained “offline” -- that is, learning entirely from off-policy data without gathering new data samples from the environment. However, without periodically gathering new experience from the environment, agents can overfit to finite-size datasets or dataset imbalances, and propose interventions that do not generalize past their offline data. The best way to check if an agent is “learning the right thing” is to deploy it in the world and verify its hypotheses under the interventionist distribution. Indeed, for our robotic grasping research at Google, we often find that fine-tuning with “online” experience improves performance substantially. This is equivalent to re-training an observational model on new data p(grasp success | do(optimal_action)).

Production "RL"

The A/B testing framework often used in production engineering is a manual version of the "automated causal inference" pipeline, where a random 50% of users (assumed to be identically distributed) are shown one intervention and the other 50% are shown the control.
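
A minimal sketch of such an A/B test follows, assuming a binary outcome (for example, a click) and a pooled two-proportion z-statistic as the comparison. The function names, the simulated click rates in `hypothetical_click`, and the sample size are all illustrative assumptions, not a description of any company's pipeline.

```python
import math
import random


def ab_test(outcome_fn, n_users, seed=0):
    """Randomly assign users to control or treatment and compare success rates.

    outcome_fn(arm, rng) -> bool is a hypothetical stand-in for showing a user
    one variant and recording whether the desired outcome happened.
    """
    rng = random.Random(seed)
    counts = {"control": [0, 0], "treatment": [0, 0]}  # [successes, users]
    for _ in range(n_users):
        # Coin-flip assignment is what makes this an intervention do(arm).
        arm = "treatment" if rng.random() < 0.5 else "control"
        counts[arm][0] += int(outcome_fn(arm, rng))
        counts[arm][1] += 1
    (s_c, n_c), (s_t, n_t) = counts["control"], counts["treatment"]
    p_c, p_t = s_c / n_c, s_t / n_t
    # Pooled two-proportion z-statistic for the difference in success rates.
    pooled = (s_c + s_t) / (n_c + n_t)
    se = math.sqrt(pooled * (1 - pooled) * (1 / n_c + 1 / n_t))
    return p_t - p_c, (p_t - p_c) / se


def hypothetical_click(arm, rng):
    """Simulated user: assumed 12% click rate under treatment, 10% under control."""
    return rng.random() < (0.12 if arm == "treatment" else 0.10)


if __name__ == "__main__":
    lift, z = ab_test(hypothetical_click, n_users=100000)
    print(f"estimated lift: {lift:.4f}, z-statistic: {z:.2f}")
```

The human in the loop decides which intervention to test and when the z-statistic is convincing enough to stop, which are exactly the steps an RL agent would have to automate.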

This is the cornerstone of data-driven decision making, and is used widely at hedge funds, Netflix, StitchFix, Google, Walmart, and so on. Although this process has humans in the loop (specifically for proposing interventions and choosing the stopping criterion), many nuances of these methodologies also arise in the RL literature, such as data non-stationarity, the difficulty of obtaining truly randomized experiments, and long-term credit assignment. I’m just starting to learn about causal inference myself, and hope that in the next few years there will be more cross-fertilization of ideas between the RL, Data Science, and Causal Inference research communities.

For a more technical introduction to Causal Inference, see this great blog series by Ferenc Huszar.

[1] A footnote on why I think RL hasn’t had much commercial deployment yet. Feel free to clue me in if there are indeed companies using RL in production that I don’t know about!

In order for a company to be justified in adopting RL technology, three conditions need to hold: 1) the problem at hand is commercially useful, 2) it is feasible for current Deep RL algorithms, and 3) the marginal utility of optimal control is worth the technical risks of Deep RL.

Let’s consider deep image understanding by comparison: 1) everything from surveillance to self-driving cars to FaceID is highly commercially interesting, 2) current models are highly accurate and scale well to a variety of image datasets, and 3) the models generally work as expected and do not require great expertise to train and deploy.

As for RL, it doesn’t take a great imagination to realize that general RL algorithms would eventually enable robots to learn skills entirely on their own, or help companies make complex financial decisions like stock buybacks and hiring, or enable far richer NPC behavior in games. Unfortunately, these problem domains don’t meet criterion (2) - the technology simply isn’t ready and requires many more years of R&D.

For problems where RL is plausible, it is difficult to justify being the first user of a technology whose marginal utility to your problem of choice is unproven. Example problems might include datacenter cooling or air traffic control. Even for domains where RL has been shown clearly to work (e.g. low-dimensional control or pixel-level control), RL still requires a lot of research skill to build a working system.