Tensor Axis Types

Tensor Axis Types.

Axis

Bases: LowercaseStrEnum

Tensor axis names.

Used to annotate tensor types.

Example

When used directly it prints a string:

>>> print(Axis.INPUT_OUTPUT_FEATURE)
input_output_feature

The primary use is to annotate tensor types:

>>> from jaxtyping import Float
>>> from torch import Tensor
>>> from typing import TypeAlias
>>> batch: TypeAlias = Float[Tensor, Axis.names(Axis.BATCH, Axis.INPUT_OUTPUT_FEATURE)]
>>> print(batch)
<class 'jaxtyping.Float[Tensor, 'batch input_output_feature']'>

You can also join multiple axes together to represent the dimensions of a tensor:

>>> print(Axis.names(Axis.BATCH, Axis.INPUT_OUTPUT_FEATURE))
batch input_output_feature
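
To make the printed values in the examples here concrete, the following is a minimal stdlib-only sketch of how an enum of this kind can produce lowercase string values. `LowercaseStrEnumSketch` and `AxisSketch` are hypothetical stand-ins, not the library's classes. The sketch assumes that `LowercaseStrEnum` derives each `auto()` value by lowercasing the member name, which is consistent with the output shown here but is not confirmed by this page.

```python
from enum import Enum, auto


class LowercaseStrEnumSketch(str, Enum):
    """Hypothetical stand-in for LowercaseStrEnum, built on the stdlib only."""

    @staticmethod
    def _generate_next_value_(name, start, count, last_values):
        # Assumption: auto() yields the lowercased member name.
        return name.lower()

    def __str__(self) -> str:
        # Print the plain value rather than "ClassName.MEMBER".
        return self.value


class AxisSketch(LowercaseStrEnumSketch):
    """Hypothetical subset of Axis, mixing auto() and explicit values."""

    BATCH = auto()
    INPUT_OUTPUT_FEATURE = auto()
    COMPONENT_OPTIONAL = "*component"
    ANY = "..."

    @staticmethod
    def names(*axis: "AxisSketch") -> str:
        # Same join as shown for Axis.names: values separated by spaces.
        return " ".join(a.value for a in axis)


print(AxisSketch.INPUT_OUTPUT_FEATURE)
print(AxisSketch.names(AxisSketch.BATCH, AxisSketch.INPUT_OUTPUT_FEATURE))
```

Because each member is also a `str`, the joined result can be passed wherever a plain shape string is expected, as in the `Float[Tensor, ...]` annotation.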

Source code in sparse_autoencoder/tensor_types.py
class Axis(LowercaseStrEnum):
    """Tensor axis names.

    Used to annotate tensor types.

    Example:
        When used directly it prints a string:

        >>> print(Axis.INPUT_OUTPUT_FEATURE)
        input_output_feature

        The primary use is to annotate tensor types:

        >>> from jaxtyping import Float
        >>> from torch import Tensor
        >>> from typing import TypeAlias
        >>> batch: TypeAlias = Float[Tensor, Axis.names(Axis.BATCH, Axis.INPUT_OUTPUT_FEATURE)]
        >>> print(batch)
        <class 'jaxtyping.Float[Tensor, 'batch input_output_feature']'>

        You can also join multiple axes together to represent the dimensions of a tensor:

        >>> print(Axis.names(Axis.BATCH, Axis.INPUT_OUTPUT_FEATURE))
        batch input_output_feature
    """

    # Component idx
    COMPONENT = auto()
    """Component index."""

    COMPONENT_OPTIONAL = "*component"
    """Optional component index."""

    # Batches
    SOURCE_DATA_BATCH = auto()
    """Batch of prompts used to generate source model activations."""

    BATCH = auto()
    """Batch of items that the SAE is being trained on."""

    STORE_BATCH = auto()
    """Batch of items to be written to the store."""

    ITEMS = auto()
    """Arbitrary number of items."""

    # Features
    INPUT_OUTPUT_FEATURE = auto()
    """Input or output feature (e.g. feature in activation vector from source model)."""

    LEARNT_FEATURE = auto()
    """Learnt feature (e.g. feature in learnt activation vector)."""

    DEAD_FEATURE = auto()
    """Dead feature."""

    ALIVE_FEATURE = auto()
    """Alive feature."""

    # Feature indices
    INPUT_OUTPUT_FEATURE_IDX = auto()
    """Input or output feature index."""

    LEARNT_FEATURE_IDX = auto()
    """Learnt feature index."""

    # Other
    POSITION = auto()
    """Token position."""

    SINGLE_ITEM = ""
    """Single item axis."""

    ANY = "..."
    """Any number of axes."""

    @staticmethod
    def names(*axis: "Axis") -> str:
        """Join multiple axes together, to represent the dimensions of a tensor.

        Example:
            >>> print(Axis.names(Axis.BATCH, Axis.INPUT_OUTPUT_FEATURE))
            batch input_output_feature

        Args:
            *axis: Axes to join.

        Returns:
            Joined axis string.
        """
        return " ".join(a.value for a in axis)

ALIVE_FEATURE = auto() class-attribute instance-attribute

Alive feature.

ANY = '...' class-attribute instance-attribute

Any number of axes.

BATCH = auto() class-attribute instance-attribute

Batch of items that the SAE is being trained on.

COMPONENT = auto() class-attribute instance-attribute

Component index.

COMPONENT_OPTIONAL = '*component' class-attribute instance-attribute

Optional component index.

DEAD_FEATURE = auto() class-attribute instance-attribute

Dead feature.

INPUT_OUTPUT_FEATURE = auto() class-attribute instance-attribute

Input or output feature (e.g. feature in activation vector from source model).

INPUT_OUTPUT_FEATURE_IDX = auto() class-attribute instance-attribute

Input or output feature index.

ITEMS = auto() class-attribute instance-attribute

Arbitrary number of items.

LEARNT_FEATURE = auto() class-attribute instance-attribute

Learnt feature (e.g. feature in learnt activation vector).

LEARNT_FEATURE_IDX = auto() class-attribute instance-attribute

Learnt feature index.

POSITION = auto() class-attribute instance-attribute

Token position.

SINGLE_ITEM = '' class-attribute instance-attribute

Single item axis.

SOURCE_DATA_BATCH = auto() class-attribute instance-attribute

Batch of prompts used to generate source model activations.

STORE_BATCH = auto() class-attribute instance-attribute

Batch of items to be written to the store.

names(*axis) staticmethod

Join multiple axes together, to represent the dimensions of a tensor.

Example

>>> print(Axis.names(Axis.BATCH, Axis.INPUT_OUTPUT_FEATURE))
batch input_output_feature

Parameters:

Name     Type    Description      Default
*axis    Axis    Axes to join.    ()

Returns:

Type    Description
str     Joined axis string.
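
The join described for `names` is a plain space-separated concatenation of member values, so the special members (`'*component'`, `'...'` and `''`) pass through unchanged. The sketch below illustrates this with a hypothetical `AxisSketch` enum and a free function `names`, both stand-ins written for this example. The member values are spelled out explicitly here rather than generated by `auto()`.

```python
from enum import Enum


class AxisSketch(str, Enum):
    """Hypothetical subset of Axis with values written out explicitly."""

    COMPONENT_OPTIONAL = "*component"
    BATCH = "batch"
    LEARNT_FEATURE = "learnt_feature"
    SINGLE_ITEM = ""
    ANY = "..."


def names(*axis: AxisSketch) -> str:
    """Join member values with single spaces, mirroring the documented body."""
    return " ".join(a.value for a in axis)


# Optional leading component axis followed by two named axes.
print(names(AxisSketch.COMPONENT_OPTIONAL, AxisSketch.BATCH, AxisSketch.LEARNT_FEATURE))
# *component batch learnt_feature

# Any number of leading axes, then a named trailing axis.
print(names(AxisSketch.ANY, AxisSketch.LEARNT_FEATURE))
# ... learnt_feature

# The single-item axis on its own gives an empty shape string.
print(repr(names(AxisSketch.SINGLE_ITEM)))
# ''

# Mixing the empty value with other axes leaves a stray space.
print(repr(names(AxisSketch.BATCH, AxisSketch.SINGLE_ITEM)))
# 'batch '
```

The last call suggests that the single-item value works best when used alone, since the join does not filter out empty strings.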

Source code in sparse_autoencoder/tensor_types.py
@staticmethod
def names(*axis: "Axis") -> str:
    """Join multiple axes together, to represent the dimensions of a tensor.

    Example:
        >>> print(Axis.names(Axis.BATCH, Axis.INPUT_OUTPUT_FEATURE))
        batch input_output_feature

    Args:
        *axis: Axes to join.

    Returns:
        Joined axis string.
    """
    return " ".join(a.value for a in axis)