
Adam Optimizer with a reset method¤

Adam Optimizer with a reset method.

This reset method is useful when resampling dead neurons during training.

AdamWithReset ¤

Bases: Adam

Adam Optimizer with a reset method.

The reset_state_all_parameters and reset_neurons_state methods are useful when manually editing the model parameters during training (e.g. when resampling dead neurons). This is because Adam maintains running averages of the gradients and the squares of gradients, which will be incorrect if the parameters are changed.

Otherwise this is the same as the standard Adam optimizer.
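
To make the problem concrete, here is a minimal sketch (an illustration, not one of the library's documented examples). It assumes a plain torch.nn.Linear as a stand-in model, since the constructor only needs the params and named_parameters arguments, and imports the class from the source file shown below:

    >>> import torch
    >>> from torch import nn
    >>> from sparse_autoencoder.optimizer.adam_with_reset import AdamWithReset
    >>> model = nn.Linear(5, 10)
    >>> optimizer = AdamWithReset(
    ...     model.parameters(),
    ...     named_parameters=model.named_parameters(),
    ...     has_components_dim=False,
    ... )
    >>> model(torch.randn(8, 5)).pow(2).sum().backward()
    >>> optimizer.step()
    >>> # Adam now holds running averages (exp_avg, exp_avg_sq) for each parameter
    >>> assert optimizer.state[model.weight]["exp_avg_sq"].abs().sum() > 0
    >>> # After overwriting parameters (e.g. resampling dead neurons), those
    >>> # averages describe the old values, so clear them before continuing
    >>> with torch.no_grad():
    ...     model.weight[0:2] = torch.randn(2, 5)
    >>> optimizer.reset_state_all_parameters()
    >>> assert optimizer.state[model.weight]["exp_avg_sq"].abs().sum() == 0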

Source code in sparse_autoencoder/optimizer/adam_with_reset.py
class AdamWithReset(Adam):
    """Adam Optimizer with a reset method.

    The :meth:`reset_state_all_parameters` and :meth:`reset_neurons_state` methods are useful when
    manually editing the model parameters during training (e.g. when resampling dead neurons). This
    is because Adam maintains running averages of the gradients and the squares of gradients, which
    will be incorrect if the parameters are changed.

    Otherwise this is the same as the standard Adam optimizer.
    """

    parameter_names: list[str]
    """Parameter Names.

    The names of the parameters, so that we can find them later when resetting the state.
    """

    _has_components_dim: bool
    """Whether the parameters have a components dimension."""

    def __init__(  # (extending existing implementation)
        self,
        params: params_t,
        lr: float | Float[Tensor, Axis.names(Axis.SINGLE_ITEM)] = 1e-3,
        betas: tuple[float, float] = (0.9, 0.999),
        eps: float = 1e-8,
        weight_decay: float = 0.0,
        *,
        amsgrad: bool = False,
        foreach: bool | None = None,
        maximize: bool = False,
        capturable: bool = False,
        differentiable: bool = False,
        fused: bool | None = None,
        named_parameters: Iterator[tuple[str, Parameter]],
        has_components_dim: bool,
    ) -> None:
        """Initialize the optimizer.

        Warning:
            Named parameters must be with default settings (remove duplicates and not recursive).

        Example:
            >>> import torch
            >>> from sparse_autoencoder.autoencoder.model import (
            ...     SparseAutoencoder, SparseAutoencoderConfig
            ... )
            >>> model = SparseAutoencoder(
            ...        SparseAutoencoderConfig(
            ...             n_input_features=5,
            ...             n_learned_features=10,
            ...             n_components=2
            ...         )
            ...    )
            >>> optimizer = AdamWithReset(
            ...     model.parameters(),
            ...     named_parameters=model.named_parameters(),
            ...     has_components_dim=True,
            ... )
            >>> optimizer.reset_state_all_parameters()

        Args:
            params: Iterable of parameters to optimize or dicts defining parameter groups.
            lr: Learning rate. A Tensor LR is not yet fully supported for all implementations. Use a
                float LR unless specifying fused=True or capturable=True.
            betas: Coefficients used for computing running averages of gradient and its square.
            eps: Term added to the denominator to improve numerical stability.
            weight_decay: Weight decay (L2 penalty).
            amsgrad: Whether to use the AMSGrad variant of this algorithm from the paper "On the
                Convergence of Adam and Beyond".
            foreach: Whether foreach implementation of optimizer is used. If None, foreach is used
                over the for-loop implementation on CUDA if more performant. Note that foreach uses
                more peak memory.
            maximize: If True, maximizes the parameters based on the objective, instead of
                minimizing.
            capturable: Whether this instance is safe to capture in a CUDA graph. True can impair
                ungraphed performance.
            differentiable: Whether autograd should occur through the optimizer step in training.
                Setting to True can impair performance.
            fused: Whether the fused implementation (CUDA only) is used. Supports torch.float64,
                torch.float32, torch.float16, and torch.bfloat16.
            named_parameters: An iterator over the named parameters of the model. This is used to
                find the parameters when resetting their state. You should set this as
                `model.named_parameters()`.
            has_components_dim: If the parameters have a components dimension (i.e. if you are
                training an SAE on more than one component).

        Raises:
            ValueError: If the number of parameter names does not match the number of parameters.
        """
        # Initialise the parent class (note we repeat the parameter names so that type hints work).
        super().__init__(
            params=params,
            lr=lr,
            betas=betas,
            eps=eps,
            weight_decay=weight_decay,
            amsgrad=amsgrad,
            foreach=foreach,
            maximize=maximize,
            capturable=capturable,
            differentiable=differentiable,
            fused=fused,
        )

        self._has_components_dim = has_components_dim

        # Store the names of the parameters, so that we can find them later when resetting the
        # state.
        self.parameter_names = [name for name, _value in named_parameters]

        if len(self.parameter_names) != len(self.param_groups[0]["params"]):
            error_message = (
                "The number of parameter names does not match the number of parameters. "
                "If using model.named_parameters() make sure remove_duplicates is True "
                "and recursive is False (the default settings)."
            )
            raise ValueError(error_message)

    def reset_state_all_parameters(self) -> None:
        """Reset the state for all parameters.

        Iterates over all parameters and resets both the running averages of the gradients and the
        squares of gradients.
        """
        # Iterate over every parameter
        for group in self.param_groups:
            for parameter in group["params"]:
                # Get the state
                state = self.state[parameter]

                # Check if state is initialized
                if len(state) == 0:
                    continue

                # Reset running averages
                exp_avg: Tensor = state["exp_avg"]
                exp_avg.zero_()
                exp_avg_sq: Tensor = state["exp_avg_sq"]
                exp_avg_sq.zero_()

                # If AdamW is used (weight decay fix), also reset the max exp_avg_sq
                if "max_exp_avg_sq" in state:
                    max_exp_avg_sq: Tensor = state["max_exp_avg_sq"]
                    max_exp_avg_sq.zero_()

    def reset_neurons_state(
        self,
        parameter: Parameter,
        neuron_indices: Int[Tensor, Axis.names(Axis.LEARNT_FEATURE_IDX)],
        axis: int,
        component_idx: int = 0,
    ) -> None:
        """Reset the state for specific neurons, on a specific parameter.

        Example:
            >>> import torch
            >>> from sparse_autoencoder.autoencoder.model import (
            ...     SparseAutoencoder, SparseAutoencoderConfig
            ... )
            >>> model = SparseAutoencoder(
            ...        SparseAutoencoderConfig(
            ...             n_input_features=5,
            ...             n_learned_features=10,
            ...             n_components=2
            ...         )
            ...    )
            >>> optimizer = AdamWithReset(
            ...     model.parameters(),
            ...     named_parameters=model.named_parameters(),
            ...     has_components_dim=True,
            ... )
            >>> # ... train the model and then resample some dead neurons, then do this ...
            >>> dead_neurons_indices = torch.tensor([0, 1]) # Dummy dead neuron indices
            >>> # Reset the optimizer state for parameters that have been updated
            >>> optimizer.reset_neurons_state(model.encoder.weight, dead_neurons_indices, axis=0)
            >>> optimizer.reset_neurons_state(model.encoder.bias, dead_neurons_indices, axis=0)
            >>> optimizer.reset_neurons_state(
            ...     model.decoder.weight,
            ...     dead_neurons_indices,
            ...     axis=1
            ... )

        Args:
            parameter: The parameter to be reset. Examples from the standard sparse autoencoder
                implementation include `tied_bias`, `_encoder._weight` and `_encoder._bias`.
            neuron_indices: The indices of the neurons to reset.
            axis: The axis of the state values to reset (i.e. the input/output features axis, as
                we're resetting all input/output features for a specific dead neuron).
            component_idx: The component index of the state values to reset.

        Raises:
            ValueError: If the parameter has a components dimension, but has_components_dim is
                False.
        """
        # Get the state of the parameter
        state = self.state[parameter]

        # If the number of dimensions is 3, we definitely have a components dimension. If 2, we may
        # do (as the bias has 2 dimensions with components, but the weight has 2 dimensions without
        # components).
        definitely_has_components_dimension = 3
        if (
            not self._has_components_dim
            and state["exp_avg"].ndim == definitely_has_components_dimension
        ):
            error_message = (
                "The parameter has a components dimension, but has_components_dim is False. "
                "This should not happen."
            )
            raise ValueError(error_message)

        # Check if state is initialized
        if len(state) == 0:
            return

        # Check there are any neurons to reset
        if neuron_indices.numel() == 0:
            return

        # Move the neuron indices to the correct device
        neuron_indices = neuron_indices.to(device=state["exp_avg"].device)

        # Reset running averages for the specified neurons
        if "exp_avg" in state:
            if self._has_components_dim:
                state["exp_avg"][component_idx].index_fill_(axis, neuron_indices, 0)
            else:
                state["exp_avg"].index_fill_(axis, neuron_indices, 0)

        if "exp_avg_sq" in state:
            if self._has_components_dim:
                state["exp_avg_sq"][component_idx].index_fill_(axis, neuron_indices, 0)
            else:
                state["exp_avg_sq"].index_fill_(axis, neuron_indices, 0)

        # If AdamW is used (weight decay fix), also reset the max exp_avg_sq
        if "max_exp_avg_sq" in state:
            if self._has_components_dim:
                state["max_exp_avg_sq"][component_idx].index_fill_(axis, neuron_indices, 0)
            else:
                state["max_exp_avg_sq"].index_fill_(axis, neuron_indices, 0)

parameter_names: list[str] instance-attribute ¤

Parameter Names.

The names of the parameters, so that we can find them later when resetting the state.

__init__(params, lr=0.001, betas=(0.9, 0.999), eps=1e-08, weight_decay=0.0, *, amsgrad=False, foreach=None, maximize=False, capturable=False, differentiable=False, fused=None, named_parameters, has_components_dim) ¤

Initialize the optimizer.

Warning

Named parameters must be provided with the default settings (duplicates removed and not recursive).

Example

    >>> import torch
    >>> from sparse_autoencoder.autoencoder.model import (
    ...     SparseAutoencoder, SparseAutoencoderConfig
    ... )
    >>> model = SparseAutoencoder(
    ...     SparseAutoencoderConfig(
    ...         n_input_features=5,
    ...         n_learned_features=10,
    ...         n_components=2
    ...     )
    ... )
    >>> optimizer = AdamWithReset(
    ...     model.parameters(),
    ...     named_parameters=model.named_parameters(),
    ...     has_components_dim=True,
    ... )
    >>> optimizer.reset_state_all_parameters()

Parameters:

    params (params_t, required):
        Iterable of parameters to optimize or dicts defining parameter groups.
    lr (float | Float[Tensor, names(SINGLE_ITEM)], default 0.001):
        Learning rate. A Tensor LR is not yet fully supported for all implementations. Use a float LR unless specifying fused=True or capturable=True.
    betas (tuple[float, float], default (0.9, 0.999)):
        Coefficients used for computing running averages of the gradient and its square.
    eps (float, default 1e-08):
        Term added to the denominator to improve numerical stability.
    weight_decay (float, default 0.0):
        Weight decay (L2 penalty).
    amsgrad (bool, default False):
        Whether to use the AMSGrad variant of this algorithm from the paper "On the Convergence of Adam and Beyond".
    foreach (bool | None, default None):
        Whether the foreach implementation of the optimizer is used. If None, foreach is used over the for-loop implementation on CUDA if more performant. Note that foreach uses more peak memory.
    maximize (bool, default False):
        If True, maximizes the parameters based on the objective, instead of minimizing.
    capturable (bool, default False):
        Whether this instance is safe to capture in a CUDA graph. True can impair ungraphed performance.
    differentiable (bool, default False):
        Whether autograd should occur through the optimizer step in training. Setting to True can impair performance.
    fused (bool | None, default None):
        Whether the fused implementation (CUDA only) is used. Supports torch.float64, torch.float32, torch.float16, and torch.bfloat16.
    named_parameters (Iterator[tuple[str, Parameter]], required):
        An iterator over the named parameters of the model. This is used to find the parameters when resetting their state. You should set this as model.named_parameters().
    has_components_dim (bool, required):
        Whether the parameters have a components dimension (i.e. if you are training an SAE on more than one component).

Raises:

    ValueError:
        If the number of parameter names does not match the number of parameters.
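
The ValueError above can be sketched as follows (illustrative only; a plain torch.nn.Linear stands in for the autoencoder). Passing a named_parameters iterator that does not cover every optimized parameter fails the length check:

    >>> from torch import nn
    >>> from sparse_autoencoder.optimizer.adam_with_reset import AdamWithReset
    >>> model = nn.Linear(5, 10)
    >>> try:
    ...     AdamWithReset(
    ...         model.parameters(),
    ...         named_parameters=iter([("weight", model.weight)]),  # bias is missing
    ...         has_components_dim=False,
    ...     )
    ... except ValueError:
    ...     print("parameter name count mismatch")
    parameter name count mismatch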

Source code in sparse_autoencoder/optimizer/adam_with_reset.py
def __init__(  # (extending existing implementation)
    self,
    params: params_t,
    lr: float | Float[Tensor, Axis.names(Axis.SINGLE_ITEM)] = 1e-3,
    betas: tuple[float, float] = (0.9, 0.999),
    eps: float = 1e-8,
    weight_decay: float = 0.0,
    *,
    amsgrad: bool = False,
    foreach: bool | None = None,
    maximize: bool = False,
    capturable: bool = False,
    differentiable: bool = False,
    fused: bool | None = None,
    named_parameters: Iterator[tuple[str, Parameter]],
    has_components_dim: bool,
) -> None:
    """Initialize the optimizer.

    Warning:
        Named parameters must be with default settings (remove duplicates and not recursive).

    Example:
        >>> import torch
        >>> from sparse_autoencoder.autoencoder.model import (
        ...     SparseAutoencoder, SparseAutoencoderConfig
        ... )
        >>> model = SparseAutoencoder(
        ...        SparseAutoencoderConfig(
        ...             n_input_features=5,
        ...             n_learned_features=10,
        ...             n_components=2
        ...         )
        ...    )
        >>> optimizer = AdamWithReset(
        ...     model.parameters(),
        ...     named_parameters=model.named_parameters(),
        ...     has_components_dim=True,
        ... )
        >>> optimizer.reset_state_all_parameters()

    Args:
        params: Iterable of parameters to optimize or dicts defining parameter groups.
        lr: Learning rate. A Tensor LR is not yet fully supported for all implementations. Use a
            float LR unless specifying fused=True or capturable=True.
        betas: Coefficients used for computing running averages of gradient and its square.
        eps: Term added to the denominator to improve numerical stability.
        weight_decay: Weight decay (L2 penalty).
        amsgrad: Whether to use the AMSGrad variant of this algorithm from the paper "On the
            Convergence of Adam and Beyond".
        foreach: Whether foreach implementation of optimizer is used. If None, foreach is used
            over the for-loop implementation on CUDA if more performant. Note that foreach uses
            more peak memory.
        maximize: If True, maximizes the parameters based on the objective, instead of
            minimizing.
        capturable: Whether this instance is safe to capture in a CUDA graph. True can impair
            ungraphed performance.
        differentiable: Whether autograd should occur through the optimizer step in training.
            Setting to True can impair performance.
        fused: Whether the fused implementation (CUDA only) is used. Supports torch.float64,
            torch.float32, torch.float16, and torch.bfloat16.
        named_parameters: An iterator over the named parameters of the model. This is used to
            find the parameters when resetting their state. You should set this as
            `model.named_parameters()`.
        has_components_dim: If the parameters have a components dimension (i.e. if you are
            training an SAE on more than one component).

    Raises:
        ValueError: If the number of parameter names does not match the number of parameters.
    """
    # Initialise the parent class (note we repeat the parameter names so that type hints work).
    super().__init__(
        params=params,
        lr=lr,
        betas=betas,
        eps=eps,
        weight_decay=weight_decay,
        amsgrad=amsgrad,
        foreach=foreach,
        maximize=maximize,
        capturable=capturable,
        differentiable=differentiable,
        fused=fused,
    )

    self._has_components_dim = has_components_dim

    # Store the names of the parameters, so that we can find them later when resetting the
    # state.
    self.parameter_names = [name for name, _value in named_parameters]

    if len(self.parameter_names) != len(self.param_groups[0]["params"]):
        error_message = (
            "The number of parameter names does not match the number of parameters. "
            "If using model.named_parameters() make sure remove_duplicates is True "
            "and recursive is False (the default settings)."
        )
        raise ValueError(error_message)

reset_neurons_state(parameter, neuron_indices, axis, component_idx=0) ¤

Reset the state for specific neurons, on a specific parameter.

Example

    >>> import torch
    >>> from sparse_autoencoder.autoencoder.model import (
    ...     SparseAutoencoder, SparseAutoencoderConfig
    ... )
    >>> model = SparseAutoencoder(
    ...     SparseAutoencoderConfig(
    ...         n_input_features=5,
    ...         n_learned_features=10,
    ...         n_components=2
    ...     )
    ... )
    >>> optimizer = AdamWithReset(
    ...     model.parameters(),
    ...     named_parameters=model.named_parameters(),
    ...     has_components_dim=True,
    ... )
    >>> # ... train the model and then resample some dead neurons, then do this ...
    >>> dead_neurons_indices = torch.tensor([0, 1])  # Dummy dead neuron indices
    >>> # Reset the optimizer state for parameters that have been updated
    >>> optimizer.reset_neurons_state(model.encoder.weight, dead_neurons_indices, axis=0)
    >>> optimizer.reset_neurons_state(model.encoder.bias, dead_neurons_indices, axis=0)
    >>> optimizer.reset_neurons_state(
    ...     model.decoder.weight,
    ...     dead_neurons_indices,
    ...     axis=1
    ... )

Parameters:

    parameter (Parameter, required):
        The parameter to be reset. Examples from the standard sparse autoencoder implementation include tied_bias, _encoder._weight and _encoder._bias.
    neuron_indices (Int[Tensor, names(LEARNT_FEATURE_IDX)], required):
        The indices of the neurons to reset.
    axis (int, required):
        The axis of the state values to reset (i.e. the input/output features axis, as we're resetting all input/output features for a specific dead neuron).
    component_idx (int, default 0):
        The component index of the state values to reset.

Raises:

    ValueError:
        If the parameter has a components dimension, but has_components_dim is False.
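
Continuing the example earlier in this section (where the config sets n_components=2), the component_idx argument selects which component's state slice is reset. A short sketch of resetting the same neurons for every component:

    >>> # Reset the encoder weight state for both components
    >>> for component_idx in range(2):
    ...     optimizer.reset_neurons_state(
    ...         model.encoder.weight,
    ...         dead_neurons_indices,
    ...         axis=0,
    ...         component_idx=component_idx,
    ...     )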

Source code in sparse_autoencoder/optimizer/adam_with_reset.py
def reset_neurons_state(
    self,
    parameter: Parameter,
    neuron_indices: Int[Tensor, Axis.names(Axis.LEARNT_FEATURE_IDX)],
    axis: int,
    component_idx: int = 0,
) -> None:
    """Reset the state for specific neurons, on a specific parameter.

    Example:
        >>> import torch
        >>> from sparse_autoencoder.autoencoder.model import (
        ...     SparseAutoencoder, SparseAutoencoderConfig
        ... )
        >>> model = SparseAutoencoder(
        ...        SparseAutoencoderConfig(
        ...             n_input_features=5,
        ...             n_learned_features=10,
        ...             n_components=2
        ...         )
        ...    )
        >>> optimizer = AdamWithReset(
        ...     model.parameters(),
        ...     named_parameters=model.named_parameters(),
        ...     has_components_dim=True,
        ... )
        >>> # ... train the model and then resample some dead neurons, then do this ...
        >>> dead_neurons_indices = torch.tensor([0, 1]) # Dummy dead neuron indices
        >>> # Reset the optimizer state for parameters that have been updated
        >>> optimizer.reset_neurons_state(model.encoder.weight, dead_neurons_indices, axis=0)
        >>> optimizer.reset_neurons_state(model.encoder.bias, dead_neurons_indices, axis=0)
        >>> optimizer.reset_neurons_state(
        ...     model.decoder.weight,
        ...     dead_neurons_indices,
        ...     axis=1
        ... )

    Args:
        parameter: The parameter to be reset. Examples from the standard sparse autoencoder
            implementation include `tied_bias`, `_encoder._weight` and `_encoder._bias`.
        neuron_indices: The indices of the neurons to reset.
        axis: The axis of the state values to reset (i.e. the input/output features axis, as
            we're resetting all input/output features for a specific dead neuron).
        component_idx: The component index of the state values to reset.

    Raises:
        ValueError: If the parameter has a components dimension, but has_components_dim is
            False.
    """
    # Get the state of the parameter
    state = self.state[parameter]

    # If the number of dimensions is 3, we definitely have a components dimension. If 2, we may
    # do (as the bias has 2 dimensions with components, but the weight has 2 dimensions without
    # components).
    definitely_has_components_dimension = 3
    if (
        not self._has_components_dim
        and state["exp_avg"].ndim == definitely_has_components_dimension
    ):
        error_message = (
            "The parameter has a components dimension, but has_components_dim is False. "
            "This should not happen."
        )
        raise ValueError(error_message)

    # Check if state is initialized
    if len(state) == 0:
        return

    # Check there are any neurons to reset
    if neuron_indices.numel() == 0:
        return

    # Move the neuron indices to the correct device
    neuron_indices = neuron_indices.to(device=state["exp_avg"].device)

    # Reset running averages for the specified neurons
    if "exp_avg" in state:
        if self._has_components_dim:
            state["exp_avg"][component_idx].index_fill_(axis, neuron_indices, 0)
        else:
            state["exp_avg"].index_fill_(axis, neuron_indices, 0)

    if "exp_avg_sq" in state:
        if self._has_components_dim:
            state["exp_avg_sq"][component_idx].index_fill_(axis, neuron_indices, 0)
        else:
            state["exp_avg_sq"].index_fill_(axis, neuron_indices, 0)

    # If AdamW is used (weight decay fix), also reset the max exp_avg_sq
    if "max_exp_avg_sq" in state:
        if self._has_components_dim:
            state["max_exp_avg_sq"][component_idx].index_fill_(axis, neuron_indices, 0)
        else:
            state["max_exp_avg_sq"].index_fill_(axis, neuron_indices, 0)

reset_state_all_parameters() ¤

Reset the state for all parameters.

Iterates over all parameters and resets both the running averages of the gradients and the squares of gradients.
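
As the source below shows, the reset also clears AMSGrad's max_exp_avg_sq buffer when it is present. A brief sketch of that case (illustrative only; a plain torch.nn.Linear stands in for the autoencoder):

    >>> import torch
    >>> from torch import nn
    >>> from sparse_autoencoder.optimizer.adam_with_reset import AdamWithReset
    >>> model = nn.Linear(5, 10)
    >>> optimizer = AdamWithReset(
    ...     model.parameters(),
    ...     named_parameters=model.named_parameters(),
    ...     has_components_dim=False,
    ...     amsgrad=True,  # adds a max_exp_avg_sq buffer to the state
    ... )
    >>> model(torch.randn(8, 5)).pow(2).sum().backward()
    >>> optimizer.step()
    >>> optimizer.reset_state_all_parameters()
    >>> state = optimizer.state[model.weight]
    >>> assert state["exp_avg"].abs().sum() == 0
    >>> assert state["max_exp_avg_sq"].abs().sum() == 0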

Source code in sparse_autoencoder/optimizer/adam_with_reset.py
def reset_state_all_parameters(self) -> None:
    """Reset the state for all parameters.

    Iterates over all parameters and resets both the running averages of the gradients and the
    squares of gradients.
    """
    # Iterate over every parameter
    for group in self.param_groups:
        for parameter in group["params"]:
            # Get the state
            state = self.state[parameter]

            # Check if state is initialized
            if len(state) == 0:
                continue

            # Reset running averages
            exp_avg: Tensor = state["exp_avg"]
            exp_avg.zero_()
            exp_avg_sq: Tensor = state["exp_avg_sq"]
            exp_avg_sq.zero_()

            # If AdamW is used (weight decay fix), also reset the max exp_avg_sq
            if "max_exp_avg_sq" in state:
                max_exp_avg_sq: Tensor = state["max_exp_avg_sq"]
                max_exp_avg_sq.zero_()