neverix / saex

SAEs in Jax



saex

Sparse autoencoders in Jax.

Running

# Train a small SAE on the GPT-2 residual stream. Requires at most 32GB of RAM.
python -m scripts.train_gpt2_sae --is_xl=False --save_steps=0 --sparsity_coefficient=1e-4
# Download GPT-2 residual stream SAEs for finetuning
scripts/download_jb_saes.sh
# Download Phi-3
wget 'https://huggingface.co/SanctumAI/Phi-3-mini-4k-instruct-GGUF/resolve/main/phi-3-mini-4k-instruct.fp16.gguf?download=true' -O weights/phi-3-16.gguf
# Generate data for a toy model
JAX_PLATFORMS=cpu python -m saex.toy_models
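The training command above takes a `--sparsity_coefficient` flag. To illustrate what such a coefficient usually controls, here is a minimal JAX sketch of a standard ReLU SAE. Its loss is reconstruction MSE plus `sparsity_coefficient` times an L1 penalty on the feature activations. This assumes the textbook formulation. The parameter names (`w_enc`, `b_enc`, `w_dec`, `b_dec`), the initialization, the 4x expansion factor, and the plain SGD update with a learning rate of 1e-3 are illustrative choices, not saex's actual code.

```python
import jax
import jax.numpy as jnp

# Assumed hyperparameters, for illustration only.
SPARSITY_COEFFICIENT = 1e-4  # same value as the --sparsity_coefficient example above
LEARNING_RATE = 1e-3         # hypothetical; not taken from saex


def init_sae(key, d_model, d_hidden):
    """Hypothetical parameter layout: encoder and decoder matrices plus biases."""
    k_enc, k_dec = jax.random.split(key)
    return {
        "w_enc": jax.random.normal(k_enc, (d_model, d_hidden)) / jnp.sqrt(d_model),
        "b_enc": jnp.zeros(d_hidden),
        "w_dec": jax.random.normal(k_dec, (d_hidden, d_model)) / jnp.sqrt(d_hidden),
        "b_dec": jnp.zeros(d_model),
    }


def sae_loss(params, x, sparsity_coefficient):
    # Encode: center by the decoder bias, project up, ReLU gives non-negative features
    f = jax.nn.relu((x - params["b_dec"]) @ params["w_enc"] + params["b_enc"])
    # Decode: linear reconstruction of the input activations
    x_hat = f @ params["w_dec"] + params["b_dec"]
    # Squared reconstruction error, summed over d_model and averaged over the batch
    mse = jnp.mean(jnp.sum((x - x_hat) ** 2, axis=-1))
    # L1 norm of the feature activations, averaged over the batch
    l1 = jnp.mean(jnp.sum(jnp.abs(f), axis=-1))
    return mse + sparsity_coefficient * l1


@jax.jit
def sgd_step(params, x):
    loss, grads = jax.value_and_grad(sae_loss)(params, x, SPARSITY_COEFFICIENT)
    # Plain gradient descent applied to every leaf of the parameter dict
    params = jax.tree_util.tree_map(lambda p, g: p - LEARNING_RATE * g, params, grads)
    return params, loss


params = init_sae(jax.random.PRNGKey(0), d_model=768, d_hidden=4 * 768)
x = jax.random.normal(jax.random.PRNGKey(1), (32, 768))  # stand-in for residual activations
params, loss = sgd_step(params, x)
```

GPT-2 Small's residual stream is 768-dimensional, hence `d_model=768`. A real run would more likely use an adaptive optimizer and stream activations from the model rather than random inputs.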

Tests (there aren't any yet):

poetry run pytest

How to install

sudo apt install -y make build-essential libssl-dev zlib1g-dev libbz2-dev libreadline-dev libsqlite3-dev wget curl llvm libncurses5-dev libncursesw5-dev xz-utils tk-dev libffi-dev liblzma-dev python3-openssl git
git clone https://github.com/pyenv/pyenv.git ~/.pyenv
echo 'export PYENV_ROOT="$HOME/.pyenv"' >> ~/.bashrc
echo 'export PATH="$PYENV_ROOT/bin:$PATH"' >> ~/.bashrc

echo -e 'if command -v pyenv 1>/dev/null 2>&1; then\n eval "$(pyenv init -)"\nfi' >> ~/.bashrc
source ~/.bashrc
pyenv install 3.12.3
pyenv global 3.12.3
python3 -m pip install poetry
echo 'export PATH="$PYENV_ROOT/versions/3.12.3/bin:$PYENV_ROOT/bin:$PATH"' >> ~/.bashrc

poetry env use 3.12
poetry lock
poetry install
poetry shell

I think it should be possible to set up a development environment without installing pyenv on tpu-ubuntu2204.

FAQ

No one actually asked these questions, but here are the answers anyway.

How is this code parallelized?

Data and tensor parallelism. In theory, the SAE can be as large as the combined memory of all devices allows. In practice, it is initialized on one device, so its size is limited by that device's memory.
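A minimal sketch of combining data and tensor parallelism with `jax.sharding`, assuming the SAE's hidden (feature) dimension is the one split across devices. The mesh axis names `dp` and `tp`, the shapes, and the mesh layout are illustrative assumptions, not how saex configures its mesh. As the answer above describes, the weights are first created on a single device and only then distributed with `jax.device_put`.

```python
import numpy as np
import jax
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# 2D device mesh: "dp" splits the batch, "tp" splits the SAE hidden dimension.
# Putting every device on "dp" is an arbitrary choice; on a single CPU this is a 1x1 mesh.
# The batch size must be divisible by the size of the "dp" axis.
mesh = Mesh(np.array(jax.devices()).reshape(-1, 1), axis_names=("dp", "tp"))

d_model, d_hidden, batch = 768, 3072, 32

k_x, k_enc, k_dec = jax.random.split(jax.random.PRNGKey(0), 3)

# Full arrays are created on the default device first, then split across the mesh.
x = jax.device_put(
    jax.random.normal(k_x, (batch, d_model)), NamedSharding(mesh, P("dp", None))
)
w_enc = jax.device_put(
    0.01 * jax.random.normal(k_enc, (d_model, d_hidden)), NamedSharding(mesh, P(None, "tp"))
)
w_dec = jax.device_put(
    0.01 * jax.random.normal(k_dec, (d_hidden, d_model)), NamedSharding(mesh, P("tp", None))
)


@jax.jit
def reconstruct(x, w_enc, w_dec):
    # Each device computes features for its slice of the hidden dimension;
    # the XLA partitioner inserts the reduction over "tp" needed by the decoder matmul.
    f = jax.nn.relu(x @ w_enc)
    return f @ w_dec


x_hat = reconstruct(x, w_enc, w_dec)
```

Because the full matrices exist on one device before `device_put`, this pattern is bounded by that device's memory. Running the initializer under `jax.jit(..., out_shardings=...)` is one common way to let XLA produce sharded outputs directly instead.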

Are results comparable to SAELens?

Yes. I haven't tested with smaller batch sizes, but for GPT-2 Small layer 9 you can get comparable results with ~25% fewer tokens and ~3x less training time.

What techniques does saex use?

TODOs

  • Anthropic's scaled sparsity loss
  • Autointerp
  • Dreaming
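As we understand it, the first TODO refers to the penalty from Anthropic's April 2024 circuits update. That penalty weights each feature's activation by the L2 norm of its decoder vector, instead of constraining decoder vectors to unit norm. Below is a minimal JAX sketch of that loss, assuming the same hypothetical parameter layout as a plain ReLU SAE (the dict keys, shapes, and the coefficient of 5.0 are illustrative, not saex code). It also checks the property that motivates the penalty.

```python
import jax
import jax.numpy as jnp


def scaled_sparsity_loss(params, x, sparsity_coefficient):
    f = jax.nn.relu((x - params["b_dec"]) @ params["w_enc"] + params["b_enc"])
    x_hat = f @ params["w_dec"] + params["b_dec"]
    mse = jnp.mean(jnp.sum((x - x_hat) ** 2, axis=-1))
    # L2 norm of each feature's decoder vector (one row of w_dec per feature)
    decoder_norms = jnp.linalg.norm(params["w_dec"], axis=-1)
    # Penalty sum_i f_i * ||W_dec[i]||, averaged over the batch
    penalty = jnp.mean(jnp.sum(f * decoder_norms, axis=-1))
    return mse + sparsity_coefficient * penalty


k_enc, k_dec, k_x = jax.random.split(jax.random.PRNGKey(0), 3)
d_model, d_hidden = 16, 64
params = {
    "w_enc": jax.random.normal(k_enc, (d_model, d_hidden)) / jnp.sqrt(d_model),
    "b_enc": jnp.zeros(d_hidden),
    "w_dec": jax.random.normal(k_dec, (d_hidden, d_model)) / jnp.sqrt(d_hidden),
    "b_dec": jnp.zeros(d_model),
}
x = jax.random.normal(k_x, (8, d_model))

# Scale feature 0 up by c in the encoder and down by c in the decoder.
c = 2.0
rescaled = dict(params)
rescaled["w_enc"] = params["w_enc"].at[:, 0].multiply(c)
rescaled["b_enc"] = params["b_enc"].at[0].multiply(c)
rescaled["w_dec"] = params["w_dec"].at[0].multiply(1.0 / c)

loss_original = scaled_sparsity_loss(params, x, 5.0)
loss_rescaled = scaled_sparsity_loss(rescaled, x, 5.0)
```

The two losses match because both the reconstruction and each term `f_i * ||W_dec[i]||` are unchanged by this rescaling. A plain L1 penalty on `f` alone could instead be lowered just by shrinking activations and growing decoder norms, which is why it is usually paired with a unit-norm decoder constraint.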

About


License: MIT


Languages

Python 61.8%, Jupyter Notebook 38.0%, Shell 0.1%