The Sparse Autoencoder Model¤

ForwardPassResult ¤

Bases: NamedTuple

SAE model forward pass result.

Source code in sparse_autoencoder/autoencoder/model.py
class ForwardPassResult(NamedTuple):
    """SAE model forward pass result."""

    learned_activations: Float[
        Tensor, Axis.names(Axis.BATCH, Axis.COMPONENT_OPTIONAL, Axis.LEARNT_FEATURE)
    ]

    decoded_activations: Float[
        Tensor, Axis.names(Axis.BATCH, Axis.COMPONENT_OPTIONAL, Axis.INPUT_OUTPUT_FEATURE)
    ]
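
A minimal usage sketch (the import path follows the source location shown above): because the result is a plain NamedTuple, it supports tuple unpacking as well as attribute access.

import torch

from sparse_autoencoder.autoencoder.model import ForwardPassResult

result = ForwardPassResult(
    learned_activations=torch.zeros(2, 8),  # (batch, learnt_feature)
    decoded_activations=torch.zeros(2, 4),  # (batch, input_output_feature)
)
learned, decoded = result  # tuple unpacking
assert learned is result.learned_activations  # attribute access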

SparseAutoencoder ¤

Bases: Module

Sparse Autoencoder Model.

Source code in sparse_autoencoder/autoencoder/model.py
class SparseAutoencoder(Module):
    """Sparse Autoencoder Model."""

    config: SparseAutoencoderConfig
    """Model config."""

    geometric_median_dataset: Float[
        Tensor, Axis.names(Axis.COMPONENT_OPTIONAL, Axis.INPUT_OUTPUT_FEATURE)
    ]
    """Estimated Geometric Median of the Dataset.

    Used for initialising :attr:`tied_bias`.
    """

    tied_bias: Float[
        Parameter, Axis.names(Axis.BATCH, Axis.COMPONENT_OPTIONAL, Axis.INPUT_OUTPUT_FEATURE)
    ]
    """Tied Bias Parameter.

    The same bias is used pre-encoder and post-decoder.
    """

    pre_encoder_bias: TiedBias
    """Pre-Encoder Bias."""

    encoder: LinearEncoder
    """Encoder."""

    decoder: UnitNormDecoder
    """Decoder."""

    post_decoder_bias: TiedBias
    """Post-Decoder Bias."""

    def __init__(
        self,
        config: SparseAutoencoderConfig,
        geometric_median_dataset: Float[
            Tensor, Axis.names(Axis.COMPONENT_OPTIONAL, Axis.INPUT_OUTPUT_FEATURE)
        ]
        | None = None,
    ) -> None:
        """Initialize the Sparse Autoencoder Model.

        Args:
            config: Model config.
            geometric_median_dataset: Estimated geometric median of the dataset.
        """
        super().__init__()

        self.config = config

        # Store the geometric median of the dataset (so that we can reset parameters). This is not a
        # parameter itself (the tied bias parameter is used for that), so gradients are disabled.
        tied_bias_shape = shape_with_optional_dimensions(
            config.n_components, config.n_input_features
        )
        if geometric_median_dataset is not None:
            self.geometric_median_dataset = geometric_median_dataset.clone()
            self.geometric_median_dataset.requires_grad = False
        else:
            self.geometric_median_dataset = torch.zeros(tied_bias_shape)
            self.geometric_median_dataset.requires_grad = False

        # Initialize the tied bias
        self.tied_bias = Parameter(torch.empty(tied_bias_shape))
        self.initialize_tied_parameters()

        # Initialize the components
        self.pre_encoder_bias = TiedBias(self.tied_bias, TiedBiasPosition.PRE_ENCODER)

        self.encoder = LinearEncoder(
            input_features=config.n_input_features,
            learnt_features=config.n_learned_features,
            n_components=config.n_components,
        )

        self.decoder = UnitNormDecoder(
            learnt_features=config.n_learned_features,
            decoded_features=config.n_input_features,
            n_components=config.n_components,
        )

        self.post_decoder_bias = TiedBias(self.tied_bias, TiedBiasPosition.POST_DECODER)

    def forward(
        self,
        x: Float[
            Tensor, Axis.names(Axis.BATCH, Axis.COMPONENT_OPTIONAL, Axis.INPUT_OUTPUT_FEATURE)
        ],
    ) -> ForwardPassResult:
        """Forward Pass.

        Args:
            x: Input activations (e.g. activations from an MLP layer in a transformer model).

        Returns:
            Tuple of learned activations and decoded activations.
        """
        x = self.pre_encoder_bias(x)
        learned_activations = self.encoder(x)
        x = self.decoder(learned_activations)
        decoded_activations = self.post_decoder_bias(x)

        return ForwardPassResult(learned_activations, decoded_activations)

    def initialize_tied_parameters(self) -> None:
        """Initialize the tied parameters."""
        # The tied bias is initialised as the geometric median of the dataset
        self.tied_bias.data = self.geometric_median_dataset

    def reset_parameters(self) -> None:
        """Reset the parameters."""
        self.initialize_tied_parameters()
        for module in self.children():
            if "reset_parameters" in dir(module):
                module.reset_parameters()

    @property
    def reset_optimizer_parameter_details(self) -> list[ResetOptimizerParameterDetails]:
        """Reset optimizer parameter details.

        Details of the parameters that should be reset in the optimizer, when resetting
        dictionary vectors.

        Returns:
            List of tuples of the form `(parameter, axis)`, where `parameter` is the parameter to
            reset (e.g. encoder.weight), and `axis` is the axis of the parameter to reset.
        """
        return (
            self.encoder.reset_optimizer_parameter_details
            + self.decoder.reset_optimizer_parameter_details
        )

    def post_backwards_hook(self) -> None:
        """Hook to be called after each learning step.

        This can be used to e.g. constrain weights to unit norm.
        """
        self.decoder.constrain_weights_unit_norm()

    @staticmethod
    @validate_call
    def get_single_component_state_dict(
        state: SparseAutoencoderState, component_idx: NonNegativeInt
    ) -> dict[str, Tensor]:
        """Get the state dict for a single component.

        Args:
            state: Sparse Autoencoder state.
            component_idx: Index of the component to get the state dict for.

        Returns:
            State dict for the component.

        Raises:
            ValueError: If the state dict doesn't contain a components dimension.
        """
        # Check the state has a components dimension
        if state.config.n_components is None:
            error_message = (
                "Trying to load a single component from the state dict, but the state dict "
                "doesn't contain a components dimension."
            )
            raise ValueError(error_message)

        # Return the state dict for the component
        return {key: value[component_idx] for key, value in state.state_dict.items()}

    def save(self, file_path: Path) -> None:
        """Save the model config and state dict to a file.

        Args:
            file_path: Path to save the model to.
        """
        file_path.parent.mkdir(parents=True, exist_ok=True)
        state = SparseAutoencoderState(config=self.config, state_dict=self.state_dict())
        torch.save(state, file_path)

    @staticmethod
    def load(
        file_path: FILE_LIKE,
        component_idx: PositiveInt | None = None,
    ) -> "SparseAutoencoder":
        """Load the model from a file.

        Args:
            file_path: Path to load the model from.
            component_idx: If loading a state dict from a model that has been trained on multiple
                components (e.g. all MLP layers) you may want to load just one component. In this
                case you can set `component_idx` to the index of the component to load. Note you
                should not set this if you want to load a state dict from a model that has been
                trained on a single component (or if you want to load all components).

        Returns:
            The loaded model.
        """
        # Load the file
        serialized_state = torch.load(file_path, map_location=torch.device("cpu"))
        state = SparseAutoencoderState.model_validate(serialized_state)

        # Initialise the model
        config = SparseAutoencoderConfig(
            n_input_features=state.config.n_input_features,
            n_learned_features=state.config.n_learned_features,
            n_components=state.config.n_components if component_idx is None else None,
        )
        state_dict = (
            SparseAutoencoder.get_single_component_state_dict(state, component_idx)
            if component_idx is not None
            else state.state_dict
        )
        model = SparseAutoencoder(config)
        model.load_state_dict(state_dict)

        return model

    def save_to_wandb(
        self,
        artifact_name: str,
        directory: DirectoryPath = DEFAULT_TMP_DIR,
    ) -> str:
        """Save the model to wandb.

        Args:
            artifact_name: A human-readable name for this artifact, which is how you can identify
                this artifact in the UI or reference it in use_artifact calls. Names can contain
                letters, numbers, underscores, hyphens, and dots. The name must be unique across a
                project. Example: "sweep_name 1e9 activations".
            directory: Directory to save the model to.

        Returns:
            Name of the wandb artifact.

        Raises:
            ValueError: If wandb is not initialised.
        """
        # Save the file
        directory.mkdir(parents=True, exist_ok=True)
        file_name = artifact_name + ".pt"
        file_path = directory / file_name
        self.save(file_path)

        # Upload to wandb
        if wandb.run is None:
            error_message = "Trying to save the model to wandb, but wandb is not initialised."
            raise ValueError(error_message)
        artifact = wandb.Artifact(
            artifact_name,
            type="model",
            description="Sparse Autoencoder model state, created with `sparse_autoencoder`.",
        )
        artifact.add_file(str(file_path), name="sae-model-state.pt")
        artifact.save()
        wandb.log_artifact(artifact)
        artifact.wait()

        return artifact.source_qualified_name

    @staticmethod
    def load_from_wandb(
        wandb_artifact_name: str,
        component_idx: PositiveInt | None = None,
    ) -> "SparseAutoencoder":
        """Load the model from wandb.

        Args:
            wandb_artifact_name: Name of the wandb artifact to load the model from (e.g.
                "username/project/artifact_name:version").
            component_idx: If loading a state dict from a model that has been trained on multiple
                components (e.g. all MLP layers) you may want to load just one component. In this
                case you can set `component_idx` to the index of the component to load. Note you
                should not set this if you want to load a state dict from a model that has been
                trained on a single component (or if you want to load all components).

        Returns:
            The loaded model.
        """
        api = wandb.Api()
        artifact = api.artifact(wandb_artifact_name, type="model")
        download_path = artifact.download()
        return SparseAutoencoder.load(Path(download_path) / "sae-model-state.pt", component_idx)

    def save_to_hugging_face(
        self,
        file_name: str,
        repo_id: str,
        directory: DirectoryPath = DEFAULT_TMP_DIR,
        hf_access_token: str | None = None,
    ) -> None:
        """Save the model to Hugging Face.

        Args:
            file_name: Name of the file (e.g. "model-something.pt").
            repo_id: ID of the repo to save the model to.
            directory: Directory to save the model to.
            hf_access_token: Hugging Face access token.
        """
        # Save the file
        directory.mkdir(parents=True, exist_ok=True)
        file_path = directory / file_name
        self.save(file_path)

        # Upload to Hugging Face
        api = HfApi(token=hf_access_token)
        api.upload_file(
            path_or_fileobj=file_path,
            path_in_repo=file_name,
            repo_id=repo_id,
            repo_type="model",
        )

    @staticmethod
    def load_from_hugging_face(
        file_name: str,
        repo_id: str,
        component_idx: PositiveInt | None = None,
    ) -> "SparseAutoencoder":
        """Load the model from Hugging Face.

        Args:
            file_name: File name of the .pt state file.
            repo_id: ID of the repo to load the model from.
            component_idx: If loading a state dict from a model that has been trained on multiple
                components (e.g. all MLP layers) you may want to load just one component. In this
                case you can set `component_idx` to the index of the component to load. Note you
                should not set this if you want to load a state dict from a model that has been
                trained on a single component (or if you want to load all components).

        Returns:
            The loaded model.
        """
        local_file = hf_hub_download(
            repo_id=repo_id,
            repo_type="model",
            filename=file_name,
            revision="main",
        )

        return SparseAutoencoder.load(Path(local_file), component_idx)
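
A minimal end-to-end sketch, assuming `SparseAutoencoder` and `SparseAutoencoderConfig` are re-exported at the package root (if not, import them from `sparse_autoencoder.autoencoder.model`):

import torch

from sparse_autoencoder import SparseAutoencoder, SparseAutoencoderConfig

config = SparseAutoencoderConfig(n_input_features=4, n_learned_features=16)
autoencoder = SparseAutoencoder(config)

activations = torch.randn(32, 4)  # (batch, input_output_feature)
learned_activations, decoded_activations = autoencoder(activations)
assert learned_activations.shape == (32, 16)
assert decoded_activations.shape == (32, 4)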

config: SparseAutoencoderConfig = config instance-attribute ¤

Model config.

decoder: UnitNormDecoder = UnitNormDecoder(learnt_features=config.n_learned_features, decoded_features=config.n_input_features, n_components=config.n_components) instance-attribute ¤

Decoder.

encoder: LinearEncoder = LinearEncoder(input_features=config.n_input_features, learnt_features=config.n_learned_features, n_components=config.n_components) instance-attribute ¤

Encoder.

geometric_median_dataset: Float[Tensor, Axis.names(Axis.COMPONENT_OPTIONAL, Axis.INPUT_OUTPUT_FEATURE)] instance-attribute ¤

Estimated Geometric Median of the Dataset.

Used for initialising `tied_bias`.

post_decoder_bias: TiedBias = TiedBias(self.tied_bias, TiedBiasPosition.POST_DECODER) instance-attribute ¤

Post-Decoder Bias.

pre_encoder_bias: TiedBias = TiedBias(self.tied_bias, TiedBiasPosition.PRE_ENCODER) instance-attribute ¤

Pre-Encoder Bias.

reset_optimizer_parameter_details: list[ResetOptimizerParameterDetails] property ¤

Reset optimizer parameter details.

Details of the parameters that should be reset in the optimizer, when resetting dictionary vectors.

Returns:

list[ResetOptimizerParameterDetails]: List of tuples of the form (parameter, axis), where parameter is the parameter to reset (e.g. encoder.weight), and axis is the axis of the parameter to reset.
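
A hedged sketch of one way these details can be consumed: zeroing Adam's moment estimates for re-initialised dictionary vectors. The Adam state keys are standard PyTorch; `neuron_indices` is an illustrative name, since the training pipeline decides which vectors to reset.

import torch

from sparse_autoencoder import SparseAutoencoder, SparseAutoencoderConfig

autoencoder = SparseAutoencoder(SparseAutoencoderConfig(n_input_features=4, n_learned_features=16))
optimizer = torch.optim.Adam(autoencoder.parameters(), lr=1e-3)

# Take one step so that the optimizer state is populated
loss = autoencoder(torch.randn(8, 4)).decoded_activations.sum()
loss.backward()
optimizer.step()

neuron_indices = torch.tensor([0, 3])  # dictionary vectors to reset (illustrative)
for parameter, axis in autoencoder.reset_optimizer_parameter_details:
    state = optimizer.state[parameter]
    state["exp_avg"].index_fill_(axis, neuron_indices, 0.0)
    state["exp_avg_sq"].index_fill_(axis, neuron_indices, 0.0)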

tied_bias: Float[Parameter, Axis.names(Axis.BATCH, Axis.COMPONENT_OPTIONAL, Axis.INPUT_OUTPUT_FEATURE)] = Parameter(torch.empty(tied_bias_shape)) instance-attribute ¤

Tied Bias Parameter.

The same bias is used pre-encoder and post-decoder.

__init__(config, geometric_median_dataset=None) ¤

Initialize the Sparse Autoencoder Model.

Parameters:

config (SparseAutoencoderConfig): Model config. Required.
geometric_median_dataset (Float[Tensor, names(COMPONENT_OPTIONAL, INPUT_OUTPUT_FEATURE)] | None): Estimated geometric median of the dataset. Default: None.
Source code in sparse_autoencoder/autoencoder/model.py
def __init__(
    self,
    config: SparseAutoencoderConfig,
    geometric_median_dataset: Float[
        Tensor, Axis.names(Axis.COMPONENT_OPTIONAL, Axis.INPUT_OUTPUT_FEATURE)
    ]
    | None = None,
) -> None:
    """Initialize the Sparse Autoencoder Model.

    Args:
        config: Model config.
        geometric_median_dataset: Estimated geometric median of the dataset.
    """
    super().__init__()

    self.config = config

    # Store the geometric median of the dataset (so that we can reset parameters). This is not a
    # parameter itself (the tied bias parameter is used for that), so gradients are disabled.
    tied_bias_shape = shape_with_optional_dimensions(
        config.n_components, config.n_input_features
    )
    if geometric_median_dataset is not None:
        self.geometric_median_dataset = geometric_median_dataset.clone()
        self.geometric_median_dataset.requires_grad = False
    else:
        self.geometric_median_dataset = torch.zeros(tied_bias_shape)
        self.geometric_median_dataset.requires_grad = False

    # Initialize the tied bias
    self.tied_bias = Parameter(torch.empty(tied_bias_shape))
    self.initialize_tied_parameters()

    # Initialize the components
    self.pre_encoder_bias = TiedBias(self.tied_bias, TiedBiasPosition.PRE_ENCODER)

    self.encoder = LinearEncoder(
        input_features=config.n_input_features,
        learnt_features=config.n_learned_features,
        n_components=config.n_components,
    )

    self.decoder = UnitNormDecoder(
        learnt_features=config.n_learned_features,
        decoded_features=config.n_input_features,
        n_components=config.n_components,
    )

    self.post_decoder_bias = TiedBias(self.tied_bias, TiedBiasPosition.POST_DECODER)

forward(x) ¤

Forward Pass.

Parameters:

x (Float[Tensor, names(BATCH, COMPONENT_OPTIONAL, INPUT_OUTPUT_FEATURE)]): Input activations (e.g. activations from an MLP layer in a transformer model). Required.

Returns:

ForwardPassResult: Tuple of learned activations and decoded activations.

Source code in sparse_autoencoder/autoencoder/model.py
def forward(
    self,
    x: Float[
        Tensor, Axis.names(Axis.BATCH, Axis.COMPONENT_OPTIONAL, Axis.INPUT_OUTPUT_FEATURE)
    ],
) -> ForwardPassResult:
    """Forward Pass.

    Args:
        x: Input activations (e.g. activations from an MLP layer in a transformer model).

    Returns:
        Tuple of learned activations and decoded activations.
    """
    x = self.pre_encoder_bias(x)
    learned_activations = self.encoder(x)
    x = self.decoder(learned_activations)
    decoded_activations = self.post_decoder_bias(x)

    return ForwardPassResult(learned_activations, decoded_activations)
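
A sketch of how the two outputs are typically consumed; the loss terms below are illustrative, not this library's loss modules: reconstruction error on the decoded activations plus an L1 sparsity penalty on the learned activations.

import torch

from sparse_autoencoder import SparseAutoencoder, SparseAutoencoderConfig

autoencoder = SparseAutoencoder(SparseAutoencoderConfig(n_input_features=4, n_learned_features=16))
activations = torch.randn(32, 4)

learned_activations, decoded_activations = autoencoder(activations)
reconstruction_loss = torch.nn.functional.mse_loss(decoded_activations, activations)
sparsity_penalty = learned_activations.abs().sum(dim=-1).mean()
loss = reconstruction_loss + 1e-3 * sparsity_penalty  # coefficient is illustrative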

get_single_component_state_dict(state, component_idx) staticmethod ¤

Get the state dict for a single component.

Parameters:

state (SparseAutoencoderState): Sparse Autoencoder state. Required.
component_idx (NonNegativeInt): Index of the component to get the state dict for. Required.

Returns:

dict[str, Tensor]: State dict for the component.

Raises:

ValueError: If the state dict doesn't contain a components dimension.

Source code in sparse_autoencoder/autoencoder/model.py
@staticmethod
@validate_call
def get_single_component_state_dict(
    state: SparseAutoencoderState, component_idx: NonNegativeInt
) -> dict[str, Tensor]:
    """Get the state dict for a single component.

    Args:
        state: Sparse Autoencoder state.
        component_idx: Index of the component to get the state dict for.

    Returns:
        State dict for the component.

    Raises:
        ValueError: If the state dict doesn't contain a components dimension.
    """
    # Check the state has a components dimension
    if state.config.n_components is None:
        error_message = (
            "Trying to load a single component from the state dict, but the state dict "
            "doesn't contain a components dimension."
        )
        raise ValueError(error_message)

    # Return the state dict for the component
    return {key: value[component_idx] for key, value in state.state_dict.items()}
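
A sketch of slicing one component out of a model configured for two components (class names follow the source location above). Every tensor in the returned dict has its leading component axis indexed away:

from sparse_autoencoder import SparseAutoencoder, SparseAutoencoderConfig
from sparse_autoencoder.autoencoder.model import SparseAutoencoderState

config = SparseAutoencoderConfig(n_input_features=4, n_learned_features=16, n_components=2)
model = SparseAutoencoder(config)
state = SparseAutoencoderState(config=config, state_dict=model.state_dict())

single_component = SparseAutoencoder.get_single_component_state_dict(state, component_idx=0)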

initialize_tied_parameters() ¤

Initialize the tied parameters.

Source code in sparse_autoencoder/autoencoder/model.py
def initialize_tied_parameters(self) -> None:
    """Initialize the tied parameters."""
    # The tied bias is initialised as the geometric median of the dataset
    self.tied_bias.data = self.geometric_median_dataset

load(file_path, component_idx=None) staticmethod ¤

Load the model from a file.

Parameters:

file_path (FILE_LIKE): Path to load the model from. Required.
component_idx (PositiveInt | None): If loading a state dict from a model that has been trained on multiple components (e.g. all MLP layers) you may want to load just one component. In this case you can set `component_idx` to the index of the component to load. Note you should not set this if you want to load a state dict from a model that has been trained on a single component (or if you want to load all components). Default: None.

Returns:

SparseAutoencoder: The loaded model.

Source code in sparse_autoencoder/autoencoder/model.py
@staticmethod
def load(
    file_path: FILE_LIKE,
    component_idx: PositiveInt | None = None,
) -> "SparseAutoencoder":
    """Load the model from a file.

    Args:
        file_path: Path to load the model from.
        component_idx: If loading a state dict from a model that has been trained on multiple
            components (e.g. all MLP layers) you may want to load just one component. In this
            case you can set `component_idx` to the index of the component to load. Note you
            should not set this if you want to load a state dict from a model that has been
            trained on a single component (or if you want to load all components).

    Returns:
        The loaded model.
    """
    # Load the file
    serialized_state = torch.load(file_path, map_location=torch.device("cpu"))
    state = SparseAutoencoderState.model_validate(serialized_state)

    # Initialise the model
    config = SparseAutoencoderConfig(
        n_input_features=state.config.n_input_features,
        n_learned_features=state.config.n_learned_features,
        n_components=state.config.n_components if component_idx is None else None,
    )
    state_dict = (
        SparseAutoencoder.get_single_component_state_dict(state, component_idx)
        if component_idx is not None
        else state.state_dict
    )
    model = SparseAutoencoder(config)
    model.load_state_dict(state_dict)

    return model
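
A save/load roundtrip sketch (the path is illustrative). `save` stores the config alongside the state dict, so `load` can rebuild the model without further arguments:

from pathlib import Path

from sparse_autoencoder import SparseAutoencoder, SparseAutoencoderConfig

model = SparseAutoencoder(SparseAutoencoderConfig(n_input_features=4, n_learned_features=16))
model.save(Path("checkpoints/sae.pt"))
restored = SparseAutoencoder.load(Path("checkpoints/sae.pt"))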

load_from_hugging_face(file_name, repo_id, component_idx=None) staticmethod ¤

Load the model from Hugging Face.

Parameters:

file_name (str): File name of the .pt state file. Required.
repo_id (str): ID of the repo to load the model from. Required.
component_idx (PositiveInt | None): If loading a state dict from a model that has been trained on multiple components (e.g. all MLP layers) you may want to load just one component. In this case you can set `component_idx` to the index of the component to load. Note you should not set this if you want to load a state dict from a model that has been trained on a single component (or if you want to load all components). Default: None.

Returns:

SparseAutoencoder: The loaded model.

Source code in sparse_autoencoder/autoencoder/model.py
@staticmethod
def load_from_hugging_face(
    file_name: str,
    repo_id: str,
    component_idx: PositiveInt | None = None,
) -> "SparseAutoencoder":
    """Load the model from Hugging Face.

    Args:
        file_name: File name of the .pt state file.
        repo_id: ID of the repo to load the model from.
        component_idx: If loading a state dict from a model that has been trained on multiple
            components (e.g. all MLP layers) you may want to to load just one component. In this
            case you can set `component_idx` to the index of the component to load. Note you
            should not set this if you want to load a state dict from a model that has been
            trained on a single component (or if you want to load all components).

    Returns:
        The loaded model.
    """
    local_file = hf_hub_download(
        repo_id=repo_id,
        repo_type="model",
        filename=file_name,
        revision="main",
    )

    return SparseAutoencoder.load(Path(local_file), component_idx)
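
An illustrative sketch; the repo and file names are placeholders, not a real published artifact:

from sparse_autoencoder import SparseAutoencoder

autoencoder = SparseAutoencoder.load_from_hugging_face(
    file_name="sae-model-state.pt",
    repo_id="your-username/your-sae-repo",
)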

load_from_wandb(wandb_artifact_name, component_idx=None) staticmethod ¤

Load the model from wandb.

Parameters:

wandb_artifact_name (str): Name of the wandb artifact to load the model from (e.g. "username/project/artifact_name:version"). Required.
component_idx (PositiveInt | None): If loading a state dict from a model that has been trained on multiple components (e.g. all MLP layers) you may want to load just one component. In this case you can set `component_idx` to the index of the component to load. Note you should not set this if you want to load a state dict from a model that has been trained on a single component (or if you want to load all components). Default: None.

Returns:

SparseAutoencoder: The loaded model.

Source code in sparse_autoencoder/autoencoder/model.py
@staticmethod
def load_from_wandb(
    wandb_artifact_name: str,
    component_idx: PositiveInt | None = None,
) -> "SparseAutoencoder":
    """Load the model from wandb.

    Args:
        wandb_artifact_name: Name of the wandb artifact to load the model from (e.g.
            "username/project/artifact_name:version").
        component_idx: If loading a state dict from a model that has been trained on multiple
            components (e.g. all MLP layers) you may want to load just one component. In this
            case you can set `component_idx` to the index of the component to load. Note you
            should not set this if you want to load a state dict from a model that has been
            trained on a single component (or if you want to load all components).

    Returns:
        The loaded model.
    """
    api = wandb.Api()
    artifact = api.artifact(wandb_artifact_name, type="model")
    download_path = artifact.download()
    return SparseAutoencoder.load(Path(download_path) / "sae-model-state.pt", component_idx)
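
An illustrative sketch; the artifact path is a placeholder, and wandb references take the form "entity/project/name:version":

from sparse_autoencoder import SparseAutoencoder

autoencoder = SparseAutoencoder.load_from_wandb("your-entity/your-project/sae:v0")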

post_backwards_hook() ¤

Hook to be called after each learning step.

This can be used to e.g. constrain weights to unit norm.

Source code in sparse_autoencoder/autoencoder/model.py
def post_backwards_hook(self) -> None:
    """Hook to be called after each learning step.

    This can be used to e.g. constrain weights to unit norm.
    """
    self.decoder.constrain_weights_unit_norm()
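
A sketch of where the hook sits in a training step (the loss is illustrative): it runs after the optimizer update, constraining the decoder weights back to unit norm.

import torch

from sparse_autoencoder import SparseAutoencoder, SparseAutoencoderConfig

autoencoder = SparseAutoencoder(SparseAutoencoderConfig(n_input_features=4, n_learned_features=16))
optimizer = torch.optim.Adam(autoencoder.parameters(), lr=1e-3)

activations = torch.randn(8, 4)
learned, decoded = autoencoder(activations)
loss = torch.nn.functional.mse_loss(decoded, activations)

optimizer.zero_grad()
loss.backward()
optimizer.step()
autoencoder.post_backwards_hook()  # re-normalise decoder weights after the update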

reset_parameters() ¤

Reset the parameters.

Source code in sparse_autoencoder/autoencoder/model.py
def reset_parameters(self) -> None:
    """Reset the parameters."""
    self.initialize_tied_parameters()
    for module in self.children():
        if "reset_parameters" in dir(module):
            module.reset_parameters()

save(file_path) ¤

Save the model config and state dict to a file.

Parameters:

file_path (Path): Path to save the model to. Required.
Source code in sparse_autoencoder/autoencoder/model.py
def save(self, file_path: Path) -> None:
    """Save the model config and state dict to a file.

    Args:
        file_path: Path to save the model to.
    """
    file_path.parent.mkdir(parents=True, exist_ok=True)
    state = SparseAutoencoderState(config=self.config, state_dict=self.state_dict())
    torch.save(state, file_path)

save_to_hugging_face(file_name, repo_id, directory=DEFAULT_TMP_DIR, hf_access_token=None) ¤

Save the model to Hugging Face.

Parameters:

file_name (str): Name of the file (e.g. "model-something.pt"). Required.
repo_id (str): ID of the repo to save the model to. Required.
directory (DirectoryPath): Directory to save the model to. Default: DEFAULT_TMP_DIR.
hf_access_token (str | None): Hugging Face access token. Default: None.
Source code in sparse_autoencoder/autoencoder/model.py
def save_to_hugging_face(
    self,
    file_name: str,
    repo_id: str,
    directory: DirectoryPath = DEFAULT_TMP_DIR,
    hf_access_token: str | None = None,
) -> None:
    """Save the model to Hugging Face.

    Args:
        file_name: Name of the file (e.g. "model-something.pt").
        repo_id: ID of the repo to save the model to.
        directory: Directory to save the model to.
        hf_access_token: Hugging Face access token.
    """
    # Save the file
    directory.mkdir(parents=True, exist_ok=True)
    file_path = directory / file_name
    self.save(file_path)

    # Upload to Hugging Face
    api = HfApi(token=hf_access_token)
    api.upload_file(
        path_or_fileobj=file_path,
        path_in_repo=file_name,
        repo_id=repo_id,
        repo_type="model",
    )
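
An illustrative sketch; the repo ID and token are placeholders, and uploading requires write access to the repo:

from sparse_autoencoder import SparseAutoencoder, SparseAutoencoderConfig

model = SparseAutoencoder(SparseAutoencoderConfig(n_input_features=4, n_learned_features=16))
model.save_to_hugging_face(
    file_name="sae-model-state.pt",
    repo_id="your-username/your-sae-repo",
    hf_access_token="hf_your_token_here",  # placeholder
)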

save_to_wandb(artifact_name, directory=DEFAULT_TMP_DIR) ¤

Save the model to wandb.

Parameters:

artifact_name (str): A human-readable name for this artifact, which is how you can identify this artifact in the UI or reference it in use_artifact calls. Names can contain letters, numbers, underscores, hyphens, and dots. The name must be unique across a project. Example: "sweep_name 1e9 activations". Required.
directory (DirectoryPath): Directory to save the model to. Default: DEFAULT_TMP_DIR.

Returns:

str: Name of the wandb artifact.

Raises:

ValueError: If wandb is not initialised.

Source code in sparse_autoencoder/autoencoder/model.py
def save_to_wandb(
    self,
    artifact_name: str,
    directory: DirectoryPath = DEFAULT_TMP_DIR,
) -> str:
    """Save the model to wandb.

    Args:
        artifact_name: A human-readable name for this artifact, which is how you can identify
            this artifact in the UI or reference it in use_artifact calls. Names can contain
            letters, numbers, underscores, hyphens, and dots. The name must be unique across a
            project. Example: "sweep_name 1e9 activations".
        directory: Directory to save the model to.

    Returns:
        Name of the wandb artifact.

    Raises:
        ValueError: If wandb is not initialised.
    """
    # Save the file
    directory.mkdir(parents=True, exist_ok=True)
    file_name = artifact_name + ".pt"
    file_path = directory / file_name
    self.save(file_path)

    # Upload to wandb
    if wandb.run is None:
        error_message = "Trying to save the model to wandb, but wandb is not initialised."
        raise ValueError(error_message)
    artifact = wandb.Artifact(
        artifact_name,
        type="model",
        description="Sparse Autoencoder model state, created with `sparse_autoencoder`.",
    )
    artifact.add_file(str(file_path), name="sae-model-state.pt")
    artifact.save()
    wandb.log_artifact(artifact)
    artifact.wait()

    return artifact.source_qualified_name
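
A sketch with placeholder project and artifact names. A wandb run must be active before calling this method, otherwise it raises ValueError:

import wandb

from sparse_autoencoder import SparseAutoencoder, SparseAutoencoderConfig

wandb.init(project="sparse-autoencoder")
model = SparseAutoencoder(SparseAutoencoderConfig(n_input_features=4, n_learned_features=16))
artifact_name = model.save_to_wandb("sae-model-state")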

SparseAutoencoderConfig ¤

Bases: BaseModel

SAE model config.

Source code in sparse_autoencoder/autoencoder/model.py
class SparseAutoencoderConfig(BaseModel):
    """SAE model config."""

    n_input_features: PositiveInt
    """Number of input features.

    E.g. `d_mlp` if training on MLP activations (from TransformerLens).
    """

    n_learned_features: PositiveInt
    """Number of learned features.

    The initial paper experimented with 1 to 256 times the number of input features, and primarily
    used a multiple of 8."""

    n_components: PositiveInt | None = None
    """Number of source model components the SAE is trained on.""

    This is useful if you want to train the SAE on several components of the source model at once.
    If `None`, the SAE is assumed to be trained on just one component (in this case the model won't
    contain a component axis in any of the parameters).
    """

n_components: PositiveInt | None = None class-attribute instance-attribute ¤

Number of source model components the SAE is trained on.

This is useful if you want to train the SAE on several components of the source model at once. If None, the SAE is assumed to be trained on just one component (in this case the model won't contain a component axis in any of the parameters).

n_input_features: PositiveInt instance-attribute ¤

Number of input features.

E.g. d_mlp if training on MLP activations (from TransformerLens).

n_learned_features: PositiveInt instance-attribute ¤

Number of learned features.

The initial paper experimented with 1 to 256 times the number of input features, and primarily used a multiple of 8.

SparseAutoencoderState ¤

Bases: BaseModel

SAE model state.

Used for saving and loading the model.

Source code in sparse_autoencoder/autoencoder/model.py
class SparseAutoencoderState(BaseModel, arbitrary_types_allowed=True):
    """SAE model state.

    Used for saving and loading the model.
    """

    config: SparseAutoencoderConfig
    """Model config."""

    state_dict: dict[str, Tensor]
    """Model state dict."""

config: SparseAutoencoderConfig instance-attribute ¤

Model config.

state_dict: dict[str, Tensor] instance-attribute ¤

Model state dict.