TJKlein / Essential-Transformer

A minimalist 45-minute implementation of the Transformer backbone (encoder and decoder)

Repository: https://github.com/TJKlein/Essential-Transformer

The Essential Transformer

Understanding the backbone of encoders and decoders in 45 minutes

The Transformer architecture has revolutionized natural language processing and machine translation. This repository provides a minimalist yet comprehensive implementation of the Transformer encoder and decoder, aimed at building an intuitive understanding of the core concepts underlying this powerful model. The implementations serve as a didactic resource for enthusiasts, researchers, and learners who want to grasp its fundamental principles. Each implementation is fewer than 100 lines of code.

To keep things simple, a couple of assumptions are made:

  • positional embeddings are trainable parameters that are added to the token embeddings (see the sketch after this list)
  • the embedding dimensionality must be a multiple of the number of heads (the joint embedding is reshaped into heads before softmax normalization)
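
As a rough illustration, the sketch below shows what these two assumptions mean in practice. It is not the repository's code; all names and shapes are made up for the example.

import torch
import torch.nn as nn

# Illustrative parameters; emb_dim must be divisible by num_heads
emb_dim, num_heads, max_len, vocab_sz, batch_sz = 768, 16, 128, 10000, 10
assert emb_dim % num_heads == 0
head_dim = emb_dim // num_heads

tok_emb = nn.Embedding(vocab_sz, emb_dim)   # token embeddings
pos_emb = nn.Embedding(max_len, emb_dim)    # trainable positional embeddings

tokens = torch.randint(0, vocab_sz, (batch_sz, max_len))
x = tok_emb(tokens) + pos_emb(torch.arange(max_len))   # (batch, max_len, emb_dim)

# The joint embedding is reshaped into heads before softmax normalization
q = x.view(batch_sz, max_len, num_heads, head_dim).transpose(1, 2)  # (batch, heads, max_len, head_dim)
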
  1. Toy example of instantiating a decoder block:
import torch
from decoder import TransformerBlock

# Some toy parameters
num_heads = 16
emb_dim = 768
ffn_dim = 1024
num_layers = 12
max_len = 128
vocab_sz = 10000
batch_sz = 10

# Toy input data corresponding to embeddings
x = torch.randn((batch_sz, max_len, emb_dim))

tb = TransformerBlock(max_len, emb_dim, ffn_dim, num_heads)
tb(x)
  2. Toy example of instantiating a transformer decoder:
import torch
from decoder import Transformer

# Some toy parameters
num_heads = 16
emb_dim = 768
ffn_dim = 1024
num_layers = 12
max_len = 128
vocab_sz = 10000
batch_sz = 10

# Toy input data corresponding to random tokens
x = torch.randint(0, vocab_sz, (batch_sz, max_len))

trans = Transformer(num_layers, num_heads, max_len, vocab_sz, emb_dim, ffn_dim)
trans(x)
  3. Toy example of instantiating a transformer decoder with multi-query attention (a sketch of the multi-query attention idea follows these examples):
import torch
from decoder_multi_query_attention import Transformer

# Some toy parameters
num_heads = 16
emb_dim = 768
ffn_dim = 1024
num_layers = 12
max_len = 128
vocab_sz = 10000
batch_sz = 10

# Toy input data corresponding to random tokens
x = torch.randint(0, vocab_sz, (batch_sz, max_len))

trans = Transformer(num_layers, num_heads, max_len, vocab_sz, emb_dim, ffn_dim)
trans(x)
  4. Toy example of instantiating a transformer encoder:
import torch
from encoder import Transformer

# Some toy parameters
num_heads = 16
emb_dim = 768
ffn_dim = 1024
num_layers = 12
max_len = 128
vocab_sz = 10000
batch_sz = 10

# Toy input data corresponding to random tokens
x = torch.randint(0, vocab_sz, (batch_sz, max_len))

trans = Transformer(num_layers, vocab_sz, emb_dim, max_len, num_heads, ffn_dim)
trans(x)
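
For context, multi-query attention differs from standard multi-head attention in that all query heads share a single key/value head, which shrinks the key/value projections and cache. The sketch below is a minimal, illustrative version of that idea; it is not taken from decoder_multi_query_attention.py, and names, shapes, and the omission of causal masking are simplifications.

import torch
import torch.nn as nn
import torch.nn.functional as F

class MultiQueryAttention(nn.Module):
    """Illustrative multi-query attention: per-head queries, one shared key/value head."""
    def __init__(self, emb_dim, num_heads):
        super().__init__()
        assert emb_dim % num_heads == 0
        self.num_heads = num_heads
        self.head_dim = emb_dim // num_heads
        self.q_proj = nn.Linear(emb_dim, emb_dim)        # one query projection per head
        self.k_proj = nn.Linear(emb_dim, self.head_dim)  # single shared key head
        self.v_proj = nn.Linear(emb_dim, self.head_dim)  # single shared value head
        self.out_proj = nn.Linear(emb_dim, emb_dim)

    def forward(self, x):                                # x: (batch, seq_len, emb_dim)
        b, t, _ = x.shape
        q = self.q_proj(x).view(b, t, self.num_heads, self.head_dim).transpose(1, 2)  # (b, h, t, d)
        k = self.k_proj(x).unsqueeze(1)                  # (b, 1, t, d), broadcast over heads
        v = self.v_proj(x).unsqueeze(1)                  # (b, 1, t, d)
        scores = q @ k.transpose(-2, -1) / self.head_dim ** 0.5  # (b, h, t, t); causal mask omitted
        out = F.softmax(scores, dim=-1) @ v              # (b, h, t, d)
        return self.out_proj(out.transpose(1, 2).reshape(b, t, -1))

mqa = MultiQueryAttention(emb_dim=768, num_heads=16)
mqa(torch.randn(10, 128, 768)).shape                     # torch.Size([10, 128, 768])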


License: MIT

Languages: Python 100.0%