
Methods to reshape activation tensors

Methods to reshape activation tensors.

ReshapeActivationsFunction: TypeAlias = Callable[[Float[Tensor, Axis.names(Axis.ANY)]], Float[Tensor, Axis.names(Axis.STORE_BATCH, Axis.INPUT_OUTPUT_FEATURE)]] module-attribute

Reshape Activations Function.

Used within hooks, e.g. to reshape activations before they are stored in the activation store.
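To illustrate where such a function plugs in, here is a minimal sketch using a plain PyTorch forward hook (not necessarily the hook mechanism this library uses). It assumes a hypothetical module ToyHeads that emits (batch, pos, head_idx, d_head) activations, a hypothetical stand-in concat_last_dims that mirrors reshape_concat_last_dimensions, and a Python list standing in for the activation store. functools.partial binds concat_dims so that the resulting callable takes a single tensor, as the alias requires.

```python
import math
from functools import partial
from typing import Callable, List, Tuple

import torch
from torch import Tensor, nn

# Simplified stand-in for the alias; the documented version also carries
# jaxtyping shape annotations on the input and output tensors.
ReshapeActivationsFunction = Callable[[Tensor], Tensor]


def concat_last_dims(batch_activations: Tensor, concat_dims: int) -> Tensor:
    """Hypothetical stand-in mirroring reshape_concat_last_dimensions."""
    neurons = math.prod(batch_activations.shape[-concat_dims:])
    return batch_activations.reshape(-1, neurons)


class ToyHeads(nn.Module):
    """Hypothetical module that splits the last dimension into 4 heads of size 5."""

    def forward(self, x: Tensor) -> Tensor:
        return x.reshape(*x.shape[:-1], 4, 5)


store: List[Tensor] = []  # hypothetical in-memory stand-in for an activation store

# Bind concat_dims so the callable takes one tensor, matching the alias.
reshape_fn: ReshapeActivationsFunction = partial(concat_last_dims, concat_dims=2)


def store_hook(module: nn.Module, inputs: Tuple[Tensor, ...], output: Tensor) -> None:
    # PyTorch forward hooks receive (module, inputs, output); returning None
    # leaves the module's output unchanged.
    store.append(reshape_fn(output.detach()))


model = ToyHeads()
handle = model.register_forward_hook(store_hook)
model(torch.randn(2, 3, 20))  # module output has shape (2, 3, 4, 5)
handle.remove()

print(store[0].shape)  # torch.Size([6, 20])
```

Binding the extra parameter with partial keeps the hook itself generic: it only needs some callable that maps one tensor to an [item, neurons] tensor.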

reshape_concat_last_dimensions(batch_activations, concat_dims)

Reshape to Last Dimension, Concatenating the Specified Dimensions.

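Takes a tensor of activation vectors with an arbitrary number of dimensions (the last concat_dims of which are the neuron dimensions) and returns a single 2D tensor of shape [item, neurons]: the leading dimensions are flattened into the item dimension, and the last concat_dims dimensions are flattened into the neuron dimension.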
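The shape bookkeeping can be sketched in plain Python. The helper name concat_last_dims_shape is hypothetical; it multiplies the leading dimensions into the item count and the trailing concat_dims dimensions into the neuron count, which is what the source listing for this function computes with functools.reduce.

```python
from math import prod
from typing import Tuple


def concat_last_dims_shape(shape: Tuple[int, ...], concat_dims: int) -> Tuple[int, int]:
    """Hypothetical helper returning the [item, neurons] shape of the result."""
    items = prod(shape[:-concat_dims])    # leading dimensions collapse into items
    neurons = prod(shape[-concat_dims:])  # trailing concat_dims dimensions collapse into neurons
    return items, neurons


print(concat_last_dims_shape((3, 4, 5), 2))     # (3, 20)
print(concat_last_dims_shape((2, 3, 4, 5), 3))  # (2, 60)
```

Because items × neurons always equals the total number of elements, the final reshape never changes how much data is stored, only how it is grouped.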

Examples:

With 3 axes (e.g. batch, pos, neuron), concatenating the last 2 dimensions:

>>> import torch
>>> input = torch.randn(3, 4, 5)
>>> res = reshape_concat_last_dimensions(input, 2)
>>> res.shape
torch.Size([3, 20])

With 4 axes (e.g. batch, pos, head_idx, neuron), concatenating the last 3 dimensions:

>>> input = torch.rand(2, 3, 4, 5)
>>> res = reshape_concat_last_dimensions(input, 3)
>>> res.shape
torch.Size([2, 60])

Parameters:

batch_activations (Float[Tensor, names(ANY)], required): Input activation store batch.

concat_dims (int, required): Number of trailing dimensions to concatenate into the neuron dimension.

Returns:

Float[Tensor, names(STORE_BATCH, INPUT_OUTPUT_FEATURE)]: Single tensor of activation store items, of shape [item, neurons].

Source code in sparse_autoencoder/source_model/reshape_activations.py
def reshape_concat_last_dimensions(
    batch_activations: Float[Tensor, Axis.names(Axis.ANY)],
    concat_dims: int,
) -> Float[Tensor, Axis.names(Axis.STORE_BATCH, Axis.INPUT_OUTPUT_FEATURE)]:
    """Reshape to Last Dimension, Concatenating the Specified Dimensions.

    Takes a tensor of activation vectors, with arbitrary numbers of dimensions (the last
    `concat_dims` of which are the neuron dimensions), and returns a single tensor of size
    [item, neurons].

    Examples:
        With 3 axes (e.g. batch, pos, neuron), concatenating the last 2 dimensions:

        >>> import torch
        >>> input = torch.randn(3, 4, 5)
        >>> res = reshape_concat_last_dimensions(input, 2)
        >>> res.shape
        torch.Size([3, 20])

        With 4 axes (e.g. batch, pos, head_idx, neuron), concatenating the last 3 dimensions:

        >>> input = torch.rand(2, 3, 4, 5)
        >>> res = reshape_concat_last_dimensions(input, 3)
        >>> res.shape
        torch.Size([2, 60])

    Args:
        batch_activations: Input Activation Store Batch
        concat_dims: Number of dimensions to concatenate

    Returns:
        Single Tensor of Activation Store Items
    """
    neurons = reduce(lambda x, y: x * y, batch_activations.shape[-concat_dims:])
    items = reduce(lambda x, y: x * y, batch_activations.shape[:-concat_dims])

    return batch_activations.reshape(items, neurons)
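One edge case follows from the source as shown: functools.reduce is called without an initial value, so when concat_dims equals the number of dimensions the leading-shape slice is empty and reduce raises TypeError instead of producing an item count of 1. Below is a minimal sketch of a variant that sidesteps this with math.prod, which returns 1 for an empty product. The name reshape_concat_last_dimensions_prod is hypothetical and not part of the library.

```python
import math
from functools import reduce

import torch

x = torch.randn(4, 5)

# Same expression as the source, with concat_dims == x.ndim == 2:
# x.shape[:-2] is empty, so reduce has nothing to fold and no initial value.
try:
    reduce(lambda a, b: a * b, x.shape[:-2])
except TypeError:
    print("reduce raised TypeError")


def reshape_concat_last_dimensions_prod(
    batch_activations: torch.Tensor, concat_dims: int
) -> torch.Tensor:
    """Hypothetical variant: math.prod treats an empty product as 1."""
    if not 1 <= concat_dims <= batch_activations.ndim:
        raise ValueError("concat_dims must be between 1 and batch_activations.ndim")
    neurons = math.prod(batch_activations.shape[-concat_dims:])
    items = math.prod(batch_activations.shape[:-concat_dims])
    return batch_activations.reshape(items, neurons)


print(reshape_concat_last_dimensions_prod(x, 2).shape)  # torch.Size([1, 20])
```

The range check also rejects concat_dims=0, for which shape[-0:] is the same as shape[0:] and would silently select every dimension.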

reshape_to_last_dimension(batch_activations)

Reshape to Last Dimension.

Takes a tensor of activation vectors with an arbitrary number of dimensions (the last of which is the neuron dimension) and returns a single 2D tensor of shape [item, neurons], flattening all leading dimensions into the item dimension.
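For intuition, here is a minimal sketch of the same flattening with plain torch.Tensor.reshape (the source listing uses einops.rearrange instead; both are assumed here to produce the same row-major ordering). It shows which leading index each output row comes from.

```python
import torch

x = torch.randn(3, 3, 3, 100)  # e.g. batch, pos, head_idx, neuron

# Collapse every leading dimension into one item dimension; keep the last as-is.
flat = x.reshape(-1, x.shape[-1])
print(flat.shape)  # torch.Size([27, 100])

# Rows follow row-major (C) order over the leading indices:
# (b, p, h) maps to row b * 9 + p * 3 + h, so row 4 is x[0, 1, 1].
print(torch.equal(flat[4], x[0, 1, 1]))  # True
```

For inputs with at least two dimensions, this matches reshape_concat_last_dimensions with concat_dims=1.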

Examples:

With 2 axes (e.g. pos, neuron), the input is already [item, neurons] and its shape is unchanged:

>>> import torch
>>> input = torch.rand(3, 100)
>>> res = reshape_to_last_dimension(input)
>>> res.shape
torch.Size([3, 100])

With 3 axes (e.g. batch, pos, neuron):

>>> input = torch.randn(3, 3, 100)
>>> res = reshape_to_last_dimension(input)
>>> res.shape
torch.Size([9, 100])

With 4 axes (e.g. batch, pos, head_idx, neuron):

>>> input = torch.rand(3, 3, 3, 100)
>>> res = reshape_to_last_dimension(input)
>>> res.shape
torch.Size([27, 100])

Parameters:

batch_activations (Float[Tensor, names(ANY)], required): Input activation store batch.

Returns:

Float[Tensor, names(STORE_BATCH, INPUT_OUTPUT_FEATURE)]: Single tensor of activation store items, of shape [item, neurons].

Source code in sparse_autoencoder/source_model/reshape_activations.py
def reshape_to_last_dimension(
    batch_activations: Float[Tensor, Axis.names(Axis.ANY)],
) -> Float[Tensor, Axis.names(Axis.STORE_BATCH, Axis.INPUT_OUTPUT_FEATURE)]:
    """Reshape to Last Dimension.

    Takes a tensor of activation vectors, with arbitrary numbers of dimensions (the last of which is
    the neurons dimension), and returns a single tensor of size [item, neurons].

    Examples:
        With 2 axes (e.g. pos, neuron):

        >>> import torch
        >>> input = torch.rand(3, 100)
        >>> res = reshape_to_last_dimension(input)
        >>> res.shape
        torch.Size([3, 100])

        With 3 axes (e.g. batch, pos, neuron):

        >>> input = torch.randn(3, 3, 100)
        >>> res = reshape_to_last_dimension(input)
        >>> res.shape
        torch.Size([9, 100])

        With 4 axes (e.g. batch, pos, head_idx, neuron):

        >>> input = torch.rand(3, 3, 3, 100)
        >>> res = reshape_to_last_dimension(input)
        >>> res.shape
        torch.Size([27, 100])

    Args:
        batch_activations: Input Activation Store Batch

    Returns:
        Single Tensor of Activation Store Items
    """
    return rearrange(batch_activations, "... input_output_feature -> (...) input_output_feature")