Transformers from scratch in pytorch

Published

October 25, 2022

Notes from https://uvadlc-notebooks.readthedocs.io/en/latest/tutorial_notebooks/tutorial6/Transformers_and_MHAttention.html

## Standard Libraries
import os
from tracemalloc import Snapshot
import numpy as numpy
import random
import math
import json
from functools import partial
import logging
import sys

# imports for plotting
import matplotlib.pyplot as plt
import matplotlib_inline
plt.set_cmap('cividis')
%matplotlib inline
matplotlib_inline.backend_inline.set_matplotlib_formats()
from matplotlib.colors import to_rgb
import matplotlib
matplotlib.rcParams['lines.linewidth'] = 2.0
import seaborn as sns
sns.reset_orig()

from tqdm.notebook import tqdm

# pytorch
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
from torchvision import transforms
import pytorch_lightning as pl
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint

# others
import einops

# Folders used for downloaded datasets and for saved model checkpoints.
DATASET_PATH = "../data"
CHECKPOINT_PATH = "../saved_models"

# Seed the global random number generators (via Lightning) for reproducibility.
pl.seed_everything(42)

# Ask cuDNN for deterministic behaviour on GPU and disable kernel autotuning.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# Route log records (including the attention layers' DEBUG shape traces) to stdout.
logging.basicConfig(stream=sys.stdout, level=logging.INFO)
logger = logging.getLogger(__name__)

# Prefer the first CUDA device when one is available, otherwise fall back to CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print("Device: ", device)
Global seed set to 42
Device:  cpu

What is Attention?

The attention mechanism computes a weighted average of a sequence of elements, where the weights are computed dynamically from the input query and the elements’ keys.

image

As we can see above, the query, key and value vectors are all just transformed (linearly projected) versions of the word embeddings.

image

(image from http://jalammar.github.io/illustrated-transformer/)

#### Self-attention

We use self-attention to replace each input word embedding with a weighted combination of the value vectors of the words in the sequence, where the weights are the computed attention scores. Some intuitions:

  • Values are nothing but modified word embeddings which will be weighted and given as outputs
  • Query is the modified word embedding which will search for relevant word embeddings in the sequence
  • Keys are modified word embeddings which will be searched by the query to create the weights
  • Because we use the dot product, the score is closely related to the cosine similarity between the query and the key: it equals the cosine similarity scaled by both vectors’ norms. The more aligned the query and key, the higher the weight.

If we look at the softmax scores, the word at its own position often has the highest score. Sometimes, however, it is useful to attend to another word that is relevant to the current one.

Let's look at the sentence “The animal didn't cross the street because it was too tired.” Here, the softmax weights for ‘it’ are spread over ‘it’ itself, ‘animal’ (more) and ‘street’ (less).

In addition to attending to other parts of the sentence, the softmax can also drown out irrelevant words by giving them very small weights.

image

Multi Head Attention

An attention layer outputs a representation of the input sequence based on its learned query, key and value projection weights. Multi-head attention runs several such heads in parallel and combines their outputs with a linear layer to obtain a richer representation of the input sequence.

def scaled_dot_product(q, k, v, mask=None):
    """Compute scaled dot-product attention.

    Works for any number of leading (batch / head) dimensions.

    Args:
        q: query tensor of shape (..., seq_len_q, d_k).
        k: key tensor of shape (..., seq_len_k, d_k).
        v: value tensor of shape (..., seq_len_k, d_v).
        mask: optional tensor broadcastable to (..., seq_len_q, seq_len_k);
            positions where ``mask == 0`` receive (numerically) zero weight.

    Returns:
        A ``(values, attention)`` tuple:

        - ``values`` has shape (..., seq_len_q, d_v).
        - ``attention`` holds the softmax weights, shape (..., seq_len_q, seq_len_k).
          Each row sums to 1 over the key axis.
    """
    d_k = q.size(-1)
    # Scale by sqrt(d_k) so logit magnitude does not grow with the key width.
    attn_logits = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(d_k)
    if mask is not None:
        # A large finite negative value (rather than -inf) avoids NaNs when an
        # entire row is masked out.
        attn_logits = attn_logits.masked_fill(mask == 0, -9e15)
    # Normalise over the key axis (last dim). dim=1 would be wrong for batched
    # (batch, head, seq, seq) inputs, where axis 1 is the head axis.
    attention = F.softmax(attn_logits, dim=-1)
    values = torch.matmul(attention, v)
    return values, attention
# Toy example: one unbatched sequence of 3 tokens with 2-dimensional q/k/v.
seq_len, d_k = 3, 2
# Drawn in the order q, k, v, so the RNG stream is consumed as before.
q, k, v = (torch.randn(seq_len, d_k) for _ in range(3))
values, attention = scaled_dot_product(q, k, v)
for label, tensor in (
    ("Q\n", q),
    ("K\n", k),
    ("V\n", v),
    ("Values\n", values),
    ("Attention\n", attention),
):
    print(label, tensor.size())
Q
 torch.Size([3, 2])
K
 torch.Size([3, 2])
V
 torch.Size([3, 2])
Values
 torch.Size([3, 2])
Attention
 torch.Size([3, 3])
class MultiHeadAttention(nn.Module):
    """Multi-head attention with a single fused Q/K/V projection.

    Args:
        input_dim: size of the last dimension of the input ``x``.
        embed_dim: total Q/K/V embedding size, split evenly across the heads.
        num_heads: number of attention heads; must divide ``embed_dim``.
        log_level: level set on the logger named after this class. The logger
            is shared by all instances, so the most recently constructed
            instance determines the level.
    """

    def __init__(self, input_dim, embed_dim, num_heads, log_level=logging.INFO):
        super().__init__()
        assert embed_dim % num_heads == 0, "Embedding dimension must be divisible by the number of heads"
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.logger = logging.getLogger(self.__class__.__name__)
        self.logger.setLevel(log_level)

        # Stack the Q, K and V weight matrices of all heads into one linear
        # layer for efficiency. Many implementations use bias=False here; the
        # bias is optional.
        self.qkv_proj = nn.Linear(input_dim, 3 * embed_dim)
        self.o_proj = nn.Linear(embed_dim, embed_dim)
        self._reset_parameters()

    def _reset_parameters(self):
        # Original Transformer initialization: Xavier-uniform weights, zero biases.
        nn.init.xavier_uniform_(self.qkv_proj.weight)
        nn.init.zeros_(self.qkv_proj.bias)
        nn.init.xavier_uniform_(self.o_proj.weight)
        nn.init.zeros_(self.o_proj.bias)

    def forward(self, x, mask=None, return_attention=False):
        """Apply multi-head attention to ``x``.

        Args:
            x: tensor of shape (batch, seq_len, input_dim).
            mask: optional mask forwarded to ``scaled_dot_product``. It must
                broadcast to (batch, num_heads, seq_len, seq_len); zero entries
                are masked out.
            return_attention: when True, also return the attention weights.

        Returns:
            The output of shape (batch, seq_len, embed_dim). When
            ``return_attention`` is True, a tuple of that output and the
            attention weights of shape (batch, num_heads, seq_len, seq_len).
        """
        batch_size, seq_length, _ = x.size()
        qkv = self.qkv_proj(x)
        # %-style arguments keep logging lazy: no message is formatted unless
        # DEBUG is enabled.
        self.logger.debug('x.shape - %s. qkv_proj.shape - %s', x.size(), self.qkv_proj.weight.size())

        # Separate Q, K, V from the linear output. The last axis is viewed as
        # num_heads groups of [q | k | v], each of width head_dim.
        self.logger.debug('generated qkv.shape - %s', qkv.size())
        qkv = qkv.reshape(batch_size, seq_length, self.num_heads, 3 * self.head_dim)
        self.logger.debug('reshaped qkv.shape - %s', qkv.size())
        qkv = qkv.permute(0, 2, 1, 3)  # [Batch, Head, SeqLen, 3*HeadDim]
        self.logger.debug('permuted qkv.shape - %s', qkv.size())
        q, k, v = qkv.chunk(3, dim=-1)  # each [Batch, Head, SeqLen, HeadDim]

        # Determine value outputs.
        values, attention = scaled_dot_product(q, k, v, mask=mask)
        self.logger.debug('values.shape - %s. attension.shape - %s', values.size(), attention.size())
        values = values.permute(0, 2, 1, 3)  # [Batch, SeqLen, Head, HeadDim]
        self.logger.debug('values.shape after permute - %s', values.size())
        # Concatenate the heads back into a single embed_dim-wide feature axis.
        values = values.reshape(batch_size, seq_length, self.embed_dim)
        self.logger.debug('values.shape after reshape - %s', values.size())

        o = self.o_proj(values)

        if return_attention:
            return o, attention
        return o

# Smoke test: push a random batch through a 4-head attention layer with DEBUG
# logging enabled so each intermediate shape gets logged.
input_dim, embed_dim = 64, 16
batch_size, seq_len = 3, 10
multi_attention = MultiHeadAttention(
    input_dim, embed_dim, num_heads=4, log_level=logging.DEBUG
)
x = torch.randn(batch_size, seq_len, input_dim)
att = multi_attention(x)
DEBUG:MultiHeadAttention:x.shape - torch.Size([3, 10, 64]). qkv_proj.shape - torch.Size([48, 64])
DEBUG:MultiHeadAttention:generated qkv.shape - torch.Size([3, 10, 48])
DEBUG:MultiHeadAttention:reshaped qkv.shape - torch.Size([3, 10, 4, 12])
DEBUG:MultiHeadAttention:permuted qkv.shape - torch.Size([3, 4, 10, 12])
DEBUG:MultiHeadAttention:values.shape - torch.Size([3, 4, 10, 4]). attension.shape - torch.Size([3, 4, 10, 10])
DEBUG:MultiHeadAttention:values.shape after permute - torch.Size([3, 10, 4, 4])
DEBUG:MultiHeadAttention:values.shape after reshape - torch.Size([3, 10, 16])
class EncoderBlock(nn.Module):
    """One post-norm Transformer encoder layer: self-attention, then an MLP.

    Each sub-layer is applied as ``LayerNorm(x + Dropout(sublayer(x)))``.

    Args:
        input_dim: model width, used for both the block's input and output.
        num_heads: number of attention heads (must divide ``input_dim``).
        dim_feedforward: hidden width of the two-layer MLP.
        dropout: dropout probability inside the MLP and on both residual branches.
    """

    def __init__(self, input_dim, num_heads, dim_feedforward, dropout=0.0):
        super().__init__()

        # Self-attention sub-layer; the Q/K/V embeddings share the model width.
        self.self_attn = MultiHeadAttention(input_dim, input_dim, num_heads)

        # Feed-forward sub-layer: expand to dim_feedforward, then project back.
        feed_forward = [
            nn.Linear(input_dim, dim_feedforward),
            nn.Dropout(dropout),
            nn.ReLU(inplace=True),
            nn.Linear(dim_feedforward, input_dim),
        ]
        self.linear_net = nn.Sequential(*feed_forward)

        # Normalisation and dropout applied around each residual connection.
        self.norm1 = nn.LayerNorm(input_dim)
        self.norm2 = nn.LayerNorm(input_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        """Run ``x`` (batch, seq_len, input_dim) through attention and the MLP."""
        # Attention sub-layer with residual connection and post-norm.
        attended = self.self_attn(x, mask=mask)
        x = self.norm1(x + self.dropout(attended))

        # Feed-forward sub-layer with residual connection and post-norm.
        projected = self.linear_net(x)
        return self.norm2(x + self.dropout(projected))

class TransformerEncoder(nn.Module):
    """A stack of ``num_layers`` identically configured ``EncoderBlock`` layers.

    Args:
        num_layers: number of encoder blocks to stack.
        **block_args: keyword arguments forwarded to every ``EncoderBlock``.
    """

    def __init__(self, num_layers, **block_args):
        super().__init__()
        self.layers = nn.ModuleList([EncoderBlock(**block_args) for _ in range(num_layers)])

    def forward(self, x, mask=None):
        """Apply every block in order, using the same ``mask`` for each."""
        for layer in self.layers:
            x = layer(x, mask=mask)
        return x

    def get_attention_maps(self, x, mask=None):
        """Return the self-attention weights of every layer for input ``x``.

        The input is propagated through the stack exactly as in ``forward``,
        including the mask. Map ``i`` is therefore layer ``i``'s attention over
        the output of layer ``i - 1``.

        Returns:
            A list with one attention tensor per layer, in layer order.
        """
        attention_maps = []
        for layer in self.layers:
            _, attn_map = layer.self_attn(x, mask=mask, return_attention=True)
            attention_maps.append(attn_map)
            # The mask must be passed here as well. Otherwise deeper layers see
            # activations that forward() never produces, and their maps are wrong.
            x = layer(x, mask=mask)
        return attention_maps


class PositionalEncoding(nn.Module):
    """Add fixed (non-trainable) sinusoidal positional encodings to embeddings.

    For column pair ``j``, even column ``2j`` holds
    ``sin(pos / 10000^(2j / d_model))`` and odd column ``2j + 1`` holds
    ``cos(pos / 10000^(2j / d_model))``.

    Args:
        d_model: embedding width (odd widths are supported).
        max_len: longest sequence length that can be encoded.
    """

    def __init__(self, d_model, max_len=5000):
        super().__init__()

        # (max_len, d_model) matrix holding the positional encoding of every position.
        pe = torch.zeros(max_len, d_model)

        # position[p, 0] = p, the index of the word in the sequence.
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)

        # Frequencies 1 / 10000^(2j / d_model), computed in log space. The
        # whole product must sit inside exp(); scaling exp(...) afterwards
        # yields huge negative frequencies instead of decaying ones.
        div_term = torch.exp(
            torch.arange(0, d_model, 2, dtype=torch.float) * (-math.log(10000.0) / d_model)
        )
        pe[:, 0::2] = torch.sin(position * div_term)
        # With an odd d_model there is one fewer cosine column than sine column.
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
        # Leading singleton dim so the buffer broadcasts over the batch axis.
        pe = pe.unsqueeze(0)

        # register_buffer stores a tensor that is part of the module's state
        # but is not a trainable parameter; it moves with the module across
        # devices. persistent=False keeps the buffer out of the state_dict.
        self.register_buffer('pe', pe, persistent=False)

    def forward(self, x):
        """Return ``x`` plus the encodings of its first ``x.size(1)`` positions.

        Assumes ``x`` is (batch, seq_len, d_model) with ``seq_len <= max_len``.
        """
        return x + self.pe[:, : x.size(1)]
        
        

\[ f(n) = \begin{cases} n/2, & \text{if $n$ is even} \\ 3n+1, & \text{if $n$ is odd} \end{cases} \]

Positional Encoding

We add a fixed (non-trainable) signal to each word embedding based on its position. The PE signal has the same dimension as the word embedding, so the two can be summed.

\[ PE_{(pos, i)} = \begin{cases} \sin\left(\frac{pos}{10000^{i/d_{model}}}\right), & \text{if $i$ mod 2 = 0}\\ \cos\left(\frac{pos}{10000^{(i-1)/d_{model}}}\right), & \text{otherwise} \end{cases} \]

Here pos is the position of the word in the sequence and i is the index of the embedding dimension.

For a deeper intuition, look at - https://kazemnejad.com/blog/transformer_architecture_positional_encoding/ - https://towardsdatascience.com/master-positional-encoding-part-i-63c05d90a0c3