Sweep config¤

Default hyperparameter setup for quick tuning of a sparse autoencoder.

Warning

The runtime hyperparameter classes (TypedDicts) must be kept in sync with the corresponding hyperparameter dataclasses by hand, so that static type checking works.
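
Every hyperparameter below is wrapped in a Parameter, which holds either a single fixed value or a set of candidate values to sweep over. A minimal sketch (the import path for Parameter is an assumption; its syntax mirrors W&B sweep parameters):

from sparse_autoencoder.train.utils.wandb_sweep_types import Parameter  # assumed import path

fixed_lr = Parameter(1e-3)                  # a single fixed value
swept_lr = Parameter(values=[1e-4, 1e-3])   # candidate values for a grid/random sweep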

ActivationResamplerHyperparameters dataclass ¤

Bases: NestedParameter

Activation resampler hyperparameters.

Source code in sparse_autoencoder/train/sweep_config.py
@dataclass(frozen=True)
class ActivationResamplerHyperparameters(NestedParameter):
    """Activation resampler hyperparameters."""

    resample_interval: Parameter[int] = field(
        default=Parameter(round_to_multiple(200_000_000, DEFAULT_STORE_SIZE))
    )
    """Resample interval."""

    max_n_resamples: Parameter[int] = field(default=Parameter(4))
    """Maximum number of resamples."""

    n_activations_activity_collate: Parameter[int] = field(
        default=Parameter(round_to_multiple(100_000_000, DEFAULT_STORE_SIZE))
    )
    """Number of steps to collate before resampling.

    Number of autoencoder learned activation vectors to collate before resampling.
    """

    resample_dataset_size: Parameter[int] = field(default=Parameter(DEFAULT_BATCH_SIZE * 100))
    """Resample dataset size.

    Number of autoencoder input activations to use for calculating the loss, as part of the
    resampling process to create the reset neuron weights.
    """

    threshold_is_dead_portion_fires: Parameter[float] = field(default=Parameter(0.0))
    """Dead neuron threshold.

    Threshold for determining if a neuron is dead (has "fired" in less than this portion of the
    collated sample).
    """

max_n_resamples: Parameter[int] = field(default=Parameter(4)) class-attribute instance-attribute ¤

Maximum number of resamples.

n_activations_activity_collate: Parameter[int] = field(default=Parameter(round_to_multiple(100000000, DEFAULT_STORE_SIZE))) class-attribute instance-attribute ¤

Number of activations to collate before resampling.

Number of autoencoder learned activation vectors to collate before resampling.

resample_dataset_size: Parameter[int] = field(default=Parameter(DEFAULT_BATCH_SIZE * 100)) class-attribute instance-attribute ¤

Resample dataset size.

Number of autoencoder input activations to use for calculating the loss, as part of the resampling process to create the reset neuron weights.

resample_interval: Parameter[int] = field(default=Parameter(round_to_multiple(200000000, DEFAULT_STORE_SIZE))) class-attribute instance-attribute ¤

Resample interval.

threshold_is_dead_portion_fires: Parameter[float] = field(default=Parameter(0.0)) class-attribute instance-attribute ¤

Dead neuron threshold.

Threshold for determining if a neuron is dead (has "fired" in less than this portion of the collated sample).
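
As a usage sketch, any default can be overridden at construction time; for example, treating a neuron as dead only if it fired in under 1% of the collated sample (the Parameter import path is an assumption):

from sparse_autoencoder.train.sweep_config import ActivationResamplerHyperparameters
from sparse_autoencoder.train.utils.wandb_sweep_types import Parameter  # assumed import path

# Keep the other defaults; only tighten the dead-neuron threshold.
resampler = ActivationResamplerHyperparameters(
    threshold_is_dead_portion_fires=Parameter(1e-2),
)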

ActivationResamplerRuntimeHyperparameters ¤

Bases: TypedDict

Activation resampler runtime hyperparameters.

Source code in sparse_autoencoder/train/sweep_config.py
class ActivationResamplerRuntimeHyperparameters(TypedDict):
    """Activation resampler runtime hyperparameters."""

    resample_interval: int
    max_n_resamples: int
    n_activations_activity_collate: int
    resample_dataset_size: int
    threshold_is_dead_portion_fires: float

AutoencoderHyperparameters dataclass ¤

Bases: NestedParameter

Sparse autoencoder hyperparameters.

Source code in sparse_autoencoder/train/sweep_config.py
@dataclass(frozen=True)
class AutoencoderHyperparameters(NestedParameter):
    """Sparse autoencoder hyperparameters."""

    expansion_factor: Parameter[int] = field(default=Parameter(2))
    """Expansion Factor.

    Size of the learned features relative to the input features. A good expansion factor to start
    with is typically 2-4.
    """

expansion_factor: Parameter[int] = field(default=Parameter(2)) class-attribute instance-attribute ¤

Expansion Factor.

Size of the learned features relative to the input features. A good expansion factor to start with is typically 2-4.
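
To sweep the expansion factor across the suggested 2-4 range rather than fixing it, a Parameter can carry several candidate values (a sketch, assuming Parameter accepts a values list as W&B sweep parameters do):

from sparse_autoencoder.train.sweep_config import AutoencoderHyperparameters
from sparse_autoencoder.train.utils.wandb_sweep_types import Parameter  # assumed import path

autoencoder = AutoencoderHyperparameters(
    expansion_factor=Parameter(values=[2, 4]),  # both ends of the suggested starting range
)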

AutoencoderRuntimeHyperparameters ¤

Bases: TypedDict

Autoencoder runtime hyperparameters.

Source code in sparse_autoencoder/train/sweep_config.py
class AutoencoderRuntimeHyperparameters(TypedDict):
    """Autoencoder runtime hyperparameters."""

    expansion_factor: int

Hyperparameters dataclass ¤

Bases: Parameters

Sweep Hyperparameters.

Source code in sparse_autoencoder/train/sweep_config.py
@dataclass
class Hyperparameters(Parameters):
    """Sweep Hyperparameters."""

    # Required parameters
    source_data: SourceDataHyperparameters

    source_model: SourceModelHyperparameters

    # Optional parameters
    activation_resampler: ActivationResamplerHyperparameters = field(
        default=ActivationResamplerHyperparameters()
    )

    autoencoder: AutoencoderHyperparameters = field(default=AutoencoderHyperparameters())

    loss: LossHyperparameters = field(default=LossHyperparameters())

    optimizer: OptimizerHyperparameters = field(default=OptimizerHyperparameters())

    pipeline: PipelineHyperparameters = field(default=PipelineHyperparameters())

    random_seed: Parameter[int] = field(default=Parameter(49))
    """Random seed."""

    def __post_init__(self) -> None:
        """Post initialisation checks."""
        # Check the resample dataset size <= the store size (currently only works if value is used
        # for both).
        if (
            self.activation_resampler.resample_dataset_size.value is not None
            and self.pipeline.max_store_size.value is not None
            and self.activation_resampler.resample_dataset_size.value
            > int(self.pipeline.max_store_size.value)
        ):
            error_message = (
                "Resample dataset size must be less than or equal to the pipeline max store size. "
                f"Resample dataset size: {self.activation_resampler.resample_dataset_size.value}, "
                f"pipeline max store size: {self.pipeline.max_store_size.value}."
            )
            raise ValueError(error_message)

    @final
    def __str__(self) -> str:
        """String representation of this object."""
        items_representation = []
        for key, value in self.__dict__.items():
            if value is not None:
                items_representation.append(f"{key}={value}")
        joined_items = "\n    ".join(items_representation)

        class_name = self.__class__.__name__

        return f"{class_name}(\n    {joined_items}\n)"

    @final
    def __repr__(self) -> str:
        """Representation of this object."""
        return self.__str__()

random_seed: Parameter[int] = field(default=Parameter(49)) class-attribute instance-attribute ¤

Random seed.

__post_init__() ¤

Post initialisation checks.

Source code in sparse_autoencoder/train/sweep_config.py
def __post_init__(self) -> None:
    """Post initialisation checks."""
    # Check the resample dataset size <= the store size (currently only works if value is used
    # for both).
    if (
        self.activation_resampler.resample_dataset_size.value is not None
        and self.pipeline.max_store_size.value is not None
        and self.activation_resampler.resample_dataset_size.value
        > int(self.pipeline.max_store_size.value)
    ):
        error_message = (
            "Resample dataset size must be less than or equal to the pipeline max store size. "
            f"Resample dataset size: {self.activation_resampler.resample_dataset_size.value}, "
            f"pipeline max store size: {self.pipeline.max_store_size.value}."
        )
        raise ValueError(error_message)

__repr__() ¤

Representation of this object.

Source code in sparse_autoencoder/train/sweep_config.py
@final
def __repr__(self) -> str:
    """Representation of this object."""
    return self.__str__()

__str__() ¤

String representation of this object.

Source code in sparse_autoencoder/train/sweep_config.py
@final
def __str__(self) -> str:
    """String representation of this object."""
    items_representation = []
    for key, value in self.__dict__.items():
        if value is not None:
            items_representation.append(f"{key}={value}")
    joined_items = "\n    ".join(items_representation)

    class_name = self.__class__.__name__

    return f"{class_name}(\n    {joined_items}\n)"

LossHyperparameters dataclass ¤

Bases: NestedParameter

Loss hyperparameters.

Source code in sparse_autoencoder/train/sweep_config.py
@dataclass(frozen=True)
class LossHyperparameters(NestedParameter):
    """Loss hyperparameters."""

    l1_coefficient: Parameter[float] = field(default=Parameter(1e-3))
    """L1 Penalty Coefficient.

    The L1 penalty is the absolute sum of learned (hidden) activations, multiplied by this constant.
    The penalty encourages sparsity in the learned activations. This loss penalty can be reduced by
    using more features, or using a lower L1 coefficient. If your expansion factor is 2, then a good
    starting point for the L1 coefficient is 1e-3.
    """

l1_coefficient: Parameter[float] = field(default=Parameter(0.001)) class-attribute instance-attribute ¤

L1 Penalty Coefficient.

The L1 penalty is the absolute sum of learned (hidden) activations, multiplied by this constant. The penalty encourages sparsity in the learned activations. This loss penalty can be reduced by using more features, or using a lower L1 coefficient. If your expansion factor is 2, then a good starting point for the L1 coefficient is 1e-3.
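
Since the right L1 coefficient depends on the expansion factor, it is a natural target for a sweep around the suggested 1e-3 starting point (a sketch, assuming Parameter accepts a values list; import path assumed):

from sparse_autoencoder.train.sweep_config import LossHyperparameters
from sparse_autoencoder.train.utils.wandb_sweep_types import Parameter  # assumed import path

loss = LossHyperparameters(
    l1_coefficient=Parameter(values=[1e-4, 1e-3, 1e-2]),  # bracket the suggested default
)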

LossRuntimeHyperparameters ¤

Bases: TypedDict

Loss runtime hyperparameters.

Source code in sparse_autoencoder/train/sweep_config.py
class LossRuntimeHyperparameters(TypedDict):
    """Loss runtime hyperparameters."""

    l1_coefficient: float

OptimizerHyperparameters dataclass ¤

Bases: NestedParameter

Optimizer hyperparameters.

Source code in sparse_autoencoder/train/sweep_config.py
@dataclass(frozen=True)
class OptimizerHyperparameters(NestedParameter):
    """Optimizer hyperparameters."""

    lr: Parameter[float] = field(default=Parameter(1e-3))
    """Learning rate.

    A good starting point for the learning rate is 1e-3, but this is one of the key parameters so
    you should probably tune it.
    """

    adam_beta_1: Parameter[float] = field(default=Parameter(0.9))
    """Adam Beta 1.

    The exponential decay rate for the first moment estimates (mean) of the gradient.
    """

    adam_beta_2: Parameter[float] = field(default=Parameter(0.99))
    """Adam Beta 2.

    The exponential decay rate for the second moment estimates (variance) of the gradient.
    """

    adam_weight_decay: Parameter[float] = field(default=Parameter(0.0))
    """Adam Weight Decay.

    Weight decay (L2 penalty).
    """

    amsgrad: Parameter[bool] = field(default=Parameter(value=False))
    """AMSGrad.

    Whether to use the AMSGrad variant of this algorithm from the paper [On the Convergence of Adam
    and Beyond](https://arxiv.org/abs/1904.09237).
    """

    fused: Parameter[bool] = field(default=Parameter(value=False))
    """Fused.

    Whether to use a fused implementation of the optimizer (may be faster on CUDA).
    """

    lr_scheduler: Parameter[Literal["reduce_on_plateau", "cosine_annealing"]] | None = field(
        default=None
    )
    """Learning rate scheduler."""

adam_beta_1: Parameter[float] = field(default=Parameter(0.9)) class-attribute instance-attribute ¤

Adam Beta 1.

The exponential decay rate for the first moment estimates (mean) of the gradient.

adam_beta_2: Parameter[float] = field(default=Parameter(0.99)) class-attribute instance-attribute ¤

Adam Beta 2.

The exponential decay rate for the second moment estimates (variance) of the gradient.

adam_weight_decay: Parameter[float] = field(default=Parameter(0.0)) class-attribute instance-attribute ¤

Adam Weight Decay.

Weight decay (L2 penalty).

amsgrad: Parameter[bool] = field(default=Parameter(value=False)) class-attribute instance-attribute ¤

AMSGrad.

Whether to use the AMSGrad variant of this algorithm from the paper On the Convergence of Adam and Beyond.

fused: Parameter[bool] = field(default=Parameter(value=False)) class-attribute instance-attribute ¤

Fused.

Whether to use a fused implementation of the optimizer (may be faster on CUDA).

lr: Parameter[float] = field(default=Parameter(0.001)) class-attribute instance-attribute ¤

Learning rate.

A good starting point for the learning rate is 1e-3, but this is one of the key parameters so you should probably tune it.

lr_scheduler: Parameter[Literal['reduce_on_plateau', 'cosine_annealing']] | None = field(default=None) class-attribute instance-attribute ¤

Learning rate scheduler.
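
Since the learning rate is flagged above as the key parameter to tune, a usage sketch might sweep it while pinning the other Adam settings (assuming Parameter accepts a values list; import path assumed):

from sparse_autoencoder.train.sweep_config import OptimizerHyperparameters
from sparse_autoencoder.train.utils.wandb_sweep_types import Parameter  # assumed import path

optimizer = OptimizerHyperparameters(
    lr=Parameter(values=[1e-4, 3e-4, 1e-3]),  # sweep around the suggested 1e-3
    adam_beta_2=Parameter(0.999),             # fixed override of the 0.99 default
)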

OptimizerRuntimeHyperparameters ¤

Bases: TypedDict

Optimizer runtime hyperparameters.

Source code in sparse_autoencoder/train/sweep_config.py
class OptimizerRuntimeHyperparameters(TypedDict):
    """Optimizer runtime hyperparameters."""

    lr: float
    adam_beta_1: float
    adam_beta_2: float
    adam_weight_decay: float
    amsgrad: bool
    fused: bool
    lr_scheduler: str | None

PipelineHyperparameters dataclass ¤

Bases: NestedParameter

Pipeline hyperparameters.

Source code in sparse_autoencoder/train/sweep_config.py
@dataclass(frozen=True)
class PipelineHyperparameters(NestedParameter):
    """Pipeline hyperparameters."""

    log_frequency: Parameter[int] = field(default=Parameter(100))
    """Training log frequency."""

    source_data_batch_size: Parameter[int] = field(default=Parameter(DEFAULT_SOURCE_BATCH_SIZE))
    """Source data batch size."""

    train_batch_size: Parameter[int] = field(default=Parameter(DEFAULT_BATCH_SIZE))
    """Train batch size."""

    max_store_size: Parameter[int] = field(default=Parameter(DEFAULT_STORE_SIZE))
    """Max store size."""

    max_activations: Parameter[int] = field(
        default=Parameter(round_to_multiple(2e9, DEFAULT_STORE_SIZE))
    )
    """Max activations."""

    num_workers_data_loading: Parameter[int] = field(default=Parameter(0))
    """Number of CPU workers for data loading."""

    checkpoint_frequency: Parameter[int] = field(
        default=Parameter(round_to_multiple(5e7, DEFAULT_STORE_SIZE))
    )
    """Checkpoint frequency."""

    validation_frequency: Parameter[int] = field(
        default=Parameter(round_to_multiple(1e8, DEFAULT_BATCH_SIZE))
    )
    """Validation frequency."""

    validation_n_activations: Parameter[int] = field(
        default=Parameter(DEFAULT_SOURCE_BATCH_SIZE * DEFAULT_SOURCE_CONTEXT_SIZE * 2)
    )
    """Number of activations to use for validation."""

checkpoint_frequency: Parameter[int] = field(default=Parameter(round_to_multiple(50000000.0, DEFAULT_STORE_SIZE))) class-attribute instance-attribute ¤

Checkpoint frequency.

log_frequency: Parameter[int] = field(default=Parameter(100)) class-attribute instance-attribute ¤

Training log frequency.

max_activations: Parameter[int] = field(default=Parameter(round_to_multiple(2000000000.0, DEFAULT_STORE_SIZE))) class-attribute instance-attribute ¤

Max activations.

max_store_size: Parameter[int] = field(default=Parameter(DEFAULT_STORE_SIZE)) class-attribute instance-attribute ¤

Max store size.

num_workers_data_loading: Parameter[int] = field(default=Parameter(0)) class-attribute instance-attribute ¤

Number of CPU workers for data loading.

source_data_batch_size: Parameter[int] = field(default=Parameter(DEFAULT_SOURCE_BATCH_SIZE)) class-attribute instance-attribute ¤

Source data batch size.

train_batch_size: Parameter[int] = field(default=Parameter(DEFAULT_BATCH_SIZE)) class-attribute instance-attribute ¤

Train batch size.

validation_frequency: Parameter[int] = field(default=Parameter(round_to_multiple(100000000.0, DEFAULT_BATCH_SIZE))) class-attribute instance-attribute ¤

Validation frequency.

validation_n_activations: Parameter[int] = field(default=Parameter(DEFAULT_SOURCE_BATCH_SIZE * DEFAULT_SOURCE_CONTEXT_SIZE * 2)) class-attribute instance-attribute ¤

Number of activations to use for validation.
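
A usage sketch overriding a couple of the frequency defaults (Parameter import path assumed). Note from Hyperparameters.__post_init__ that max_store_size must stay at least as large as the resampler's resample_dataset_size:

from sparse_autoencoder.train.sweep_config import PipelineHyperparameters
from sparse_autoencoder.train.utils.wandb_sweep_types import Parameter  # assumed import path

pipeline = PipelineHyperparameters(
    checkpoint_frequency=Parameter(100_000_000),  # checkpoint less often than the default
    validation_frequency=Parameter(50_000_000),   # validate more often than the default
)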

PipelineRuntimeHyperparameters ¤

Bases: TypedDict

Pipeline runtime hyperparameters.

Source code in sparse_autoencoder/train/sweep_config.py
class PipelineRuntimeHyperparameters(TypedDict):
    """Pipeline runtime hyperparameters."""

    log_frequency: int
    source_data_batch_size: int
    train_batch_size: int
    max_store_size: int
    max_activations: int
    num_workers_data_loading: int
    checkpoint_frequency: int
    validation_frequency: int
    validation_n_activations: int

RuntimeHyperparameters ¤

Bases: TypedDict

Runtime hyperparameters.

Source code in sparse_autoencoder/train/sweep_config.py
class RuntimeHyperparameters(TypedDict):
    """Runtime hyperparameters."""

    source_data: SourceDataRuntimeHyperparameters
    source_model: SourceModelRuntimeHyperparameters
    activation_resampler: ActivationResamplerRuntimeHyperparameters
    autoencoder: AutoencoderRuntimeHyperparameters
    loss: LossRuntimeHyperparameters
    optimizer: OptimizerRuntimeHyperparameters
    pipeline: PipelineRuntimeHyperparameters
    random_seed: int

SourceDataHyperparameters dataclass ¤

Bases: NestedParameter

Source data hyperparameters.

Source code in sparse_autoencoder/train/sweep_config.py
@dataclass(frozen=True)
class SourceDataHyperparameters(NestedParameter):
    """Source data hyperparameters."""

    dataset_path: Parameter[str]
    """Dataset path."""

    context_size: Parameter[int] = field(default=Parameter(DEFAULT_SOURCE_CONTEXT_SIZE))
    """Context size."""

    dataset_column_name: Parameter[str] | None = field(default=Parameter(value="input_ids"))
    """Dataset column name."""

    dataset_dir: Parameter[str] | None = field(default=None)
    """Dataset directory (within the HF dataset)"""

    dataset_files: Parameter[list[str]] | None = field(default=None)
    """Dataset files (within the HF dataset)."""

    pre_download: Parameter[bool] = field(default=Parameter(value=False))
    """Whether to pre-download the dataset."""

    pre_tokenized: Parameter[bool] = field(default=Parameter(value=True))
    """If the dataset is pre-tokenized."""

    tokenizer_name: Parameter[str] | None = field(default=None)
    """Tokenizer name.

    Only set this if the dataset is not pre-tokenized.
    """

    def __post_init__(self) -> None:
        """Post initialisation checks.

        Raises:
            ValueError: If there is an error in the source data hyperparameters.
        """
        if self.pre_tokenized.value is False and not isinstance(self.tokenizer_name, Parameter):
            error_message = "The tokenizer name must be specified, when `pre_tokenized` is False."
            raise ValueError(error_message)

        if self.pre_tokenized.value is True and isinstance(self.tokenizer_name, Parameter):
            error_message = "The tokenizer name must not be set, when `pre_tokenized` is True."
            raise ValueError(error_message)

context_size: Parameter[int] = field(default=Parameter(DEFAULT_SOURCE_CONTEXT_SIZE)) class-attribute instance-attribute ¤

Context size.

dataset_column_name: Parameter[str] | None = field(default=Parameter(value='input_ids')) class-attribute instance-attribute ¤

Dataset column name.

dataset_dir: Parameter[str] | None = field(default=None) class-attribute instance-attribute ¤

Dataset directory (within the HF dataset).

dataset_files: Parameter[list[str]] | None = field(default=None) class-attribute instance-attribute ¤

Dataset files (within the HF dataset).

dataset_path: Parameter[str] instance-attribute ¤

Dataset path.

pre_download: Parameter[bool] = field(default=Parameter(value=False)) class-attribute instance-attribute ¤

Whether to pre-download the dataset.

pre_tokenized: Parameter[bool] = field(default=Parameter(value=True)) class-attribute instance-attribute ¤

If the dataset is pre-tokenized.

tokenizer_name: Parameter[str] | None = field(default=None) class-attribute instance-attribute ¤

Tokenizer name.

Only set this if the dataset is not pre-tokenized.

__post_init__() ¤

Post initialisation checks.

Raises:

ValueError: If there is an error in the source data hyperparameters.

Source code in sparse_autoencoder/train/sweep_config.py
def __post_init__(self) -> None:
    """Post initialisation checks.

    Raises:
        ValueError: If there is an error in the source data hyperparameters.
    """
    if self.pre_tokenized.value is False and not isinstance(self.tokenizer_name, Parameter):
        error_message = "The tokenizer name must be specified, when `pre_tokenized` is False."
        raise ValueError(error_message)

    if self.pre_tokenized.value is True and isinstance(self.tokenizer_name, Parameter):
        error_message = "The tokenizer name must not be set, when `pre_tokenized` is True."
        raise ValueError(error_message)
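
A sketch of the two valid configurations (dataset paths and tokenizer name are placeholders; Parameter import path assumed): a raw-text dataset needs a tokenizer, while a pre-tokenized one must not name one.

from sparse_autoencoder.train.sweep_config import SourceDataHyperparameters
from sparse_autoencoder.train.utils.wandb_sweep_types import Parameter  # assumed import path

# Raw text: pre_tokenized is False, so tokenizer_name is required.
raw_text = SourceDataHyperparameters(
    dataset_path=Parameter("my-org/raw-text"),
    pre_tokenized=Parameter(value=False),
    tokenizer_name=Parameter("gpt2"),
)

# Pre-tokenized (the default): tokenizer_name must be left unset.
pre_tokenized = SourceDataHyperparameters(
    dataset_path=Parameter("my-org/tokenized"),
)

# Omitting tokenizer_name while pre_tokenized is False raises ValueError.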

SourceDataRuntimeHyperparameters ¤

Bases: TypedDict

Source data runtime hyperparameters.

Source code in sparse_autoencoder/train/sweep_config.py
class SourceDataRuntimeHyperparameters(TypedDict):
    """Source data runtime hyperparameters."""

    context_size: int
    dataset_column_name: str
    dataset_dir: str | None
    dataset_files: list[str] | None
    dataset_path: str
    pre_download: bool
    pre_tokenized: bool
    tokenizer_name: str | None

SourceModelHyperparameters dataclass ¤

Bases: NestedParameter

Source model hyperparameters.

Source code in sparse_autoencoder/train/sweep_config.py
@dataclass(frozen=True)
class SourceModelHyperparameters(NestedParameter):
    """Source model hyperparameters."""

    name: Parameter[str]
    """Source model name."""

    cache_names: Parameter[list[str]]
    """Source model hook site."""

    hook_dimension: Parameter[int]
    """Source model hook point dimension."""

    dtype: Parameter[str] = field(default=Parameter("float32"))
    """Source model dtype."""

cache_names: Parameter[list[str]] instance-attribute ¤

Source model hook sites.

dtype: Parameter[str] = field(default=Parameter('float32')) class-attribute instance-attribute ¤

Source model dtype.

hook_dimension: Parameter[int] instance-attribute ¤

Source model hook point dimension.

name: Parameter[str] instance-attribute ¤

Source model name.
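
A usage sketch for a small source model (the hook name follows TransformerLens conventions and 768 is GPT-2 small's hidden size; both are illustrative, and the Parameter import path is assumed):

from sparse_autoencoder.train.sweep_config import SourceModelHyperparameters
from sparse_autoencoder.train.utils.wandb_sweep_types import Parameter  # assumed import path

source_model = SourceModelHyperparameters(
    name=Parameter("gpt2"),
    cache_names=Parameter(["blocks.0.hook_mlp_out"]),  # hook site to read activations from
    hook_dimension=Parameter(768),                     # width of that hook point
)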

SourceModelRuntimeHyperparameters ¤

Bases: TypedDict

Source model runtime hyperparameters.

Source code in sparse_autoencoder/train/sweep_config.py
class SourceModelRuntimeHyperparameters(TypedDict):
    """Source model runtime hyperparameters."""

    name: str
    cache_names: list[str]
    hook_dimension: int
    dtype: str

SweepConfig dataclass ¤

Bases: WandbSweepConfig

Sweep Config.

Source code in sparse_autoencoder/train/sweep_config.py
@dataclass
class SweepConfig(WandbSweepConfig):
    """Sweep Config."""

    parameters: Hyperparameters

    method: Method = Method.GRID

    metric: Metric = field(default=Metric(name="train/loss/total_loss"))
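
Putting it together, a sweep config wires the hyperparameters (with any swept values) into the W&B sweep machinery; per the defaults above it grid-searches and tracks train/loss/total_loss. A sketch with placeholder dataset and model values (the Parameter import path is assumed, and how the config is serialized and passed to wandb.sweep depends on WandbSweepConfig and is not shown here):

from sparse_autoencoder.train.sweep_config import (
    AutoencoderHyperparameters,
    Hyperparameters,
    SourceDataHyperparameters,
    SourceModelHyperparameters,
    SweepConfig,
)
from sparse_autoencoder.train.utils.wandb_sweep_types import Parameter  # assumed import path

sweep_config = SweepConfig(
    parameters=Hyperparameters(
        source_data=SourceDataHyperparameters(dataset_path=Parameter("my-org/my-dataset")),
        source_model=SourceModelHyperparameters(
            name=Parameter("gpt2"),
            cache_names=Parameter(["blocks.0.hook_mlp_out"]),
            hook_dimension=Parameter(768),
        ),
        # One swept value is enough to give the grid search something to vary.
        autoencoder=AutoencoderHyperparameters(expansion_factor=Parameter(values=[2, 4])),
    ),
)
print(sweep_config)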