Tuesday, February 5, 2019

Thoughts on the BagNet Paper

Some thoughts on the interesting BagNet paper (accepted at ICLR 2019) currently being circulated around the Machine Learning Twitter Community.

Disclaimer: I wasn't a reviewer of this paper for ICLR. I think it was worthy of acceptance to the conference, and hope it prompts further investigation by the research community. Please feel free to email me if you spot any mistakes / misunderstandings in this post.

Paper Summary:

Deep Convolutional Networks (CNNs) work by aggregating local features via learned convolutions followed by spatial pooling. Successive application of these "convolutional layers" results in a "hierarchy of features" that integrate low-level information across a wide spatial extent to form high-level information. 

On the algorithmic side, those of us aboard the deep learning hype train (myself included) believe that current deep CNNs perform global integration of information. There is a hand-wavy notion that intelligent visual understanding requires "seeing the forest for the trees."

In the BagNet paper, the authors find that for the ImageNet classification task, the following algorithm (BagNet) works surprisingly well (86% Top-5 accuracy) compared to the deep AlexNet model (84.7% Top-5 accuracy):

1) Chop up the input image into 33x33 patches.
2) Run each patch through a deep net (built from 1x1 convolutions) to get a class vector.
3) Add up the resulting class vectors spatially (across all patches).
4) Predict the class with the most counts (a rough code sketch of this procedure follows the list).
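To make the idea concrete, here is a rough, hypothetical sketch of this kind of bag-of-local-features inference in PyTorch. It is not the authors' implementation - the real BagNet limits the receptive field of a ResNet-50-style network to 33x33 pixels rather than literally slicing the image - but it captures the chop / score / sum / argmax recipe:

```python
import torch
import torch.nn as nn

# Hypothetical sketch (not the authors' code): score each local patch
# independently, then sum class evidence across all spatial locations.
class ToyBagNet(nn.Module):
    def __init__(self, num_classes=1000, patch=33, stride=8):
        super().__init__()
        self.patch, self.stride = patch, stride
        # Per-patch "deep net": map each flattened 33x33x3 patch to class logits.
        self.patch_classifier = nn.Sequential(
            nn.Linear(3 * patch * patch, 512),
            nn.ReLU(),
            nn.Linear(512, num_classes),
        )

    def forward(self, x):  # x: (B, 3, H, W)
        # 1) Chop the image into (overlapping) 33x33 patches.
        patches = x.unfold(2, self.patch, self.stride).unfold(3, self.patch, self.stride)
        B, C, nh, nw, ph, pw = patches.shape
        patches = patches.permute(0, 2, 3, 1, 4, 5).reshape(B, nh * nw, C * ph * pw)
        # 2) Score each patch independently -> (B, num_patches, num_classes).
        logits_per_patch = self.patch_classifier(patches)
        # 3) Sum class evidence across all patches (the "bagging" step).
        return logits_per_patch.sum(dim=1)

# 4) The prediction is the class with the most accumulated evidence.
model = ToyBagNet()
image = torch.randn(1, 3, 224, 224)
prediction = model(image).argmax(dim=-1)
```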



By way of analogy, this result suggests that for image classification you don't need a non-linear model to integrate a bunch of local features into a global representation; you just need to "count a bunch of trees to guess that it's a forest".

Some other experimental conclusions:
  • BagNet works better when using 33x33 patches (86%) than 17x17 patches (80%). So deep nets do extract useful spatial information within patches (9x9 vs. 17x17 vs. 33x33), just perhaps not across the global spatial extent we might have previously imagined (e.g. 112x112, 224x224).
  • Spatially distinct features in the BagNet model do not interact beyond the bagging step. This raises the question of whether most of the "power" of deep nets comes merely from examining local features. Are Deep Nets just BagNets? That would be quite concerning!
  • VGG appears to approximate BagNets quite well (though I am a bit skeptical of the authors' methodology for showing this), while DenseNets and ResNets appear to be doing something quite different from BagNets (the authors explain in their rebuttal that this may come from "(1) a more non-linear classifier on top of the local features or (2) larger local feature sizes").

Thoughts & Questions

Regardless of your beliefs on whether CNNs can/should take us all the way to Artificial General Intelligence or not, this paper offers a neat bit of evidence that we can build surprisingly powerful image classification models by only examining local features. It is often more helpful to tackle applied problems with a more interpretable model, and I'm glad to see such models doing surprisingly well for certain problems.

BagNet seems quite similar in principle to Generalized Additive Models (GAMs), which predate Deep Learning by quite a bit. The basic idea of GAMs is to combine non-linear univariate feature functions (i.e. $f_i(x_i)$, where each $x_i$ is a pixel and each $f_i$ could be a neural net) into a simple, additive prediction, so that the marginal predictive distribution with respect to each variable can be interrogated. I'm particularly excited about ideas like Lou et al., which relax GAMs to support pairwise interactions between univariate feature extractors (2D marginals are still interpretable to humans).
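For concreteness, the standard GAM form and its pairwise-interaction relaxation look like this (my notation, not anything from the BagNet paper):

$$g\big(\mathbb{E}[y]\big) = \beta_0 + \sum_i f_i(x_i) \quad \text{(GAM)}$$

$$g\big(\mathbb{E}[y]\big) = \beta_0 + \sum_i f_i(x_i) + \sum_{i<j} f_{ij}(x_i, x_j) \quad \text{(GAM with pairwise interactions)}$$

where $g$ is a link function and each $f_i$, $f_{ij}$ is a learned non-linear function of one or two input variables. BagNet's bagging step is essentially the additive combination above, with one "variable" per image patch.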

The authors do not claim this explicitly, but it's easy to skim the paper quickly and come away thinking "DNNs suck; they are nothing more than BagNets". That's not actually the case, and the authors' own experiments suggest as much.

One counterexample: adversarial examples are clear instances where local modifications (sometimes a single pixel) can change global feature representations. So some global integration of information is clearly happening at test time. The remaining questions are whether global shape integration happens where we think it should, and on which tasks it happens. As someone who is deeply interested in AGI, I find ImageNet much less interesting now, precisely because it can be solved with models that have little global understanding of images.
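As a crude illustration of this point (my own sketch, not the actual single-pixel attack from the literature, which uses differential evolution rather than gradients), one can nudge just the single pixel that most increases the loss and check whether the model's prediction flips:

```python
import torch
import torch.nn.functional as F

# Hypothetical sketch: perturb only the one pixel whose gradient most
# increases the loss. `model`, `image` (1, 3, H, W), and `label` are
# assumed to be supplied by the user.
def single_pixel_nudge(model, image, label, step=0.5):
    image = image.clone().requires_grad_(True)
    F.cross_entropy(model(image), label).backward()
    grad = image.grad[0]                          # (3, H, W)
    # Pick the spatial location with the largest gradient magnitude.
    flat_idx = grad.abs().sum(dim=0).argmax()
    h, w = flat_idx // grad.shape[2], flat_idx % grad.shape[2]
    perturbed = image.detach().clone()
    perturbed[0, :, h, w] += step * grad[:, h, w].sign()
    return perturbed
```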

The authors say as much themselves: we need harder tasks that require global shape integration.



Generative modeling of images (e.g. GANs) is a task where it's quite clear that linear interactions between patch features are insufficient to model the unconditional joint distribution across pixels. Or consider my favorite RL task, Life on Earth, in which agents clearly need to perform spatial reasoning to solve problems like chasing prey and running away from predators. It would be fun to design an artificial life setup and see if organisms using bag-of-features perception can actually compete with organisms that use non-linear global integration (I doubt it).

If we train a model on a task that should be solved better by integrating global information (e.g. classification), and it ends up just overfitting to local features, then this is a truly interesting result - it means we need optimization objectives that don't allow models to cheat in this way. I think the "Life on Earth" task is great for this, though I hope to find one that is less computationally resource-intensive :)

Finally, a word on interpretability vs. causal inference. In the near term, I could see BagNet being useful for self-driving cars, where the ability to consider each patch in parallel would give even better speedups for large images. Everyone wants ML models on self-driving cars to be interpretable, right? But there is also the psychological question of whether a human would prefer to get in a car that drives with a black-box CNN that is "accurate, uninterpretable, and maybe wrong", or a car that makes decisions using a Bag-of-Features model that is "accurate, interpretable, and definitely wrong". Lobbying for interpretability (of the kind BagNet offers) seems to be at odds with demands for "causal inference" and "program induction" as means of achieving more generalizable machine learning, because a strong assumption of causal inference is that your model can express the true causal distribution. I'm curious how members of the community think we should reconcile this difference.

Update (Feb 9): There is a more positive way to look at these methods in the context of better causal inference. Methods like BagNet can serve as a very useful sanity check when designing end-to-end systems (like robotics or self-driving cars): if your deep net is not performing much better than a system that only examines local statistical regularities (like BagNet), that's a good sign that your model may yet benefit from better global information integration. One might even consider jointly optimizing BagNet and Advantage(DeepNet, BagNet), so that the DeepNet must explicitly extract strictly better information than the BagNet does. I have been thinking about how to better verify our ML systems for robotics, and building such "null hypothesis" models can be a good way to check that they aren't learning something silly.
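Here is a rough sketch of what such a "null hypothesis" check might look like (entirely my own illustration; deep_net, bag_net, and loader are placeholders you'd supply): simply track the accuracy advantage of the DeepNet over a trained BagNet baseline on held-out data.

```python
import torch

# Hypothetical sanity check: measure how much better the full deep net does
# than a bag-of-local-features baseline on a held-out set.
# `deep_net`, `bag_net`, and `loader` are assumed to be provided by the user.
def advantage_sanity_check(deep_net, bag_net, loader, device="cpu"):
    deep_net.eval()
    bag_net.eval()
    deep_correct = bag_correct = total = 0
    with torch.no_grad():
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)
            deep_correct += (deep_net(images).argmax(-1) == labels).sum().item()
            bag_correct += (bag_net(images).argmax(-1) == labels).sum().item()
            total += labels.numel()
    advantage = (deep_correct - bag_correct) / total
    # If advantage ~ 0, the deep net may not be integrating much information
    # beyond the local statistical regularities the BagNet already captures.
    return advantage
```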
