TransformerLens Hook for storing activations
store_activations_hook(value, hook, store, reshape_method=reshape_to_last_dimension, component_idx=0)
Store Activations Hook.
Useful for capturing only the specific activations you need, rather than the full activation cache.
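The behaviour described in the parameter and return tables can be sketched without PyTorch. This is a minimal, framework-free illustration and not the library's implementation. Nested Python lists stand in for tensors. `ListActivationStore`, `flatten_to_last_dimension` and `store_activations_hook_sketch` are hypothetical stand-ins for an `ActivationStore`, `reshape_to_last_dimension` and the real hook. The sketch follows three steps: reshape the activations, store them under `component_idx`, then return the input unchanged.

```python
from typing import Any, Callable, List

Vector = List[float]


class ListActivationStore:
    """Hypothetical in-memory store holding one list of vectors per component."""

    def __init__(self, n_components: int = 1) -> None:
        self.components: List[List[Vector]] = [[] for _ in range(n_components)]

    def extend(self, vectors: List[Vector], component_idx: int = 0) -> None:
        self.components[component_idx].extend(vectors)


def flatten_to_last_dimension(value: Any) -> List[Vector]:
    """Hypothetical list-based analogue of reshape_to_last_dimension.

    Collapses every leading dimension so only the last (neuron) dimension is
    kept, e.g. [batch][pos][neuron] -> [batch * pos][neuron].
    Assumes every level of nesting is non-empty.
    """
    if isinstance(value[0], list):
        return [vec for sub in value for vec in flatten_to_last_dimension(sub)]
    return [value]  # value is already a single vector


def store_activations_hook_sketch(
    value: Any,
    hook: Any,
    store: ListActivationStore,
    reshape_method: Callable[[Any], List[Vector]] = flatten_to_last_dimension,
    component_idx: int = 0,
) -> Any:
    """Reshape, store, and pass the activations through unchanged."""
    vectors = reshape_method(value)
    store.extend(vectors, component_idx=component_idx)
    return value


# One prompt (batch of 1) with 3 token positions and 2 neurons each.
activations = [[[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]]]
store = ListActivationStore()
returned = store_activations_hook_sketch(activations, "blocks.0.hook_mlp_out", store)
print(len(store.components[0]))  # 3
```

The hook returns the very object it received, so callers downstream see exactly what they would have seen without it.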
Example
First we'll need a source model from TransformerLens and an activation store.
```python
>>> from functools import partial
>>> from transformer_lens import HookedTransformer
>>> from sparse_autoencoder.activation_store.tensor_store import TensorActivationStore
>>> from sparse_autoencoder.source_model.store_activations_hook import store_activations_hook
>>> store = TensorActivationStore(max_items=1000, n_neurons=64, n_components=1)
>>> model = HookedTransformer.from_pretrained("tiny-stories-1M")
Loaded pretrained model tiny-stories-1M into HookedTransformer
```
Next we add the hook to a specific hook point (here, the output of the first MLP layer) and create the tokens for a forward pass.
```python
>>> model.add_hook(
...     "blocks.0.hook_mlp_out", partial(store_activations_hook, store=store)
... )
>>> tokens = model.to_tokens("Hello world")
>>> tokens.shape
torch.Size([1, 3])
```
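A TransformerLens hook point calls its hook with the activation plus the hook point itself as the `hook` keyword argument. Any other arguments, such as `store`, therefore have to be bound in advance, which is what `functools.partial` does in the example. The toy below illustrates only that binding step. `collecting_hook` and `fake_hook_point` are hypothetical, and a plain string stands in for the `HookPoint` object.

```python
from functools import partial


def collecting_hook(value, hook, store, component_idx=0):
    """Toy hook with the same leading parameters as store_activations_hook."""
    store.append((hook, component_idx, value))
    return value


def fake_hook_point(value, hook_fn, name="blocks.0.hook_mlp_out"):
    # Hypothetical: mimics a hook point passing only the activation and itself.
    return hook_fn(value, hook=name)


store = []
# Bind `store` now so the hook can later be called as hook_fn(value, hook=...).
bound_hook = partial(collecting_hook, store=store)
result = fake_hook_point([1.0, 2.0], bound_hook)
print(store)  # [('blocks.0.hook_mlp_out', 0, [1.0, 2.0])]
```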
Then, when we run the model, we should get one activation vector for each token (as we have just one batch item). Note that we also set `stop_at_layer=1`, as we don't need the logits or any other activations after the hook point we've specified (in this case the first MLP layer).
```python
>>> _output = model.forward("Hello world", stop_at_layer=1)  # Change this layer as required
>>> len(store)
3
```
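Why `len(store)` is 3 can be checked with shape arithmetic alone. This sketch makes two assumptions. First, `reshape_to_last_dimension` flattens every leading dimension into a single item dimension, which is consistent with the output above. Second, `blocks.0.hook_mlp_out` has shape `[batch, pos, d_model]` with `d_model = 64`, matching `n_neurons=64` in the store. The helper name is hypothetical.

```python
from math import prod
from typing import Tuple


def shape_after_reshape_to_last_dimension(shape: Tuple[int, ...]) -> Tuple[int, int]:
    """Collapse all leading dimensions into one, keeping the last dimension."""
    *leading, last = shape
    return (prod(leading), last)


# [batch, pos, d_model] for one prompt of 3 tokens, assuming d_model = 64.
print(shape_after_reshape_to_last_dimension((1, 3, 64)))  # (3, 64)
# Two prompts of 3 tokens each would add 6 vectors to the store.
print(shape_after_reshape_to_last_dimension((2, 3, 64)))  # (6, 64)
```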
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `value` | `Float[Tensor, names(ANY)]` | The activations to store. | required |
| `hook` | `HookPoint` | The hook point. | required |
| `store` | `ActivationStore` | The activation store. This should be pre-initialised with | required |
| `reshape_method` | `ReshapeActivationsFunction` | The method to reshape the activations before storing them. | `reshape_to_last_dimension` |
| `component_idx` | `int` | The component index of the activations to store. | `0` |
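When one store collects activations from several hook points, `component_idx` plausibly selects which of the store's `n_components` slots each hook writes to. The sketch below assumes that interpretation. The dict-based store and `store_hook` are hypothetical, and the activations are taken as already reshaped to `[items][neurons]`.

```python
from functools import partial
from typing import Dict, List


def store_hook(value, hook, store, component_idx=0):
    """Toy hook: route each activation vector to the list for component_idx."""
    store[component_idx].extend(value)
    return value


# Hypothetical store with one slot per component (n_components=2).
store: Dict[int, List[List[float]]] = {0: [], 1: []}

# One hook per hook point, each writing into its own component slot.
hook_for_layer_0 = partial(store_hook, store=store, component_idx=0)
hook_for_layer_1 = partial(store_hook, store=store, component_idx=1)

hook_for_layer_0([[0.1, 0.2], [0.3, 0.4]], hook="blocks.0.hook_mlp_out")
hook_for_layer_1([[0.5, 0.6], [0.7, 0.8]], hook="blocks.1.hook_mlp_out")
print(store[0], store[1])
```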
Returns:

| Type | Description |
|---|---|
| `Float[Tensor, names(ANY)]` | Unmodified activations. |
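Returning the activations unmodified matters because of how PyTorch forward hooks work, and TransformerLens hook points build on them. A non-None return value from such a hook replaces the module's output. Passing `value` straight back therefore means storing activations leaves the rest of the forward pass unchanged. The toy below uses hypothetical names throughout and plain lists in place of tensors. It contrasts a pass-through hook with one that alters its input.

```python
from typing import Callable, List, Optional

Activation = List[float]


def run_with_hook(
    activation: Activation, hook_fn: Callable[[Activation], Optional[Activation]]
) -> Activation:
    """Mimic a PyTorch forward hook: a non-None return replaces the output."""
    returned = hook_fn(activation)
    return activation if returned is None else returned


def next_layer(activation: Activation) -> Activation:
    """Hypothetical downstream computation that consumes the activation."""
    return [2 * x for x in activation]


def passthrough_hook(value: Activation) -> Activation:
    return value  # same object back, as store_activations_hook does


def zeroing_hook(value: Activation) -> Activation:
    return [0.0 for _ in value]  # an ablation, for contrast


acts = [1.0, -0.5]
print(next_layer(run_with_hook(acts, passthrough_hook)))  # [2.0, -1.0]
print(next_layer(run_with_hook(acts, zeroing_hook)))  # [0.0, 0.0]
```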
Source code in sparse_autoencoder/source_model/store_activations_hook.py