Eric-Wallace / JAXSeq

Train very large language models in Jax.

JaxSeq

Overview

Built on top of HuggingFace's Transformers library, JaxSeq enables training very large language models in Jax. It currently supports GPT2, GPTJ, T5, and OPT models. JaxSeq is designed to be lightweight and easily extensible, with the aim of demonstrating a workflow for training large language models without the heft typical of other existing frameworks.

Thanks to Jax's pjit function, you can straightforwardly train models with arbitrary model and data parallelism, trading off the two as you like. You can also do model parallelism across multiple hosts. Gradient checkpointing, gradient accumulation, and bfloat16 training/inference are also supported for memory-efficient training.
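To make the trade-off concrete: a fixed pool of devices can be arranged as a two-axis mesh, with one axis for data parallelism and one for model parallelism, and the product of the two axis sizes has to equal the device count. The pure-Python sketch below enumerates those options. The name mesh_shapes is hypothetical and is not part of JaxSeq or Jax; it only illustrates the arithmetic, not how JaxSeq actually builds its mesh.

```python
from typing import List, Tuple


def mesh_shapes(n_devices: int) -> List[Tuple[int, int]]:
    """Return every (data_parallel, model_parallel) pair whose product is n_devices.

    Hypothetical helper for illustration only; not a JaxSeq or Jax API.
    """
    if n_devices < 1:
        raise ValueError("n_devices must be a positive integer")
    # Each divisor of the device count is a valid model-parallel degree;
    # the remaining factor is the data-parallel degree.
    return [
        (n_devices // mp, mp)
        for mp in range(1, n_devices + 1)
        if n_devices % mp == 0
    ]


if __name__ == "__main__":
    for dp, mp in mesh_shapes(8):
        print(f"data parallel = {dp}, model parallel = {mp}")
```

For 8 devices this yields (8, 1), (4, 2), (2, 4), and (1, 8): from pure data parallelism at one end to pure model parallelism at the other.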

If you encounter an error or want to contribute, feel free to open an issue!

Installation

1. Pull from GitHub

git clone https://github.com/Sea-Snell/JAXSeq.git
cd JAXSeq
export PYTHONPATH=${PWD}/src/

2. Install dependencies

Install with conda, using the commands that match your hardware (CPU, GPU, or TPU).

Install with conda (CPU):

conda env create -f environment.yml
conda activate JaxSeq

Install with conda (GPU):

conda env create -f environment.yml
conda activate JaxSeq
python -m pip install --upgrade pip
python -m pip install --upgrade "jax[cuda]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

Install with conda (TPU):

conda env create -f environment.yml
conda activate JaxSeq
python -m pip install --upgrade pip
python -m pip install "jax[tpu]==0.3.21" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html

Workflow

We provide example scripts for training and evaluating GPT2, GPTJ, OPT, and T5 models using JaxSeq; you can find them in the examples/ directory. You should also feel free to build your own training workflow. Each script takes as input a JSON file with the following structure:

{
    "train": [{"in_text": "something", "out_text": "something else"}, ...],
    "eval": [{"in_text": "something else else", "out_text": "something else else else"}, ...]
}
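A file in this shape can be produced with Python's standard json module. The sketch below is only an illustration: the helper name make_dataset is hypothetical and not part of JaxSeq, while the keys train, eval, in_text, and out_text are the ones given in the structure above.

```python
import json
from typing import Dict, List, Tuple

Pair = Tuple[str, str]  # (input text, target text)


def make_dataset(train_pairs: List[Pair], eval_pairs: List[Pair]) -> Dict[str, list]:
    """Build a dict with "train" and "eval" lists of in_text/out_text records.

    Hypothetical helper for illustration only; not a JaxSeq API.
    """
    def to_records(pairs: List[Pair]) -> List[Dict[str, str]]:
        return [{"in_text": src, "out_text": tgt} for src, tgt in pairs]

    return {"train": to_records(train_pairs), "eval": to_records(eval_pairs)}


dataset = make_dataset(
    train_pairs=[("something", "something else")],
    eval_pairs=[("something else else", "something else else else")],
)

# json.dumps emits strict JSON (no trailing commas); write the string to
# whatever path you plan to pass to the example script.
serialized = json.dumps(dataset, indent=2)
```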

This code was largely developed, tested, and optimized for use on TPU pods, though it should also work well on GPU clusters.

Google Cloud Buckets

To further support TPU workflows, the example scripts can upload and download data and/or checkpoints to and from Google Cloud Storage buckets. To use this, prefix the path with gcs://. Depending on the permissions of the bucket, you may also need to specify the Google Cloud project and provide an authentication token.
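As an illustration of the prefix convention, the sketch below tells a bucket path apart from a local path and splits it into a bucket name and an object name. It is not how JaxSeq resolves paths internally: split_gcs_path and the bucket name in the usage comment are hypothetical.

```python
from typing import Optional, Tuple

GCS_PREFIX = "gcs://"


def split_gcs_path(path: str) -> Optional[Tuple[str, str]]:
    """Return (bucket, object_name) for a gcs:// path, or None for a local path.

    Hypothetical helper for illustration only; not a JaxSeq API.
    """
    if not path.startswith(GCS_PREFIX):
        return None  # treat anything without the prefix as a local path
    # Everything up to the first "/" is the bucket; the rest is the object name.
    bucket, _, object_name = path[len(GCS_PREFIX):].partition("/")
    if not bucket:
        raise ValueError(f"missing bucket name in {path!r}")
    return bucket, object_name


# Example with a made-up bucket name:
# split_gcs_path("gcs://my-bucket/data/train.json") gives ("my-bucket", "data/train.json")
# split_gcs_path("data/train.json") gives None
```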

Other Excellent References for Working with Large Models in Jax

About

License: MIT License
