Linear layer with unit norm weights¤
UnitNormDecoder¤
Bases: Module
Constrained unit norm linear decoder layer.
Linear decoder layer where the dictionary vectors (columns of the weight matrix) are constrained to have unit norm. This is done by removing the gradient information parallel to the dictionary vectors before applying the gradient step, using a backward hook. It also requires `constrain_weights_unit_norm` to be called after each gradient step, to prevent drift of the dictionary vectors away from unit norm (as optimisers such as Adam don't strictly follow the gradient, but instead follow a modified gradient that includes momentum).
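To make the mechanism concrete, here is a minimal sketch of the projection the backward hook performs (an illustration of the idea, not the library's exact implementation): for each unit-norm dictionary vector \(w\) (a column of the weight matrix), the gradient component parallel to \(w\), namely \((g \cdot w)\,w\), is removed before the optimiser step.

```python
import torch


def project_out_parallel_component(
    weight: torch.Tensor,  # (decoded_features, learnt_features)
    grad: torch.Tensor,  # same shape as the weight
) -> torch.Tensor:
    """Remove the gradient component parallel to each dictionary vector (column)."""
    unit_columns = weight / weight.norm(dim=0, keepdim=True)
    parallel_component = (grad * unit_columns).sum(dim=0, keepdim=True) * unit_columns
    return grad - parallel_component
```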
Motivation
Normalisation of the columns (dictionary features) prevents the model from reducing the sparsity loss term by increasing the size of the feature vectors in \(W_d\).
Note that the Towards Monosemanticity: Decomposing Language Models With Dictionary Learning paper found that removing the gradient information parallel to the dictionary vectors before applying the gradient step, rather than resetting the dictionary vectors to unit norm after each gradient step, [results in a small but real reduction in total loss](https://transformer-circuits.pub/2023/monosemantic-features/index.html#appendix-autoencoder-optimization).
Source code in sparse_autoencoder/autoencoder/components/unit_norm_decoder.py
reset_optimizer_parameter_details: list[ResetOptimizerParameterDetails] property¤
Reset optimizer parameter details.
Details of the parameters that should be reset in the optimizer, when resetting dictionary vectors.
Returns:

Type | Description
---|---
`list[ResetOptimizerParameterDetails]` | List of tuples of the form `(parameter, axis)`, where `parameter` is the parameter to be reset (e.g. encoder.weight), and `axis` is the axis of the parameter to reset.
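A minimal sketch of consuming this property, assuming each `ResetOptimizerParameterDetails` entry unpacks to the `(parameter, axis)` pair described above, and using the import path implied by the source file location:

```python
from sparse_autoencoder.autoencoder.components.unit_norm_decoder import UnitNormDecoder

layer = UnitNormDecoder(learnt_features=4, decoded_features=3, n_components=None)

# Inspect which parameter/axis pairs an optimizer state reset should touch
# when dictionary vectors are resampled.
for parameter, axis in layer.reset_optimizer_parameter_details:
    print(parameter.shape, axis)
```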
weight: Float[Parameter, Axis.names(Axis.COMPONENT_OPTIONAL, Axis.INPUT_OUTPUT_FEATURE, Axis.LEARNT_FEATURE)] = Parameter(torch.empty(shape_with_optional_dimensions(n_components, decoded_features, learnt_features))) instance-attribute¤
Weight parameter.
Each column in the weights matrix acts as a dictionary vector, representing a single basis element in the learned activation space.
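As a quick shape check (a sketch; the import path follows the source file location shown in this page):

```python
from sparse_autoencoder.autoencoder.components.unit_norm_decoder import UnitNormDecoder

layer = UnitNormDecoder(learnt_features=4, decoded_features=3, n_components=None)

# With n_components=None the weight has shape (decoded_features, learnt_features);
# each of the 4 columns is one dictionary vector in the 3-dimensional output space.
assert layer.weight.shape == (3, 4)
```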
__init__(learnt_features, decoded_features, n_components, *, enable_gradient_hook=True)¤
Initialize the constrained unit norm linear layer.
Parameters:

Name | Type | Description | Default
---|---|---|---
`learnt_features` | `PositiveInt` | Number of learnt features in the autoencoder. | *required*
`decoded_features` | `PositiveInt` | Number of decoded (output) features in the autoencoder. | *required*
`n_components` | `PositiveInt \| None` | Number of source model components the SAE is trained on. | *required*
`enable_gradient_hook` | `bool` | Enable the gradient backwards hook (modify the gradient before applying the gradient step, to maintain unit norm of the dictionary vectors). | `True`
Source code in sparse_autoencoder/autoencoder/components/unit_norm_decoder.py
constrain_weights_unit_norm()¤
Constrain the weights to have unit norm.
Warning

Note this must be called after each gradient step. This is because optimisers such as Adam don't strictly follow the gradient, but instead follow a modified gradient that includes momentum. This means that the gradient step can change the norm of the dictionary vectors, even when the hook `_weight_backward_hook` is applied.

Note this can't be applied directly in the backward hook, as it would interfere with a variety of use cases (e.g. gradient accumulation across mini-batches, concurrency issues with asynchronous operations, etc.).
Example
```python
>>> import torch
>>> layer = UnitNormDecoder(3, 3, None)
>>> layer.weight.data = torch.ones((3, 3)) * 10
>>> layer.constrain_weights_unit_norm()
>>> column_norms = torch.sqrt(torch.sum(layer.weight ** 2, dim=0))
>>> column_norms.round(decimals=3).tolist()
[1.0, 1.0, 1.0]
```
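A minimal training-loop sketch showing where this call fits; the data and loss here are placeholders, not the library's training pipeline:

```python
import torch
from sparse_autoencoder.autoencoder.components.unit_norm_decoder import UnitNormDecoder

layer = UnitNormDecoder(learnt_features=4, decoded_features=3, n_components=None)
optimizer = torch.optim.Adam(layer.parameters(), lr=1e-3)

for _ in range(10):
    learnt_activations = torch.randn(8, 4)  # (batch, learnt_features)
    reconstruction = layer(learnt_activations)
    loss = reconstruction.pow(2).mean()  # placeholder loss
    optimizer.zero_grad()
    loss.backward()  # backward hook removes the parallel gradient component
    optimizer.step()  # Adam's momentum can still change the column norms...
    layer.constrain_weights_unit_norm()  # ...so re-normalise after every step
```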
Source code in sparse_autoencoder/autoencoder/components/unit_norm_decoder.py
extra_repr()¤
String extra representation of the module.
Source code in sparse_autoencoder/autoencoder/components/unit_norm_decoder.py
forward(x)¤
Forward pass.
Parameters:

Name | Type | Description | Default
---|---|---|---
`x` | `Float[Tensor, names(BATCH, COMPONENT_OPTIONAL, LEARNT_FEATURE)]` | Input tensor. | *required*

Returns:

Type | Description
---|---
`Float[Tensor, names(BATCH, COMPONENT_OPTIONAL, INPUT_OUTPUT_FEATURE)]` | Output of the forward pass.
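For example, with a single-component layer (a sketch; shapes follow the axis annotations above):

```python
import torch
from sparse_autoencoder.autoencoder.components.unit_norm_decoder import UnitNormDecoder

layer = UnitNormDecoder(learnt_features=4, decoded_features=3, n_components=None)

x = torch.randn(8, 4)  # (batch, learnt_features)
decoded = layer(x)  # decode back into the input space
assert decoded.shape == (8, 3)  # (batch, decoded_features)
```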
Source code in sparse_autoencoder/autoencoder/components/unit_norm_decoder.py
reset_parameters()¤
Initialize or reset the parameters.
Example

```python
>>> import torch
>>> # Create a layer with 4 columns (learnt features) and 3 rows (decoded features)
>>> layer = UnitNormDecoder(learnt_features=4, decoded_features=3, n_components=None)
>>> layer.reset_parameters()
>>> # Get the squared norm of each column (by summing across the rows)
>>> column_norms = torch.sum(layer.weight ** 2, dim=0)
>>> column_norms.round(decimals=3).tolist()
[1.0, 1.0, 1.0, 1.0]
```
Source code in sparse_autoencoder/autoencoder/components/unit_norm_decoder.py
update_dictionary_vectors(dictionary_vector_indices, updated_weights, component_idx=None)¤
Update decoder dictionary vectors.
Updates the dictionary vectors (columns in the weight matrix) with the given values. Typically this is used when resampling neurons (dictionary vectors) that have died.
Parameters:

Name | Type | Description | Default
---|---|---|---
`dictionary_vector_indices` | `Int64[Tensor, names(COMPONENT_OPTIONAL, LEARNT_FEATURE_IDX)]` | Indices of the dictionary vectors to update. | *required*
`updated_weights` | `Float[Tensor, names(COMPONENT_OPTIONAL, INPUT_OUTPUT_FEATURE, LEARNT_FEATURE_IDX)]` | Updated weights for just these dictionary vectors. | *required*
`component_idx` | `int \| None` | Component index to update. | `None`
Raises:

Type | Description
---|---
`ValueError` | If `component_idx` is not specified when the layer has multiple components.
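A sketch of a resampling-style update for a single-component layer; the indices and replacement values are illustrative only, with shapes following the axis annotations above:

```python
import torch
from sparse_autoencoder.autoencoder.components.unit_norm_decoder import UnitNormDecoder

layer = UnitNormDecoder(learnt_features=4, decoded_features=3, n_components=None)

# Replace dictionary vectors 0 and 2 with fresh unit-norm directions.
indices = torch.tensor([0, 2])
new_vectors = torch.randn(3, 2)  # (decoded_features, n_updated)
new_vectors /= new_vectors.norm(dim=0, keepdim=True)  # normalise each column
layer.update_dictionary_vectors(indices, new_vectors)
```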
Source code in sparse_autoencoder/autoencoder/components/unit_norm_decoder.py