Perceiver: General Perception with Iterative Attention

Written by Michael (Mike) Erlihson, PhD.

This review is part of a series of Machine & Deep Learning reviews, originally published in Hebrew, that aims to make the material accessible in plain language under the name #DeepNightLearners.

Good night, friends. We are back with our #DeepNightLearners section and a review of a deep learning article. Tonight I've chosen to review the article Perceiver: General Perception with Iterative Attention.

Reviewer Corner:

Reading recommendation: a must (!!!) if you're into Transformers. For everyone else - highly recommended (the idea behind it is very cool!).
Clarity of writing: Medium plus
Required math and ML/DL knowledge level to understand the article: Basic familiarity with the Transformer architecture and with computational complexity.
Practical applications: Low-complexity Transformers that can process long data sequences (image patches, video frames, long texts, etc.).

Article Details:

Article link: available for download here
Code link: available here, here, and here
Published on: 04/03/2021 on arXiv
Presented at: Unknown

Article Domains:

  • Transformers with low complexity and low storage

Mathematical Tools, Concepts and Marks:

  • Transformers Architecture Basics

Introduction:

The Transformer is a neural network architecture designed to process sequential data. It was first introduced in 2017 in the article Attention is All You Need. Since then, Transformers have taken over the NLP world and become its default architecture. Through pre-training, Transformers are used to build meaningful data representations (embeddings), which in turn can be fine-tuned for a variety of downstream tasks.

Recently, Transformers have also started invading the computer vision field. Among the articles that use Transformers for different computer vision tasks are An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale, TransGAN, Pretrained Image Transformer, and DETR. Lately, Transformers have also been used for video processing (Knowledge Fusion Transformers for Video Action Classification). In computer vision, the Transformer inputs are usually image patches.

Yet, there are several challenges that prevent a wider usage of Transformers in the visual domain:

  • The inherent local dependencies that exist in images
    Convolutional networks traditionally star in almost every computer vision task. They compute lower-layer features from nearby pixels and thus exploit the local dependencies (correlations) that exist inside images. Transformers, however, do not build such local representations, since the data representation is constructed by simultaneously analyzing the connections between all parts of the input. This difficulty can be partially overcome with a sophisticated weight-initialization mechanism, and some works use convolutional layers as a first stage to build patch representations before feeding them into the Transformer.
  • Quadratic computational complexity w.r.t. the Transformer input length
    The Transformer builds its data representation by analyzing the connections between all the input parts using a mechanism called "Self-Attention" (SA) - the Transformer's heart. In SA, the computation for an input of length M has complexity O(M²). This becomes problematic, in both computation time and memory, for high-resolution images due to the large number of patches. In the past two years, several articles have suggested computationally cheaper variants of the Transformer, such as the Reformer, Linformer, and Performer. However (and AFAIK), these versions have not matched the classic Transformer's performance on downstream tasks.

The Article in Essence:

The quadratic complexity of the Transformer originates in the Self-Attention (SA) mechanism. It comes from the matrix product, denoted here L = QKᵀ, of the matrices Q = XQ' and K = XK', where Q', K' are the Query and Key matrices and X is the Transformer input. The matrices Q and K are of size MxD, where M is the input length and D is the representation dimension (embedding size). It is now easy to see where the O(M²) complexity of SA comes from. Recall that the SA output is computed as softmax(L)V, where V = XV' and V' is the Value matrix.
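To make the source of the quadratic cost concrete, here is a minimal sketch of vanilla self-attention in PyTorch (the dimensions and weight names are illustrative, not taken from the article); note that the attention matrix L alone has M² entries:

```python
import torch

def self_attention(x, Wq, Wk, Wv):
    """Vanilla self-attention: x is (M, D_in), the W* projections are (D_in, D)."""
    Q = x @ Wq                          # (M, D)
    K = x @ Wk                          # (M, D)
    V = x @ Wv                          # (M, D)
    L = Q @ K.T / K.shape[-1] ** 0.5    # (M, M) -- quadratic in the input length M
    return torch.softmax(L, dim=-1) @ V # (M, D)

M, D_in, D = 1024, 32, 64               # e.g. 1024 input elements
x = torch.randn(M, D_in)
Wq, Wk, Wv = (torch.randn(D_in, D) for _ in range(3))
out = self_attention(x, Wq, Wk, Wv)     # the attention matrix holds M*M ≈ 1M entries
```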

In contrast with most other articles, which offer computationally cheap alternatives to the Transformer using different approximations of the SA mechanism, this article attacks the problem from a different direction. It suggests learning (!) the matrix Q instead of computing it from the input. This way, the number of rows of Q, denoted N, can be significantly smaller than the input length M, and the complexity of multiplying Q by Kᵀ is no longer quadratic but linear in M: O(MN).

The basic idea:

The article proposes computing Q as AQ', where A is a learned matrix called a latent array. The K and V matrices are computed as in the original SA mechanism. Then, instead of computing Self-Attention over the input X, the article computes Cross-Attention between the input X and the latent array A. The length N of the latent array is significantly smaller than the input size M, hence the quadratic complexity is avoided.
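Here is a minimal sketch of this trick under the same illustrative notation: the queries now come from a learned latent array A with N rows, N ≪ M, so the attention matrix is only NxM:

```python
import torch

N, M, D = 256, 50_000, 64            # N latent vectors, M input elements (N << M)
A  = torch.nn.Parameter(torch.randn(N, D))   # learned latent array
Wq, Wk, Wv = (torch.randn(D, D) for _ in range(3))

x = torch.randn(M, D)                # flattened input (byte array), already embedded
Q = A @ Wq                           # (N, D) -- queries come from the latent array
K = x @ Wk                           # (M, D)
V = x @ Wv                           # (M, D)
L = Q @ K.T / D ** 0.5               # (N, M) -- linear in M, not quadratic
out = torch.softmax(L, dim=-1) @ V   # (N, D): a fixed-size latent representation
```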

Note: The Cross-Attention mechanism is not new - it already appears in the original Transformer (Attention is All You Need), where the decoder's intermediate outputs attend to the encoder's outputs on tasks such as machine translation or text summarization.

The Perceiver is an architecture based on attentional principles that scales to high-dimensional inputs such as images, videos, audio, point-clouds (and multimodal combinations) without making any domain-specific assumptions. The Perceiver uses a cross-attention module to project an input high-dimensional byte array to a fixed-dimensional latent bottleneck (M ≫ N) before processing it using a stack of Transformers in the low-dimensional latent space. The Perceiver iteratively attends to the input byte array by alternating cross-attention and latent transformer blocks.

Detailed Explanation:

In Cross-Attention (CA), the matrices K and V are built as in SA, by multiplying the input by the learned K' and V', respectively. Since the quadratic-complexity (w.r.t. input size) limitation is removed, we can use much longer sequences than the standard Transformer. For example, when the Transformer input is a high-resolution image, it is commonly divided into 16x16-pixel patches because of this complexity limitation. Using a latent array A whose size is independent of the input removes this limitation and allows much longer input sequences. The article suggests flattening the input into a byte array before multiplying it by the Key and Value matrices; if the input is an image, each item in the byte array holds the values of a single pixel, as in the sketch below.
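For instance (a hypothetical example, not the paper's exact preprocessing), a 224x224 RGB image becomes a byte array of M = 50,176 rows, one per pixel:

```python
import torch

# Flatten a 224x224 RGB image into a byte array: one row per pixel,
# carrying its 3 channel values (positional encodings are concatenated later).
image = torch.rand(224, 224, 3)
byte_array = image.reshape(-1, 3)    # (50176, 3)
print(byte_array.shape)
```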

The Perceiver input can also be a long audio or video sequence. Furthermore, the article argues that the input can even be a combination of audio and video, which was impossible with previous Transformer variants, since they required architectural adjustments for each input type. The Perceiver architecture is input-agnostic. Impressive!

The Perceiver architecture - detailed

Now that we understand the basic principles of the Perceiver architecture, we can dive into the details. After computing the Cross-Attention between the latent array and the input, the CA output is fed into a classic Transformer, which the article calls the Latent Transformer (LTr).
The output size of the CA does not depend on the input size but on the latent array size, which is set according to the available computational resources. Since the latent array is usually much smaller than the original input, passing it through the LTr has reasonable complexity. The LTr architecture is similar to GPT-2, which itself is based on the decoder of the original Transformer.

The LTr output is fed again into the CA mechanism, together with the original input (the same K and V matrices are reused), and the CA output is then fed to an additional LTr. By repeating this CA + LTr block, one can build a deep and robust architecture that constructs a strong input representation. The LTr blocks can share identical weights, have different weights in each layer, or combine the two (e.g., a few sets of weights reused across layers). Think of the Perceiver as a multi-layer network in which every layer is composed of a CA followed by an LTr, as in the sketch below.
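Putting the pieces together, here is a rough sketch of the iterative CA + LTr structure. It uses PyTorch's nn.MultiheadAttention and nn.TransformerEncoder as stand-ins for the paper's cross-attention and Latent Transformer blocks; all hyperparameters and names are illustrative:

```python
import torch
import torch.nn as nn

class PerceiverSketch(nn.Module):
    def __init__(self, dim=256, n_latents=128, depth=4):
        super().__init__()
        self.latents = nn.Parameter(torch.randn(n_latents, dim))  # learned latent array
        # One CA + LTr pair, reused at every iteration (i.e., shared weights).
        self.cross_attn = nn.MultiheadAttention(dim, num_heads=1, batch_first=True)
        self.latent_tr = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(dim, nhead=8, batch_first=True), num_layers=2)
        self.depth = depth

    def forward(self, byte_array):                         # byte_array: (B, M, dim)
        B = byte_array.shape[0]
        z = self.latents.unsqueeze(0).expand(B, -1, -1)    # (B, N, dim)
        for _ in range(self.depth):
            # Cross-attention: the latents query the (possibly huge) input byte array.
            z = z + self.cross_attn(z, byte_array, byte_array)[0]
            # Latent Transformer: ordinary self-attention in the small latent space.
            z = self.latent_tr(z)
        return z.mean(dim=1)                               # pooled representation

model = PerceiverSketch()
x = torch.randn(2, 10_000, 256)       # batch of two long, flattened, embedded inputs
features = model(x)                   # (2, 256), ready for a linear classifier
```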

Intuition Corner:

The latent array can be seen as a set of learned queries about the input. A potential query in the first CA layer could be: measure the relationships between a patch p at the center of the image and all the other patches inside a larger patch that contains p. In deeper layers of the Perceiver, the latent array depends on values computed in previous layers, and, similarly to convolutional networks, these arrays try to capture increasingly semantic features of the image. The Perceiver also resembles an RNN in which each layer receives the whole input.

Positional encoding:

Self-Attention and Cross-Attention are agnostic to the order of the input items: the resulting representation would remain the same after permuting the input sequence. Clearly, there are cases where item order matters, such as natural language, images, video, audio, etc.

The positional information of the sequence items is added to CA and SA through positional encodings (PE), which encode the position of each item in the input sequence. For the CA mechanism, the article uses Fourier-feature-based PE, in the spirit of the sinusoidal encodings of the original Transformer, whereas learned PE are used for the LTr's SA mechanism.
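For intuition, here is a rough sketch of Fourier-feature positional encodings along a single 1-D position axis (the number of bands and the maximum frequency are illustrative, not the paper's exact configuration):

```python
import math
import torch

def fourier_features(positions, num_bands=6, max_freq=10.0):
    """positions: (M,) values scaled to [-1, 1]. Returns (M, 2*num_bands + 1)."""
    freqs = torch.linspace(1.0, max_freq / 2, num_bands)         # (num_bands,)
    scaled = positions[:, None] * freqs[None, :] * math.pi       # (M, num_bands)
    # Concatenate the raw position with sin/cos features at each frequency band.
    return torch.cat([positions[:, None], scaled.sin(), scaled.cos()], dim=-1)

M = 50_176                                  # e.g. 224*224 flattened pixels
pos = torch.linspace(-1.0, 1.0, M)
pe = fourier_features(pos)                  # concatenated to each byte-array row
print(pe.shape)                             # torch.Size([50176, 13])
```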

The topic of positional encoding is extensively discussed in the article; the authors introduce several fascinating variants and try to give intuition for why they improve performance.

Achievements:

The article compares the Perceiver's representations to several other self-supervised training methods (combined with a linear layer for classification) and to supervised SOTA methods in different domains: images, video, audio, audio + video, and point clouds.

We train the Perceiver architecture on images from ImageNet (left), video and audio from AudioSet (considered both multi- and uni-modally) (center), and 3D point clouds from ModelNet40 (right). Essentially no architectural changes are required to use the model on a diverse range of input data.

In all these domains, the Perceiver performed better than the other unsupervised methods they checked, including Transformer-based ones. Some of these methods were built for a specific data domain and exploit its inherent structure (such as ResNets in the image domain). Yet the Perceiver scored slightly worse than supervised methods that exploit these domain-specific characteristics.

Top-1 validation accuracy (in %) on ImageNet. Methods shown in red exploit domain-specific grid structure, while methods in blue do not. The first block reports standard performance from pixels – these numbers are taken from the literature. The second block shows performance when the inputs are RGB values concatenated with Fourier features (FF) of the xy positions – the same that the Perceiver receives. This block uses our implementation of the baselines. The Perceiver is competitive with standard baselines on ImageNet while not relying on domain-specific architectural assumptions.

P.S.

This remarkably interesting article suggests a cool way to overcome the Transformer's quadratic complexity. The suggested architecture is agnostic to the input structure and can be used as-is to construct data representations in different domains.

Also, check out Yam Peleg's Keras implementation here:

Minimal Keras implementation: “Perceiver: General Perception with Iterative Attention. Jaegle et al” - perceiver.py

#deepnightlearners

This post was written by Michael (Mike) Erlihson, Ph.D.

Michael works in the cybersecurity company Salt Security as a principal data scientist. Michael researches and works in the deep learning field while lecturing and making scientific material more accessible to the public audience.
