A pedagogical PyTorch implementation of TinyLlama, from the paper *TinyLlama: An Open-Source Small Language Model* (Zhang et al., 2024).
```python
import torch

from tinyllama import TinyLlama, TinyLlamaConfiguration

# As specified in the paper.
configuration = TinyLlamaConfiguration(
    embedding_dimension=2048,
    intermediate_dimension=5632,  # 2.75 × embedding_dimension.
    number_of_heads=32,  # The paper uses 32 attention heads (head dimension 64).
    number_of_layers=22,
    vocabulary_size=32_000,
    context_length=2048,
)
model = TinyLlama(configuration=configuration)
tokens = torch.tensor([[1, 2, 3, 4]])  # (batch, sequence) token IDs.
logits = model(tokens, mask=None)  # (batch, sequence, vocabulary_size).
```
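
For example, here is a minimal greedy-decoding loop over the forward signature shown above. The `generate` helper and its `steps` parameter are hypothetical, and the loop recomputes the full prefix on every step, which is exactly the cost the caching item in the roadmap below would remove:

```python
import torch

@torch.no_grad()
def generate(model, tokens, steps=16):
    # Hypothetical helper: assumes only model(tokens, mask=None) -> logits
    # of shape (batch, sequence, vocabulary_size), as shown above.
    for _ in range(steps):
        logits = model(tokens, mask=None)
        # Greedily take the most likely token at the last position.
        next_token = logits[:, -1, :].argmax(dim=-1, keepdim=True)
        tokens = torch.cat([tokens, next_token], dim=-1)
    return tokens
```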
Planned improvements:

- Implement caching for the RoPE cos/sin tables and the attention key/value pairs (a sketch follows this list).
- Switch to FlashAttention-2 for grouped-query attention (GQA); see the second sketch below.
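
A minimal sketch of what that caching could look like. The `KVCache` class, the `precompute_rotary_cache` helper, and the buffer layout are assumptions for illustration, not the repository's API:

```python
import torch

def precompute_rotary_cache(context_length, head_dimension, base=10000.0):
    # RoPE caching: the cos/sin tables depend only on position and head
    # dimension, so they can be computed once and reused on every forward pass.
    inverse_frequencies = 1.0 / (
        base ** (torch.arange(0, head_dimension, 2).float() / head_dimension)
    )
    angles = torch.outer(torch.arange(context_length).float(), inverse_frequencies)
    return torch.cos(angles), torch.sin(angles)

class KVCache:
    """Hypothetical per-layer key/value cache with pre-allocated buffers."""

    def __init__(self, batch_size, number_of_heads, context_length,
                 head_dimension, device=None, dtype=None):
        shape = (batch_size, number_of_heads, context_length, head_dimension)
        self.keys = torch.zeros(shape, device=device, dtype=dtype)
        self.values = torch.zeros(shape, device=device, dtype=dtype)
        self.length = 0  # Positions filled so far.

    def update(self, keys, values):
        # keys/values hold only the newly computed positions, shaped
        # (batch, heads, new_positions, head_dimension); append them after
        # the cached prefix and return views over the whole prefix.
        new = keys.shape[2]
        self.keys[:, :, self.length:self.length + new] = keys
        self.values[:, :, self.length:self.length + new] = values
        self.length += new
        return self.keys[:, :, :self.length], self.values[:, :, :self.length]
```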
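For the second item, one option is PyTorch's `torch.nn.functional.scaled_dot_product_attention`, which dispatches to a FlashAttention-2 kernel on supported GPUs; its `enable_gqa` flag (PyTorch >= 2.5) broadcasts a small set of key/value heads across the query heads. The head counts below follow the paper's 1.1B configuration (32 query heads, 4 key/value heads), not necessarily this repository's:

```python
import torch
import torch.nn.functional as F

batch, sequence, head_dimension = 1, 4, 64
query_heads, kv_heads = 32, 4  # Grouped-query attention: 8 queries per KV head.

query = torch.randn(batch, query_heads, sequence, head_dimension)
key = torch.randn(batch, kv_heads, sequence, head_dimension)
value = torch.randn(batch, kv_heads, sequence, head_dimension)

# Dispatches to a FlashAttention-2 kernel on supported GPUs; enable_gqa
# (PyTorch >= 2.5) shares the 4 KV heads across the 32 query heads
# without materialising repeated tensors.
output = F.scaled_dot_product_attention(
    query, key, value, is_causal=True, enable_gqa=True
)
```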