Sunday, August 7, 2016

A Beginner's Guide to Variational Methods: Mean-Field Approximation

Variational Bayesian (VB) methods are a family of techniques that are very popular in statistical Machine Learning. VB methods allow us to re-write statistical inference problems (i.e. infer the value of a random variable given the value of another random variable) as optimization problems (i.e. find the parameter values that minimize some objective function).

This inference-optimization duality is powerful because it allows us to use the latest-and-greatest optimization algorithms to solve statistical Machine Learning problems (and vice versa, minimize functions using statistical techniques).

This post is an introductory tutorial on Variational Methods. I will derive the optimization objective for the simplest of VB methods, known as the Mean-Field Approximation. This objective, also known as the Variational Lower Bound, is exactly the same one used in Variational Autoencoders (a neat paper which I will explain in a follow-up post).

Table of Contents

  1. Preliminaries and Notation
  2. Problem formulation
  3. Variational Lower Bound for Mean-field Approximation
  4. Forward KL vs. Reverse KL
  5. Connections to Deep Learning

Preliminaries and Notation


This article assumes that the reader is familiar with concepts like random variables, probability distributions, and expectations. Here's a refresher if you forgot some stuff. Machine Learning & Statistics notation isn't standardized very well, so it's helpful to be really precise with notation in this post:
  • Uppercase $X$ denotes a random variable
  • Uppercase $P(X)$ denotes the probability distribution over that variable
  • Lowercase $x \sim P(X)$ denotes a value $x$ sampled ($\sim$) from the probability distribution $P(X)$ via some generative process.
  • Lowercase $p(X)$ is the density function of the distribution of $X$. It is a scalar function over the measure space of $X$.
  • $p(X=x)$ (shorthand $p(x)$) denotes the density function evaluated at a particular value $x$. 
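
To make these type distinctions concrete, here is a minimal Python sketch using scipy.stats (the standard normal is an arbitrary choice for illustration): the frozen distribution object plays the role of $P(X)$, .rvs() draws a sample $x \sim P(X)$, and .pdf() evaluates the density $p(X=x)$.

```python
from scipy import stats

# P(X): a distribution object (a standard normal, chosen arbitrarily for illustration)
P_X = stats.norm(loc=0.0, scale=1.0)

# x ~ P(X): a concrete value sampled from the distribution via its generative process
x = P_X.rvs()

# p(X=x): the density function evaluated at that particular value -- a plain scalar
density_at_x = P_X.pdf(x)

print(x, density_at_x)
```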

Many academic papers use the terms "variables", "distributions", "densities", and even "models" interchangeably. This is not necessarily wrong per se, since $X$, $P(X)$, and $p(X)$ all imply each other via a one-to-one correspondence. However, it's confusing to mix these words together because their types are different (it doesn't make sense to sample a function, nor does it make sense to integrate a distribution). 

We model systems as a collection of random variables, where some variables ($X$) are "observable", while other variables ($Z$) are "hidden". We can draw this relationship via the following graph:

[Figure: a graphical model with an edge drawn from the hidden variable $Z$ to the observed variable $X$.]
The edge drawn from $Z$ to $X$ relates the two variables together via the conditional distribution $P(X|Z)$.

Here's a more concrete example: $X$ might represent the "raw pixel values of an image", while $Z$ is a binary variable such that $Z=1$ if $X$ is an image of a cat.

$X = $ [image 1: a cat]
$P(Z=1) = 1$ (definitely a cat)

$X = $ [image 2: not a cat]
$P(Z=1) = 0$ (definitely not a cat)

$X = $ [image 3: ambiguous]
$P(Z=1) = 0.1$ (sort of cat-like)

Bayes' Theorem gives us a general relationship between any pair of random variables:

$$p(Z|X) = \frac{p(X|Z)p(Z)}{p(X)}$$

The various pieces of this are associated with common names:

$p(Z|X)$ is the posterior probability: "given the image, what is the probability that this is of a cat?" If we can sample from $z \sim P(Z|X)$, we can use this to make a cat classifier that tells us whether a given image is a cat or not.

$p(X|Z)$ is the likelihood: given a value of $Z$, this computes how "probable" this image $X$ is under that category ("is-a-cat" / "is-not-a-cat"). If we can sample $x \sim P(X|Z)$, then we can generate images of cats and images of non-cats just as easily as we can generate random numbers. If you'd like to learn more about this, see my other articles on generative models: [1], [2].

$p(Z)$ is the prior probability. This captures any prior information we know about $Z$ - for example, if we think that 1/3 of all images in existence are of cats, then $p(Z=1) = \frac{1}{3}$ and $p(Z=0) = \frac{2}{3}$.
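
Putting these pieces together, here is a tiny numerical sketch of Bayes' Theorem for the cat example. The likelihood values are made up purely for illustration; only the $1/3$ and $2/3$ priors come from the text above.

```python
# Hypothetical numbers for a single image x (purely illustrative).
p_x_given_cat = 0.2       # p(x | Z=1): likelihood under "is-a-cat"
p_x_given_not_cat = 0.05  # p(x | Z=0): likelihood under "is-not-a-cat"
p_cat = 1.0 / 3.0         # p(Z=1): prior
p_not_cat = 2.0 / 3.0     # p(Z=0): prior

# Marginal likelihood p(x) = sum_z p(x|z) p(z)
p_x = p_x_given_cat * p_cat + p_x_given_not_cat * p_not_cat

# Posterior p(Z=1|x) via Bayes' Theorem
p_cat_given_x = p_x_given_cat * p_cat / p_x
print(p_cat_given_x)  # ~0.667: this image is probably a cat
```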

Hidden Variables as Priors


This is an aside for interested readers. Skip to the next section to continue with the tutorial.

The previous cat example presents a very conventional example of observed variables, hidden variables, and priors. However, it's important to realize that the distinction between hidden / observed variables is somewhat arbitrary, and you're free to factor the graphical model however you like.

We can re-write Bayes' Theorem by swapping the terms:

$$\frac{p(Z|X)p(X)}{p(Z)} = p(X|Z)$$

The "posterior" in question is now $P(X|Z)$.

Hidden variables can be interpreted from a Bayesian Statistics framework as prior beliefs attached to the observed variables. For example, if we believe $X$ is a multivariate Gaussian, the hidden variable $Z$ might represent the mean and variance of the Gaussian distribution. The distribution over parameters $P(Z)$ is then a prior distribution to $P(X)$.
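
As a sketch of this interpretation, here is a hypothetical generative process in which the hidden variable $Z = (\mu, \sigma)$ parameterizes a Gaussian over $X$; the particular priors on $\mu$ and $\sigma$ are arbitrary choices made just so the code runs.

```python
import numpy as np

rng = np.random.default_rng(0)

# Sample the hidden variable Z = (mu, sigma) from a prior P(Z).
# The particular prior distributions here are arbitrary, purely for illustration.
mu = rng.normal(loc=0.0, scale=10.0)     # prior belief about the mean
sigma = rng.gamma(shape=2.0, scale=1.0)  # prior belief about the (positive) std-dev

# Sample the observed variable X from the conditional P(X | Z).
x = rng.normal(loc=mu, scale=sigma, size=100)
```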

You are also free to choose which values $X$ and $Z$ represent. For example, $Z$ could instead be "mean, cube root of variance, and $X+Y$ where $Y \sim \mathcal{N}(0,1)$". This is somewhat unnatural and weird, but the structure is still valid, as long as $P(X|Z)$ is modified accordingly.

You can even "add" variables to your system. The prior itself might be dependent on other random variables via $P(Z|\theta)$, which have prior distributions of their own $P(\theta)$, and those have priors still, and so on. Any hyper-parameter can be thought of as a prior. In Bayesian statistics, it's priors all the way down.



Problem Formulation


The key problem we are interested in is posterior inference, or computing functions on the hidden variable $Z$. Some canonical examples of posterior inference:
  • Given this surveillance footage $X$, did the suspect show up in it?
  • Given this twitter feed $X$, is the author depressed?
  • Given historical stock prices $X_{1:t-1}$, what will $X_t$ be?

We usually assume that we know how to compute the likelihood function $p(X|Z)$ and the prior $p(Z)$.

The problem is that, for complicated tasks like the ones above, we often don't know how to sample from $P(Z|X)$ or evaluate its density $p(z|x)$. Alternatively, we might know the form of $p(Z|X)$, but the corresponding computation is so complicated that we cannot evaluate it in a reasonable amount of time. We could try to use sampling-based approaches like MCMC, but these are slow to converge.
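
To see where the difficulty comes from, expand the denominator of Bayes' Theorem: each term $p(x|z)p(z)$ is easy to evaluate, but the normalizer $p(x)$ requires summing (or integrating) over every possible configuration of $Z$, which becomes intractable when $Z$ is high-dimensional:

$$p(z|x) = \frac{p(x|z)p(z)}{p(x)} = \frac{p(x|z)p(z)}{\sum_{z'} p(x|z')p(z')}$$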

Variational Lower Bound for Mean-field Approximation


The idea behind variational inference is this: let's just perform inference on an easy, parametric distribution $Q_\phi(Z|X)$ (like a Gaussian) for which we know how to do posterior inference, but adjust the parameters $\phi$ so that $Q_\phi$ is as close to $P$ as possible.

This is visually illustrated below: the blue curve is the true posterior distribution, and the green distribution is the variational approximation (Gaussian) that we fit to the blue density via optimization.

[Figure: the true posterior density (blue) and the Gaussian variational approximation (green) fit to it via optimization.]

What does it mean for distributions to be "close"? Mean-field variational Bayes (the most common type) uses the reverse KL divergence as the measure of "distance" between the two distributions.

$$KL(Q_\phi(Z|X)||P(Z|X)) = \sum_{z \in Z}{q_\phi(z|x)\log\frac{q_\phi(z|x)}{p(z|x)}}$$

Reverse KL divergence measures the amount of information (in nats, or units of $\frac{1}{\log(2)}$ bits) required to "distort" $P(Z|X)$ into $Q_\phi(Z|X)$. We wish to minimize this quantity with respect to $\phi$.
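
As a quick numerical sketch of this definition (using a made-up discrete posterior and approximation, purely for illustration):

```python
import numpy as np

# Two made-up discrete distributions over the same three values of z (illustrative only).
q = np.array([0.6, 0.3, 0.1])   # q_phi(z|x), the variational approximation
p = np.array([0.4, 0.4, 0.2])   # p(z|x), the true posterior

# Reverse KL: KL(Q || P) = sum_z q(z|x) * log( q(z|x) / p(z|x) ), measured in nats
reverse_kl = np.sum(q * np.log(q / p))
print(reverse_kl)
```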

By definition of a conditional distribution, $p(z|x) = \frac{p(x,z)}{p(x)}$. Let's substitute this expression into our original $KL$ expression, and then distribute:

$$
\begin{align}
KL(Q||P) & = \sum_{z \in Z}{q_\phi(z|x)\log\frac{q_\phi(z|x)p(x)}{p(z,x)}} && \text{(1)} \\
& = \sum_{z \in Z}{q_\phi(z|x)\big(\log{\frac{q_\phi(z|x)}{p(z,x)}} + \log{p(x)}\big)} \\
& = \Big(\sum_{z}{q_\phi(z|x)\log{\frac{q_\phi(z|x)}{p(z,x)}}}\Big) + \Big(\sum_{z}{\log{p(x)}q_\phi(z|x)}\Big) \\
& = \Big(\sum_{z}{q_\phi(z|x)\log{\frac{q_\phi(z|x)}{p(z,x)}}}\Big) + \Big(\log{p(x)}\sum_{z}{q_\phi(z|x)}\Big) && \text{note: $\sum_{z}{q(z)} = 1 $} \\
& = \log{p(x)} + \Big(\sum_{z}{q_\phi(z|x)\log{\frac{q_\phi(z|x)}{p(z,x)}}}\Big)  \\
\end{align}
$$

To minimize $KL(Q||P)$ with respect to variational parameters $\phi$, we just have to minimize $\sum_{z}{q_\phi(z|x)\log{\frac{q_\phi(z|x)}{p(z,x)}}}$, since $\log{p(x)}$ is fixed with respect to $\phi$. Let's re-write this quantity as an expectation over the distribution $Q_\phi(Z|X)$.

$$
\begin{align}
\sum_{z}{q_\phi(z|x)\log{\frac{q_\phi(z|x)}{p(z,x)}}} & = \mathbb{E}_{z \sim Q_\phi(Z|X)}\big[\log{\frac{q_\phi(z|x)}{p(z,x)}}\big]\\
& = \mathbb{E}_Q\big[ \log{q_\phi(z|x)} - \log{p(x,z)} \big] \\
& = \mathbb{E}_Q\big[ \log{q_\phi(z|x)} - (\log{p(x|z)} + \log{p(z)}) \big] && \text{(via $p(x,z)=p(x|z)p(z)$)}\\
& = \mathbb{E}_Q\big[ \log{q_\phi(z|x)} - \log{p(x|z)} - \log{p(z)} \big] \\
\end{align}
$$

Minimizing this is equivalent to maximizing the negation of this function:

$$
\begin{align}
\text{maximize } \mathcal{L} & = -\sum_{z}{q_\phi(z|x)\log{\frac{q_\phi(z|x)}{p(z,x)}}} \\
& = \mathbb{E}_Q\big[ -\log{q_\phi(z|x)} + \log{p(x|z)} + \log{p(z)} \big] \\
& =  \mathbb{E}_Q\big[ \log{p(x|z)} + \log{\frac{p(z)}{ q_\phi(z|x)}} \big] && \text{(2)} \\
\end{align}
$$

In the literature, $\mathcal{L}$ is known as the variational lower bound, and it is computationally tractable as long as we can evaluate $p(x|z)$, $p(z)$, and $q_\phi(z|x)$. We can further re-arrange the terms in a way that yields an intuitive formula:

$$
\begin{align*}
\mathcal{L} & =  \mathbb{E}_Q\big[ \log{p(x|z)} + \log{\frac{p(z)}{ q_\phi(z|x)}} \big] \\
& =   \mathbb{E}_Q\big[ \log{p(x|z)} \big] + \sum_{z}{q_\phi(z|x)\log{\frac{p(z)}{ q_\phi(z|x)}}} && \text{Definition of expectation} \\
& =  \mathbb{E}_Q\big[ \log{p(x|z)} \big] - KL(Q(Z|X)||P(Z)) && \text{Definition of KL divergence} && \text{(3)}
\end{align*}
$$

If sampling $z \sim Q(Z|X)$ is an "encoding" process that converts an observation $x$ into a latent code $z$, then sampling $x \sim P(X|Z)$ is a "decoding" process that reconstructs the observation from $z$.

It follows that $\mathcal{L}$ is the expected "decoding" likelihood (how well we can decode a sample of $Z$ back into the observation $X$), minus the KL divergence between the variational approximation and the prior on $Z$. If we assume $Q(Z|X)$ is conditionally Gaussian, the prior $P(Z)$ is often chosen to be a diagonal Gaussian with mean 0 and standard deviation 1.
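
For that common choice, the KL term in Eq. (3) has a well-known closed form: if $Q(Z|X) = \mathcal{N}(\mu, \mathrm{diag}(\sigma^2))$ and $P(Z) = \mathcal{N}(0, I)$, then $KL(Q||P) = \frac{1}{2}\sum_i \left(\mu_i^2 + \sigma_i^2 - 1 - \log \sigma_i^2\right)$. Here is a minimal sketch (the particular $\mu$, $\sigma$ values are arbitrary):

```python
import numpy as np

def kl_diag_gaussian_to_standard_normal(mu, sigma):
    """KL( N(mu, diag(sigma^2)) || N(0, I) ), summed over latent dimensions."""
    return 0.5 * np.sum(mu ** 2 + sigma ** 2 - 1.0 - np.log(sigma ** 2))

# Example with a 3-dimensional latent code (the numbers are arbitrary).
mu = np.array([0.5, -1.0, 0.0])
sigma = np.array([1.0, 0.5, 2.0])
print(kl_diag_gaussian_to_standard_normal(mu, sigma))
```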

Why is $\mathcal{L}$ called the variational lower bound? Substituting $\mathcal{L}$ back into Eq. (1), we have:

$$
\begin{align*}
KL(Q||P) & = \log p(x) - \mathcal{L} \\
\log p(x) & = \mathcal{L} + KL(Q||P) && \text{(4)}
\end{align*}
$$

The meaning of Eq. (4), in plain language, is that $\log p(x)$, the log-likelihood of a data point $x$ under the true distribution, is $\mathcal{L}$ plus an error term $KL(Q||P)$ that captures the distance between $Q(Z|X=x)$ and $P(Z|X=x)$ at that particular value of $X$.

Since $KL(Q||P) \geq 0$, $\log p(x)$ must be greater than or equal to $\mathcal{L}$, so $\mathcal{L}$ is a lower bound for $\log p(x)$. $\mathcal{L}$ is also referred to as the evidence lower bound (ELBO), via the alternate formulation:

$$
\mathcal{L} = \log p(x) - KL(Q(Z|X)||P(Z|X)) = \mathbb{E}_Q\big[ \log{p(x|z)} \big] - KL(Q(Z|X)||P(Z))
$$

Note that $\mathcal{L}$ itself contains a KL divergence term between the approximate posterior and the prior, so there are two KL terms in total in $\log p(x)$.
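
Since everything in Eqs. (2)-(4) is tractable when the latent variable is small and discrete, here is a minimal numerical sketch (with made-up probabilities for a binary $Z$) that evaluates $\mathcal{L}$ by enumerating $z$ and checks Eq. (4), $\log p(x) = \mathcal{L} + KL(Q||P)$:

```python
import numpy as np

# Made-up model for a single observation x, with a binary latent z in {0, 1}.
p_z = np.array([2/3, 1/3])            # prior p(z)
p_x_given_z = np.array([0.05, 0.2])   # likelihood p(x|z) evaluated at this particular x
q_z_given_x = np.array([0.5, 0.5])    # variational approximation q_phi(z|x)

# Variational lower bound, Eq. (2): L = E_q[ log p(x|z) + log p(z) - log q(z|x) ]
L = np.sum(q_z_given_x * (np.log(p_x_given_z) + np.log(p_z) - np.log(q_z_given_x)))

# Exact quantities, tractable here because z takes only two values.
p_x = np.sum(p_x_given_z * p_z)          # marginal likelihood p(x)
p_z_given_x = p_x_given_z * p_z / p_x    # true posterior p(z|x)
kl = np.sum(q_z_given_x * np.log(q_z_given_x / p_z_given_x))

print(np.log(p_x), L + kl)  # the two values agree, as in Eq. (4)
```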

Forward KL vs. Reverse KL


KL divergence is not a symmetric distance function, i.e. $KL(P||Q) \neq KL(Q||P)$ (except when $Q \equiv P$). The first is known as the "forward KL", while the latter is the "reverse KL". So why do we use the reverse KL? Because the forward KL requires taking expectations with respect to $p(Z|X)$, which is exactly the distribution we don't know how to work with in the first place.

I really like Kevin Murphy's explanation in the PML textbook, which I shall attempt to re-phrase here:

Let's consider the forward KL first. As we saw from the above derivations, we can write KL as the expectation of a "penalty" function $\log \frac{p(z)}{q(z)}$ under a weighting function $p(z)$:

$$
\begin{align*}
KL(P||Q) & = \sum_z p(z) \log \frac{p(z)}{q(z)} \\
& = \mathbb{E}_{p(z)}{\big[\log \frac{p(z)}{q(z)}\big]}\\
\end{align*}
$$

The penalty function contributes loss to the total KL wherever $p(z) > 0$: for such $z$, $\lim_{q(z) \to 0} \log \frac{p(z)}{q(z)} = \infty$. This means that the forward KL will be large wherever $Q(Z)$ fails to "cover up" $P(Z)$.

Therefore, the forward KL is minimized when we ensure that $q(z) > 0$ wherever $p(z) > 0$. The optimized variational distribution $Q(Z)$ is known as "zero-avoiding" (its density avoids being zero wherever $p(z)$ is nonzero).



Minimizing the Reverse-KL has exactly the opposite behavior:

$$
\begin{align*}
KL(Q||P) & = \sum_z q(z) \log \frac{q(z)}{p(z)} \\
& = \mathbb{E}_{q(z)}{\big[\log \frac{q(z)}{p(z)}\big]}
\end{align*}
$$

Wherever $p(z) = 0$, we must ensure that the weighting function $q(z) = 0$ as well, otherwise the penalty $\log \frac{q(z)}{p(z)}$ blows up. The optimized $Q(Z)$ is therefore known as "zero-forcing".


So in summary, minimizing forward-KL "stretches" your variational distribution $Q(Z)$ to cover over the entire $P(Z)$ like a tarp, while minimizing reverse-KL "squeezes" the $Q(Z)$ under $P(Z)$.

It's important to keep in mind the implications of using reverse-KL when using the mean-field approximation in machine learning problems. If we are fitting a unimodal distribution to a multi-modal one, we'll end up with more false negatives (there is actually probability mass in $P(Z)$ where we think there is none in $Q(Z)$).
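
Here is a small numerical sketch of the two behaviors (the target distribution and search grid are made up): we fit a single Gaussian $Q(Z)$ to a bimodal mixture $P(Z)$ on a discretized grid by brute-force search, once minimizing forward KL and once minimizing reverse KL. With these particular settings, the forward-KL fit spreads out to cover both modes, while the reverse-KL fit concentrates on a single mode.

```python
import numpy as np

# Discretize Z on a grid and build a bimodal target P(Z): a mixture of two Gaussians.
z = np.linspace(-10, 10, 2001)

def normal_pdf(z, mu, sigma):
    return np.exp(-0.5 * ((z - mu) / sigma) ** 2) / (sigma * np.sqrt(2 * np.pi))

p = 0.5 * normal_pdf(z, -3.0, 1.0) + 0.5 * normal_pdf(z, 3.0, 1.0)
p /= p.sum()  # normalize on the grid

def kl(a, b):
    eps = 1e-12  # guard against log(0) on the grid
    return np.sum(a * np.log((a + eps) / (b + eps)))

# Brute-force search over a family of single Gaussians Q(Z).
best_fwd, best_rev = None, None
for mu in np.linspace(-5, 5, 101):
    for sigma in np.linspace(0.5, 5, 46):
        q = normal_pdf(z, mu, sigma)
        q /= q.sum()
        fwd, rev = kl(p, q), kl(q, p)  # forward KL(P||Q) and reverse KL(Q||P)
        if best_fwd is None or fwd < best_fwd[0]:
            best_fwd = (fwd, mu, sigma)
        if best_rev is None or rev < best_rev[0]:
            best_rev = (rev, mu, sigma)

print("forward-KL fit (mu, sigma):", best_fwd[1:])  # wide Gaussian covering both modes
print("reverse-KL fit (mu, sigma):", best_rev[1:])  # narrow Gaussian hugging one mode
```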

Connections to Deep Learning


Variational methods are really important for Deep Learning. I will elaborate more in a later post, but here's a quick spoiler:
  1. Deep learning is really good at optimization (specifically, gradient descent) over very large parameter spaces using lots of data.
  2. Variational Bayes gives us a framework with which we can re-write statistical inference problems as optimization problems.
Combining Deep Learning and VB methods allows us to perform inference on extremely complex posterior distributions. As it turns out, modern techniques like Variational Autoencoders optimize the exact same mean-field variational lower bound derived in this post!

Thanks for reading, and stay tuned!