
TransformerLens Hook for storing activations

TransformerLens Hook for storing activations.

store_activations_hook(value, hook, store, reshape_method=reshape_to_last_dimension, component_idx=0)

Store Activations Hook.

Useful for getting just the specific activations wanted, rather than the full cache.

Example

First we'll need a source model from TransformerLens and an activation store.

>>> from functools import partial
>>> from transformer_lens import HookedTransformer
>>> from sparse_autoencoder.activation_store.tensor_store import TensorActivationStore
>>> store = TensorActivationStore(max_items=1000, n_neurons=64, n_components=1)
>>> model = HookedTransformer.from_pretrained("tiny-stories-1M")
Loaded pretrained model tiny-stories-1M into HookedTransformer

Next we can add the hook to a specific activation (in this case the output of the first MLP layer), and create the tokens for a forward pass.

>>> model.add_hook(
...     "blocks.0.hook_mlp_out", partial(store_activations_hook, store=store)
... )
>>> tokens = model.to_tokens("Hello world")
>>> tokens.shape
torch.Size([1, 3])

Then when we run the model, we should get one activation vector for each token (as we just have one batch item). Note that we also set `stop_at_layer=1`, as we don't need the logits or any other activations after the hook point we've specified (in this case the first MLP layer).

>>> _output = model.forward("Hello world", stop_at_layer=1)  # Change this layer as required
>>> len(store)
3
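Once you've collected what you need, it's worth removing the hook so that later forward passes don't keep appending to the store. A minimal sketch using TransformerLens's `reset_hooks` (which clears all hooks added via `add_hook`):

>>> model.reset_hooks()
>>> _output = model.forward("Hello world", stop_at_layer=1)
>>> len(store)  # Unchanged, since the hook is gone
3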

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `value` | `Float[Tensor, Axis.names(Axis.ANY)]` | The activations to store. | required |
| `hook` | `HookPoint` | The hook point. | required |
| `store` | `ActivationStore` | The activation store. This should be pre-bound with `functools.partial`, as hooks only receive `value` and `hook`. | required |
| `reshape_method` | `ReshapeActivationsFunction` | The method used to reshape the activations before storing them. | `reshape_to_last_dimension` |
| `component_idx` | `int` | The component index under which to store the activations. | `0` |
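The `component_idx` parameter lets one store hold activations from several hook points side by side. A hedged sketch, continuing from the example above (`store_2` is a hypothetical name; the store is created with `n_components=2`, and `blocks.1.hook_mlp_out` is simply the second MLP output in the same model):

>>> model.reset_hooks()
>>> store_2 = TensorActivationStore(max_items=1000, n_neurons=64, n_components=2)
>>> model.add_hook(
...     "blocks.0.hook_mlp_out", partial(store_activations_hook, store=store_2, component_idx=0)
... )
>>> model.add_hook(
...     "blocks.1.hook_mlp_out", partial(store_activations_hook, store=store_2, component_idx=1)
... )
>>> _output = model.forward("Hello world", stop_at_layer=2)

Note `stop_at_layer=2` here, so that the second block's hook point is actually reached during the forward pass.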

Returns:

| Type | Description |
| --- | --- |
| `Float[Tensor, Axis.names(Axis.ANY)]` | Unmodified activations. |
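For intuition, the default `reshape_to_last_dimension` collapses all leading axes (batch, position, etc.) into a single store-batch axis while keeping the final feature axis intact. A rough plain-torch equivalent (a sketch of the behaviour, not the library's actual implementation):

>>> import torch
>>> acts = torch.randn(1, 3, 64)  # (batch, position, d_model)
>>> acts.reshape(-1, acts.shape[-1]).shape
torch.Size([3, 64])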

Source code in sparse_autoencoder/source_model/store_activations_hook.py
def store_activations_hook(
    value: Float[Tensor, Axis.names(Axis.ANY)],
    hook: HookPoint,  # noqa: ARG001
    store: ActivationStore,
    reshape_method: ReshapeActivationsFunction = reshape_to_last_dimension,
    component_idx: int = 0,
) -> Float[Tensor, Axis.names(Axis.ANY)]:
    """Store Activations Hook.

    Useful for getting just the specific activations wanted, rather than the full cache.

    Example:
        First we'll need a source model from TransformerLens and an activation store.

        >>> from functools import partial
        >>> from transformer_lens import HookedTransformer
        >>> from sparse_autoencoder.activation_store.tensor_store import TensorActivationStore
        >>> store = TensorActivationStore(max_items=1000, n_neurons=64, n_components=1)
        >>> model = HookedTransformer.from_pretrained("tiny-stories-1M")
        Loaded pretrained model tiny-stories-1M into HookedTransformer

        Next we can add the hook to specific neurons (in this case the first MLP neurons), and
        create the tokens for a forward pass.

        >>> model.add_hook(
        ...     "blocks.0.hook_mlp_out", partial(store_activations_hook, store=store)
        ... )
        >>> tokens = model.to_tokens("Hello world")
        >>> tokens.shape
        torch.Size([1, 3])

        Then when we run the model, we should get one activation vector for each token (as we just
        have one batch item). Note we also set `stop_at_layer=1` as we don't need the logits or any
        other activations after the hook point that we've specified (in this case the first MLP
        layer).

        >>> _output = model.forward("Hello world", stop_at_layer=1) # Change this layer as required
        >>> len(store)
        3

    Args:
        value: The activations to store.
        hook: The hook point.
        store: The activation store. This should be pre-initialised with `functools.partial`.
        reshape_method: The method to reshape the activations before storing them.
        component_idx: The component index of the activations to store.

    Returns:
        Unmodified activations.
    """
    reshaped: Float[
        Tensor, Axis.names(Axis.STORE_BATCH, Axis.INPUT_OUTPUT_FEATURE)
    ] = reshape_method(value)

    store.extend(reshaped, component_idx=component_idx)

    # Return the unmodified value
    return value