Vision Transformer

A study of the Vision Transformer
type: insight
level: medium

In computer vision, we usually rely on convolutional architectures because of how effective they are: from small-scale to large-scale settings, CNN-based methods remain dominant. Meanwhile, the Transformer is the go-to architecture for natural language processing (NLP) because it is computationally efficient and scalable; nowadays, self-attention-based models with 500B parameters can be trained. Inspired by this success in NLP, the Vision Transformer (ViT) [1] is a novel approach that tackles computer vision with a Transformer encoder and minimal modifications. While small and mid-sized datasets are ViT's weakness, further experiments show that, given enough data, ViT performs well and can reach state-of-the-art results. Furthermore, when pre-trained on ImageNet-21k, ViT can easily beat the SoTA on small datasets such as CIFAR-10. In this blog, we introduce the basic idea of the Vision Transformer.

Vision Transformer (ViT)

Figure 1: Model overview. An input image is split into multiple patches, which are treated as ordinary tokens: they are embedded and fed into the encoder. To perform the classification task, an extra learnable [class] token is added at the beginning of the sequence, and an MLP head is attached at its output position.

Unlike the standard Transformer, which has both an encoder and a decoder, the Vision Transformer only needs the encoder part, with a few modifications, because our main task is image classification. The first modification is to split the input image into several patches and embed them with a trainable linear projection (or a small CNN feature extractor in the hybrid variant). The second modification is to attach a [class] token at the beginning of the embedded input sequence; its output behaves as the image representation. Formally, we reshape the image $\mathbf{x} \in \mathbb{R}^{H \times W \times C} \rightarrow \left(\mathbf{x}_{p}\right) \in \mathbb{R}^{N \times\left(P^{2} \cdot C\right)}$, with each flattened patch $\mathbf{x}_{p} \in \mathbb{R}^{P^{2} \cdot C}$ (a short sketch of this reshape follows the list below), where:

  • $(H, W)$ is the resolution of the original image,
  • $C$ is the number of channels,
  • $(P, P)$ is the resolution of each image patch,
  • $N = HW / P^{2}$ is the resulting number of patches.
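
To make the reshape concrete, here is a minimal PyTorch sketch; the tensor names and the unfold-based implementation are my own illustration, not code from the paper:

```python
import torch

# Minimal sketch of the reshape x ∈ R^{H×W×C} -> x_p ∈ R^{N×(P²·C)}.
B, C, H, W, P = 1, 3, 224, 224, 16            # batch, channels, height, width, patch size
x = torch.randn(B, C, H, W)                   # dummy image batch

# Cut the image into non-overlapping P×P patches, then flatten each patch.
patches = x.unfold(2, P, P).unfold(3, P, P)   # (B, C, H/P, W/P, P, P)
patches = patches.permute(0, 2, 3, 1, 4, 5)   # (B, H/P, W/P, C, P, P)
patches = patches.reshape(B, (H // P) * (W // P), C * P * P)

print(patches.shape)                          # torch.Size([1, 196, 768]): N = 196, P²·C = 768
```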

Similar to BERT's [class] token, a learnable embedding $\mathbf{z}_0^0 = \mathbf{x}_{\text{class}}$ is prepended to the sequence of embedded patches. After $L$ encoder layers, $\mathbf{z}_0^0 \rightarrow \underbrace{\text{encoder} \rightarrow \text{encoder} \rightarrow \cdots}_{L} \rightarrow \mathbf{z}_0^L$, the final $\mathbf{z}_0^L$ serves as the image representation. During pre-training and fine-tuning, an MLP head that classifies $\mathbf{z}_0^L$ is attached at this position. The position embedding is the standard learnable 1-D positional embedding. Note that every MLP block in the encoder consists of two layers with a GELU non-linearity. Formally, the ViT encoder is:

$$\mathbf{z}_{0} = \left[\mathbf{x}_{\text{class}}; \, \mathbf{x}_{p}^{1} \mathbf{E}; \, \mathbf{x}_{p}^{2} \mathbf{E}; \, \cdots; \, \mathbf{x}_{p}^{N} \mathbf{E}\right] + \mathbf{E}_{\text{pos}}, \qquad \mathbf{E} \in \mathbb{R}^{\left(P^{2} \cdot C\right) \times D}, \; \mathbf{E}_{\text{pos}} \in \mathbb{R}^{(N+1) \times D}$$

$$\mathbf{z}_{\ell}^{\prime} = \operatorname{MSA}\left(\operatorname{LN}\left(\mathbf{z}_{\ell-1}\right)\right) + \mathbf{z}_{\ell-1}, \qquad \ell = 1 \ldots L$$

$$\mathbf{z}_{\ell} = \operatorname{MLP}\left(\operatorname{LN}\left(\mathbf{z}_{\ell}^{\prime}\right)\right) + \mathbf{z}_{\ell}^{\prime}, \qquad \ell = 1 \ldots L$$

$$\mathbf{y} = \operatorname{LN}\left(\mathbf{z}_{L}^{0}\right)$$

where $\operatorname{LN}(\cdot)$ is LayerNorm, $\operatorname{MSA}(\cdot)$ is multi-head self-attention, $\mathbf{E}$ is the patch embedding matrix, $\mathbf{z}_{\ell}$ is the embedded sequence at layer $\ell$, and $D$ is the hidden dimension. A visual illustration of the model is shown in Figure 1.
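
Putting the four equations together, the whole encoder fits in a short PyTorch sketch. This is a simplified illustration under my own naming, not the reference implementation; it omits dropout and the initialization details used in the paper:

```python
import torch
import torch.nn as nn

class ViTEncoderSketch(nn.Module):
    """Minimal ViT classifier following the equations above: patch embedding E,
    [class] token, 1-D position embedding E_pos, L pre-norm encoder blocks,
    and a final LayerNorm on the [class] output. Sketch only, not the
    reference implementation."""

    def __init__(self, patch_dim, dim=768, depth=12, heads=12, mlp_dim=3072,
                 num_patches=196, num_classes=10):
        super().__init__()
        self.patch_embed = nn.Linear(patch_dim, dim)                  # E ∈ R^{(P²·C)×D}
        self.cls_token = nn.Parameter(torch.zeros(1, 1, dim))         # x_class
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, dim))  # E_pos
        self.blocks = nn.ModuleList([
            nn.ModuleDict({
                "norm1": nn.LayerNorm(dim),
                "attn": nn.MultiheadAttention(dim, heads, batch_first=True),
                "norm2": nn.LayerNorm(dim),
                "mlp": nn.Sequential(nn.Linear(dim, mlp_dim), nn.GELU(),
                                     nn.Linear(mlp_dim, dim)),
            })
            for _ in range(depth)
        ])
        self.norm = nn.LayerNorm(dim)
        self.head = nn.Linear(dim, num_classes)                       # classification head

    def forward(self, patches):                                       # patches: (B, N, P²·C)
        B = patches.shape[0]
        z = self.patch_embed(patches)                                 # x_p E
        cls = self.cls_token.expand(B, -1, -1)
        z = torch.cat([cls, z], dim=1) + self.pos_embed               # z_0 = [x_class; x_p E] + E_pos
        for blk in self.blocks:
            h = blk["norm1"](z)
            z = blk["attn"](h, h, h, need_weights=False)[0] + z       # z'_l = MSA(LN(z_{l-1})) + z_{l-1}
            z = blk["mlp"](blk["norm2"](z)) + z                       # z_l = MLP(LN(z'_l)) + z'_l
        y = self.norm(z[:, 0])                                        # y = LN(z_L^0)
        return self.head(y)

model = ViTEncoderSketch(patch_dim=768)            # ViT-B/16-like settings on 224×224 inputs
logits = model(torch.randn(2, 196, 768))           # -> shape (2, 10)
```

The head here is a single linear layer for brevity; the paper uses an MLP head with one hidden layer during pre-training and a single linear layer at fine-tuning time.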

Vision Transformer variants

There are three main variants: Base, Large, and Huge; the details are shown in Table 1. A particular backbone is usually denoted ViT-<size>/<patch size>; for example, ViT-B/8 means the Base variant with $8 \times 8$ patches.

| Model | Layers | Hidden size $D$ | MLP size | Heads | Params |
| --- | --- | --- | --- | --- | --- |
| ViT-Base | 12 | 768 | 3072 | 12 | 86M |
| ViT-Large | 24 | 1024 | 4096 | 16 | 307M |
| ViT-Huge | 32 | 1280 | 5120 | 16 | 632M |

Table 1: Details of ViT model variants.
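
For reference, Table 1 can be written as plain configuration dictionaries that plug into the encoder sketch above (the dictionary keys are my own naming, not an official API):

```python
# Table 1 as configuration dictionaries for the ViTEncoderSketch defined earlier.
VIT_VARIANTS = {
    "ViT-Base":  dict(depth=12, dim=768,  mlp_dim=3072, heads=12),   # ~86M params
    "ViT-Large": dict(depth=24, dim=1024, mlp_dim=4096, heads=16),   # ~307M params
    "ViT-Huge":  dict(depth=32, dim=1280, mlp_dim=5120, heads=16),   # ~632M params
}

# Example: ViT-L/16 on 224×224 RGB inputs -> N = (224/16)² = 196 patches of dimension 16·16·3 = 768.
vit_l_16 = ViTEncoderSketch(patch_dim=768, num_patches=196, **VIT_VARIANTS["ViT-Large"])
```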

While increasing model size usually improves performance on large datasets, this is not always the case for the Vision Transformer. As shown in Table 2, the larger variants can perform worse than the Base architecture, especially when the pre-training data is limited (ImageNet rather than ImageNet-21k). This characteristic should be kept in mind when working with ViT.

| Pre-trained on | Dataset | ViT-B/16 | ViT-B/32 | ViT-L/16 | ViT-L/32 | ViT-H/14 |
| --- | --- | --- | --- | --- | --- | --- |
| ImageNet | CIFAR-10 | 98.13 | 97.77 | 97.86 | 97.94 | - |
| ImageNet | CIFAR-100 | 87.13 | 86.31 | 86.35 | 87.04 | - |
| ImageNet | Oxford Flowers-102 | 89.49 | 85.43 | 89.66 | 86.36 | - |
| ImageNet | Oxford-IIIT-Pets | 93.81 | 92.04 | 93.64 | 91.35 | - |
| ImageNet-21k | CIFAR-10 | 98.95 | 98.79 | 99.16 | 99.13 | 99.27 |
| ImageNet-21k | CIFAR-100 | 91.67 | 91.98 | 93.44 | 93.04 | 93.82 |
| ImageNet-21k | Oxford Flowers-102 | 99.38 | 99.11 | 99.61 | 99.19 | 99.51 |
| ImageNet-21k | Oxford-IIIT-Pets | 94.43 | 93.02 | 94.73 | 93.09 | 94.82 |

Table 2: Top-1 accuracy (in %) of ViT variants on various datasets when pre-trained on ImageNet and ImageNet-21k.

References

[1] Dosovitskiy, A., Beyer, L., Kolesnikov, A., Weissenborn, D., Zhai, X., Unterthiner, T., Dehghani, M., Minderer, M., Heigold, G., Gelly, S., Uszkoreit, J., & Houlsby, N. (2021). An image is worth 16x16 words: Transformers for image recognition at scale. In International Conference on Learning Representations.