Attention is all you need. (aka the Transformer network)

No matter how we frame it, in the end, studying the brain is equivalent to trying to predict one sequence from another sequence. We want to predict complicated movements from neural activity. We want to predict neural activity from time-varying stimuli. We even want to predict neural activity in one brain area from neural activity in another area. Thus, as computational neuroscientists, we should be intimately familiar with new machine learning techniques that allow us to better relate one sequence to another.

One sequence-to-sequence problem receiving a lot of interest in machine learning is “translation”—converting a sentence in one language (e.g., English) to another (e.g., Polish). The main challenge is getting context correct. For example, consider “bark” for the following sentences: “The bark is loud.” and “The bark is brown.” In Polish, bark would be “szczekanie” and “kora,” respectively, based on context. Machine learning folks have devised some ingenious architectures to tackle this problem. In this post, I’ll talk about one architecture that is unexpected but seems to work quite well: Transformer networks. This architecture was proposed in “Attention is all you need.” by Vaswani, Shazeer, Parmar, Uszkoreit, Jones, Gomez, Kaiser, Polosukhin, NIPS 2017.

High-level intuition

The Transformer network relies on an encoding-decoding approach. The idea is to “encode” the input sequence into a latent representation (i.e., an embedding) that is easier to work with. This embedding (typically a vector of abstract variables) is then “decoded” into an output sequence (e.g., another language). Our problem, then, is to design a “useful” embedding for translation, and build an architecture that can express it.

What properties do we want our embedding to have? One property is a measure of similarity: “tomato” is more similar to “blueberry” than to “chair.” Embeddings with this property can already be learned with word2vec. Another important property is context: is “tomato” the subject? The direct object? Is it modified by “tasty” or “rotten”? It turns out we can incorporate context by adding embeddings together. For example, let e1 be the embedding of “tomato” and e2 the embedding of “tasty”. Then e3 = e1 + e2 corresponds to an embedding of a “tasty tomato”. We can keep tacking on more context (e.g., e4 = “subject of sentence”, e5 = “the”, e6 = “belongs to Tom”, etc.). With this embedding property, if we come across an ambiguous word like “it”, we simply find the word that “it” refers to, and add that word’s embedding to the embedding of “it”!
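
To make the two properties concrete, here is a minimal NumPy sketch. The four-dimensional vectors are numbers I made up for illustration (real word2vec embeddings are learned and have hundreds of dimensions), and the helper name cosine_similarity is my own choice.

```python
import numpy as np

# Toy 4-dimensional embeddings. These values are invented for illustration;
# they are NOT taken from word2vec or from the paper.
e_tomato = np.array([0.9, 0.1, 0.0, 0.2])
e_blueberry = np.array([0.8, 0.2, 0.1, 0.1])
e_chair = np.array([0.0, 0.9, 0.8, 0.1])
e_tasty = np.array([0.1, 0.0, 0.0, 0.9])


def cosine_similarity(a, b):
    """Cosine of the angle between two embedding vectors."""
    return float(a @ b / (np.linalg.norm(a) * np.linalg.norm(b)))


# Property 1: similarity. "tomato" should be closer to "blueberry" than to "chair".
sim_fruit = cosine_similarity(e_tomato, e_blueberry)
sim_chair = cosine_similarity(e_tomato, e_chair)

# Property 2: context by addition. e3 = e1 + e2 gives a "tasty tomato".
e_tasty_tomato = e_tomato + e_tasty

print(sim_fruit > sim_chair)
print(cosine_similarity(e_tasty_tomato, e_tomato))
print(cosine_similarity(e_tasty_tomato, e_tasty))
```

The summed vector stays close to both of its ingredients, which is the sense in which addition “tacks on” context without throwing away the original word.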

Architecture: Stacked attention layers

The authors came up with a way to efficiently search for context and add that context to the embedding. They use “attention” (in the general sense of the word, not based on neuroscience) as a way to find distant relationships between words. Figure 1 is a stylized version of this computation. The idea is, given a word, compare this word with all other words in the sentence for a specific context. For example, if the word is “sees”, attention will try to find “who sees?” and “sees what?” If attention finds answers to its context questions, these answers will be incorporated into the embedding via a linear combination. If attention does not receive answers, no harm is done: attention will simply pass on an essentially unmodified word embedding (i.e., the weights for the other words will be close to zero).
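
The computation in Figure 1 can be sketched in a few lines of NumPy. This is a simplified, single-head version of the paper’s scaled dot-product attention, softmax(QK^T / sqrt(d_k)) V, with random untrained weights. The function and variable names are my own, and I leave out the multi-head split and output projection that the paper uses. Adding the result back onto the input is my shorthand for “incorporating context into the embedding”.

```python
import numpy as np


def softmax(x, axis=-1):
    """Numerically stable softmax along the given axis."""
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)


def self_attention(X, W_q, W_k, W_v):
    """Single-head scaled dot-product self-attention (simplified sketch).

    X has one row per word. Each word asks a "question" (query), every word
    offers a "label" (key), and the answers (values) are mixed according to
    how well each key matches the query.
    """
    Q, K, V = X @ W_q, X @ W_k, X @ W_v
    scores = Q @ K.T / np.sqrt(K.shape[-1])  # all pairwise comparisons
    weights = softmax(scores, axis=-1)       # each row sums to 1
    return weights @ V, weights


# Hypothetical sizes and random (untrained) weights, for illustration only.
rng = np.random.default_rng(0)
n_words, d_model, d_k = 5, 8, 4
X = rng.normal(size=(n_words, d_model))
W_q = rng.normal(size=(d_model, d_k))
W_k = rng.normal(size=(d_model, d_k))
W_v = rng.normal(size=(d_model, d_model))

context, weights = self_attention(X, W_q, W_k, W_v)
updated = X + context  # add the context found by attention to each embedding

print(weights.shape)   # one weight for every (word, other word) pair
print(updated.shape)   # same number of embeddings as input words
```

Because the weights come out of a softmax, they are never exactly zero, but for a trained network they can get very small for words that are irrelevant to the current question.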

Figure 1: The attention mechanism. (this is my first time making gifs—sorry for the blurriness!)

One can imagine that having one level of attention may not be the best way to extract the structure of natural language. Just like we learned in sixth grade, sentences have hierarchical structure (did you also have to draw those syntax trees?). So, we should stack levels of attention. The lower levels correspond to easy questions (“Is there an adjective to this noun?”) that likely involve only two or three words of the input sequence. The deeper levels correspond to more nuanced questions (“What is the direct object of this sentence?”) that span most if not all of the input sequence. To achieve this, the authors stacked layers of attention on top of each other. To decode, they use a similar stack of attention layers that attends both to the encoder’s output and to the output words generated so far. Figure 2 is an illustration of this process.

Figure 2: Attention blocks form an encoder stack, which feeds into a decoder stack.

The authors claim this approach minimizes the path length between any two positions in a sequence. This is because attention looks at all pairwise interactions, while other architectures (such as recurrent neural networks) look at the input sequence sequentially (making it difficult to compare the first word to the last word). My primary point of confusion while reading the paper was how input sequences were fed into the network. Hopefully Fig. 2 clarifies this: each word embedding is transformed in parallel with the other word embeddings, so if you input N words, each encoder will output N embeddings. This architecture produced state-of-the-art results, and has since been used in many different natural language processing tasks. It also uses many bells and whistles (residual connections, layer normalization, …), and it will be interesting to see future work argue which components are the most important for this architecture.
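
Here is a rough sketch of how an encoder stack keeps “N words in, N embeddings out” at every level. It is a heavily simplified stand-in for the paper’s encoder: single-head attention, layer normalization without learned gain and bias, no positional encoding, and random untrained weights. All names and sizes (init_layer, encoder_layer, d_ff, and so on) are my own assumptions, apart from the depth of six layers, which matches the paper’s encoder stack.

```python
import numpy as np


def softmax(x, axis=-1):
    """Numerically stable softmax along the given axis."""
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)


def layer_norm(x, eps=1e-6):
    """Normalize each embedding (row) to zero mean and roughly unit variance."""
    mean = x.mean(axis=-1, keepdims=True)
    std = x.std(axis=-1, keepdims=True)
    return (x - mean) / (std + eps)


def init_layer(rng, d_model, d_ff):
    """Random, untrained parameters for one encoder layer (illustration only)."""
    s = 1.0 / np.sqrt(d_model)
    return {
        "W_q": rng.normal(scale=s, size=(d_model, d_model)),
        "W_k": rng.normal(scale=s, size=(d_model, d_model)),
        "W_v": rng.normal(scale=s, size=(d_model, d_model)),
        "W_1": rng.normal(scale=s, size=(d_model, d_ff)),
        "W_2": rng.normal(scale=1.0 / np.sqrt(d_ff), size=(d_ff, d_model)),
    }


def encoder_layer(X, p):
    """Attention, then a per-word feed-forward step, each with residual + norm."""
    Q, K, V = X @ p["W_q"], X @ p["W_k"], X @ p["W_v"]
    weights = softmax(Q @ K.T / np.sqrt(K.shape[-1]))
    X = layer_norm(X + weights @ V)         # residual connection around attention
    hidden = np.maximum(0.0, X @ p["W_1"])  # ReLU, applied to each word separately
    return layer_norm(X + hidden @ p["W_2"])


def encoder_stack(X, layers):
    """Pass all word embeddings through every layer in turn."""
    for p in layers:
        X = encoder_layer(X, p)
    return X


rng = np.random.default_rng(0)
n_words, d_model, d_ff, n_layers = 7, 16, 32, 6
layers = [init_layer(rng, d_model, d_ff) for _ in range(n_layers)]
X = rng.normal(size=(n_words, d_model))

encoded = encoder_stack(X, layers)
print(X.shape, encoded.shape)  # the number of embeddings never changes
```

Note that nothing in the loop walks along the sentence: every layer processes all N embeddings at once, and the only place words interact is the N-by-N attention weight matrix. That is why the first and last word are always exactly one attention step apart.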

Getting back to the neuroscience

As computational neuroscientists, we should take advantage of these architectures (which have been designed and trained with the aid of GPU armies), and use them to help solve our own problems. One notable aspect of this work (and of natural language processing in general) is the idea of adding two embedding vectors to form a more “context-relevant” embedding. This is missing from our current latent variable models. More importantly, it may be a way in which the brain also encodes context. I also think this type of attention computation will be useful when we are trying to predict sequences of natural behavior, where we cannot make use of task structure (e.g., a delay period, stimulus onset, etc.). Experimentalists are now collecting massive datasets of natural behavior, which is a perfect opportunity to see how this attention computation holds up!