Replace activations hook

replace_activations_hook(value, hook, sparse_autoencoder, component_idx=None, n_components=None)

Replace activations hook.

This should be pre-initialised with functools.partial.

Parameters:

value (Tensor, required): The activations to replace.

hook (HookPoint, required): The hook point. Not used in the function body.

sparse_autoencoder (SparseAutoencoder | DataParallel[SparseAutoencoder] | LitSparseAutoencoder | Module, required): The sparse autoencoder.

component_idx (int | None, default None): The component index to replace the activations with, if just replacing activations for a single component. Requires the model to have a component axis.

n_components (int | None, default None): The number of components that the SAE is trained on. Must be set whenever component_idx is set.

Returns:

Tensor: The replaced activations, that is, the SAE's output activations reshaped to the original shape of value.

Raises:

RuntimeError: If component_idx is specified but n_components is not set.
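
The note above about pre-initialising with functools.partial can be illustrated with a minimal, dependency-free sketch. ToyAutoencoder and toy_replace_hook are hypothetical stand-ins invented for this example (nested lists play the role of tensors); they are not part of the library. Only the binding pattern is the point here.

```python
from functools import partial


class ToyAutoencoder:
    """Hypothetical stand-in for an SAE: forward returns (learned, output)."""

    def forward(self, activations):
        # "Encode" with a ReLU, then "decode" by doubling each value.
        learned = [[max(x, 0.0) for x in row] for row in activations]
        output = [[2.0 * x for x in row] for row in learned]
        return learned, output


def toy_replace_hook(value, hook, sparse_autoencoder):
    """Simplified hook using the same (value, hook, ...) argument order."""
    _learned, output = sparse_autoencoder.forward(value)
    return output


# Bind the autoencoder up front, leaving a callable that takes (value, hook).
hook_fn = partial(toy_replace_hook, sparse_autoencoder=ToyAutoencoder())

# The caller now only needs to supply the activations and the hook point.
replaced = hook_fn([[1.0, -3.0], [0.5, 2.0]], None)
print(replaced)
```

With the real function, the corresponding binding would follow the signature documented above, for example partial(replace_activations_hook, sparse_autoencoder=autoencoder), with component_idx and n_components added as keyword arguments when only one component should be replaced.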

Source code in sparse_autoencoder/source_model/replace_activations_hook.py
def replace_activations_hook(
    value: Tensor,
    hook: HookPoint,  # noqa: ARG001
    sparse_autoencoder: SparseAutoencoder
    | DataParallel[SparseAutoencoder]
    | LitSparseAutoencoder
    | Module,
    component_idx: int | None = None,
    n_components: int | None = None,
) -> Tensor:
    """Replace activations hook.

    This should be pre-initialised with `functools.partial`.

    Args:
        value: The activations to replace.
        hook: The hook point.
        sparse_autoencoder: The sparse autoencoder.
        component_idx: The component index to replace the activations with, if just replacing
            activations for a single component. Requires the model to have a component axis.
        n_components: The number of components that the SAE is trained on.

    Returns:
        Replaced activations.

    Raises:
        RuntimeError: If `component_idx` is specified but `n_components` is not set.
    """
    # Flatten all leading dimensions into a single batch dimension, keeping the feature dimension
    original_shape = value.shape

    squashed_value: Float[Tensor, Axis.names(Axis.BATCH, Axis.INPUT_OUTPUT_FEATURE)] = value.view(
        -1, value.size(-1)
    )

    if component_idx is not None:
        if n_components is None:
            error_message = "The number of model components must be set if component_idx is set."
            raise RuntimeError(error_message)

        # The approach here is to run a forward pass with dummy values for all components other than
        # the one we want to replace. This is done by expanding the inputs to the SAE for a specific
        # component across all components. We then simply discard the activations for all other
        # components.
        expanded_shape = [
            squashed_value.shape[0],
            n_components,
            squashed_value.shape[-1],
        ]
        expanded = squashed_value.unsqueeze(1).expand(*expanded_shape)

        _learned_activations, output_activations = sparse_autoencoder.forward(expanded)
        component_output_activations = output_activations[:, component_idx]

        return component_output_activations.view(*original_shape)

    # Get the output activations from a forward pass of the SAE
    _learned_activations, output_activations = sparse_autoencoder.forward(squashed_value)

    # Reshape to the original shape
    return output_activations.view(*original_shape)
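
The single-component branch in the listing above expands the input across every component, runs one forward pass, and keeps only the requested component's output. The following plain-Python sketch mirrors that idea with nested lists instead of tensors. ToyComponentAutoencoder, expand_across_components and replace_single_component are hypothetical names made up for illustration, and the toy "autoencoder" simply scales each component differently so the selection is visible.

```python
def expand_across_components(rows, n_components):
    """Copy each feature vector into every component slot.

    Shape goes from (batch, features) to (batch, n_components, features).
    """
    return [[list(row) for _ in range(n_components)] for row in rows]


class ToyComponentAutoencoder:
    """Hypothetical stand-in: component c multiplies its input by (c + 1)."""

    def forward(self, activations):
        output = [
            [[(c + 1) * x for x in vec] for c, vec in enumerate(per_component)]
            for per_component in activations
        ]
        return activations, output


def replace_single_component(rows, autoencoder, component_idx, n_components):
    if n_components is None:
        message = "The number of model components must be set if component_idx is set."
        raise RuntimeError(message)
    expanded = expand_across_components(rows, n_components)
    _learned, output = autoencoder.forward(expanded)
    # Keep only the requested component; the other slots are discarded.
    return [per_component[component_idx] for per_component in output]


rows = [[1.0, 2.0], [3.0, 4.0]]
result = replace_single_component(
    rows, ToyComponentAutoencoder(), component_idx=1, n_components=3
)
print(result)
```

The trade-off in this approach is extra computation for the discarded components in exchange for reusing the autoencoder's normal forward pass without a special single-component code path.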