Run a sweep on all layers of GPT2 Small.
Command:

```bash
git clone https://github.com/ai-safety-foundation/sparse_autoencoder.git && cd sparse_autoencoder &&
  poetry env use python3.11 && poetry install &&
  poetry run python sparse_autoencoder/training_runs/gpt2.py
```
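If you would rather launch the same sweep from Python (for example inside a notebook), a minimal sketch is below. It assumes the module path shown in the source listing further down is importable once the package is installed; adjust the import if your checkout differs.

```python
# Minimal sketch: start the same grid sweep from Python.
# Assumes `sparse_autoencoder/training_runs/gpt2.py` is importable as a module
# from the installed package (it is part of the cloned repository above).
from sparse_autoencoder.training_runs.gpt2 import train

if __name__ == "__main__":
    train()  # Builds the SweepConfig shown below and starts the grid sweep.
```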
train()
Train sparse autoencoders on the MLP outputs of all 12 GPT2 Small layers, as a grid sweep over the hyperparameters below.
Source code in sparse_autoencoder/training_runs/gpt2.py
```python
def train() -> None:
    """Train."""
    sweep_config = SweepConfig(
        parameters=Hyperparameters(
            loss=LossHyperparameters(
                l1_coefficient=Parameter(values=[0.0001]),
            ),
            optimizer=OptimizerHyperparameters(
                lr=Parameter(value=0.0001),
            ),
            source_model=SourceModelHyperparameters(
                name=Parameter("gpt2"),
                cache_names=Parameter(
                    value=[f"blocks.{layer}.hook_mlp_out" for layer in range(12)]
                ),
                hook_dimension=Parameter(768),
            ),
            source_data=SourceDataHyperparameters(
                dataset_path=Parameter("alancooney/sae-monology-pile-uncopyrighted-tokenizer-gpt2"),
                context_size=Parameter(256),
                pre_tokenized=Parameter(value=True),
                pre_download=Parameter(value=True),
                # Total dataset is c. 7bn activations (64 files)
                # c. 1.5TB needed to store all activations
                dataset_files=Parameter(
                    [f"data/train-{str(i).zfill(5)}-of-00064.parquet" for i in range(20)]
                ),
            ),
            autoencoder=AutoencoderHyperparameters(expansion_factor=Parameter(values=[32, 64])),
            pipeline=PipelineHyperparameters(),
            activation_resampler=ActivationResamplerHyperparameters(
                threshold_is_dead_portion_fires=Parameter(1e-5),
            ),
        ),
        method=Method.GRID,
    )
    sweep(sweep_config=sweep_config)
```
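To adapt this run without editing the repository file, the same configuration objects can be built in your own script. The sketch below is an illustrative example under stated assumptions, not the canonical API: it assumes the hyperparameter classes and `sweep` are importable from the top-level `sparse_autoencoder` package (adjust the import path if your version exposes them elsewhere), and it narrows the sweep to a single MLP layer, one expansion factor, and fewer dataset shards.

```python
# Hedged sketch: sweep only layer 6's MLP output with a single expansion factor.
# Assumes these names are re-exported from the top-level `sparse_autoencoder`
# package; if not, import them from their defining submodules instead.
from sparse_autoencoder import (
    ActivationResamplerHyperparameters,
    AutoencoderHyperparameters,
    Hyperparameters,
    LossHyperparameters,
    Method,
    OptimizerHyperparameters,
    Parameter,
    PipelineHyperparameters,
    SourceDataHyperparameters,
    SourceModelHyperparameters,
    SweepConfig,
    sweep,
)

sweep_config = SweepConfig(
    parameters=Hyperparameters(
        loss=LossHyperparameters(l1_coefficient=Parameter(values=[0.0001])),
        optimizer=OptimizerHyperparameters(lr=Parameter(value=0.0001)),
        source_model=SourceModelHyperparameters(
            name=Parameter("gpt2"),
            # Single layer instead of all 12.
            cache_names=Parameter(value=["blocks.6.hook_mlp_out"]),
            hook_dimension=Parameter(768),
        ),
        source_data=SourceDataHyperparameters(
            dataset_path=Parameter("alancooney/sae-monology-pile-uncopyrighted-tokenizer-gpt2"),
            context_size=Parameter(256),
            pre_tokenized=Parameter(value=True),
            pre_download=Parameter(value=True),
            # Fewer shards than the full run, to keep disk usage down.
            dataset_files=Parameter(
                [f"data/train-{str(i).zfill(5)}-of-00064.parquet" for i in range(4)]
            ),
        ),
        autoencoder=AutoencoderHyperparameters(expansion_factor=Parameter(value=32)),
        pipeline=PipelineHyperparameters(),
        activation_resampler=ActivationResamplerHyperparameters(
            threshold_is_dead_portion_fires=Parameter(1e-5),
        ),
    ),
    method=Method.GRID,
)

sweep(sweep_config=sweep_config)
```

With a single value per swept parameter the grid collapses to one run, which is a cheap way to smoke-test the pipeline before launching the full multi-layer sweep.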