
Component slice tensor utils

Component slice tensor utils.

get_component_slice_tensor(input_tensor, n_dim_with_component, component_dim, component_idx)

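Get the slice of a tensor for a specific component by indexing component_idx along the component axis component_dim. The component axis is removed from the result. If the tensor has one fewer dimension than n_dim_with_component (that is, it has no component axis), it is returned unchanged.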

Examples:

```python
>>> import torch
>>> input_tensor = torch.tensor([[1, 2], [3, 4], [5, 6], [7, 8]])
>>> get_component_slice_tensor(input_tensor, 2, 1, 0)
tensor([1, 3, 5, 7])

>>> input_tensor = torch.tensor([[1, 2], [3, 4], [5, 6], [7, 8]])
>>> get_component_slice_tensor(input_tensor, 3, 1, 0)
tensor([[1, 2],
        [3, 4],
        [5, 6],
        [7, 8]])
```
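Internally, the slice is built by placing the integer component_idx at position component_dim and a full slice (`:`) on every other axis, so integer indexing drops the component axis from the result. A minimal sketch of that construction on a 3-D tensor; the (batch, component, feature) layout and the sizes are assumptions chosen for illustration only:

```python
import torch

# Hypothetical activations shaped (batch=2, component=3, feature=4).
activations = torch.arange(24).reshape(2, 3, 4)

component_dim = 1
component_idx = 2

# Same construction as in get_component_slice_tensor: an integer index on the
# component axis and slice(None) (i.e. ':') on every other axis.
slice_tuple = tuple(
    component_idx if i == component_dim else slice(None)
    for i in range(activations.ndim)
)
component_slice = activations[slice_tuple]  # equivalent to activations[:, 2, :]

# Integer indexing drops the component axis: (2, 3, 4) -> (2, 4).
print(component_slice.shape)  # torch.Size([2, 4])

# For a non-negative component_dim this matches Tensor.select.
assert torch.equal(component_slice, activations.select(component_dim, component_idx))
```

For non-negative component_dim values, `Tensor.select(component_dim, component_idx)` gives the same result as the tuple-indexing approach.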

Parameters:

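| Name | Type | Description | Default |
|---|---|---|---|
| input_tensor | Tensor | Input tensor. | required |
| n_dim_with_component | int | Number of dimensions the input tensor has when it includes the component axis. A tensor with one fewer dimension is returned unchanged. | required |
| component_dim | int | Position of the component axis. Must be non-negative: a negative value never matches any axis index, so no slicing takes place. | required |
| component_idx | int | Index of the component to get the slice for. | required |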

Returns:

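| Type | Description |
|---|---|
| Tensor | The slice at component_idx along component_dim (one fewer dimension than the input), or the input tensor unchanged if it has no component axis. |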

Raises:

| Type | Description |
|---|---|
| ValueError | If input_tensor.ndim is neither n_dim_with_component nor n_dim_with_component - 1. |
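Only two dimensionalities are accepted: exactly n_dim_with_component (the component axis is indexed away) or one fewer (the tensor is passed through). The sketch below restates the listed source as a standalone function under the hypothetical name component_slice, so it runs without the sparse_autoencoder package installed, and exercises all three branches:

```python
import torch
from torch import Tensor


def component_slice(
    input_tensor: Tensor,
    n_dim_with_component: int,
    component_dim: int,
    component_idx: int,
) -> Tensor:
    """Illustrative copy of get_component_slice_tensor's logic (hypothetical name)."""
    # Component axis absent (one dimension short): return the input unchanged.
    if n_dim_with_component - 1 == input_tensor.ndim:
        return input_tensor
    # Any other mismatch is rejected.
    if n_dim_with_component != input_tensor.ndim:
        msg = (
            f"Cannot get component slice for tensor with {input_tensor.ndim} dimensions "
            f"and {n_dim_with_component} dimensions with component."
        )
        raise ValueError(msg)
    # Component axis present: index it away.
    slice_tuple = tuple(
        component_idx if i == component_dim else slice(None) for i in range(input_tensor.ndim)
    )
    return input_tensor[slice_tuple]


x = torch.zeros(4, 2)  # ndim == 2

print(component_slice(x, 2, 1, 0).shape)  # torch.Size([4])
print(component_slice(x, 3, 1, 0) is x)   # True
try:
    component_slice(x, 4, 1, 0)  # ndim 2 is neither 4 nor 4 - 1
except ValueError as err:
    print(err)
```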

Source code in sparse_autoencoder/activation_resampler/utils/component_slice_tensor.py
```python
# Import assumed from the module header (the first lines of the file are not shown).
from torch import Tensor


def get_component_slice_tensor(
    input_tensor: Tensor,
    n_dim_with_component: int,
    component_dim: int,
    component_idx: int,
) -> Tensor:
    """Get a slice of a tensor for a specific component.

    Examples:
        >>> import torch
        >>> input_tensor = torch.tensor([[1, 2], [3, 4], [5, 6], [7, 8]])
        >>> get_component_slice_tensor(input_tensor, 2, 1, 0)
        tensor([1, 3, 5, 7])

        >>> input_tensor = torch.tensor([[1, 2], [3, 4], [5, 6], [7, 8]])
        >>> get_component_slice_tensor(input_tensor, 3, 1, 0)
        tensor([[1, 2],
                [3, 4],
                [5, 6],
                [7, 8]])

    Args:
        input_tensor: Input tensor.
        n_dim_with_component: Number of dimensions in the input tensor with the component axis.
        component_dim: Dimension of the component axis.
        component_idx: Index of the component to get the slice for.

    Returns:
        Tensor slice.

    Raises:
        ValueError: If the input tensor does not have the expected number of dimensions.
    """
    # No component axis (one dimension short): nothing to slice.
    if n_dim_with_component - 1 == input_tensor.ndim:
        return input_tensor

    if n_dim_with_component != input_tensor.ndim:
        error_message = (
            f"Cannot get component slice for tensor with {input_tensor.ndim} dimensions "
            f"and {n_dim_with_component} dimensions with component."
        )
        raise ValueError(error_message)

    # Create a tuple of slices for each dimension
    slice_tuple = tuple(
        component_idx if i == component_dim else slice(None) for i in range(input_tensor.ndim)
    )

    return input_tensor[slice_tuple]
```
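Because the comprehension compares `i == component_dim` with `i` drawn from `range(input_tensor.ndim)`, a negative component_dim such as -1 never matches, and the full tensor comes back unsliced without an error. The sketch below demonstrates that behavior, followed by a hypothetical variant, select_component (not part of the library), which uses `Tensor.select` to accept negative dimensions. It assumes the input already has the component axis, so it omits the dimension check:

```python
import torch
from torch import Tensor

x = torch.tensor([[1, 2], [3, 4], [5, 6], [7, 8]])

# The listing's tuple construction with component_dim=-1: `i == -1` is never
# true for i in range(ndim), so every axis gets slice(None).
component_dim, component_idx = -1, 0
slice_tuple = tuple(
    component_idx if i == component_dim else slice(None) for i in range(x.ndim)
)
print(x[slice_tuple].shape)  # torch.Size([4, 2]): nothing was sliced


def select_component(input_tensor: Tensor, component_dim: int, component_idx: int) -> Tensor:
    """Hypothetical variant that indexes the component axis and accepts negative dims.

    Assumes input_tensor already has the component axis (no dimension check).
    """
    # Tensor.select removes the indexed axis, like integer indexing, and
    # interprets negative dims the usual way (-1 is the last axis).
    return input_tensor.select(component_dim, component_idx)


print(select_component(x, -1, 0))  # tensor([1, 3, 5, 7])
```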