
PyTorch Lightning module for training a sparse autoencoder ¤

PyTorch Lightning module for training a sparse autoencoder.

LitSparseAutoencoder ¤

Bases: LightningModule

Lightning Sparse Autoencoder.
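
The module bundles a SparseAutoencoder with its loss, training metrics and activation resampler so it can be trained with a standard Lightning Trainer. A minimal end-to-end sketch, assuming the classes are importable from the module path shown below, that the base SparseAutoencoderConfig defines an n_input_features field, and using a synthetic activation tensor in place of a real activation store:

import torch
from lightning.pytorch import Trainer  # or pytorch_lightning, depending on your install
from torch.utils.data import DataLoader

from sparse_autoencoder.autoencoder.lightning import (
    LitSparseAutoencoder,
    LitSparseAutoencoderConfig,
)

# Hypothetical sizes; n_input_features is assumed to be inherited from
# SparseAutoencoderConfig.
config = LitSparseAutoencoderConfig(
    n_input_features=512,
    n_learned_features=2048,
    n_components=2,
    component_names=["blocks.0.mlp", "blocks.1.mlp"],
    l1_coefficient=0.001,
)
model = LitSparseAutoencoder(config)

# Synthetic source activations, shape (samples, components, input features).
# In practice these would come from a store of cached model activations.
activations = torch.randn(10_000, 2, 512)
train_loader = DataLoader(activations, batch_size=256, shuffle=True)

trainer = Trainer(max_epochs=1)
trainer.fit(model, train_loader)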

Source code in sparse_autoencoder/autoencoder/lightning.py, lines 67-243:
class LitSparseAutoencoder(LightningModule):
    """Lightning Sparse Autoencoder."""

    sparse_autoencoder: SparseAutoencoder

    config: LitSparseAutoencoderConfig

    loss_fn: SparseAutoencoderLoss

    train_metrics: MetricCollection

    def __init__(
        self,
        config: LitSparseAutoencoderConfig,
    ):
        """Initialise the module."""
        super().__init__()
        self.sparse_autoencoder = SparseAutoencoder(config)
        self.config = config

        num_components = config.n_components or 1
        add_component_names = partial(
            ClasswiseWrapperWithMean, component_names=config.component_names
        )

        # Create the loss & metrics
        self.loss_fn = SparseAutoencoderLoss(
            num_components, config.l1_coefficient, keep_batch_dim=True
        )

        self.train_metrics = MetricCollection(
            {
                "l0": add_component_names(L0NormMetric(num_components), prefix="train/l0_norm"),
                "activity": add_component_names(
                    NeuronActivityMetric(config.n_learned_features, num_components),
                    prefix="train/neuron_activity",
                ),
                "l1": add_component_names(
                    L1AbsoluteLoss(num_components), prefix="loss/l1_learned_activations"
                ),
                "l2": add_component_names(
                    L2ReconstructionLoss(num_components), prefix="loss/l2_reconstruction"
                ),
                "loss": add_component_names(
                    SparseAutoencoderLoss(num_components, config.l1_coefficient),
                    prefix="loss/total",
                ),
            },
            # Share state & updates across groups (to avoid e.g. computing l1 twice for both the
            # loss and l1 metrics). Note the metric that goes first must calculate all the states
            # needed by the rest of the group.
            compute_groups=[
                ["loss", "l1", "l2"],
                ["activity"],
                ["l0"],
            ],
        )

        self.activation_resampler = ActivationResampler(
            n_learned_features=config.n_learned_features,
            n_components=num_components,
            resample_interval=config.resample_interval,
            max_n_resamples=config.max_n_resamples,
            n_activations_activity_collate=config.resample_dead_neurons_dataset_size,
            resample_dataset_size=config.resample_loss_dataset_size,
            threshold_is_dead_portion_fires=config.resample_threshold_is_dead_portion_fires,
        )

    def forward(  # type: ignore[override]
        self,
        inputs: Float[
            Tensor, Axis.names(Axis.BATCH, Axis.COMPONENT_OPTIONAL, Axis.INPUT_OUTPUT_FEATURE)
        ],
    ) -> ForwardPassResult:
        """Forward pass."""
        return self.sparse_autoencoder.forward(inputs)

    def update_parameters(self, parameter_updates: list[ParameterUpdateResults]) -> None:
        """Update the parameters of the model from the results of the resampler.

        Args:
            parameter_updates: Parameter updates from the resampler.

        Raises:
            TypeError: If the optimizer is not an AdamWithReset.
        """
        for component_idx, component_parameter_update in enumerate(parameter_updates):
            # Update the weights and biases
            self.sparse_autoencoder.encoder.update_dictionary_vectors(
                component_parameter_update.dead_neuron_indices,
                component_parameter_update.dead_encoder_weight_updates,
                component_idx=component_idx,
            )
            self.sparse_autoencoder.encoder.update_bias(
                component_parameter_update.dead_neuron_indices,
                component_parameter_update.dead_encoder_bias_updates,
                component_idx=component_idx,
            )
            self.sparse_autoencoder.decoder.update_dictionary_vectors(
                component_parameter_update.dead_neuron_indices,
                component_parameter_update.dead_decoder_weight_updates,
                component_idx=component_idx,
            )

            # Reset the optimizer
            for (
                parameter,
                axis,
            ) in self.reset_optimizer_parameter_details:
                optimizer = self.optimizers(use_pl_optimizer=False)
                if not isinstance(optimizer, AdamWithReset):
                    error_message = "Cannot reset the optimizer. "
                    raise TypeError(error_message)

                optimizer.reset_neurons_state(
                    parameter=parameter,
                    neuron_indices=component_parameter_update.dead_neuron_indices,
                    axis=axis,
                    component_idx=component_idx,
                )

    def training_step(  # type: ignore[override]
        self,
        batch: Float[
            Tensor, Axis.names(Axis.BATCH, Axis.COMPONENT_OPTIONAL, Axis.INPUT_OUTPUT_FEATURE)
        ],
        batch_idx: int | None = None,  # noqa: ARG002
    ) -> Float[Tensor, Axis.SINGLE_ITEM]:
        """Training step."""
        # Forward pass
        output: ForwardPassResult = self.forward(batch)

        # Metrics & loss
        train_metrics = self.train_metrics.forward(
            source_activations=batch,
            learned_activations=output.learned_activations,
            decoded_activations=output.decoded_activations,
        )

        loss = self.loss_fn.forward(
            source_activations=batch,
            learned_activations=output.learned_activations,
            decoded_activations=output.decoded_activations,
        )

        if wandb.run is not None:
            self.log_dict(train_metrics)

        # Resample dead neurons
        parameter_updates = self.activation_resampler.forward(
            input_activations=batch,
            learned_activations=output.learned_activations,
            loss=loss,
            encoder_weight_reference=self.sparse_autoencoder.encoder.weight,
        )
        if parameter_updates is not None:
            self.update_parameters(parameter_updates)

        # Return the mean loss
        return loss.mean()

    def on_after_backward(self) -> None:
        """After-backward pass hook."""
        self.sparse_autoencoder.post_backwards_hook()

    def configure_optimizers(self) -> Optimizer:
        """Configure the optimizer."""
        return AdamWithReset(
            self.sparse_autoencoder.parameters(),
            named_parameters=self.sparse_autoencoder.named_parameters(),
            has_components_dim=True,
        )

    @property
    def reset_optimizer_parameter_details(self) -> list[ResetOptimizerParameterDetails]:
        """Reset optimizer parameter details."""
        return self.sparse_autoencoder.reset_optimizer_parameter_details

reset_optimizer_parameter_details: list[ResetOptimizerParameterDetails] property ¤

Reset optimizer parameter details.

__init__(config) ¤

Initialise the module.

Source code in sparse_autoencoder/autoencoder/lightning.py, lines 78-133:
def __init__(
    self,
    config: LitSparseAutoencoderConfig,
):
    """Initialise the module."""
    super().__init__()
    self.sparse_autoencoder = SparseAutoencoder(config)
    self.config = config

    num_components = config.n_components or 1
    add_component_names = partial(
        ClasswiseWrapperWithMean, component_names=config.component_names
    )

    # Create the loss & metrics
    self.loss_fn = SparseAutoencoderLoss(
        num_components, config.l1_coefficient, keep_batch_dim=True
    )

    self.train_metrics = MetricCollection(
        {
            "l0": add_component_names(L0NormMetric(num_components), prefix="train/l0_norm"),
            "activity": add_component_names(
                NeuronActivityMetric(config.n_learned_features, num_components),
                prefix="train/neuron_activity",
            ),
            "l1": add_component_names(
                L1AbsoluteLoss(num_components), prefix="loss/l1_learned_activations"
            ),
            "l2": add_component_names(
                L2ReconstructionLoss(num_components), prefix="loss/l2_reconstruction"
            ),
            "loss": add_component_names(
                SparseAutoencoderLoss(num_components, config.l1_coefficient),
                prefix="loss/total",
            ),
        },
        # Share state & updates across groups (to avoid e.g. computing l1 twice for both the
        # loss and l1 metrics). Note the metric that goes first must calculate all the states
        # needed by the rest of the group.
        compute_groups=[
            ["loss", "l1", "l2"],
            ["activity"],
            ["l0"],
        ],
    )

    self.activation_resampler = ActivationResampler(
        n_learned_features=config.n_learned_features,
        n_components=num_components,
        resample_interval=config.resample_interval,
        max_n_resamples=config.max_n_resamples,
        n_activations_activity_collate=config.resample_dead_neurons_dataset_size,
        resample_dataset_size=config.resample_loss_dataset_size,
        threshold_is_dead_portion_fires=config.resample_threshold_is_dead_portion_fires,
    )

configure_optimizers() ¤

Configure the optimizer.
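
Lightning calls this hook itself at the start of fit, so you normally never invoke it directly; the AdamWithReset it returns is the optimizer that update_parameters later resets for dead neurons. A quick inspection, assuming the model from the usage example above:

optimizer = model.configure_optimizers()
print(type(optimizer).__name__)  # AdamWithReset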

Source code in sparse_autoencoder/autoencoder/lightning.py, lines 232-238:
def configure_optimizers(self) -> Optimizer:
    """Configure the optimizer."""
    return AdamWithReset(
        self.sparse_autoencoder.parameters(),
        named_parameters=self.sparse_autoencoder.named_parameters(),
        has_components_dim=True,
    )

forward(inputs) ¤

Forward pass.
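
The result is a ForwardPassResult whose learned_activations and decoded_activations fields feed the loss, the metrics and the resampler (the same fields used in training_step). A shape sketch, assuming the model from the usage example above:

batch = torch.randn(32, 2, 512)            # (batch, components, input features)
result = model.forward(batch)
print(result.learned_activations.shape)    # expected: torch.Size([32, 2, 2048])
print(result.decoded_activations.shape)    # expected: torch.Size([32, 2, 512])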

Source code in sparse_autoencoder/autoencoder/lightning.py, lines 135-142:
def forward(  # type: ignore[override]
    self,
    inputs: Float[
        Tensor, Axis.names(Axis.BATCH, Axis.COMPONENT_OPTIONAL, Axis.INPUT_OUTPUT_FEATURE)
    ],
) -> ForwardPassResult:
    """Forward pass."""
    return self.sparse_autoencoder.forward(inputs)

on_after_backward() ¤

After-backward pass hook.

Source code in sparse_autoencoder/autoencoder/lightning.py, lines 228-230:
def on_after_backward(self) -> None:
    """After-backward pass hook."""
    self.sparse_autoencoder.post_backwards_hook()

training_step(batch, batch_idx=None) ¤

Training step.
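
Each step runs a forward pass, updates the shared metric collection, computes the per-item loss, feeds the batch and loss to the activation resampler, and returns the mean loss. Trainer.fit drives this for you, but a single manual step can be useful as a smoke test; the sketch below assumes no wandb run is active (so logging is skipped) and that the default resample_interval means the resampler does not fire on one small batch:

batch = torch.randn(64, 2, 512)    # (batch, components, input features)
loss = model.training_step(batch)  # scalar mean loss over batch and components
loss.backward()
model.on_after_backward()          # applies the autoencoder's post-backward hook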

Source code in sparse_autoencoder/autoencoder/lightning.py, lines 188-226:
def training_step(  # type: ignore[override]
    self,
    batch: Float[
        Tensor, Axis.names(Axis.BATCH, Axis.COMPONENT_OPTIONAL, Axis.INPUT_OUTPUT_FEATURE)
    ],
    batch_idx: int | None = None,  # noqa: ARG002
) -> Float[Tensor, Axis.SINGLE_ITEM]:
    """Training step."""
    # Forward pass
    output: ForwardPassResult = self.forward(batch)

    # Metrics & loss
    train_metrics = self.train_metrics.forward(
        source_activations=batch,
        learned_activations=output.learned_activations,
        decoded_activations=output.decoded_activations,
    )

    loss = self.loss_fn.forward(
        source_activations=batch,
        learned_activations=output.learned_activations,
        decoded_activations=output.decoded_activations,
    )

    if wandb.run is not None:
        self.log_dict(train_metrics)

    # Resample dead neurons
    parameter_updates = self.activation_resampler.forward(
        input_activations=batch,
        learned_activations=output.learned_activations,
        loss=loss,
        encoder_weight_reference=self.sparse_autoencoder.encoder.weight,
    )
    if parameter_updates is not None:
        self.update_parameters(parameter_updates)

    # Return the mean loss
    return loss.mean()

update_parameters(parameter_updates) ¤

Update the parameters of the model from the results of the resampler.

Parameters:

    parameter_updates (list[ParameterUpdateResults], required):
        Parameter updates from the resampler.

Raises:

    TypeError: If the optimizer is not an AdamWithReset.

Source code in sparse_autoencoder/autoencoder/lightning.py, lines 144-186:
def update_parameters(self, parameter_updates: list[ParameterUpdateResults]) -> None:
    """Update the parameters of the model from the results of the resampler.

    Args:
        parameter_updates: Parameter updates from the resampler.

    Raises:
        TypeError: If the optimizer is not an AdamWithReset.
    """
    for component_idx, component_parameter_update in enumerate(parameter_updates):
        # Update the weights and biases
        self.sparse_autoencoder.encoder.update_dictionary_vectors(
            component_parameter_update.dead_neuron_indices,
            component_parameter_update.dead_encoder_weight_updates,
            component_idx=component_idx,
        )
        self.sparse_autoencoder.encoder.update_bias(
            component_parameter_update.dead_neuron_indices,
            component_parameter_update.dead_encoder_bias_updates,
            component_idx=component_idx,
        )
        self.sparse_autoencoder.decoder.update_dictionary_vectors(
            component_parameter_update.dead_neuron_indices,
            component_parameter_update.dead_decoder_weight_updates,
            component_idx=component_idx,
        )

        # Reset the optimizer
        for (
            parameter,
            axis,
        ) in self.reset_optimizer_parameter_details:
            optimizer = self.optimizers(use_pl_optimizer=False)
            if not isinstance(optimizer, AdamWithReset):
                error_message = "Cannot reset the optimizer. "
                raise TypeError(error_message)

            optimizer.reset_neurons_state(
                parameter=parameter,
                neuron_indices=component_parameter_update.dead_neuron_indices,
                axis=axis,
                component_idx=component_idx,
            )

LitSparseAutoencoderConfig ¤

Bases: SparseAutoencoderConfig

PyTorch Lightning Sparse Autoencoder config.
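
The config extends SparseAutoencoderConfig with per-component names (used to label the logged metrics) and the dead-neuron resampling schedule. A construction sketch, assuming n_input_features is defined on the base config:

config = LitSparseAutoencoderConfig(
    n_input_features=512,          # assumed base-config field
    n_learned_features=2048,
    n_components=2,
    component_names=["blocks.0.mlp", "blocks.1.mlp"],
    l1_coefficient=0.001,
    resample_interval=200_000_000,
    max_n_resamples=4,
)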

Source code in sparse_autoencoder/autoencoder/lightning.py, lines 33-64:
class LitSparseAutoencoderConfig(SparseAutoencoderConfig):
    """PyTorch Lightning Sparse Autoencoder config."""

    component_names: list[str]

    l1_coefficient: float = 0.001

    resample_interval: PositiveInt = 200000000

    max_n_resamples: NonNegativeInt = 4

    resample_dead_neurons_dataset_size: PositiveInt = 100000000

    resample_loss_dataset_size: PositiveInt = 819200

    resample_threshold_is_dead_portion_fires: NonNegativeFloat = 0.0

    def model_post_init(self, __context: Any) -> None:  # noqa: ANN401
        """Model post init validation.

        Args:
            __context: Pydantic context.

        Raises:
            ValueError: If the number of component names does not match the number of components.
        """
        if self.n_components and len(self.component_names) != self.n_components:
            error_message = (
                f"Number of component names ({len(self.component_names)}) must match the number of "
                f"components ({self.n_components})"
            )
            raise ValueError(error_message)

model_post_init(__context) ¤

Model post init validation.

Parameters:

    __context (Any, required):
        Pydantic context.

Raises:

    ValueError: If the number of component names does not match the number of components.
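
Because this validation runs during model construction, a mismatch fails immediately. A sketch, assuming the same base-config field as above (with pydantic v2 the error may surface wrapped in a ValidationError, which is itself a ValueError):

try:
    LitSparseAutoencoderConfig(
        n_input_features=512,               # assumed base-config field
        n_learned_features=2048,
        n_components=2,
        component_names=["only_one_name"],  # one name for two components
    )
except ValueError as error:
    print(error)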

Source code in sparse_autoencoder/autoencoder/lightning.py, lines 50-64:
def model_post_init(self, __context: Any) -> None:  # noqa: ANN401
    """Model post init validation.

    Args:
        __context: Pydantic context.

    Raises:
        ValueError: If the number of component names does not match the number of components.
    """
    if self.n_components and len(self.component_names) != self.n_components:
        error_message = (
            f"Number of component names ({len(self.component_names)}) must match the number of "
            f"components ({self.n_components})"
        )
        raise ValueError(error_message)