vpj / jax_transformer

Autoregressive transformer in JAX from scratch


This implementation builds a transformer decoder from the ground up. It doesn't use any higher-level frameworks like Flax; I have used labml for logging and experiment tracking.

I have implemented a simple Module class to build basic building blocks upon.
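The actual `Module` class in this repo may work differently; the following is a minimal sketch of the idea, with a hypothetical `Linear` layer to show how building blocks could be composed on top of it:

```python
import jax
import jax.numpy as jnp


class Module:
    """Base class; parameters live as attributes so they can be
    collected into a dict and passed around as pure-function arguments."""

    def get_params(self):
        # Recursively gather parameters from arrays and child modules
        params = {}
        for name, value in vars(self).items():
            if isinstance(value, Module):
                params[name] = value.get_params()
            elif isinstance(value, jnp.ndarray):
                params[name] = value
        return params


class Linear(Module):
    """A hypothetical linear layer built on the Module base class."""

    def __init__(self, rng_key, in_features, out_features):
        self.weight = jax.random.normal(rng_key, (in_features, out_features)) * 0.02
        self.bias = jnp.zeros(out_features)

    def __call__(self, x):
        return x @ self.weight + self.bias
```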

This was my first JAX project and many implementations were taken from PyTorch implementations at nn.labml.ai.

JAX can optimize and differentiate pure Python functions. A pure function takes a set of arguments and returns a result without reading or modifying any external state. JAX can also compile these functions with XLA, as well as vectorize them, to run efficiently.
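For example, a simple loss function like the one below is pure, so it can be passed to `jax.grad` and `jax.jit` directly (the names here are illustrative, not from this repo):

```python
import jax
import jax.numpy as jnp

# A pure function: the output depends only on the inputs,
# and nothing outside the function is read or modified.
def loss(w, x, y):
    pred = x @ w
    return jnp.mean((pred - y) ** 2)

grad_fn = jax.grad(loss)   # differentiate with respect to w (the first argument)
fast_loss = jax.jit(loss)  # compile with XLA

w = jnp.zeros(3)
x = jnp.ones((8, 3))
y = jnp.ones(8)
print(fast_loss(w, x, y), grad_fn(w, x, y))
```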

In JAX you don't have to worry about batches. Functions are implemented for a single sample, and jax.vmap can vectorize (parallelize) them across the batch dimension (or any other dimension if needed).
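As a sketch (the function and names here are illustrative, not from this repo), a forward pass written for a single sample can be lifted to a whole batch with `jax.vmap`:

```python
import jax
import jax.numpy as jnp

# Written for a single sample: x has shape (features,)
def forward(w, x):
    return jnp.dot(w, x)

w = jnp.ones(4)
x_batch = jnp.ones((32, 4))  # a batch of 32 samples

# vmap maps forward over axis 0 of x_batch while broadcasting w unchanged
batched_forward = jax.vmap(forward, in_axes=(None, 0))
out = batched_forward(w, x_batch)  # shape (32,)
```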

Contents

- View Run
- Twitter thread
