Adam Optimizer with a reset method
Adam Optimizer with a reset method.
This reset method is useful when resampling dead neurons during training.
AdamWithReset
Bases: Adam
Adam Optimizer with a reset method.
The reset_state_all_parameters and reset_neurons_state methods are useful when
manually editing the model parameters during training (e.g. when resampling dead neurons). This
is because Adam maintains running averages of the gradients and the squares of gradients, which
will be incorrect if the parameters are changed.
Otherwise this is the same as the standard Adam optimizer.
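For illustration, here is a minimal sketch (using plain torch.optim.Adam rather than this class) of the per-parameter state that becomes stale when parameter values are overwritten:

```python
import torch
from torch.nn import Parameter
from torch.optim import Adam

# Plain Adam keeps per-parameter moment buffers estimated from past gradients
# of the *old* parameter values.
weight = Parameter(torch.randn(10, 5))
optimizer = Adam([weight], lr=1e-3)

weight.sum().backward()
optimizer.step()

state = optimizer.state[weight]
print(state["exp_avg"].shape)     # torch.Size([10, 5]) - running average of gradients
print(state["exp_avg_sq"].shape)  # torch.Size([10, 5]) - running average of squared gradients

# Overwriting a row of `weight` (e.g. when resampling a dead neuron) without
# also clearing the matching rows of these buffers would apply momentum that
# belongs to the old neuron.
```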
Source code in sparse_autoencoder/optimizer/adam_with_reset.py
parameter_names: list[str] = [name for (name, _value) in named_parameters]
instance-attribute
Parameter Names.
The names of the parameters, so that we can find them later when resetting the state.
__init__(params, lr=0.001, betas=(0.9, 0.999), eps=1e-08, weight_decay=0.0, *, amsgrad=False, foreach=None, maximize=False, capturable=False, differentiable=False, fused=None, named_parameters, has_components_dim)
Initialize the optimizer.
Warning
Named parameters must be provided with the default settings (duplicates removed and not recursive).
Example
```python
>>> import torch
>>> from sparse_autoencoder.autoencoder.model import (
...     SparseAutoencoder, SparseAutoencoderConfig
... )
>>> model = SparseAutoencoder(
...     SparseAutoencoderConfig(
...         n_input_features=5,
...         n_learned_features=10,
...         n_components=2
...     )
... )
>>> optimizer = AdamWithReset(
...     model.parameters(),
...     named_parameters=model.named_parameters(),
...     has_components_dim=True,
... )
>>> optimizer.reset_state_all_parameters()
```
Parameters:

Name | Type | Description | Default |
---|---|---|---|
params | params_t | Iterable of parameters to optimize or dicts defining parameter groups. | required |
lr | float \| Float[Tensor, names(SINGLE_ITEM)] | Learning rate. A Tensor LR is not yet fully supported for all implementations. Use a float LR unless specifying fused=True or capturable=True. | 0.001 |
betas | tuple[float, float] | Coefficients used for computing running averages of the gradient and its square. | (0.9, 0.999) |
eps | float | Term added to the denominator to improve numerical stability. | 1e-08 |
weight_decay | float | Weight decay (L2 penalty). | 0.0 |
amsgrad | bool | Whether to use the AMSGrad variant of this algorithm from the paper "On the Convergence of Adam and Beyond". | False |
foreach | bool \| None | Whether the foreach implementation of the optimizer is used. If None, foreach is used over the for-loop implementation on CUDA if more performant. Note that foreach uses more peak memory. | None |
maximize | bool | If True, maximizes the parameters based on the objective, instead of minimizing. | False |
capturable | bool | Whether this instance is safe to capture in a CUDA graph. True can impair ungraphed performance. | False |
differentiable | bool | Whether autograd should occur through the optimizer step in training. Setting to True can impair performance. | False |
fused | bool \| None | Whether the fused implementation (CUDA only) is used. Supports torch.float64, torch.float32, torch.float16, and torch.bfloat16. | None |
named_parameters | Iterator[tuple[str, Parameter]] | An iterator over the named parameters of the model. This is used to find the parameters when resetting their state. You should set this as model.named_parameters(). | required |
has_components_dim | bool | If the parameters have a components dimension (i.e. if you are training an SAE on more than one component). | required |
Raises:

Type | Description |
---|---|
ValueError | If the number of parameter names does not match the number of parameters. |
Source code in sparse_autoencoder/optimizer/adam_with_reset.py
reset_neurons_state(parameter, neuron_indices, axis, component_idx=0)
Reset the state for specific neurons, on a specific parameter.
Example
```python
>>> import torch
>>> from sparse_autoencoder.autoencoder.model import (
...     SparseAutoencoder, SparseAutoencoderConfig
... )
>>> model = SparseAutoencoder(
...     SparseAutoencoderConfig(
...         n_input_features=5,
...         n_learned_features=10,
...         n_components=2
...     )
... )
>>> optimizer = AdamWithReset(
...     model.parameters(),
...     named_parameters=model.named_parameters(),
...     has_components_dim=True,
... )
>>> # ... train the model and then resample some dead neurons, then do this ...
>>> dead_neurons_indices = torch.tensor([0, 1])  # Dummy dead neuron indices
>>> # Reset the optimizer state for parameters that have been updated
>>> optimizer.reset_neurons_state(model.encoder.weight, dead_neurons_indices, axis=0)
>>> optimizer.reset_neurons_state(model.encoder.bias, dead_neurons_indices, axis=0)
>>> optimizer.reset_neurons_state(
...     model.decoder.weight,
...     dead_neurons_indices,
...     axis=1
... )
```
Parameters:

Name | Type | Description | Default |
---|---|---|---|
parameter | Parameter | The parameter to be reset. Examples from the standard sparse autoencoder implementation include model.encoder.weight and model.decoder.weight (see the example above). | required |
neuron_indices | Int[Tensor, names(LEARNT_FEATURE_IDX)] | The indices of the neurons to reset. | required |
axis | int | The axis of the state values to reset (i.e. the input/output features axis, as we're resetting all input/output features for a specific dead neuron). | required |
component_idx | int | The component index of the state values to reset. | 0 |
Raises:

Type | Description |
---|---|
ValueError | If the parameter has a components dimension, but has_components_dim is False. |
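Conceptually, resetting the state for specific neurons amounts to zeroing the Adam moment buffers along the chosen axis. A rough hand-rolled sketch with plain torch.optim.Adam (no components dimension; variable names are illustrative only):

```python
import torch
from torch.nn import Parameter
from torch.optim import Adam

# Hand-rolled sketch: zero the Adam moments for selected neurons along one axis.
weight = Parameter(torch.randn(10, 5))  # (learned_features, input_features)
optimizer = Adam([weight], lr=1e-3)
weight.sum().backward()
optimizer.step()

dead_neuron_indices = torch.tensor([0, 1])
axis = 0  # each row corresponds to one learned feature (neuron)

state = optimizer.state[weight]
state["exp_avg"].index_fill_(axis, dead_neuron_indices, 0.0)
state["exp_avg_sq"].index_fill_(axis, dead_neuron_indices, 0.0)
```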
Source code in sparse_autoencoder/optimizer/adam_with_reset.py
reset_state_all_parameters()
Reset the state for all parameters.
Iterates over all parameters and resets both the running averages of the gradients and the squares of gradients.
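A rough equivalent for a plain torch.optim.Adam instance (a sketch only, not this class's implementation; the helper name is hypothetical):

```python
from torch.optim import Adam

def reset_all_adam_state(optimizer: Adam) -> None:
    """Zero every parameter's Adam moment buffers in-place (illustrative sketch)."""
    for group in optimizer.param_groups:
        for param in group["params"]:
            state = optimizer.state.get(param, {})
            for key in ("exp_avg", "exp_avg_sq", "max_exp_avg_sq"):
                if key in state:
                    state[key].zero_()
```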
Source code in sparse_autoencoder/optimizer/adam_with_reset.py