thevasudevgupta / bigbird

Google's BigBird (Jax/Flax & PyTorch) @ 🤗Transformers

Home Page: https://www.youtube.com/watch?v=G22vNvHmHQ0


BigBird

This repository tracks my work on porting Google's BigBird to 🤗 Transformers. I trained 🤗's BigBirdModel & FlaxBigBirdModel (with suitable heads) on some of the datasets mentioned in the paper Big Bird: Transformers for Longer Sequences. This repository also hosts the scripts for that training.
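For intuition about what the ported models compute: the paper's sparse attention lets each token attend to a sliding window of neighbours, a few global tokens, and a few random tokens instead of the full sequence. The sketch below builds that pattern as a token-level 0/1 mask in plain Python. It is a simplified illustration under assumed sizes (`window`, `num_global`, `num_random` and the function name are hypothetical); 🤗's implementation works on blocks of tokens (configured via `block_size` and `num_random_blocks`) rather than on single tokens.

```python
import random

def bigbird_attention_mask(seq_len, window=3, num_global=2, num_random=2, seed=0):
    """Return a seq_len x seq_len 0/1 mask; mask[i][j] == 1 means query i may attend to key j.

    Hypothetical, token-level illustration of BigBird's three attention components
    (sliding window, random, global). Not the block-sparse 🤗 implementation.
    """
    rng = random.Random(seed)
    mask = [[0] * seq_len for _ in range(seq_len)]
    half = window // 2
    for i in range(seq_len):
        # 1) sliding window: neighbours within +-half positions (including i itself)
        for j in range(max(0, i - half), min(seq_len, i + half + 1)):
            mask[i][j] = 1
        # 2) random: a few randomly chosen keys per query
        for j in rng.sample(range(seq_len), k=min(num_random, seq_len)):
            mask[i][j] = 1
    # 3) global: the first num_global tokens attend to everything and are attended by everything
    for g in range(min(num_global, seq_len)):
        for j in range(seq_len):
            mask[g][j] = 1
            mask[j][g] = 1
    return mask

if __name__ == "__main__":
    m = bigbird_attention_mask(12)
    for row in m:
        print("".join("#" if x else "." for x in row))
    density = sum(map(sum, m)) / (len(m) ** 2)
    print(f"fraction of attended query/key pairs: {density:.2f}")
```

Apart from the global rows, every row has a small, bounded number of 1s, which is why the attention cost grows roughly linearly with sequence length instead of quadratically.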

You can find a quick demo in 🤗 Spaces: https://hf.co/spaces/vasudevgupta/BIGBIRD_NATURAL_QUESTIONS

Check out the following notebooks to dive deeper into using 🤗 BigBird:

Description | Notebook
Flax BigBird evaluation on natural-questions dataset | Open In Colab
PyTorch BigBird evaluation on natural-questions dataset | Open In Colab
PyTorch BigBirdPegasus evaluation on PubMed dataset | Open In Colab
How to use 🤗's BigBird (RoBERTa & Pegasus) for inference | Open In Colab

Updates @ 🤗

Description | Date | Link
Script for training FlaxBigBird (with QA heads) on natural-questions | June 25, 2021 | PR #12233
Added Flax/Jax BigBird-RoBERTa to 🤗Transformers | June 15, 2021 | PR #11967
Added PyTorch BigBird-Pegasus to 🤗Transformers | May 7, 2021 | PR #10991
Published blog post @ 🤗Blog | March 31, 2021 | Link
Added PyTorch BigBird-RoBERTa to 🤗Transformers | March 30, 2021 | PR #10183

Training BigBird

I trained BigBird on the natural-questions dataset, which takes around 100 GB of disk space. Before diving deeper into the scripts, let's set up the system using the following commands:

# clone my repository
git clone https://github.com/vasudevgupta7/bigbird

# install requirements
cd bigbird
pip3 install -r requirements.txt

# switch to code directory
cd src

# create data directory for preparing natural questions
mkdir -p data
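Since the natural-questions download is ~100 GB and preparing it needs roughly 250 GB of free disk space, a quick check before starting can save a failed multi-hour run. A minimal sketch using only the standard library; the function name is hypothetical, the 250 GB threshold is taken from this README, and GB here means 10^9 bytes.

```python
import shutil

# Free space recommended for downloading & preparing natural-questions (per this README).
REQUIRED_FREE_GB = 250

def has_enough_space(path=".", required_gb=REQUIRED_FREE_GB):
    """Return True if the filesystem containing `path` has at least `required_gb` GB free."""
    free_bytes = shutil.disk_usage(path).free
    return free_bytes >= required_gb * 10**9

if __name__ == "__main__":
    # Run from `src/`, where the `data/` directory lives.
    free_gb = shutil.disk_usage(".").free / 10**9
    if has_enough_space("."):
        print(f"OK: {free_gb:.0f} GB free")
    else:
        print(f"Warning: only {free_gb:.0f} GB free, ~{REQUIRED_FREE_GB} GB recommended")
```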

Now that your system is ready, let's preprocess & prepare the dataset for training. Just run the following commands:

# this downloads the ~100 GB dataset from 🤗 Hub & prepares training data in `data/nq-training.jsonl`
PROCESS_TRAIN=true python3 prepare_natural_questions.py

# for preparing validation data in `data/nq-validation.jsonl`
PROCESS_TRAIN=false python3 prepare_natural_questions.py

The above commands first download the dataset from 🤗 Hub & then prepare it for training. Since this downloads ~100 GB of data, you need a good internet connection & enough disk space (~250 GB free). Preparing the dataset takes ~3 hours.
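Once both commands finish, a quick sanity check of the output files can catch a truncated or failed run. The helper below is a hypothetical sketch, not part of this repository: it streams a JSON Lines file, counts records, and lists the top-level keys it sees, without assuming any particular schema for the prepared examples.

```python
import json
import os

def inspect_jsonl(path, max_records=None):
    """Count records in a JSON Lines file and collect their top-level keys.

    Reads one line at a time, so even very large files are not loaded into memory.
    """
    count = 0
    keys = set()
    with open(path, encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if not line:
                continue  # ignore blank lines
            record = json.loads(line)  # raises ValueError on a malformed line
            if isinstance(record, dict):
                keys.update(record.keys())
            count += 1
            if max_records is not None and count >= max_records:
                break
    return count, sorted(keys)

if __name__ == "__main__":
    for path in ("data/nq-training.jsonl", "data/nq-validation.jsonl"):
        if not os.path.exists(path):
            print(f"{path}: not found (has the preparation step run?)")
            continue
        n, keys = inspect_jsonl(path, max_records=1000)
        print(f"{path}: {n} records checked, keys: {keys}")
```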

Now that you have prepared the dataset, let's start training. You have two options here:

  1. Train the PyTorch version of BigBird with 🤗 Trainer
  2. Train FlaxBigBird with a custom training loop

PyTorch BigBird distributed training on multiple GPUs

# For distributed training (using nq-training.jsonl & nq-validation.jsonl) on 2 GPUs
python3 -m torch.distributed.launch --nproc_per_node=2 train_nq_torch.py
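With `--nproc_per_node=2`, the launcher starts one process per GPU and each process trains on its own share of the data, so the number of examples per optimizer step is multiplied by the number of processes. A small sketch of that arithmetic; the per-device batch size, gradient-accumulation steps and example count in the usage lines are hypothetical, not the values used in train_nq_torch.py (with 🤗 Trainer they correspond to `per_device_train_batch_size` and `gradient_accumulation_steps`).

```python
import math

def effective_batch_size(per_device_batch_size, num_processes, grad_accum_steps=1):
    """Examples contributing to one optimizer step in data-parallel training."""
    return per_device_batch_size * num_processes * grad_accum_steps

def optimizer_steps_per_epoch(num_examples, per_device_batch_size, num_processes, grad_accum_steps=1):
    """Approximate optimizer steps per epoch (a final partial batch counts as one step)."""
    ebs = effective_batch_size(per_device_batch_size, num_processes, grad_accum_steps)
    return math.ceil(num_examples / ebs)

# Hypothetical numbers: 2 GPUs (as in the command above), 4 examples per GPU, 8 accumulation steps.
print(effective_batch_size(4, 2, 8))             # 64
print(optimizer_steps_per_epoch(1000, 4, 2, 8))  # 16
```

Adding a GPU therefore changes the effective batch size, which is why learning-rate settings tuned on one setup may need revisiting on another.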

Flax BigBird distributed training on TPUs/GPUs

# start training
python3 train_nq_flax.py
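Data-parallel JAX training typically replicates the model on every device and splits each batch along a new leading axis of size `jax.local_device_count()` before calling a `jax.pmap`-ed train step; Flax provides `flax.training.common_utils.shard` for this. Whether train_nq_flax.py shards batches exactly this way is an assumption. The sketch below shows the same split in plain Python on a list of examples (the `shard_examples` name is hypothetical); for NumPy/JAX arrays the equivalent is `x.reshape((num_devices, -1) + x.shape[1:])`.

```python
def shard_examples(batch, num_devices):
    """Split a batch into `num_devices` equal, contiguous chunks (one per device)."""
    if num_devices <= 0:
        raise ValueError("num_devices must be positive")
    if len(batch) % num_devices != 0:
        # pmap needs identical shapes on every device, so the batch must divide evenly.
        raise ValueError(f"batch of {len(batch)} does not split evenly over {num_devices} devices")
    per_device = len(batch) // num_devices
    return [batch[d * per_device:(d + 1) * per_device] for d in range(num_devices)]

# e.g. 8 examples over 4 devices -> 2 examples per device
print(shard_examples(list(range(8)), 4))  # [[0, 1], [2, 3], [4, 5], [6, 7]]
```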

# For hparams tuning, try a wandb sweep (`random search` is used by default):
wandb sweep sweep_flax.yaml
wandb agent <sweep-id-printed-by-above-command>
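Each `wandb agent` repeatedly picks up a hyperparameter configuration chosen by the sweep and runs training with it. The idea behind random search fits in a few lines of plain Python; the sketch below does not use wandb at all, and its search space, toy objective and function names are hypothetical stand-ins (in the real setup the space comes from sweep_flax.yaml and the score from a training run).

```python
import math
import random

# Hypothetical search space, for illustration only; the real one is defined in sweep_flax.yaml.
SEARCH_SPACE = {
    "lr": ("log_uniform", 1e-5, 1e-4),
    "batch_size": ("choice", [4, 8, 16]),
    "warmup_steps": ("int_uniform", 0, 2000),
}

def sample_config(space, rng):
    """Draw one configuration, sampling every hyperparameter independently."""
    config = {}
    for name, spec in space.items():
        kind = spec[0]
        if kind == "choice":
            config[name] = rng.choice(spec[1])
        elif kind == "int_uniform":
            config[name] = rng.randint(spec[1], spec[2])  # inclusive bounds
        elif kind == "log_uniform":
            # uniform in log-space, so 1e-5..2e-5 is as likely as 5e-5..1e-4
            config[name] = math.exp(rng.uniform(math.log(spec[1]), math.log(spec[2])))
        else:
            raise ValueError(f"unknown distribution: {kind}")
    return config

def random_search(space, evaluate, num_trials=10, seed=0):
    """Evaluate `num_trials` random configs and return (best_score, best_config)."""
    rng = random.Random(seed)
    best = None
    for _ in range(num_trials):
        config = sample_config(space, rng)
        score = evaluate(config)  # in a real sweep: train, then return a validation metric
        if best is None or score > best[0]:
            best = (score, config)
    return best

# Toy objective standing in for a training run: prefer learning rates close to 5e-5.
score, config = random_search(SEARCH_SPACE, lambda c: -abs(math.log(c["lr"] / 5e-5)), num_trials=20)
print(config)
```

Random search samples each hyperparameter independently, so adding more agents simply means more samples, with no coordination needed between runs.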

You can find my fine-tuned checkpoints on the Hugging Face Hub. Refer to the following table:

Checkpoint | Description
flax-bigbird-natural-questions | Obtained by running train_nq_flax.py script
bigbird-roberta-natural-questions | Obtained by running train_nq_torch.py script

To see how the above checkpoints perform on the QA task, check out the demo:

[Demo: the context is just a tweet taken from the 🤗 Twitter handle. 💥💥💥]
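Extractive QA heads such as 🤗's `BigBirdForQuestionAnswering` output a start logit and an end logit per token, and the answer is the context span with the highest combined score. A minimal sketch of that selection in plain Python; `best_span`, `max_answer_len` and the logits in the usage lines are made up for illustration, and the repository's own post-processing (e.g. for natural-questions answer categories) may differ.

```python
def best_span(start_logits, end_logits, max_answer_len=30):
    """Return (start, end, score) of the best span with start <= end < start + max_answer_len.

    score = start_logits[start] + end_logits[end]; `end` is inclusive.
    """
    best = (0, 0, float("-inf"))
    for s, s_score in enumerate(start_logits):
        for e in range(s, min(s + max_answer_len, len(end_logits))):
            score = s_score + end_logits[e]
            if score > best[2]:
                best = (s, e, score)
    return best

# Toy logits for a 5-token context.
start_logits = [0.1, 0.2, 0.3, 2.5, 0.1]
end_logits = [0.1, 0.2, 0.3, 0.4, 3.0]
print(best_span(start_logits, end_logits))                    # (3, 4, 5.5)
print(best_span(start_logits, end_logits, max_answer_len=1))  # single-token spans only
```

The `start <= end` and length constraints rule out spans that independent argmaxes of the two logit vectors could produce, such as an end before the start.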

About


License: MIT License


Languages

Jupyter Notebook 94.9%, Python 5.1%