(A writeup of a side project that I’m working on. I’ve posted working code on GitHub for both PyTorch and TensorFlow.)

Imagine a deep neural network that has many chances to query information from its input rather than just one: a network that reacts to the information it has accumulated so far when selecting its next query. How might this network work differently from today’s neural networks?

I think this type of deep network would be more opportunistic and symbiotic with the input. Rather than copying all the needed information from an input image into neural activations, the network would let the image continue to store the bulk of the information, and the network would perform a series of economical queries on the image. Neural representations would not store the answers to every conceivable question about the image. Instead, they would provide the minimal information needed to quickly answer questions via subsequent image queries, when that answer is needed.

The core benefit: Doing more with less

This type of network can do more with less. It iteratively gathers some information, then decides what information to gather next. Standard feedforward networks gather all the information they might need up front, and as a result a lot of processing is spent gathering information that does not end up being used.

This is especially evident on trivial toy data. Consider classifying which integer between \(1\) and \(N\) a point on the number line is nearest to. The neural network must output a unique binary representation for each integer, and it can only use weights, biases, and nonlinearities. (This may seem like a bad example of the type of task performed by a neural network, but it captures the core complexity of inferring the class of an observation generated by a mixture of high-dimensional Gaussians.) A standard neural network will need to perform \(N - 1\) queries, one threshold comparison for each boundary between adjacent integers. An iterative network can binary-search, choosing each threshold based on the previous answers, so it only needs to perform \(\lceil \log_2 N \rceil\).
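To make the counting argument concrete, here is a small sketch in plain Python. It illustrates the query counts only and is not a neural network: each threshold comparison stands in for one query, and the function names `static_nearest` and `reactive_nearest` are hypothetical. Ties at exactly \(k + 0.5\) are rounded up in both versions.

```python
def static_nearest(x, n):
    """Static approach: evaluate all n - 1 midpoint thresholds up front.

    No threshold depends on the answer to any other, so they could all run
    in parallel, but every one of them must be computed.
    """
    thresholds = [k + 0.5 for k in range(1, n)]  # n - 1 fixed queries
    fired = [x >= t for t in thresholds]
    return 1 + sum(fired), len(thresholds)


def reactive_nearest(x, n):
    """Reactive approach: each query is chosen using the earlier answers."""
    lo, hi = 1, n          # the answer is always in [lo, hi]
    queries = 0
    while lo < hi:
        mid = (lo + hi) // 2
        queries += 1
        if x >= mid + 0.5:  # is x closer to mid + 1 than to mid?
            lo = mid + 1
        else:
            hi = mid
    return lo, queries


# Both agree on the answer; only the number of queries differs.
print(static_nearest(11.2, 16))    # (11, 15)
print(reactive_nearest(11.2, 16))  # (11, 4)
```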

Because the traditional static parallel approach can’t adapt its queries in response to the result of other queries, it has to do a lot more work.

This principle scales up to real-world data. Consider any task that involves observing an image. A sensible approach might involve first performing a low-resolution initial pass on the image, looking for regions of interest, followed by more detailed queries of those regions of interest. This will be more efficient and powerful than the non-reactive version, which would need to perform many different detailed queries on every part of the image. Yes, I may be underestimating the cleverness of a deep network, but I hope you can see the appeal of the idea.
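The coarse-to-fine strategy can be sketched in NumPy. This is a hypothetical illustration, not code from the project: `detailed_query` stands in for any expensive analysis of a patch, and "region of interest" is simplified to "the brightest tile of a block-averaged image". The coarse pass still touches every pixel, but only cheaply; what changes is the number of expensive detailed queries, one for the reactive version versus one per tile for the exhaustive version.

```python
import numpy as np


def coarse_to_fine(image, block, detailed_query):
    """Low-resolution pass, then one detailed query on the most promising tile."""
    h, w = image.shape
    # Low-resolution pass: average each block x block tile into one value.
    coarse = image.reshape(h // block, block, w // block, block).mean(axis=(1, 3))
    # Region of interest (simplified assumption): the brightest tile.
    bi, bj = np.unravel_index(np.argmax(coarse), coarse.shape)
    patch = image[bi * block:(bi + 1) * block, bj * block:(bj + 1) * block]
    return detailed_query(patch), (int(bi), int(bj)), 1  # one detailed query


def exhaustive(image, block, detailed_query):
    """Non-reactive version: run the detailed query on every tile."""
    h, w = image.shape
    results = {}
    for bi in range(h // block):
        for bj in range(w // block):
            patch = image[bi * block:(bi + 1) * block, bj * block:(bj + 1) * block]
            results[(bi, bj)] = detailed_query(patch)
    return results, len(results)
```

For example, on a 64×64 image with a single bright pixel and 8×8 tiles, `coarse_to_fine` issues one detailed query while `exhaustive` issues 64.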

Traditionally, deep learning emphasizes parallel distributed processing. This reactive approach tries to bring back some sequential (as opposed to distributed) processing, getting the best of both worlds.

How to perform a reactive query

There are many ways to make a neural network dynamically query its input. One way is via a form of attention that mimics a saccading vision system. I opted, instead, to use dynamically generated weights. Given a context vector, the model passes that vector into a normal neural network to output a set of weights, which are then run on the input image. This weight-generating network is called a hypernetwork. For image processing the model might generate the weights of a 4-layer convolutional neural network with a relatively small number of channels. This dynamically generated network is run on the image and outputs a new context vector, which can then be used to generate a new network, and so on. Context vectors are aggregated over time, either through residual connections or LSTM/GRU-like gates.
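The loop described above can be sketched as a NumPy forward pass. This is a minimal sketch under simplifying assumptions that are mine, not the project's: the input is a flattened image vector, each generated "network" is a single linear layer with a ReLU rather than a 4-layer CNN, the hypernetwork is one linear map, and all weights are random and untrained. The class name `ReactiveQueryNet` and its parameters are hypothetical.

```python
import numpy as np


class ReactiveQueryNet:
    """Forward-pass sketch: context -> generated query weights -> new context.

    Hypothetical simplification: each dynamic query is one generated linear
    layer plus ReLU applied to a flattened image; weights are random.
    """

    def __init__(self, in_dim, ctx_dim, query_dim, gated=False, seed=0):
        rng = np.random.default_rng(seed)
        self.in_dim, self.query_dim, self.gated = in_dim, query_dim, gated
        # Hypernetwork: maps a context vector to the flattened query weights.
        self.W_hyper = rng.normal(0.0, 0.1, (query_dim * in_dim, ctx_dim))
        # Maps the query's output back into context space.
        self.W_out = rng.normal(0.0, 0.1, (ctx_dim, query_dim))
        # Parameters of the GRU-like gate (used only when gated=True).
        self.W_gate = rng.normal(0.0, 0.1, (ctx_dim, 2 * ctx_dim))
        self.ctx0 = rng.normal(0.0, 0.1, ctx_dim)

    def step(self, ctx, x):
        # 1. Generate this step's query weights from the current context.
        W_query = (self.W_hyper @ ctx).reshape(self.query_dim, self.in_dim)
        # 2. Run the generated layer on the input image.
        q = np.maximum(W_query @ x, 0.0)
        update = self.W_out @ q
        # 3. Aggregate the result into the context.
        if not self.gated:
            return ctx + update  # residual connection
        z = 1.0 / (1.0 + np.exp(-(self.W_gate @ np.concatenate([ctx, update]))))
        return (1.0 - z) * ctx + z * np.tanh(update)  # GRU-like gated update

    def forward(self, x, n_steps=4):
        ctx = self.ctx0
        for _ in range(n_steps):
            ctx = self.step(ctx, x)  # each query depends on the previous ones
        return ctx
```

Setting `gated=True` switches from the residual update to the gated one. In a trained model the hypernetwork, output projection, and gate would all be learned end to end, and the final context would feed a classifier head.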

Hardware efficiency

For deep learning, it is prudent to use algorithms that support training in batches. If training in batches isn’t possible, you face an uphill battle. (This is part of why Transformers have overtaken recurrent neural networks for sequence processing.)

I was happy to find that this algorithm supports efficient training in batches, despite the fact that it departs from traditional deep networks by applying different weights to each input image. This operation can still be performed in parallel by implementing dynamic linear layers as batch matrix multiplications and dynamic convolutional layers as group convolutions.
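The batch-matrix-multiplication trick for dynamic linear layers can be checked in a few lines of NumPy: one call applies a different weight matrix to each example and matches a per-example loop. This sketch covers only the linear case; the function names are hypothetical. In PyTorch the analogous call is `torch.bmm`, and the convolutional case is usually handled by folding the batch into the channel dimension and calling a convolution with `groups` equal to the batch size.

```python
import numpy as np


def dynamic_linear(x, W, b):
    """Apply a different weight matrix to every example in one batched call.

    x: (batch, in_features)
    W: (batch, out_features, in_features), one matrix per example
    b: (batch, out_features)
    """
    # np.matmul broadcasts over the leading batch axis: (B, O, I) @ (B, I, 1).
    # np.einsum("boi,bi->bo", W, x) would compute the same thing.
    return np.matmul(W, x[:, :, None])[:, :, 0] + b


def dynamic_linear_loop(x, W, b):
    """Reference implementation: one example at a time."""
    return np.stack([W[i] @ x[i] + b[i] for i in range(x.shape[0])])
```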

This algorithm has a mix of performance costs and performance benefits that may balance out. On one hand, using different weights for every input image is more expensive. On the other hand, the dynamic approach enables each query to be significantly less complex, so each weight tensor and activation tensor is much smaller.
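The trade-off can be made explicit with multiply-accumulate (MAC) counts. This is back-of-the-envelope arithmetic under assumptions of mine: stride-1 "same"-padded convolutions, and a hypernetwork that is a single linear layer emitting one kernel from a context vector of size `ctx_dim`. The channel counts in the usage note are hypothetical, not measurements from the project.

```python
def conv_macs(h, w, k, c_in, c_out):
    """MACs for one stride-1, 'same'-padded k x k convolution layer."""
    return h * w * k * k * c_in * c_out


def weight_generation_macs(ctx_dim, k, c_in, c_out):
    """MACs for a single linear hypernetwork layer emitting one conv kernel.

    This cost is paid per example, because every image gets its own kernel.
    """
    return ctx_dim * k * k * c_in * c_out
```

With these formulas, a 16-channel dynamic 3×3 layer on a 32×32 feature map costs 1/16 the convolution MACs of a 64-channel static layer (`conv_macs(32, 32, 3, 16, 16)` versus `conv_macs(32, 32, 3, 64, 64)`), plus `weight_generation_macs(ctx_dim, 3, 16, 16)` to generate its kernel. Whether the total comes out ahead depends on how many queries the network makes per image.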

The neuroscience angle

This algorithm was partly inspired by a theory from neuroscience, and it provides a concrete illustration of what types of neural representations we might find in the brain.

Dana Ballard and colleagues have long written about how the best economical solution to the inference problem changes entirely when you consider systems that are capable of actively responding to the world. They have criticized the classic Marr view that the human visual system builds up a representation of “what is where”. (This classic view overlaps heavily with some of my previous neuroscience work with Numenta.) Instead, Ballard and colleagues write:

the visual system is used to subserve problem-solving behaviors and such behaviors often do not require an accurate model of the world in the traditional sense of remembering positions of people and objects in the room

and

in a dynamic world, the cost of maintaining the correspondence between the representation and the world becomes prohibitive. For this reason animate vision systems may have to travel light and depend on highly adaptive behaviors that can quickly discover how to use current context

This algorithm uses a simple form of behavior – a dynamic query on the input image – and it “travels light” by using a small context vector whose job is to customize these dynamic queries and help them succeed in answering questions about the input image. This type of neural representation – a population whose job is to help the network figure out what actions it should take to answer subsequent questions – is an enticing model of what might be happening in the brain. Many thousands of careers have been spent trying to understand neural representations, especially in the hippocampus and mammalian neocortex. Some cortical representations have been partially explained, but most have been difficult to interpret. The neurons in this algorithm, similarly, would be hard to interpret, because they are used to complement the outside world and make it easy to query, rather than coding it directly.

Why it might not work

This adaptive query architecture might have too much potential to overfit. In terms of raw potential, traditional deep networks are inferior to these networks, but that inferiority may constrain them to only learn algorithms that successfully generalize outside the training set. By switching to an iterative reactive approach, we give the network potential to discover more powerful and economical algorithms, but gradient descent may be prone to finding solutions that don’t generalize. Maybe traditional deep architectures and traditional gradient descent are pair-bonded, and more advanced architectures will require slightly different learning algorithms.

Current status

The algorithm works reasonably well on toy data and MNIST (>99% accuracy, etc.) using fewer total convolutional FLOPs than a comparable static network. Preliminarily, it seems to work fine on CIFAR10, but I haven’t pushed it very hard. I’m still early in the tweaking / hyperparameter tuning process. Feel free to take a look at the code and run it yourself.

(Thanks to Eric Frank for helpful discussions on hypernetworks. Also, I decided to share this unfinished project after reading a Twitter thread from Rosanne Liu.)