
Training Demo¤

This is a quick-start demo to get an SAE training right away. All you need to do is choose a few hyperparameters (like the model to train on) and then set it off.

In this demo we'll train a sparse autoencoder on all MLP layer outputs in GPT-2 small (effectively training an SAE on each layer in parallel).

Setup¤

Imports¤

# Check if we're in Colab
try:
    import google.colab  # noqa: F401 # type: ignore

    in_colab = True
except ImportError:
    in_colab = False

# Install if in Colab
if in_colab:
    %pip install sparse_autoencoder transformer_lens transformers wandb

# Otherwise enable hot reloading in dev mode
if not in_colab:
    %load_ext autoreload
    %autoreload 2
import os

from sparse_autoencoder import (
    ActivationResamplerHyperparameters,
    AutoencoderHyperparameters,
    Hyperparameters,
    LossHyperparameters,
    Method,
    OptimizerHyperparameters,
    Parameter,
    PipelineHyperparameters,
    SourceDataHyperparameters,
    SourceModelHyperparameters,
    SweepConfig,
    sweep,
)

os.environ["WANDB_NOTEBOOK_NAME"] = "demo.ipynb"

Hyperparameters¤

Customize any hyperparameters you want below (by default we're sweeping over the l1 coefficient and the learning rate).

Note that we are using the RANDOM sweep approach (trying random combinations of hyperparameters), which works surprisingly well but will need to be stopped at some point (otherwise it will run forever). If you instead want a pre-defined set of runs, use e.g. Parameter(values=[0.01, 0.05, ...]) rather than Parameter(max=0.03, min=0.008) for each parameter you are sweeping over, and set the method to Method.GRID.
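For example, a minimal grid-style sketch could reuse the classes imported above (the specific values and variable names here are purely illustrative):

# Illustrative grid-sweep parameters (example values only). Swap these in for the
# range-based Parameter(max=..., min=...) calls below and pass method=Method.GRID.
grid_loss = LossHyperparameters(
    l1_coefficient=Parameter(values=[0.008, 0.01, 0.03]),
)
grid_optimizer = OptimizerHyperparameters(
    lr=Parameter(values=[1e-5, 1e-4, 1e-3]),
)
grid_method = Method.GRID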

def train_gpt_small_mlp_layers(
    expansion_factor: int = 4,
    n_layers: int = 12,
) -> None:
    """Run a new sweep experiment on GPT 2 Small's MLP layers.

    Args:
        expansion_factor: Expansion factor for the autoencoder.
        n_layers: Number of layers to train on. Max is 12.

    """
    sweep_config = SweepConfig(
        parameters=Hyperparameters(
            loss=LossHyperparameters(
                l1_coefficient=Parameter(max=0.03, min=0.008),
            ),
            optimizer=OptimizerHyperparameters(
                lr=Parameter(max=0.001, min=0.00001),
            ),
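            # GPT-2 small: hook the MLP output of every layer (residual stream width 768)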
            source_model=SourceModelHyperparameters(
                name=Parameter("gpt2"),
                cache_names=Parameter(
                    [f"blocks.{layer}.hook_mlp_out" for layer in range(n_layers)]
                ),
                hook_dimension=Parameter(768),
            ),
            source_data=SourceDataHyperparameters(
                dataset_path=Parameter("alancooney/sae-monology-pile-uncopyrighted-tokenizer-gpt2"),
                context_size=Parameter(256),
                pre_tokenized=Parameter(value=True),
                pre_download=Parameter(value=False),  # Default to streaming the dataset
            ),
            autoencoder=AutoencoderHyperparameters(
                expansion_factor=Parameter(value=expansion_factor)
            ),
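            # Quantities below are counted in source activations (train on 1B total;
            # checkpoint and validate every 100M)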
            pipeline=PipelineHyperparameters(
                max_activations=Parameter(1_000_000_000),
                checkpoint_frequency=Parameter(100_000_000),
                validation_frequency=Parameter(100_000_000),
                max_store_size=Parameter(1_000_000),
            ),
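            # Periodically resample dead neurons (every 200M activations, at most 4 times)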
            activation_resampler=ActivationResamplerHyperparameters(
                resample_interval=Parameter(200_000_000),
                n_activations_activity_collate=Parameter(100_000_000),
                threshold_is_dead_portion_fires=Parameter(1e-6),
                max_n_resamples=Parameter(4),
            ),
        ),
        method=Method.RANDOM,
    )

    sweep(sweep_config=sweep_config)

Run the sweep¤

This will start a sweep with just one agent (the current machine). If you have multiple GPUs, it will use them automatically. Similarly, it will work on Apple silicon devices by automatically using MPS.

train_gpt_small_mlp_layers()
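
You can also narrow the sweep to fewer layers or a different dictionary size via the function's arguments; for example (the values here are purely illustrative):

# Example only: sweep just the first 2 MLP layers with an 8x expansion factor
train_gpt_small_mlp_layers(expansion_factor=8, n_layers=2)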

Want to speed things up? You can trivially add extra machines to the sweep; each one runs its own agent and peels off some runs from the sweep (coordinated on Wandb). To do this, simply run the following on another machine, using the sweep ID shown on Wandb:

pip install sparse_autoencoder
join-sae-sweep --id=SWEEP_ID_SHOWN_ON_WANDB