Linear encoder layer¤

Linear encoder layer.

LinearEncoder ¤

Bases: Module

Linear encoder layer.

Linear encoder layer (essentially nn.Linear, with a ReLU activation function). Designed to be used as the encoder in a sparse autoencoder (excluding any outer tied bias).

\[ \begin{align*} m &= \text{learned features dimension} \\ n &= \text{input and output dimension} \\ b &= \text{batch items dimension} \\ \overline{\mathbf{x}} \in \mathbb{R}^{b \times n} &= \text{input after tied bias} \\ W_e \in \mathbb{R}^{m \times n} &= \text{weight matrix} \\ b_e \in \mathbb{R}^{m} &= \text{bias vector} \\ f &= \text{ReLU}(\overline{\mathbf{x}} W_e^T + b_e) = \text{LinearEncoder output} \end{align*} \]
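A minimal usage sketch (the import path follows the source note below; the batch size and dimensions are illustrative):

```python
import torch

# Import path taken from the "Source code in ..." note below; a top-level
# re-export may also be available.
from sparse_autoencoder.autoencoder.components.linear_encoder import LinearEncoder

# Illustrative sizes: n = 512 input features, m = 2048 learnt features.
encoder = LinearEncoder(input_features=512, learnt_features=2048, n_components=None)

x = torch.randn(32, 512)  # (batch, input_features), the input after the tied bias
f = encoder(x)            # (batch, learnt_features)
assert f.shape == (32, 2048)
assert (f >= 0).all()  # ReLU output is non-negative
```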
Source code in sparse_autoencoder/autoencoder/components/linear_encoder.py
@final
class LinearEncoder(Module):
    r"""Linear encoder layer.

    Linear encoder layer (essentially `nn.Linear`, with a ReLU activation function). Designed to be
    used as the encoder in a sparse autoencoder (excluding any outer tied bias).

    $$
    \begin{align*}
        m &= \text{learned features dimension} \\
        n &= \text{input and output dimension} \\
        b &= \text{batch items dimension} \\
        \overline{\mathbf{x}} \in \mathbb{R}^{b \times n} &= \text{input after tied bias} \\
        W_e \in \mathbb{R}^{m \times n} &= \text{weight matrix} \\
        b_e \in \mathbb{R}^{m} &= \text{bias vector} \\
        f &= \text{ReLU}(\overline{\mathbf{x}} W_e^T + b_e) = \text{LinearEncoder output}
    \end{align*}
    $$
    """

    _learnt_features: int
    """Number of learnt features (inputs to this layer)."""

    _input_features: int
    """Number of input features from the source model."""

    _n_components: int | None

    weight: Float[
        Parameter,
        Axis.names(Axis.COMPONENT_OPTIONAL, Axis.LEARNT_FEATURE, Axis.INPUT_OUTPUT_FEATURE),
    ]
    """Weight parameter.

    Each row in the weights matrix acts as a dictionary vector, representing a single basis
    element in the learned activation space.
    """

    bias: Float[Parameter, Axis.names(Axis.COMPONENT_OPTIONAL, Axis.LEARNT_FEATURE)]
    """Bias parameter."""

    @property
    def reset_optimizer_parameter_details(self) -> list[ResetOptimizerParameterDetails]:
        """Reset optimizer parameter details.

        Details of the parameters that should be reset in the optimizer, when resetting
        dictionary vectors.

        Returns:
            List of tuples of the form `(parameter, axis)`, where `parameter` is the parameter to
            reset (e.g. encoder.weight), and `axis` is the axis of the parameter to reset.
        """
        return [
            ResetOptimizerParameterDetails(parameter=self.weight, axis=-2),
            ResetOptimizerParameterDetails(parameter=self.bias, axis=-1),
        ]

    activation_function: ReLU
    """Activation function."""

    @validate_call
    def __init__(
        self,
        input_features: PositiveInt,
        learnt_features: PositiveInt,
        n_components: PositiveInt | None,
    ):
        """Initialize the linear encoder layer.

        Args:
            input_features: Number of input features to the autoencoder.
            learnt_features: Number of learnt features in the autoencoder.
            n_components: Number of source model components the SAE is trained on.
        """
        super().__init__()

        self._learnt_features = learnt_features
        self._input_features = input_features
        self._n_components = n_components

        self.weight = Parameter(
            torch.empty(
                shape_with_optional_dimensions(n_components, learnt_features, input_features),
            )
        )
        self.bias = Parameter(
            torch.zeros(shape_with_optional_dimensions(n_components, learnt_features))
        )
        self.activation_function = ReLU()

        self.reset_parameters()

    def reset_parameters(self) -> None:
        """Initialize or reset the parameters."""
        # Assumes a ReLU activation function (e.g. for leaky ReLU, the `a` parameter and
        # `nonlinearity` must be changed).
        init.kaiming_uniform_(self.weight, nonlinearity="relu")

        # Bias (approach from nn.Linear): fan_in is the number of input features, i.e. the size
        # of the last weight axis (so the optional component axis is not counted).
        fan_in = self.weight.size(-1)
        bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
        init.uniform_(self.bias, -bound, bound)

    def forward(
        self,
        x: Float[
            Tensor, Axis.names(Axis.BATCH, Axis.COMPONENT_OPTIONAL, Axis.INPUT_OUTPUT_FEATURE)
        ],
    ) -> Float[Tensor, Axis.names(Axis.BATCH, Axis.COMPONENT_OPTIONAL, Axis.LEARNT_FEATURE)]:
        """Forward pass.

        Args:
            x: Input tensor.

        Returns:
            Output of the forward pass.
        """
        z = (
            einops.einsum(
                x,
                self.weight,
                f"{Axis.BATCH} ... {Axis.INPUT_OUTPUT_FEATURE}, \
                    ... {Axis.LEARNT_FEATURE} {Axis.INPUT_OUTPUT_FEATURE} \
                    -> {Axis.BATCH} ... {Axis.LEARNT_FEATURE}",
            )
            + self.bias
        )

        return self.activation_function(z)

    @final
    def update_dictionary_vectors(
        self,
        dictionary_vector_indices: Int64[Tensor, Axis.names(Axis.LEARNT_FEATURE_IDX)],
        updated_dictionary_weights: Float[
            Tensor, Axis.names(Axis.LEARNT_FEATURE_IDX, Axis.INPUT_OUTPUT_FEATURE)
        ],
        component_idx: int | None = None,
    ) -> None:
        """Update encoder dictionary vectors.

        Updates the dictionary vectors (rows in the weight matrix) with the given values.

        Args:
            dictionary_vector_indices: Indices of the dictionary vectors to update.
            updated_dictionary_weights: Updated weights for just these dictionary vectors.
            component_idx: Component index to update.

        Raises:
            ValueError: If there are multiple components and `component_idx` is not specified.
        """
        if dictionary_vector_indices.numel() == 0:
            return

        with torch.no_grad():
            if component_idx is None:
                if self._n_components is not None:
                    error_message = "component_idx must be specified when n_components is not None"
                    raise ValueError(error_message)

                self.weight[dictionary_vector_indices] = updated_dictionary_weights
            else:
                self.weight[component_idx, dictionary_vector_indices] = updated_dictionary_weights

    @final
    def update_bias(
        self,
        update_parameter_indices: Int64[
            Tensor, Axis.names(Axis.COMPONENT_OPTIONAL, Axis.LEARNT_FEATURE_IDX)
        ],
        updated_bias_features: Float[
            Tensor, Axis.names(Axis.COMPONENT_OPTIONAL, Axis.LEARNT_FEATURE_IDX)
        ],
        component_idx: int | None = None,
    ) -> None:
        """Update encoder bias.

        Args:
            update_parameter_indices: Indices of the bias features to update.
            updated_bias_features: Updated bias features for just these indices.
            component_idx: Component index to update.

        Raises:
            ValueError: If there are multiple components and `component_idx` is not specified.
        """
        if update_parameter_indices.numel() == 0:
            return

        with torch.no_grad():
            if component_idx is None:
                if self._n_components is not None:
                    error_message = "component_idx must be specified when n_components is not None"
                    raise ValueError(error_message)

                self.bias[update_parameter_indices] = updated_bias_features
            else:
                self.bias[component_idx, update_parameter_indices] = updated_bias_features

    def extra_repr(self) -> str:
        """String extra representation of the module."""
        return (
            f"input_features={self._input_features}, "
            f"learnt_features={self._learnt_features}, "
            f"n_components={self._n_components}"
        )

activation_function: ReLU = ReLU() instance-attribute ¤

Activation function.

bias: Float[Parameter, Axis.names(Axis.COMPONENT_OPTIONAL, Axis.LEARNT_FEATURE)] = Parameter(torch.zeros(shape_with_optional_dimensions(n_components, learnt_features))) instance-attribute ¤

Bias parameter.

reset_optimizer_parameter_details: list[ResetOptimizerParameterDetails] property ¤

Reset optimizer parameter details.

Details of the parameters that should be reset in the optimizer, when resetting dictionary vectors.

Returns:

Type Description
list[ResetOptimizerParameterDetails]

List of tuples of the form (parameter, axis), where parameter is the parameter to reset (e.g. encoder.weight), and axis is the axis of the parameter to reset.
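
A quick sketch of inspecting these details (sizes are illustrative; the `parameter` and `axis` field names come from the source shown later on this page):

```python
from sparse_autoencoder.autoencoder.components.linear_encoder import LinearEncoder

encoder = LinearEncoder(input_features=64, learnt_features=256, n_components=None)

for details in encoder.reset_optimizer_parameter_details:
    print(tuple(details.parameter.shape), details.axis)
# (256, 64) -2  (weight: reset along the learnt-feature axis)
# (256,) -1  (bias: reset along the learnt-feature axis)
```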

weight: Float[Parameter, Axis.names(Axis.COMPONENT_OPTIONAL, Axis.LEARNT_FEATURE, Axis.INPUT_OUTPUT_FEATURE)] = Parameter(torch.empty(shape_with_optional_dimensions(n_components, learnt_features, input_features))) instance-attribute ¤

Weight parameter.

Each row in the weights matrix acts as a dictionary vector, representing a single basis element in the learned activation space.

__init__(input_features, learnt_features, n_components) ¤

Initialize the linear encoder layer.

Parameters:

Name Type Description Default
input_features PositiveInt

Number of input features to the autoencoder.

required
learnt_features PositiveInt

Number of learnt features in the autoencoder.

required
n_components PositiveInt | None

Number of source model components the SAE is trained on.

required
Source code in sparse_autoencoder/autoencoder/components/linear_encoder.py
@validate_call
def __init__(
    self,
    input_features: PositiveInt,
    learnt_features: PositiveInt,
    n_components: PositiveInt | None,
):
    """Initialize the linear encoder layer.

    Args:
        input_features: Number of input features to the autoencoder.
        learnt_features: Number of learnt features in the autoencoder.
        n_components: Number of source model components the SAE is trained on.
    """
    super().__init__()

    self._learnt_features = learnt_features
    self._input_features = input_features
    self._n_components = n_components

    self.weight = Parameter(
        torch.empty(
            shape_with_optional_dimensions(n_components, learnt_features, input_features),
        )
    )
    self.bias = Parameter(
        torch.zeros(shape_with_optional_dimensions(n_components, learnt_features))
    )
    self.activation_function = ReLU()

    self.reset_parameters()
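
The optional component dimension changes the parameter shapes. A short sketch (sizes are illustrative):

```python
from sparse_autoencoder.autoencoder.components.linear_encoder import LinearEncoder

enc_single = LinearEncoder(input_features=64, learnt_features=256, n_components=None)
enc_multi = LinearEncoder(input_features=64, learnt_features=256, n_components=3)

print(enc_single.weight.shape)  # torch.Size([256, 64])
print(enc_multi.weight.shape)   # torch.Size([3, 256, 64])
print(enc_multi.bias.shape)     # torch.Size([3, 256])
```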

extra_repr() ¤

String extra representation of the module.

Source code in sparse_autoencoder/autoencoder/components/linear_encoder.py
def extra_repr(self) -> str:
    """String extra representation of the module."""
    return (
        f"input_features={self._input_features}, "
        f"learnt_features={self._learnt_features}, "
        f"n_components={self._n_components}"
    )

forward(x) ¤

Forward pass.

Parameters:

Name Type Description Default
x Float[Tensor, names(BATCH, COMPONENT_OPTIONAL, INPUT_OUTPUT_FEATURE)]

Input tensor.

required

Returns:

Type Description
Float[Tensor, names(BATCH, COMPONENT_OPTIONAL, LEARNT_FEATURE)]

Output of the forward pass.

Source code in sparse_autoencoder/autoencoder/components/linear_encoder.py
def forward(
    self,
    x: Float[
        Tensor, Axis.names(Axis.BATCH, Axis.COMPONENT_OPTIONAL, Axis.INPUT_OUTPUT_FEATURE)
    ],
) -> Float[Tensor, Axis.names(Axis.BATCH, Axis.COMPONENT_OPTIONAL, Axis.LEARNT_FEATURE)]:
    """Forward pass.

    Args:
        x: Input tensor.

    Returns:
        Output of the forward pass.
    """
    z = (
        einops.einsum(
            x,
            self.weight,
            f"{Axis.BATCH} ... {Axis.INPUT_OUTPUT_FEATURE}, \
                ... {Axis.LEARNT_FEATURE} {Axis.INPUT_OUTPUT_FEATURE} \
                -> {Axis.BATCH} ... {Axis.LEARNT_FEATURE}",
        )
        + self.bias
    )

    return self.activation_function(z)
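
The `...` in the einsum pattern carries the optional component axis through unchanged. A shape sketch (sizes are illustrative):

```python
import torch

from sparse_autoencoder.autoencoder.components.linear_encoder import LinearEncoder

encoder = LinearEncoder(input_features=64, learnt_features=256, n_components=3)

x = torch.randn(8, 3, 64)  # (batch, component, input_features)
f = encoder(x)             # (batch, component, learnt_features)
assert f.shape == (8, 3, 256)
```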

reset_parameters() ¤

Initialize or reset the parameters.

Source code in sparse_autoencoder/autoencoder/components/linear_encoder.py
def reset_parameters(self) -> None:
    """Initialize or reset the parameters."""
    # Assumes a ReLU activation function (e.g. for leaky ReLU, the `a` parameter and
    # `nonlinearity` must be changed).
    init.kaiming_uniform_(self.weight, nonlinearity="relu")

    # Bias (approach from nn.Linear): fan_in is the number of input features, i.e. the size
    # of the last weight axis (so the optional component axis is not counted).
    fan_in = self.weight.size(-1)
    bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
    init.uniform_(self.bias, -bound, bound)

update_bias(update_parameter_indices, updated_bias_features, component_idx=None) ¤

Update encoder bias.

Parameters:

Name Type Description Default
update_parameter_indices Int64[Tensor, names(COMPONENT_OPTIONAL, LEARNT_FEATURE_IDX)]

Indices of the bias features to update.

required
updated_bias_features Float[Tensor, names(COMPONENT_OPTIONAL, LEARNT_FEATURE_IDX)]

Updated bias features for just these indices.

required
component_idx int | None

Component index to update.

None

Raises:

Type Description
ValueError

If there are multiple components and component_idx is not specified.

Source code in sparse_autoencoder/autoencoder/components/linear_encoder.py
@final
def update_bias(
    self,
    update_parameter_indices: Int64[
        Tensor, Axis.names(Axis.COMPONENT_OPTIONAL, Axis.LEARNT_FEATURE_IDX)
    ],
    updated_bias_features: Float[
        Tensor, Axis.names(Axis.COMPONENT_OPTIONAL, Axis.LEARNT_FEATURE_IDX)
    ],
    component_idx: int | None = None,
) -> None:
    """Update encoder bias.

    Args:
        update_parameter_indices: Indices of the bias features to update.
        updated_bias_features: Updated bias features for just these indices.
        component_idx: Component index to update.

    Raises:
        ValueError: If there are multiple components and `component_idx` is not specified.
    """
    if update_parameter_indices.numel() == 0:
        return

    with torch.no_grad():
        if component_idx is None:
            if self._n_components is not None:
                error_message = "component_idx must be specified when n_components is not None"
                raise ValueError(error_message)

            self.bias[update_parameter_indices] = updated_bias_features
        else:
            self.bias[component_idx, update_parameter_indices] = updated_bias_features
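
A hedged usage sketch (the indices and replacement values are illustrative; with multiple components, `component_idx` must be given):

```python
import torch

from sparse_autoencoder.autoencoder.components.linear_encoder import LinearEncoder

encoder = LinearEncoder(input_features=64, learnt_features=256, n_components=3)

# Zero the bias of two learnt features in component 0.
idx = torch.tensor([5, 17], dtype=torch.int64)
encoder.update_bias(idx, torch.zeros(2), component_idx=0)
assert torch.all(encoder.bias[0, idx] == 0)
```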

update_dictionary_vectors(dictionary_vector_indices, updated_dictionary_weights, component_idx=None) ¤

Update encoder dictionary vectors.

Updates the dictionary vectors (rows in the weight matrix) with the given values.

Parameters:

Name Type Description Default
dictionary_vector_indices Int64[Tensor, names(LEARNT_FEATURE_IDX)]

Indices of the dictionary vectors to update.

required
updated_dictionary_weights Float[Tensor, names(LEARNT_FEATURE_IDX, INPUT_OUTPUT_FEATURE)]

Updated weights for just these dictionary vectors.

required
component_idx int | None

Component index to update.

None

Raises:

Type Description
ValueError

If there are multiple components and component_idx is not specified.

Source code in sparse_autoencoder/autoencoder/components/linear_encoder.py
@final
def update_dictionary_vectors(
    self,
    dictionary_vector_indices: Int64[Tensor, Axis.names(Axis.LEARNT_FEATURE_IDX)],
    updated_dictionary_weights: Float[
        Tensor, Axis.names(Axis.LEARNT_FEATURE_IDX, Axis.INPUT_OUTPUT_FEATURE)
    ],
    component_idx: int | None = None,
) -> None:
    """Update encoder dictionary vectors.

    Updates the dictionary vectors (rows in the weight matrix) with the given values.

    Args:
        dictionary_vector_indices: Indices of the dictionary vectors to update.
        updated_dictionary_weights: Updated weights for just these dictionary vectors.
        component_idx: Component index to update.

    Raises:
        ValueError: If there are multiple components and `component_idx` is not specified.
    """
    if dictionary_vector_indices.numel() == 0:
        return

    with torch.no_grad():
        if component_idx is None:
            if self._n_components is not None:
                error_message = "component_idx must be specified when n_components is not None"
                raise ValueError(error_message)

            self.weight[dictionary_vector_indices] = updated_dictionary_weights
        else:
            self.weight[component_idx, dictionary_vector_indices] = updated_dictionary_weights
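
A hedged sketch of resurrecting dead features by overwriting their dictionary vectors (rows of the weight matrix); the indices and the unit-norm re-initialization are illustrative, not the library's prescribed resampling scheme:

```python
import torch

from sparse_autoencoder.autoencoder.components.linear_encoder import LinearEncoder

encoder = LinearEncoder(input_features=64, learnt_features=256, n_components=None)

dead = torch.tensor([3, 42], dtype=torch.int64)
fresh = torch.randn(2, 64)
fresh = fresh / fresh.norm(dim=-1, keepdim=True)  # unit-norm replacement directions

encoder.update_dictionary_vectors(dead, fresh)
assert torch.allclose(encoder.weight[dead], fresh)
```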