VPeterV / annotated-s4

Implementation of https://srush.github.io/annotated-s4

Home Page: https://srush.github.io/annotated-s4


Experiments

MNIST Sequence Modeling

# Default arguments (--bsz is the batch size, --d_model the model/feature dimension, --ssm_n the SSM state size N)
python -m s4.train --dataset mnist --model s4 --epochs 100 --bsz 128 --d_model 128 --ssm_n 64

QuickDraw Sequence Modeling

# Default arguments
python -m s4.train --dataset quickdraw --model s4 --epochs 10 --bsz 128 --d_model 128 --ssm_n 64

# "Run in a day" variant
python -m s4.train --dataset quickdraw --model s4 --epochs 1 --bsz 512 --d_model 256 --ssm_n 64 --p_dropout 0.05

MNIST Classification

# Default arguments
python -m s4.train --dataset mnist-classification --model s4 --epochs 10 --bsz 128 --d_model 128 --ssm_n 64

(Default Arguments, as shown above): Gets "best" 97.76% accuracy in 10 epochs @ 40s/epoch on a TitanRTX.

CIFAR-10 Classification

## Adding a Cubic Decay Schedule for the last 70% of training

# Default arguments (100 epochs for CIFAR)
python -m s4.train --dataset cifar-classification --model s4 --epochs 100 --bsz 128 --d_model 128 --ssm_n 64 --lr 1e-2 --lr_schedule

# S4 replication from central repository
python -m s4.train --dataset cifar-classification --model s4 --epochs 100 --bsz 64 --d_model 512 --ssm_n 64 --lr 1e-2 --lr_schedule
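
For intuition, a cubic decay over the final 70% of training can be written with optax's polynomial schedule. This is only a sketch under assumed step counts and weight decay; the repository's --lr_schedule flag may be wired differently.

# Illustrative sketch only -- the step count and weight decay below are assumptions
import optax

total_steps = 100 * 390                          # e.g. 100 epochs at ~390 steps/epoch (CIFAR-10, bsz 128)
schedule = optax.polynomial_schedule(
    init_value=1e-2,                             # starting LR, matching --lr 1e-2 above
    end_value=0.0,
    power=3,                                     # cubic decay
    transition_steps=int(0.7 * total_steps),     # decay spans the last 70% of training
    transition_begin=int(0.3 * total_steps),     # hold the LR constant for the first 30%
)
optimizer = optax.adamw(learning_rate=schedule, weight_decay=0.01)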

## After Fixing S4-Custom Optimization & Dropout2D (both implemented inline now; flags can be added if desired)

# Default arguments (100 epochs for CIFAR)
python -m s4.train --dataset cifar-classification --model s4 --epochs 100 --bsz 128 --d_model 128 --ssm_n 64 --lr 1e-2

# S4 replication from central repository
python -m s4.train --dataset cifar-classification --model s4 --epochs 100 --bsz 64 --d_model 512 --ssm_n 64 --lr 1e-2

## Before Fixing S4-Custom Optimization...

# Default arguments (100 epochs for CIFAR)
python -m s4.train --dataset cifar-classification --model s4 --epochs 100 --bsz 128 --d_model 128 --ssm_n 64

# S4 replication from central repository
python -m s4.train --dataset cifar-classification --model s4 --epochs 100 --bsz 64 --d_model 512 --ssm_n 64

Adding a Schedule:

  • (LR 1e-2 w/ Replication Args -- "big" model): 71.55% (still running, 39 epochs) @ 3m16s on a TitanRTX
  • (LR 1e-2 w/ Default Args -- not "bigger" model): 71.92% @ 36s/epoch on a TitanRTX

After Fixing Dropout2D (w/ Optimization in Place):

  • (LR 1e-2 w/ Replication Args -- "big" model): 70.68% (still running, 47 epochs) @ 3m17s on a TitanRTX
  • (LR 1e-2 w/ Default Args -- not "bigger" model): 68.20% @ 36s/epoch on a TitanRTX

After Fixing Optimization, Before Fixing Dropout2D:

  • (LR 1e-2 w/ Default Args -- not "bigger" model): 67.14% @ 36s/epoch on a TitanRTX

Before Fixing S4 Optimization -- AdamW w/ LR 1e-3 for ALL Parameters:

  • (Default Arguments): Gets "best" 63.51% accuracy @ 46s/epoch on a TitanRTX
  • (S4 Arguments): Gets "best" 66.44% accuracy @ 3m11s on a TitanRTX
    • Possible reasons for failure to meet replication: LR Schedule (Decay on Plateau), Custom LR per Parameter (see the per-parameter LR sketch below).
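
For illustration, "custom LR per parameter" can be approximated in optax with multi_transform, which routes different parameter groups to different optimizers. The parameter names and learning rates below are assumptions, not necessarily what this repo or the reference S4 implementation uses.

# Sketch only: parameter names and LR values are assumptions
import optax
from flax.traverse_util import flatten_dict, unflatten_dict

def label_params(params):
    # Give the SSM kernel parameters their own label; everything else is "regular"
    flat = flatten_dict(params)
    labels = {
        path: "ssm" if any(name in ("Lambda_re", "Lambda_im", "B", "log_step") for name in path)
        else "regular"
        for path in flat
    }
    return unflatten_dict(labels)

optimizer = optax.multi_transform(
    {
        "ssm": optax.adam(learning_rate=1e-3),                         # smaller LR, no weight decay
        "regular": optax.adamw(learning_rate=1e-2, weight_decay=0.01),
    },
    label_params,
)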

Quickstart (Development)

We provide two requirements files for this project: requirements-cpu.txt for CPU-only setups and requirements-gpu.txt for GPU setups.

CPU-Only (MacOS, Linux)

# Set up virtual/conda environment of your choosing & activate...
pip install -r requirements-cpu.txt

# Set up pre-commit
pre-commit install

GPU (CUDA > 11 & CUDNN > 8.2)

# Set up virtual/conda environment of your choosing & activate...
pip install -r requirements-gpu.txt

# Set up pre-commit
pre-commit install

Dependencies from Scratch

In case the above requirements files don't work, here are the commands used to install the dependencies.

CPU-Only

# Set up virtual/conda environment of your choosing & activate... then install the following:
pip install --upgrade "jax[cpu]"
pip install flax
pip install torch torchvision torchaudio

# Defaults
pip install black celluloid flake8 google-cloud-storage isort ipython matplotlib pre-commit seaborn tensorflow tqdm

# Set up pre-commit
pre-commit install
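
As a quick sanity check after installing, the core libraries can be imported and their versions printed (this one-liner is just a convenience, not part of the repo):

# Confirm jax, flax, and torch import cleanly and report their versions
python -c "import jax, flax, torch; print(jax.__version__, flax.__version__, torch.__version__)"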

GPU (CUDA > 11, CUDNN > 8.2)

Note - CUDNN > 8.2 is critical for compilation without warnings, and a GPU with at least the Turing architecture is needed for full efficiency.

# Set up virtual/conda environment of your choosing & activate... then install the following:
pip install jax[cuda11_cudnn82] -f https://storage.googleapis.com/jax-releases/jax_releases.html
pip install flax
pip install torch==1.10.1+cpu torchvision==0.11.2+cpu torchaudio==0.10.1+cpu -f https://download.pytorch.org/whl/cpu/torch_stable.html

# Defaults
pip install black celluloid flake8 google-cloud-storage isort ipython matplotlib pre-commit seaborn tensorflow tqdm

# Set up pre-commit
pre-commit install
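
To confirm the GPU install worked, one quick check is to ask JAX which devices it sees; it should list a GPU device rather than only the CPU:

# Should report a GPU device if CUDA/cuDNN are set up correctly
python -c "import jax; print(jax.devices())"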


Languages

Python 97.8%, Makefile 2.2%