joelburget / mamba-sae

Training and evaluating Sparse Autoencoders for Mamba

Mamba Interpretability

This repo is for doing interpretability work on Mamba (Linear-Time Sequence Modeling with Selective State Spaces). We follow the approach from Anthropic's Towards Monosemanticity: Decomposing Language Models With Dictionary Learning, though the scope might broaden in the future.

We make heavy use of the nnsight library for interpreting neural networks and the dictionary_learning library for training and understanding SAEs.

Tools

Main scripts

Before running any of these scripts, set the appropriate parameters in params.py.

train_model.py

Script for training a one-layer Mamba model. Outputs stats to wandb.

train_sae.py

Once you've trained your model, you can train a Sparse Autoencoder on it. This script actually trains a grid of SAEs, one for each combination of sparsity penalty and relative size you configure.
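As a rough illustration of what "a grid of SAEs" means here, the sketch below enumerates one configuration per combination of sparsity penalty and relative size. The specific values, the `activation_dim` of 320, the `sae_grid` helper, and the dictionary keys are all assumptions made for the example; the real settings live in params.py and may be named and structured differently.

```python
from itertools import product

# Hypothetical values for illustration; the real ones are set in params.py.
sparsity_penalties = [1e-3, 3e-3, 1e-2, 3e-2]
relative_sizes = [1, 2, 4, 8]
activation_dim = 320  # assumed width of the activations being reconstructed


def sae_grid(penalties, sizes, activation_dim):
    """Return one config dict per (sparsity penalty, relative size) pair."""
    return [
        {
            "sparsity_penalty": penalty,
            "relative_size": size,
            # The SAE's hidden layer is `size` times wider than its input.
            "dictionary_size": size * activation_dim,
        }
        for penalty, size in product(penalties, sizes)
    ]


configs = sae_grid(sparsity_penalties, relative_sizes, activation_dim)
print(len(configs))  # 4 penalties x 4 sizes = 16 configurations
```

Each entry would then correspond to one SAE training run.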

evaluate_saes.py

For each autoencoder trained in the previous step, evaluate stats such as MSE loss, percentage of neurons alive, and percentage of loss recovered.
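To make these stats concrete, here is a minimal pure-Python sketch of how each one could be computed. It is not taken from evaluate_saes.py: the function names are hypothetical, "alive" is assumed to mean "activates above zero on at least one input", and "loss recovered" is assumed to be measured against a zero-ablation baseline, which is one common convention.

```python
def mse(xs, x_hats):
    """Mean squared error between activations and their SAE reconstructions."""
    total = sum(
        (a - b) ** 2
        for x, x_hat in zip(xs, x_hats)
        for a, b in zip(x, x_hat)
    )
    count = sum(len(x) for x in xs)
    return total / count


def percent_alive(feature_acts):
    """Percentage of SAE neurons that fire (> 0) on at least one input.

    `feature_acts` is a list of rows, one row of neuron activations per input.
    """
    n_features = len(feature_acts[0])
    alive = sum(
        any(row[j] > 0 for row in feature_acts) for j in range(n_features)
    )
    return 100.0 * alive / n_features


def percent_loss_recovered(clean_loss, sae_loss, zero_ablation_loss):
    """Share of the gap between zero-ablated and clean loss that is recovered
    when the model's activations are replaced by the SAE reconstruction."""
    return 100.0 * (zero_ablation_loss - sae_loss) / (zero_ablation_loss - clean_loss)
```

Under this convention, a reconstruction that leaves the model's loss unchanged scores 100%, and one that does no better than zeroing the activations scores 0%.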

analyze_sae.py

Given a single SAE, find top activations for each neuron. Hopefully more features in the future.
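The idea of "top activations for each neuron" can be sketched as follows. This is an illustrative stand-in rather than the code in analyze_sae.py: `top_activations` and its arguments are hypothetical, and the activations are plain nested lists rather than tensors.

```python
import heapq


def top_activations(feature_acts, tokens, k=3):
    """For each neuron, return the k (activation, token) pairs with the
    largest activation, in descending order.

    `feature_acts[i][j]` is the activation of neuron j on input i, and
    `tokens[i]` is a label for input i.
    """
    n_features = len(feature_acts[0])
    return [
        heapq.nlargest(
            k,
            ((row[j], tok) for row, tok in zip(feature_acts, tokens)),
            key=lambda pair: pair[0],  # rank by activation only
        )
        for j in range(n_features)
    ]


acts = [[0.1, 0.0], [0.9, 0.2], [0.5, 0.7]]
tokens = ["a", "b", "c"]
for neuron, top in enumerate(top_activations(acts, tokens, k=2)):
    print(neuron, top)
```

Reading the highest-activating inputs for a neuron is a common first step toward guessing what feature it represents.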

Other

sae_analyze_320.py / sae_analyze_640.py

Scripts to plot heatmaps of stats for two sets of SAEs I trained (comparing four different sparsity penalties and four different relative sizes).

run_model.py

Helpers for running a model you've trained.

Pretrained Models / SAEs

I've uploaded two models (pytorch_model-{320,640}.bin) and two sets of SAEs trained on them to Google Drive. Doc with their stats.
