
Linear layer with unit norm weights¤

Linear layer with unit norm weights.

UnitNormDecoder ¤

Bases: Module

Constrained unit norm linear decoder layer.

Linear layer decoder, where the dictionary vectors (columns of the weight matrix) are constrained to have unit norm. This is done by removing the gradient information parallel to the dictionary vectors before applying the gradient step, using a backward hook. It also requires constrain_weights_unit_norm to be called after each gradient step, to prevent drift of the dictionary vectors away from unit norm (as optimisers such as Adam don't strictly follow the gradient, but instead follow a modified gradient that includes momentum).

\[ \begin{align*} m &= \text{learned features dimension} \\ n &= \text{input and output dimension} \\ b &= \text{batch items dimension} \\ f \in \mathbb{R}^{b \times m} &= \text{encoder output} \\ W_d \in \mathbb{R}^{n \times m} &= \text{weight matrix} \\ z \in \mathbb{R}^{b \times n} &= f W_d^T = \text{UnitNormDecoder output (pre-tied bias)} \end{align*} \]
Motivation

Normalisation of the columns (dictionary features) prevents the model from reducing the sparsity loss term by increasing the size of the feature vectors in \(W_d\).

Note that the Towards Monosemanticity: Decomposing Language Models With Dictionary Learning paper found that removing the gradient information parallel to the dictionary vectors before applying the gradient step, rather than resetting the dictionary vectors to unit norm after each gradient step, [results in a small but real reduction in total loss](https://transformer-circuits.pub/2023/monosemantic-features/index.html#appendix-autoencoder-optimization).
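A minimal sketch of the projection idea in plain PyTorch (independent of this class; the tensor names and sizes are made up for illustration). Removing the component of the gradient parallel to each unit-norm dictionary vector leaves a gradient that is orthogonal to that vector, so a plain gradient step only moves the vector around the unit hypersphere to first order:

import torch

# Hypothetical sizes: n = 4 input/output features, m = 6 dictionary vectors (columns).
n, m = 4, 6
weight = torch.nn.functional.normalize(torch.randn(n, m), dim=0)  # unit-norm columns
grad = torch.randn(n, m)  # stand-in for the gradient from the backward pass

# Scalar projection of each gradient column onto its (unit-norm) dictionary vector.
scalar_projection = (grad * weight).sum(dim=0, keepdim=True)  # shape (1, m)

# Remove the parallel component: what remains is tangent to the hypersphere.
grad_orthogonal = grad - scalar_projection * weight

# Each adjusted gradient column is orthogonal to its dictionary vector (up to float error).
assert (grad_orthogonal * weight).sum(dim=0).abs().max() < 1e-5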

Source code in sparse_autoencoder/autoencoder/components/unit_norm_decoder.py
@final
class UnitNormDecoder(Module):
    r"""Constrained unit norm linear decoder layer.

    Linear layer decoder, where the dictionary vectors (columns of the weight matrix) are
    constrained to have unit norm. This is done by removing the gradient information parallel to the
    dictionary vectors before applying the gradient step, using a backward hook. It also requires
    `constrain_weights_unit_norm` to be called after each gradient step, to prevent drift of the
    dictionary vectors away from unit norm (as optimisers such as Adam don't strictly follow the
    gradient, but instead follow a modified gradient that includes momentum).

    $$ \begin{align*}
        m &= \text{learned features dimension} \\
        n &= \text{input and output dimension} \\
        b &= \text{batch items dimension} \\
        f \in \mathbb{R}^{b \times m} &= \text{encoder output} \\
        W_d \in \mathbb{R}^{n \times m} &= \text{weight matrix} \\
        z \in \mathbb{R}^{b \times n} &= f W_d^T = \text{UnitNormDecoder output (pre-tied bias)}
    \end{align*} $$

    Motivation:
        Normalisation of the columns (dictionary features) prevents the model from reducing the
        sparsity loss term by increasing the size of the feature vectors in $W_d$.

        Note that the *Towards Monosemanticity: Decomposing Language Models With Dictionary
        Learning* paper found that removing the gradient information parallel to the dictionary
        vectors before applying the gradient step, rather than resetting the dictionary vectors to
        unit norm after each gradient step, [results in a small but real reduction in total
        loss](https://transformer-circuits.pub/2023/monosemantic-features/index.html#appendix-autoencoder-optimization).
    """

    _learnt_features: int
    """Number of learnt features (inputs to this layer)."""

    _decoded_features: int
    """Number of decoded features (outputs from this layer)."""

    _n_components: int | None

    weight: Float[
        Parameter,
        Axis.names(Axis.COMPONENT_OPTIONAL, Axis.INPUT_OUTPUT_FEATURE, Axis.LEARNT_FEATURE),
    ]
    """Weight parameter.

    Each column in the weights matrix acts as a dictionary vector, representing a single basis
    element in the learned activation space.
    """

    @property
    def reset_optimizer_parameter_details(self) -> list[ResetOptimizerParameterDetails]:
        """Reset optimizer parameter details.

        Details of the parameters that should be reset in the optimizer, when resetting
        dictionary vectors.

        Returns:
            List of tuples of the form `(parameter, axis)`, where `parameter` is the parameter to
            reset (e.g. encoder.weight), and `axis` is the axis of the parameter to reset.
        """
        return [ResetOptimizerParameterDetails(parameter=self.weight, axis=-1)]

    @validate_call
    def __init__(
        self,
        learnt_features: PositiveInt,
        decoded_features: PositiveInt,
        n_components: PositiveInt | None,
        *,
        enable_gradient_hook: bool = True,
    ) -> None:
        """Initialize the constrained unit norm linear layer.

        Args:
            learnt_features: Number of learnt features in the autoencoder.
            decoded_features: Number of decoded (output) features in the autoencoder.
            n_components: Number of source model components the SAE is trained on.
            enable_gradient_hook: Enable the gradient backwards hook (modify the gradient before
                applying the gradient step, to maintain unit norm of the dictionary vectors).
        """
        super().__init__()

        self._learnt_features = learnt_features
        self._decoded_features = decoded_features
        self._n_components = n_components

        # Create the linear layer as per the standard PyTorch linear layer
        self.weight = Parameter(
            torch.empty(
                shape_with_optional_dimensions(n_components, decoded_features, learnt_features),
            )
        )
        self.reset_parameters()

        # Register backward hook to remove any gradient information parallel to the dictionary
        # vectors (columns of the weight matrix) before applying the gradient step.
        if enable_gradient_hook:
            self.weight.register_hook(self._weight_backward_hook)

    def update_dictionary_vectors(
        self,
        dictionary_vector_indices: Int64[
            Tensor, Axis.names(Axis.COMPONENT_OPTIONAL, Axis.LEARNT_FEATURE_IDX)
        ],
        updated_weights: Float[
            Tensor,
            Axis.names(Axis.COMPONENT_OPTIONAL, Axis.INPUT_OUTPUT_FEATURE, Axis.LEARNT_FEATURE_IDX),
        ],
        component_idx: int | None = None,
    ) -> None:
        """Update decoder dictionary vectors.

        Updates the dictionary vectors (columns of the weight matrix) with the given values. Typically
        this is used when resampling neurons (dictionary vectors) that have died.

        Args:
            dictionary_vector_indices: Indices of the dictionary vectors to update.
            updated_weights: Updated weights for just these dictionary vectors.
            component_idx: Component index to update.

        Raises:
            ValueError: If `component_idx` is not specified when `n_components` is not None.
        """
        if dictionary_vector_indices.numel() == 0:
            return

        with torch.no_grad():
            if component_idx is None:
                if self._n_components is not None:
                    error_message = "component_idx must be specified when n_components is not None"
                    raise ValueError(error_message)

                self.weight[:, dictionary_vector_indices] = updated_weights
            else:
                self.weight[component_idx, :, dictionary_vector_indices] = updated_weights

    def constrain_weights_unit_norm(self) -> None:
        """Constrain the weights to have unit norm.

        Warning:
            Note this must be called after each gradient step. This is because optimisers such as
            Adam don't strictly follow the gradient, but instead follow a modified gradient that
            includes momentum. This means that the gradient step can change the norm of the
            dictionary vectors, even when the hook `_weight_backward_hook` is applied.

            Note this can't be applied directly in the backward hook, as it would interfere with a
            variety of use cases (e.g. gradient accumulation across mini-batches, concurrency issues
            with asynchronous operations, etc).

        Example:
            >>> import torch
            >>> layer = UnitNormDecoder(3, 3, None)
            >>> layer.weight.data = torch.ones((3, 3)) * 10
            >>> layer.constrain_weights_unit_norm()
            >>> column_norms = torch.sqrt(torch.sum(layer.weight ** 2, dim=0))
            >>> column_norms.round(decimals=3).tolist()
            [1.0, 1.0, 1.0]

        """
        with torch.no_grad():
            torch.nn.functional.normalize(self.weight, dim=-2, out=self.weight)

    def reset_parameters(self) -> None:
        """Initialize or reset the parameters.

        Example:
            >>> import torch
            >>> # Create a layer with 4 columns (learnt features) and 3 rows (decoded features)
            >>> layer = UnitNormDecoder(learnt_features=4, decoded_features=3, n_components=None)
            >>> layer.reset_parameters()
            >>> # Get the squared norm of each column (by summing over the rows)
            >>> column_norms = torch.sum(layer.weight ** 2, dim=0)
            >>> column_norms.round(decimals=3).tolist()
            [1.0, 1.0, 1.0, 1.0]

        """
        # Initialize the weights with a normal distribution. Note we don't use e.g. kaiming
        # normalisation here, since we immediately scale the weights to have unit norm (so the
        # initial standard deviation doesn't matter). Note also that `init.normal_` is in place.
        self.weight: Float[
            Parameter,
            Axis.names(Axis.COMPONENT_OPTIONAL, Axis.LEARNT_FEATURE, Axis.INPUT_OUTPUT_FEATURE),
        ] = init.normal_(self.weight, mean=0, std=1)  # type: ignore

        # Scale so that each column has unit norm
        self.constrain_weights_unit_norm()

    def _weight_backward_hook(
        self,
        grad: Float[
            Tensor,
            Axis.names(Axis.COMPONENT_OPTIONAL, Axis.LEARNT_FEATURE, Axis.INPUT_OUTPUT_FEATURE),
        ],
    ) -> Float[
        Tensor, Axis.names(Axis.COMPONENT_OPTIONAL, Axis.LEARNT_FEATURE, Axis.INPUT_OUTPUT_FEATURE)
    ]:
        r"""Unit norm backward hook.

        By subtracting the projection of the gradient onto the dictionary vectors, we remove the
        component of the gradient that is parallel to the dictionary vectors and just keep the
        component that is orthogonal to the dictionary vectors (i.e. moving around the hypersphere).
        The result is that the backward pass does not change the norm of the dictionary vectors.

        $$
        \begin{align*}
            W_d &\in \mathbb{R}^{n \times m} = \text{Decoder weight matrix} \\
            g &\in \mathbb{R}^{n \times m} = \text{Gradient w.r.t. } W_d
                \text{ from the backward pass} \\
            W_{d, \text{norm}} &= \frac{W_d}{\|W_d\|} = \text{Normalized decoder weight matrix
                (over columns)} \\
            g_{\parallel} &\in \mathbb{R}^{n \times m} = \text{Component of } g
                \text{ parallel to } W_{d, \text{norm}} \\
            g_{\perp} &\in \mathbb{R}^{n \times m} = \text{Component of } g \text{ orthogonal to }
                W_{d, \text{norm}} \\
            g_{\parallel} &= W_{d, \text{norm}} \cdot (W_{d, \text{norm}}^\top \cdot g) \\
            g_{\perp} &= g - g_{\parallel} =
                \text{Adjusted gradient with parallel component removed} \\
        \end{align*}
        $$

        Args:
            grad: Gradient with respect to the weights.

        Returns:
            Gradient with respect to the weights, with the component parallel to the dictionary
            vectors removed.
        """
        # Project the gradients onto the dictionary vectors. Intuitively the dictionary vectors can
        # be thought of as vectors that end on the circumference of a hypersphere. The projection of
        # the gradient onto the dictionary vectors is the component of the gradient that is parallel
        # to the dictionary vectors, i.e. the component that moves to or from the center of the
        # hypersphere.
        normalized_weight: Float[
            Tensor,
            Axis.names(Axis.COMPONENT_OPTIONAL, Axis.LEARNT_FEATURE, Axis.INPUT_OUTPUT_FEATURE),
        ] = self.weight / torch.norm(self.weight, dim=-2, keepdim=True)

        scalar_projections = einops.einsum(
            grad,
            normalized_weight,
            f"... {Axis.LEARNT_FEATURE} {Axis.INPUT_OUTPUT_FEATURE}, \
                ... {Axis.LEARNT_FEATURE} {Axis.INPUT_OUTPUT_FEATURE} \
                -> ... {Axis.INPUT_OUTPUT_FEATURE}",
        )

        projection = einops.einsum(
            scalar_projections,
            normalized_weight,
            f"... {Axis.INPUT_OUTPUT_FEATURE}, \
                ... {Axis.LEARNT_FEATURE} {Axis.INPUT_OUTPUT_FEATURE} \
                -> ... {Axis.LEARNT_FEATURE} {Axis.INPUT_OUTPUT_FEATURE}",
        )

        # Subtracting the parallel component from the gradient leaves only the component that is
        # orthogonal to the dictionary vectors, i.e. the component that moves around the surface of
        # the hypersphere.
        return grad - projection

    def forward(
        self, x: Float[Tensor, Axis.names(Axis.BATCH, Axis.COMPONENT_OPTIONAL, Axis.LEARNT_FEATURE)]
    ) -> Float[Tensor, Axis.names(Axis.BATCH, Axis.COMPONENT_OPTIONAL, Axis.INPUT_OUTPUT_FEATURE)]:
        """Forward pass.

        Args:
            x: Input tensor.

        Returns:
            Output of the forward pass.
        """
        return einops.einsum(
            x,
            self.weight,
            f"{Axis.BATCH} ... {Axis.LEARNT_FEATURE}, \
            ... {Axis.INPUT_OUTPUT_FEATURE} {Axis.LEARNT_FEATURE} \
                -> {Axis.BATCH} ... {Axis.INPUT_OUTPUT_FEATURE}",
        )

    def extra_repr(self) -> str:
        """String extra representation of the module."""
        return (
            f"learnt_features={self._learnt_features}, "
            f"decoded_features={self._decoded_features}, "
            f"n_components={self._n_components}"
        )

reset_optimizer_parameter_details: list[ResetOptimizerParameterDetails] property ¤

Reset optimizer parameter details.

Details of the parameters that should be reset in the optimizer, when resetting dictionary vectors.

Returns:

    list[ResetOptimizerParameterDetails]: List of tuples of the form (parameter, axis), where parameter is the parameter to reset (e.g. encoder.weight), and axis is the axis of the parameter to reset.
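A hedged sketch of how these details might be consumed when resampling dead dictionary vectors (not this library's actual resampler; the helper name, the dead_indices tensor, and the assumption that the optimizer is torch.optim.Adam with its standard exp_avg / exp_avg_sq state are all illustrative):

import torch

def zero_adam_state_for_resampled_vectors(layer, optimizer, dead_indices):
    """Zero Adam's moment estimates for resampled dictionary vectors (illustrative only)."""
    for details in layer.reset_optimizer_parameter_details:
        state = optimizer.state.get(details.parameter, {})
        for key in ("exp_avg", "exp_avg_sq"):
            if key in state:
                # Zero the slices along the axis that indexes the dictionary vectors,
                # so stale momentum does not immediately move the new vectors.
                state[key].index_fill_(details.axis, dead_indices, 0.0)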

weight: Float[Parameter, Axis.names(Axis.COMPONENT_OPTIONAL, Axis.INPUT_OUTPUT_FEATURE, Axis.LEARNT_FEATURE)] = Parameter(torch.empty(shape_with_optional_dimensions(n_components, decoded_features, learnt_features))) instance-attribute ¤

Weight parameter.

Each column in the weights matrix acts as a dictionary vector, representing a single basis element in the learned activation space.

__init__(learnt_features, decoded_features, n_components, *, enable_gradient_hook=True) ¤

Initialize the constrained unit norm linear layer.

Parameters:

    learnt_features (PositiveInt, required): Number of learnt features in the autoencoder.

    decoded_features (PositiveInt, required): Number of decoded (output) features in the autoencoder.

    n_components (PositiveInt | None, required): Number of source model components the SAE is trained on.

    enable_gradient_hook (bool, default True): Enable the gradient backwards hook (modify the gradient before applying the gradient step, to maintain unit norm of the dictionary vectors).
Source code in sparse_autoencoder/autoencoder/components/unit_norm_decoder.py
@validate_call
def __init__(
    self,
    learnt_features: PositiveInt,
    decoded_features: PositiveInt,
    n_components: PositiveInt | None,
    *,
    enable_gradient_hook: bool = True,
) -> None:
    """Initialize the constrained unit norm linear layer.

    Args:
        learnt_features: Number of learnt features in the autoencoder.
        decoded_features: Number of decoded (output) features in the autoencoder.
        n_components: Number of source model components the SAE is trained on.
        enable_gradient_hook: Enable the gradient backwards hook (modify the gradient before
            applying the gradient step, to maintain unit norm of the dictionary vectors).
    """
    super().__init__()

    self._learnt_features = learnt_features
    self._decoded_features = decoded_features
    self._n_components = n_components

    # Create the linear layer as per the standard PyTorch linear layer
    self.weight = Parameter(
        torch.empty(
            shape_with_optional_dimensions(n_components, decoded_features, learnt_features),
        )
    )
    self.reset_parameters()

    # Register backward hook to remove any gradient information parallel to the dictionary
    # vectors (columns of the weight matrix) before applying the gradient step.
    if enable_gradient_hook:
        self.weight.register_hook(self._weight_backward_hook)
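A short usage sketch (sizes are arbitrary): with n_components=None the weight has shape (decoded_features, learnt_features), so each of the learnt_features columns is one unit-norm dictionary vector.

    >>> import torch
    >>> layer = UnitNormDecoder(learnt_features=8, decoded_features=4, n_components=None)
    >>> layer.weight.shape
    torch.Size([4, 8])
    >>> torch.allclose(torch.norm(layer.weight, dim=0), torch.ones(8))
    True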

constrain_weights_unit_norm() ¤

Constrain the weights to have unit norm.

Warning

Note this must be called after each gradient step. This is because optimisers such as Adam don't strictly follow the gradient, but instead follow a modified gradient that includes momentum. This means that the gradient step can change the norm of the dictionary vectors, even when the hook _weight_backward_hook is applied.

Note this can't be applied directly in the backward hook, as it would interfere with a variety of use cases (e.g. gradient accumulation across mini-batches, concurrency issues with asynchronous operations, etc).

Example:

    >>> import torch
    >>> layer = UnitNormDecoder(3, 3, None)
    >>> layer.weight.data = torch.ones((3, 3)) * 10
    >>> layer.constrain_weights_unit_norm()
    >>> column_norms = torch.sqrt(torch.sum(layer.weight ** 2, dim=0))
    >>> column_norms.round(decimals=3).tolist()
    [1.0, 1.0, 1.0]

Source code in sparse_autoencoder/autoencoder/components/unit_norm_decoder.py
def constrain_weights_unit_norm(self) -> None:
    """Constrain the weights to have unit norm.

    Warning:
        Note this must be called after each gradient step. This is because optimisers such as
        Adam don't strictly follow the gradient, but instead follow a modified gradient that
        includes momentum. This means that the gradient step can change the norm of the
        dictionary vectors, even when the hook `_weight_backward_hook` is applied.

        Note this can't be applied directly in the backward hook, as it would interfere with a
        variety of use cases (e.g. gradient accumulation across mini-batches, concurrency issues
        with asynchronous operations, etc).

    Example:
        >>> import torch
        >>> layer = UnitNormDecoder(3, 3, None)
        >>> layer.weight.data = torch.ones((3, 3)) * 10
        >>> layer.constrain_weights_unit_norm()
        >>> column_norms = torch.sqrt(torch.sum(layer.weight ** 2, dim=0))
        >>> column_norms.round(decimals=3).tolist()
        [1.0, 1.0, 1.0]

    """
    with torch.no_grad():
        torch.nn.functional.normalize(self.weight, dim=-2, out=self.weight)
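A hedged sketch of where this call fits in a training step (optimizer, loss_fn, model, batch, and decoder are placeholders, not part of this module): the backward hook keeps the gradient tangent to the unit hypersphere, and this method snaps the columns back to exactly unit norm after Adam's momentum-adjusted update.

# Illustrative training step; `decoder` is a UnitNormDecoder and the rest are placeholders.
optimizer.zero_grad()
loss = loss_fn(model(batch), batch)
loss.backward()                          # hook removes the parallel gradient component
optimizer.step()                         # Adam's momentum can still drift the column norms
decoder.constrain_weights_unit_norm()    # so snap back to unit norm after every step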

extra_repr() ¤

String extra representation of the module.

Source code in sparse_autoencoder/autoencoder/components/unit_norm_decoder.py
def extra_repr(self) -> str:
    """String extra representation of the module."""
    return (
        f"learnt_features={self._learnt_features}, "
        f"decoded_features={self._decoded_features}, "
        f"n_components={self._n_components}"
    )

forward(x) ¤

Forward pass.

Parameters:

    x (Float[Tensor, names(BATCH, COMPONENT_OPTIONAL, LEARNT_FEATURE)], required): Input tensor.

Returns:

    Float[Tensor, names(BATCH, COMPONENT_OPTIONAL, INPUT_OUTPUT_FEATURE)]: Output of the forward pass.

Source code in sparse_autoencoder/autoencoder/components/unit_norm_decoder.py
def forward(
    self, x: Float[Tensor, Axis.names(Axis.BATCH, Axis.COMPONENT_OPTIONAL, Axis.LEARNT_FEATURE)]
) -> Float[Tensor, Axis.names(Axis.BATCH, Axis.COMPONENT_OPTIONAL, Axis.INPUT_OUTPUT_FEATURE)]:
    """Forward pass.

    Args:
        x: Input tensor.

    Returns:
        Output of the forward pass.
    """
    return einops.einsum(
        x,
        self.weight,
        f"{Axis.BATCH} ... {Axis.LEARNT_FEATURE}, \
        ... {Axis.INPUT_OUTPUT_FEATURE} {Axis.LEARNT_FEATURE} \
            -> {Axis.BATCH} ... {Axis.INPUT_OUTPUT_FEATURE}",
    )
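For reference, a sketch of the shapes when a component axis is present (sizes are arbitrary); the optional component dimension is carried through unchanged by the `...` in the einsum.

    >>> import torch
    >>> layer = UnitNormDecoder(learnt_features=8, decoded_features=4, n_components=3)
    >>> x = torch.randn(2, 3, 8)  # (batch, component, learnt_feature)
    >>> layer(x).shape
    torch.Size([2, 3, 4])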

reset_parameters() ¤

Initialize or reset the parameters.

Example:

    >>> import torch
    >>> # Create a layer with 4 columns (learnt features) and 3 rows (decoded features)
    >>> layer = UnitNormDecoder(learnt_features=4, decoded_features=3, n_components=None)
    >>> layer.reset_parameters()
    >>> # Get the squared norm of each column (by summing over the rows)
    >>> column_norms = torch.sum(layer.weight ** 2, dim=0)
    >>> column_norms.round(decimals=3).tolist()
    [1.0, 1.0, 1.0, 1.0]

Source code in sparse_autoencoder/autoencoder/components/unit_norm_decoder.py
def reset_parameters(self) -> None:
    """Initialize or reset the parameters.

    Example:
        >>> import torch
        >>> # Create a layer with 4 columns (learnt features) and 3 rows (decoded features)
        >>> layer = UnitNormDecoder(learnt_features=4, decoded_features=3, n_components=None)
        >>> layer.reset_parameters()
        >>> # Get the squared norm of each column (by summing over the rows)
        >>> column_norms = torch.sum(layer.weight ** 2, dim=0)
        >>> column_norms.round(decimals=3).tolist()
        [1.0, 1.0, 1.0, 1.0]

    """
    # Initialize the weights with a normal distribution. Note we don't use e.g. kaiming
    # normalisation here, since we immediately scale the weights to have unit norm (so the
    # initial standard deviation doesn't matter). Note also that `init.normal_` is in place.
    self.weight: Float[
        Parameter,
        Axis.names(Axis.COMPONENT_OPTIONAL, Axis.LEARNT_FEATURE, Axis.INPUT_OUTPUT_FEATURE),
    ] = init.normal_(self.weight, mean=0, std=1)  # type: ignore

    # Scale so that each column has unit norm
    self.constrain_weights_unit_norm()

update_dictionary_vectors(dictionary_vector_indices, updated_weights, component_idx=None) ¤

Update decoder dictionary vectors.

Updates the dictionary vectors (columns of the weight matrix) with the given values. Typically this is used when resampling neurons (dictionary vectors) that have died.

Parameters:

    dictionary_vector_indices (Int64[Tensor, names(COMPONENT_OPTIONAL, LEARNT_FEATURE_IDX)], required): Indices of the dictionary vectors to update.

    updated_weights (Float[Tensor, names(COMPONENT_OPTIONAL, INPUT_OUTPUT_FEATURE, LEARNT_FEATURE_IDX)], required): Updated weights for just these dictionary vectors.

    component_idx (int | None, default None): Component index to update.

Raises:

    ValueError: If component_idx is not specified when n_components is not None.

Source code in sparse_autoencoder/autoencoder/components/unit_norm_decoder.py
def update_dictionary_vectors(
    self,
    dictionary_vector_indices: Int64[
        Tensor, Axis.names(Axis.COMPONENT_OPTIONAL, Axis.LEARNT_FEATURE_IDX)
    ],
    updated_weights: Float[
        Tensor,
        Axis.names(Axis.COMPONENT_OPTIONAL, Axis.INPUT_OUTPUT_FEATURE, Axis.LEARNT_FEATURE_IDX),
    ],
    component_idx: int | None = None,
) -> None:
    """Update decoder dictionary vectors.

    Updates the dictionary vectors (columns of the weight matrix) with the given values. Typically
    this is used when resampling neurons (dictionary vectors) that have died.

    Args:
        dictionary_vector_indices: Indices of the dictionary vectors to update.
        updated_weights: Updated weights for just these dictionary vectors.
        component_idx: Component index to update.

    Raises:
        ValueError: If `component_idx` is not specified when `n_components` is not None.
    """
    if dictionary_vector_indices.numel() == 0:
        return

    with torch.no_grad():
        if component_idx is None:
            if self._n_components is not None:
                error_message = "component_idx must be specified when n_components is not None"
                raise ValueError(error_message)

            self.weight[:, dictionary_vector_indices] = updated_weights
        else:
            self.weight[component_idx, :, dictionary_vector_indices] = updated_weights
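A hedged example of resampling a single dead dictionary vector when n_components is None (the index and replacement values are made up): the replacement column should itself be unit norm, as this method writes values directly without re-normalising.

    >>> import torch
    >>> layer = UnitNormDecoder(learnt_features=8, decoded_features=4, n_components=None)
    >>> dead_index = torch.tensor([2])
    >>> replacement = torch.nn.functional.normalize(torch.randn(4, 1), dim=0)
    >>> layer.update_dictionary_vectors(dead_index, replacement)
    >>> torch.allclose(layer.weight[:, 2], replacement.squeeze())
    True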