MClarke1991 / e2e_sae

Sparse Autoencoder Training Library

e2e_sae

This library is used to train and evaluate Sparse Autoencoders (SAEs). It handles the following training types:

  • e2e (end-to-end): Loss function includes sparsity and the KL divergence at the final model output.
  • e2e + downstream reconstruction: Loss function includes sparsity, the KL divergence at the final model output, and MSE at downstream layers.
  • local (i.e. vanilla SAEs): Loss function includes sparsity and MSE at the SAE layer.
  • Any combination of the above.
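As a rough illustration of how these terms combine, here is a minimal pure-Python sketch for a single token position. It assumes an L1 sparsity penalty on the SAE feature activations, a KL divergence taken from the original model's output distribution to the SAE-augmented model's output distribution, and MSE between original and SAE-augmented activations. The function names and coefficient arguments are hypothetical and are not this library's API.

```python
import math
from typing import Iterable, Sequence, Tuple


def log_softmax(logits: Sequence[float]) -> list[float]:
    """Numerically stable log-softmax over a single vector of logits."""
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_sum_exp for x in logits]


def kl_div(orig_logits: Sequence[float], sae_logits: Sequence[float]) -> float:
    """KL(P_orig || P_sae) between the two models' output distributions."""
    log_p = log_softmax(orig_logits)
    log_q = log_softmax(sae_logits)
    return sum(math.exp(lp) * (lp - lq) for lp, lq in zip(log_p, log_q))


def mse(x: Sequence[float], x_hat: Sequence[float]) -> float:
    """Mean squared error between two activation vectors."""
    return sum((a - b) ** 2 for a, b in zip(x, x_hat)) / len(x)


def l1(codes: Sequence[float]) -> float:
    """Assumed L1 sparsity penalty on the SAE feature activations."""
    return sum(abs(c) for c in codes)


def total_loss(
    codes: Sequence[float],
    sparsity_coeff: float,
    kl_coeff: float = 0.0,
    kl: float = 0.0,
    mse_terms: Iterable[Tuple[float, Sequence[float], Sequence[float]]] = (),
) -> float:
    """Weighted sum of sparsity, KL and any number of (coeff, x, x_hat) MSE terms."""
    loss = sparsity_coeff * l1(codes) + kl_coeff * kl
    for coeff, x, x_hat in mse_terms:
        loss += coeff * mse(x, x_hat)
    return loss
```

Under this sketch, a local SAE passes a single MSE term at the SAE layer, an e2e SAE passes `kl_coeff` and `kl`, and e2e + downstream reconstruction passes the KL term plus one MSE term per downstream layer.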

See our paper which argues for training SAEs e2e rather than locally. All SAEs presented in the paper can be found at https://wandb.ai/sparsify/gpt2 and can be loaded using this library.

Usage

Installation

pip install e2e_sae

Train SAEs on any TransformerLens model

If you would like to track your runs with Weights and Biases, place your API key and entity name in a new file called .env. An example is provided in .env.example.
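The variable names the library actually reads are defined in .env.example. The sketch below assumes WANDB_API_KEY and WANDB_ENTITY (environment variables that wandb itself recognizes) purely for illustration. It is a small, hypothetical standard-library helper for checking that a .env file sets the expected keys before starting a long run, not part of e2e_sae.

```python
from pathlib import Path

# Hypothetical key names; check .env.example for the ones the library expects.
REQUIRED_KEYS = ("WANDB_API_KEY", "WANDB_ENTITY")


def parse_env(text: str) -> dict[str, str]:
    """Parse simple KEY=VALUE lines, skipping blank lines and # comments."""
    env: dict[str, str] = {}
    for raw in text.splitlines():
        line = raw.strip()
        if not line or line.startswith("#") or "=" not in line:
            continue
        key, _, value = line.partition("=")
        # Drop surrounding single or double quotes from the value.
        env[key.strip()] = value.strip().strip("\"'")
    return env


def missing_keys(env: dict[str, str], required=REQUIRED_KEYS) -> list[str]:
    """Return the required keys that are absent or empty."""
    return [k for k in required if not env.get(k)]


def check_env_file(path: str = ".env") -> list[str]:
    """Read a .env file and report which required keys are missing."""
    return missing_keys(parse_env(Path(path).read_text()))
```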

Create a config file (see gpt2 configs here for examples). Then run

python e2e_sae/scripts/train_tlens_saes/run_train_tlens_saes.py <path_to_config>
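If you want to launch several configs from Python rather than the shell, one option is to wrap the command above with subprocess. This is a minimal sketch: the helper names are hypothetical, and the script path is the one given above, relative to the repository root.

```python
import subprocess
import sys
from pathlib import Path
from typing import Sequence

# Script path as given in this README, relative to the repository root.
TRAIN_SCRIPT = Path("e2e_sae/scripts/train_tlens_saes/run_train_tlens_saes.py")


def build_train_command(config_path: str) -> list[str]:
    """Build the argv list for one training run using the current interpreter."""
    return [sys.executable, str(TRAIN_SCRIPT), config_path]


def run_configs(config_paths: Sequence[str], dry_run: bool = False) -> None:
    """Run each config in sequence, stopping at the first failed run."""
    for config_path in config_paths:
        cmd = build_train_command(config_path)
        print(" ".join(cmd))
        if not dry_run:
            # check=True raises CalledProcessError if the training script fails.
            subprocess.run(cmd, check=True)
```

For example, run_configs(["my_config.yaml"], dry_run=True) prints the command without executing it; the config file name here is a placeholder.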

If using a Colab notebook, see this example.

Sample wandb sweep configs are provided in e2e_sae/scripts/train_tlens_saes/.
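The standard wandb CLI workflow is to register a sweep config with `wandb sweep <config>`, which prints a sweep ID, and then start one or more agents with `wandb agent <entity/project/sweep_id>`. The sketch below only builds those commands from Python; the helper names, the sweep file name, and the sweep path in the usage lines are placeholders, not files or IDs shipped with this library.

```python
import shlex
from typing import Optional


def sweep_command(sweep_config_path: str) -> list[str]:
    """`wandb sweep` registers the sweep and prints its ID."""
    return ["wandb", "sweep", sweep_config_path]


def agent_command(sweep_path: str, count: Optional[int] = None) -> list[str]:
    """`wandb agent` pulls runs from the sweep; --count caps how many it executes."""
    cmd = ["wandb", "agent"]
    if count is not None:
        cmd += ["--count", str(count)]
    cmd.append(sweep_path)
    return cmd


# Placeholder names; pass the lists to subprocess.run(..., check=True) to execute.
print(shlex.join(sweep_command("my_sweep.yaml")))
print(shlex.join(agent_command("my-entity/my-project/abc123", count=4)))
```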

The library also contains scripts for training MLPs and SAEs on MLPs, as well as for training custom TransformerLens models and SAEs on these models (see here).

Load a Pre-trained SAE

You can load any pre-trained SAE (and its accompanying TransformerLens model) trained with this library, either from Weights and Biases or from a local checkpoint, by running:

from e2e_sae import SAETransformer
model = SAETransformer.from_wandb("<entity/project/run_id>")
# or, if stored locally
model = SAETransformer.from_local_path("/path/to/checkpoint/dir") 

All runs in our paper can be loaded this way (e.g. sparsify/gpt2/tvj2owza).

This instantiates an SAETransformer, which contains a TransformerLens model with SAEs attached. To do a forward pass without the SAEs, use the forward_raw method; to do a forward pass with the SAEs, use the forward method (or simply call the SAETransformer instance).
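To make the distinction concrete, here is a toy, dependency-free sketch of what "a model with SAEs attached" means: the raw pass runs the layers unchanged, while the SAE pass replaces the activation at each chosen position with the SAE's reconstruction before continuing. ToySAETransformer, its layer functions, and the position names are hypothetical stand-ins and do not mirror the real SAETransformer signatures.

```python
from typing import Callable, Dict, List

Vector = List[float]
Layer = Callable[[Vector], Vector]


class ToySAETransformer:
    """Hypothetical stand-in: a stack of named layers with optional SAEs spliced in."""

    def __init__(self, layers: Dict[str, Layer], saes: Dict[str, Layer]):
        self.layers = layers  # position name -> layer function, in execution order
        self.saes = saes  # position name -> SAE (activation -> reconstruction)

    def forward_raw(self, x: Vector) -> Vector:
        """Run the underlying model only, ignoring every SAE."""
        for layer in self.layers.values():
            x = layer(x)
        return x

    def forward(self, x: Vector) -> Vector:
        """Run the model, replacing the output of each hooked layer with its SAE reconstruction."""
        for name, layer in self.layers.items():
            x = layer(x)
            if name in self.saes:
                x = self.saes[name](x)
        return x

    # Calling the instance performs the forward pass with SAEs.
    __call__ = forward
```

Because every later layer consumes the reconstruction rather than the original activation, any error an SAE introduces propagates to the model output, which is why the two passes can differ.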

The dictionary elements of an SAE can be accessed via SAE.dict_elements; the decoder elements are normalized to have norm 1.
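As an illustration of what unit-norm dictionary elements means, the sketch below scales each column of a decoder weight matrix to L2 norm 1. It assumes a PyTorch-style nn.Linear decoder, whose weight has shape (input_size, n_dict_components) so that each column is one dictionary element. The helper name is hypothetical; the computation matches what torch.nn.functional.normalize(weight, dim=0) does, written in plain Python rather than calling the library.

```python
import math
from typing import List

Matrix = List[List[float]]  # row-major: weight[i][j], one column per dictionary element


def normalize_columns(weight: Matrix, eps: float = 1e-12) -> Matrix:
    """Return a copy of `weight` with every column scaled to unit L2 norm.

    The eps floor (an assumption of this sketch) avoids division by zero
    for an all-zero column.
    """
    n_rows, n_cols = len(weight), len(weight[0])
    norms = [
        max(math.sqrt(sum(weight[i][j] ** 2 for i in range(n_rows))), eps)
        for j in range(n_cols)
    ]
    return [[weight[i][j] / norms[j] for j in range(n_cols)] for i in range(n_rows)]
```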

Analysis

To reproduce all of the analysis in our paper, use the scripts in e2e_sae/scripts/analysis/.

Contributing

Developer dependencies are installed with make install-dev, which will also install pre-commit hooks.

Suggested extensions and settings for VSCode are provided in .vscode/. To use the suggested settings, copy .vscode/settings-example.json to .vscode/settings.json.

There are various make commands that may be helpful:

make check  # Run pre-commit checks on all files (i.e. pyright, ruff linter, and ruff formatter)
make type  # Run pyright on all files
make format  # Run ruff linter and formatter on all files
make test  # Run tests that aren't marked `slow`
make test-all  # Run all tests

This library is maintained by Dan Braun.

Join the Open Source Mechanistic Interpretability Slack to chat about this library and other projects in the space!

About

License: MIT License