Saturday, February 13, 2021

Don't Mess with Backprop: Doubts about Biologically Plausible Deep Learning

Biologically Plausible Deep Learning (BPDL) is an active research field at the intersection of Neuroscience and Machine Learning, studying how we can train deep neural networks with a "learning rule" that could conceivably be implemented in the brain.

The line of reasoning that typically motivates BPDL is as follows:

  1. A Deep Neural Network (DNN) can learn to perform perception tasks that biological brains are capable of (such as detecting and recognizing objects).
  2. If activation units and their weights are to DNNs what neurons and synapses are to biological brains, then what is backprop (the primary method for training deep neural nets) analogous to?
  3. If learning rules in brains are not implemented using backprop, then how are they implemented? How can we achieve similar performance to backprop-based update rules while still respecting biological constraints?

A nice overview of the ways in which backprop is not biologically plausible can be found here, along with various algorithms that propose fixes.

My somewhat contrarian opinion is that designing biologically plausible alternatives to backprop is the wrong question to be asking. The motivating premises of BPDL make a faulty assumption: that layer activations are neurons and weights are synapses, and therefore learning-via-backprop must have a counterpart or alternative in biological learning.

Despite the name and their impressive capabilities on various tasks, DNNs actually have very little to do with biological neural networks. One of the great errors in the field of Machine Learning is that we ascribe too much biological meaning to our statistical tools and optimal control algorithms. This confuses newcomers, who ascribe entirely different meanings to "learning", "evolutionary algorithms", and so on.

DNNs are a sequence of linear operations interspersed with nonlinear operations, applied sequentially to real-valued inputs - nothing more. They are optimized via gradient descent, and gradients are computed efficiently using a dynamic programming scheme known as backprop. Note that I didn't use the word "learning"!
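To make that description concrete, here is a minimal pure-Python sketch (a toy of our own, not taken from any library): a two-layer network that is literally a linear map, an elementwise tanh, and another linear map. Gradients are computed by hand-written backprop and applied by plain gradient descent. The layer sizes, initial values, and learning rate are arbitrary assumptions chosen for illustration.

```python
import math

def forward(params, x):
    """Linear map -> tanh -> linear map. Returns the output and the hidden activations."""
    W1, w2 = params
    z = [sum(wij * xj for wij, xj in zip(row, x)) for row in W1]  # linear
    h = [math.tanh(zi) for zi in z]                                # elementwise nonlinearity
    y = sum(w2i * hi for w2i, hi in zip(w2, h))                    # linear
    return y, h

def loss_and_grads(params, x, t):
    """Squared-error loss and its gradients via backprop (chain rule, output to input)."""
    W1, w2 = params
    y, h = forward(params, x)
    dy = y - t                                               # dL/dy for L = 0.5 * (y - t)^2
    dw2 = [dy * hi for hi in h]                              # dL/dw2
    dh = [dy * w2i for w2i in w2]                            # dL/dh
    dz = [dhi * (1.0 - hi * hi) for dhi, hi in zip(dh, h)]   # tanh'(z) = 1 - tanh(z)^2
    dW1 = [[dzi * xj for xj in x] for dzi in dz]             # dL/dW1
    return 0.5 * dy * dy, (dW1, dw2)

def sgd_step(params, grads, lr):
    """Plain gradient descent: move every parameter against its gradient."""
    (W1, w2), (dW1, dw2) = params, grads
    W1 = [[w - lr * g for w, g in zip(row, grow)] for row, grow in zip(W1, dW1)]
    w2 = [w - lr * g for w, g in zip(w2, dw2)]
    return W1, w2

# Usage: fit a single (input, target) pair.
params = ([[0.5, -0.3], [0.1, 0.8], [-0.6, 0.2]], [0.3, -0.2, 0.5])
x, t = [1.0, 2.0], 0.5
initial_loss, _ = loss_and_grads(params, x, t)
for _ in range(100):
    _, grads = loss_and_grads(params, x, t)
    params = sgd_step(params, grads, lr=0.1)
final_loss, _ = loss_and_grads(params, x, t)
```

Nothing in this loop refers to neurons or synapses. It is function composition, the chain rule, and a descent step.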

Dynamic programming is the ninth wonder of the world[1], and in my opinion one of the top three achievements of Computer Science. Backprop has linear time-complexity in network depth, which makes it extraordinarily hard to beat from a computational cost perspective. Many BPDL algorithms don't do better than backprop, because they take an efficient optimization scheme and shoehorn in an update mechanism with additional constraints.
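The dynamic-programming point can be illustrated on a hypothetical deep chain of scalar layers, h_i = tanh(w_i * h_{i-1}). Computing the derivative of the final output with respect to every w_i by re-multiplying the chain of local derivatives from scratch costs O(L²) multiplications for depth L. Backprop instead reuses the running product from the top of the chain (the "adjoint") and costs O(L). The sketch below counts multiplications to make the difference visible; it is an illustration under these toy assumptions, not a benchmark.

```python
import math

def chain_forward(ws, x):
    """Layer i maps hs[i] to hs[i+1] = tanh(ws[i] * hs[i]); returns all activations."""
    hs = [x]
    for w in ws:
        hs.append(math.tanh(w * hs[-1]))
    return hs

def local_derivs(ws, hs):
    """For each layer i: d hs[i+1] / d hs[i] and d hs[i+1] / d ws[i]."""
    dh_dprev, dh_dw = [], []
    for i, w in enumerate(ws):
        s = 1.0 - hs[i + 1] * hs[i + 1]  # derivative of tanh at layer i's pre-activation
        dh_dprev.append(s * w)
        dh_dw.append(s * hs[i])
    return dh_dprev, dh_dw

def grads_naive(ws, x):
    """d(output)/d ws[i] for every i, re-multiplying the chain from scratch each time."""
    dh_dprev, dh_dw = local_derivs(ws, chain_forward(ws, x))
    grads, mults = [], 0
    for i in range(len(ws)):
        g = dh_dw[i]
        for j in range(i + 1, len(ws)):
            g *= dh_dprev[j]
            mults += 1
        grads.append(g)
    return grads, mults

def grads_backprop(ws, x):
    """Same gradients, reusing the running product d(output)/d hs[i] (the adjoint)."""
    dh_dprev, dh_dw = local_derivs(ws, chain_forward(ws, x))
    grads, mults = [0.0] * len(ws), 0
    adjoint = 1.0  # d(output)/d(output)
    for i in reversed(range(len(ws))):
        grads[i] = adjoint * dh_dw[i]
        adjoint *= dh_dprev[i]
        mults += 2
    return grads, mults

ws = [0.9 + 0.01 * i for i in range(50)]
g_naive, m_naive = grads_naive(ws, 0.5)
g_bp, m_bp = grads_backprop(ws, 0.5)
print(m_naive, m_bp)  # 1225 100
```

Both functions return the same gradients; only the amount of repeated work differs, and the gap grows quadratically with depth.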

If the goal is to build a biologically plausible learning mechanism, there's no reason that units in Deep Neural Networks should be one-to-one with biological neurons. Trying to emulate a DNN with models of biological neurons feels backwards, like trying to emulate the Windows OS with a human brain. It's hard, and a human brain can't simulate Windows well.

Instead, let's do the emulation the other way around: optimizing a function approximator to implement a biologically plausible learning rule. The recipe is straightforward:

  1. Build a biologically plausible model of a neural network with model neurons and synaptic connections. Neurons communicate with each other using spike trains, rate coding, or gradients, and respect whatever constraints you deem to be "sufficiently biologically plausible". It has parameters that need to be trained.
  2. Use computer-aided search to design a biologically plausible learning rule for these model neurons. For instance, each neuron's feedforward behavior and local update rules can be modeled as the output of an artificial neural network.
  3. Update the function approximator so that the biological model produces the desired learning behavior. We could train these artificial neural networks via backprop.
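The recipe can be sketched in a few dozen lines. This is a toy built on several assumptions of ours, not the author's setup: the "biological" model is a single linear neuron; the hypothetical local rule Δw_i = lr · (θ0·x_i·e + θ1·x_i·y + θ2·w_i) lets each synapse see only its own input x_i, its own weight w_i, the neuron's output y, and a broadcast scalar error e; and the outer loop estimates the meta-gradient with central finite differences as a dependency-free stand-in for backprop through the inner loop (an autodiff library such as JAX could differentiate the inner loop directly).

```python
import random

def dot(a, b):
    return sum(ai * bi for ai, bi in zip(a, b))

def make_tasks(n_tasks, dim=3, n_examples=5, seed=0):
    """Each toy task: a hidden 'teacher' weight vector plus a few input patterns."""
    rng = random.Random(seed)
    tasks = []
    for _ in range(n_tasks):
        w_true = [rng.gauss(0.0, 1.0) for _ in range(dim)]
        xs = [[rng.gauss(0.0, 1.0) for _ in range(dim)] for _ in range(n_examples)]
        tasks.append((w_true, xs))
    return tasks

def inner_loop(theta, task, epochs=10, lr=0.05):
    """Step 1 + 2: one linear neuron trained by a parameterized local rule.
    Returns the neuron's final mean squared error on the task."""
    w_true, xs = task
    w = [0.0] * len(w_true)
    for _ in range(epochs):
        for x in xs:
            y = dot(w, x)
            e = dot(w_true, x) - y  # scalar error signal broadcast to every synapse
            # Synapse i uses only x_i, its own weight w_i, the output y, and e.
            w = [wi + lr * (theta[0] * xi * e + theta[1] * xi * y + theta[2] * wi)
                 for wi, xi in zip(w, x)]
    errs = [dot(w_true, x) - dot(w, x) for x in xs]
    return sum(d * d for d in errs) / len(errs)  # d * d, not d ** 2, so overflow gives inf instead of raising

def meta_loss(theta, tasks):
    return sum(inner_loop(theta, task) for task in tasks) / len(tasks)

def meta_train(tasks, steps=20, meta_lr=0.5, eps=1e-4):
    """Step 3: descend the meta-loss with a central finite-difference gradient,
    halving the step size until the loss actually decreases."""
    theta = [0.0, 0.0, 0.0]
    loss = meta_loss(theta, tasks)
    for _ in range(steps):
        grad = []
        for k in range(len(theta)):
            up, dn = theta[:], theta[:]
            up[k] += eps
            dn[k] -= eps
            grad.append((meta_loss(up, tasks) - meta_loss(dn, tasks)) / (2 * eps))
        step = meta_lr
        while step > 1e-8:
            cand = [t - step * g for t, g in zip(theta, grad)]
            cand_loss = meta_loss(cand, tasks)
            if cand_loss < loss:  # a diverging rule (inf/NaN loss) is never accepted
                theta, loss = cand, cand_loss
                break
            step *= 0.5
    return theta, loss

tasks = make_tasks(6)
initial_loss = meta_loss([0.0, 0.0, 0.0], tasks)
theta, final_loss = meta_train(tasks)
```

Starting from θ = 0 (no plasticity at all), the outer loop moves θ0 toward an error-driven, delta-rule-like update. In a less toy setting, the three hand-picked terms would be replaced by a small network that outputs the update, as step 2 suggests.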

The choice of function approximator we use to find our learning rule is irrelevant. What we care about at the end of the day is answering how a biological brain is able to learn hard tasks like perception while respecting known constraints, such as the fact that biological neurons don't store all activations in memory and only employ local learning rules. We should leverage Deep Learning's ability to find good function approximators, and direct that towards finding good biological learning rules.

The insight that we should (artificially) learn to (biologically) learn is not a new idea, but it is one that I think is not yet obvious to the neuroscience + AI community. Meta-Learning, or "Learning to Learn", is a field that has emerged in recent years, which studies how to acquire a system that is itself capable of learning (potentially better than gradient descent). If meta-learning can find us more sample-efficient, more capable, or more robust learners, why can't it find us rules that respect biological learning constraints? Indeed, recent work [1, 2, 3, 4, 5] shows this to be the case: you can use backprop to train a separate learning rule superior to naïve backprop.

I think the reason that many researchers have not really caught onto this idea (that we should emulate biologically plausible circuits with a meta-learning approach) is that until recently, there wasn't enough compute to train both a meta-learner and a learner. It still requires substantial computing power and research infrastructure to set up a meta-optimization scheme, but tools like JAX make it considerably easier now.

A true biology purist might argue that finding a learning rule using gradient descent and backprop is not an "evolutionarily plausible learning rule", because evolution clearly lacks the ability to perform dynamic programming or even gradient computation. But this can be addressed by making the meta-learner evolutionarily plausible. For instance, the mechanism with which we select good function approximators does not need to rely on backprop at all. Alternatively, we could formulate a meta-meta problem whereby the selection process itself obeys rules of evolutionary selection, but the selection process is found using, once again, backprop.
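As one illustration of a selection mechanism that uses no gradients, here is a (1+1) evolution strategy: mutate the parameters of a learning rule and keep the mutant only if it does better. The "lifetime" task and the two-parameter update rule are hypothetical toys chosen for brevity; this is not meant as a model of actual biological evolution, only as a demonstration that the outer loop can be pure mutation and selection.

```python
import random

def lifetime_error(theta, targets=(-2.0, 0.5, 3.0), steps=10):
    """Hypothetical 'lifetime': a 1-D learner nudges its weight w toward each
    target t with the rule dw = theta[0] * (t - w) + theta[1] * w.
    Returns the mean squared error left at the end of the lifetime."""
    total = 0.0
    for t in targets:
        w = 0.0
        for _ in range(steps):
            w += theta[0] * (t - w) + theta[1] * w
        total += (t - w) * (t - w)
    return total / len(targets)

def evolve(error_fn, n_params, generations=300, sigma=0.2, seed=0):
    """(1+1) evolution strategy: mutate the parent's rule parameters and keep the
    child only if its lifetime error is lower. No gradients are computed."""
    rng = random.Random(seed)
    parent = [0.0] * n_params
    parent_err = error_fn(parent)
    for _ in range(generations):
        child = [p + rng.gauss(0.0, sigma) for p in parent]
        child_err = error_fn(child)
        if child_err < parent_err:
            parent, parent_err = child, child_err
    return parent, parent_err

initial_err = lifetime_error([0.0, 0.0])
theta, err = evolve(lifetime_error, n_params=2)
```

Selection only ever compares two candidates' outcomes, so the same loop works whether the inner learner is differentiable or not.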

Don't mess with backprop!


[1] The eighth wonder being, of course, compound interest.

Monday, January 25, 2021

How to Understand ML Papers Quickly

My ML mentees often ask me some variant of the question "how do you choose which papers to read from the deluge of publications flooding arXiv every day?"

The nice thing about reading most ML papers is that you can cut through the jargon by asking just five simple questions. I try to answer these questions as quickly as I can when skimming papers.

1) What are the inputs to the function approximator?

E.g. a 224x224x3 RGB image with a single object roughly centered in the view. 

2) What are the outputs of the function approximator?

E.g. a 1000-long vector corresponding to the class of the input image.

Thinking about inputs and outputs to the system in a method-agnostic way lets you take a step back from the algorithmic jargon and consider whether other fields have developed methods that might work here using different terminology. I find this approach especially useful when reading Meta-Learning papers.

By thinking about an ML problem first as a set of inputs and desired outputs, you can reason about whether the input is even sufficient to predict the output. Without this exercise you might accidentally set up an ML problem where the output can't possibly be determined by the inputs. The result might be an ML system whose predictions are problematic for society.
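For datasets with discrete inputs, one rough way to run that sanity check is sketched here (a hypothetical helper, not from any library): group examples by identical input and count how often the most common label wins. The result is an upper bound on the training accuracy of any deterministic predictor that sees only the input; a bound well below 1.0 means the inputs alone cannot determine the outputs. It assumes hashable inputs; continuous inputs would need a different estimate.

```python
from collections import Counter, defaultdict

def best_achievable_accuracy(examples):
    """Upper bound on the accuracy of ANY deterministic predictor f(input) -> label
    on this dataset: for each distinct input, the best possible choice is the label
    that most often co-occurs with it. Inputs must be hashable."""
    labels_by_input = defaultdict(Counter)
    for x, y in examples:
        labels_by_input[x][y] += 1
    correct = sum(c.most_common(1)[0][1] for c in labels_by_input.values())
    return correct / len(examples)

# The input (0, 1) appears with two different labels, so no function of the
# input alone can get both of those examples right.
data = [((0, 1), "a"), ((0, 1), "b"), ((1, 0), "a"), ((1, 0), "a")]
print(best_achievable_accuracy(data))  # 0.75
```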

3) What loss supervises the output predictions? What assumptions about the world does this particular objective make?

ML models are formed by combining biases and data. Sometimes the biases are strong, other times they are weak. To make a model generalize better, you need to add more biases or add more unbiased data. There is no free lunch.

An example: many optimal control algorithms assume a stationary, episodic data-generating process that is a Markov Decision Process (MDP). In an MDP, “state” and “action” map via the environment’s transition dynamics to “a next-state, reward, and whether the episode is over or not”. This structure, though very general, can be used to formulate a loss that trains Q-values to satisfy the Bellman Equation.
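A minimal sketch of how that MDP structure becomes a loss, assuming a tiny hypothetical deterministic chain environment (states 0 to 3, actions left/right, reward 1 for stepping right into the terminal state). The Bellman target is r + γ · max over a' of Q(s', a'), with no bootstrapping when the episode is over. The squared difference between that target and Q(s, a) is the loss, and each update takes a gradient step on it while holding the target fixed, as in standard tabular Q-learning. With γ = 0.9, the optimal value of moving right from state 0 is γ² = 0.81.

```python
def bellman_target(reward, next_q_values, done, gamma):
    """r + gamma * max_a' Q(s', a'); no bootstrapping once the episode is over."""
    return reward + (0.0 if done else gamma * max(next_q_values))

def make_chain(n_states=4):
    """Hypothetical deterministic chain MDP with actions 0 = left, 1 = right.
    Stepping right into the last state yields reward 1 and ends the episode.
    Maps (state, action) -> (next_state, reward, done)."""
    transitions = {}
    for s in range(n_states - 1):
        for a in (0, 1):
            s_next = max(s - 1, 0) if a == 0 else s + 1
            done = s_next == n_states - 1
            transitions[(s, a)] = (s_next, 1.0 if done else 0.0, done)
    return transitions

def fit_q(transitions, n_actions=2, gamma=0.9, lr=0.5, sweeps=500):
    """Tabular Q-values trained on the squared Bellman (TD) error."""
    q = {sa: 0.0 for sa in transitions}
    loss = 0.0
    for _ in range(sweeps):
        loss = 0.0
        for (s, a), (s_next, r, done) in transitions.items():
            # The terminal state has no table entries; done=True ignores them anyway.
            next_q = [q.get((s_next, b), 0.0) for b in range(n_actions)]
            td_error = bellman_target(r, next_q, done, gamma) - q[(s, a)]
            loss += 0.5 * td_error * td_error   # squared Bellman error for this pair
            q[(s, a)] += lr * td_error          # gradient step, target held fixed
    return q, loss

q, loss = fit_q(make_chain())
# q[(0, 1)] converges to gamma ** 2 = 0.81
```

The only assumptions baked into the loss are the ones the MDP makes: the next state, reward, and episode end depend on the current state and action alone.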

4) Once trained, what is the model able to generalize to, with regard to input/output pairs it hasn’t seen before?

Due to the information captured in the data or the architecture of the model, the ML system may generalize fairly well to inputs it has never seen before. In recent years we have been seeing more and more ambitious levels of generalization, so when reading papers I watch for surprising generalization capabilities and where they come from (data, bias, or both).

There is a lot of noise in the field about better inductive biases, like causal reasoning or symbolic methods or object-centric representations. These are important tools for building robust and reliable ML systems, and I get that the line between structured data and model biases can be blurry. That being said, it baffles me how many researchers think that the way to move ML forward is to reduce the amount of learning and increase the amount of hard-coded behavior.

We do ML precisely because there are things we don't know how to hard-code. As Machine Learning researchers, we should focus our work on making learning methods better, and leave the hard-coding and symbolic methods to the Machine Hard-Coding Researchers. 

5) Are the claims in the paper falsifiable? 

Papers that make claims that cannot be falsified are not within the realm of science. 

P.S. for additional hot takes and mentorship for aspiring ML researchers, sign up for my free office hours. I've been mentoring students over Google Video Chat most weekends for 7 months now and it's going great.