Tensor Axis Types
Tensor Axis Types.
Axis
Bases: LowercaseStrEnum
Tensor axis names.
Used to annotate tensor types.
Example
When used directly it prints a string:
>>> print(Axis.INPUT_OUTPUT_FEATURE)
input_output_feature
The primary use is to annotate tensor types:
>>> from jaxtyping import Float
>>> from torch import Tensor
>>> from typing import TypeAlias
>>> from sparse_autoencoder.tensor_types import Axis
>>> batch: TypeAlias = Float[Tensor, Axis.names(Axis.BATCH, Axis.INPUT_OUTPUT_FEATURE)]
>>> print(batch)
You can also join multiple axes together to represent the dimensions of a tensor:
>>> print(Axis.names(Axis.BATCH, Axis.INPUT_OUTPUT_FEATURE))
batch input_output_feature
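The page shows only the behaviour, not the implementation. As a rough illustration, here is a minimal stdlib-only sketch of how such an enum could be built. It assumes LowercaseStrEnum works like the class of the same name in the third-party strenum package (each auto() value is the lowercased member name) and that names simply joins member values with single spaces, which is consistent with the printed output. Only a subset of members is included, and this is not the sparse_autoencoder source.

```python
from enum import Enum, auto


class LowercaseStrEnum(str, Enum):
    """Sketch of a str-valued enum whose auto() values are lowercased member names."""

    @staticmethod
    def _generate_next_value_(name, start, count, last_values):
        # Called once per auto() member: BATCH -> "batch"
        return name.lower()

    def __str__(self) -> str:
        # Print the raw value ("batch") instead of "Axis.BATCH"
        return str(self.value)


class Axis(LowercaseStrEnum):
    """Hypothetical subset of the real Axis members, for illustration only."""

    BATCH = auto()
    INPUT_OUTPUT_FEATURE = auto()
    COMPONENT_OPTIONAL = "*component"
    ANY = "..."

    @staticmethod
    def names(*axis: "Axis") -> str:
        # jaxtyping dimension strings separate axis names with spaces
        return " ".join(a.value for a in axis)


print(Axis.INPUT_OUTPUT_FEATURE)  # input_output_feature
print(Axis.names(Axis.BATCH, Axis.INPUT_OUTPUT_FEATURE))  # batch input_output_feature
print(Axis.names(Axis.COMPONENT_OPTIONAL, Axis.BATCH))  # *component batch
```

Because each member is also a str instance, it can be used anywhere a plain dimension string is expected, such as the second argument of jaxtyping's Float[...].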
Source code in sparse_autoencoder/tensor_types.py
ALIVE_FEATURE = auto() (class-attribute, instance-attribute)
Alive feature.
ANY = '...' (class-attribute, instance-attribute)
Any number of axes.
BATCH = auto() (class-attribute, instance-attribute)
Batch of items that the SAE is being trained on.
COMPONENT = auto() (class-attribute, instance-attribute)
Component index.
COMPONENT_OPTIONAL = '*component' (class-attribute, instance-attribute)
Optional component index (the leading * makes this a variadic jaxtyping axis, so it can match zero or more dimensions).
DEAD_FEATURE = auto() (class-attribute, instance-attribute)
Dead feature.
INPUT_OUTPUT_FEATURE = auto() (class-attribute, instance-attribute)
Input or output feature (e.g. feature in activation vector from source model).
INPUT_OUTPUT_FEATURE_IDX = auto() (class-attribute, instance-attribute)
Input or output feature index.
ITEMS = auto() (class-attribute, instance-attribute)
Arbitrary number of items.
LEARNT_FEATURE = auto() (class-attribute, instance-attribute)
Learnt feature (e.g. a feature in the learnt activation vector).
LEARNT_FEATURE_IDX = auto() (class-attribute, instance-attribute)
Learnt feature index.
POSITION = auto() (class-attribute, instance-attribute)
Token position.
SINGLE_ITEM = '' (class-attribute, instance-attribute)
Single item axis (in jaxtyping, an empty dimension string denotes a scalar with no dimensions).
SOURCE_DATA_BATCH = auto() (class-attribute, instance-attribute)
Batch of prompts used to generate source model activations.
STORE_BATCH = auto() (class-attribute, instance-attribute)
Batch of items to be written to the store.
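Three members hold literal strings rather than auto() values: ANY ('...'), COMPONENT_OPTIONAL ('*component') and SINGLE_ITEM (''). In jaxtyping dimension strings, '...' matches zero or more anonymous axes, a *name prefix marks a variadic group that may also be empty, and the empty string denotes a scalar. The hypothetical shape_matches function below is a toy check of those rules against a concrete shape (ranks, plus consistency of a name repeated within one string). It is a sketch for intuition only, not jaxtyping's implementation, and it ignores the rest of jaxtyping's syntax such as integer sizes or '_'.

```python
from typing import Dict, List, Tuple


def shape_matches(spec: str, shape: Tuple[int, ...]) -> bool:
    """Toy check of a jaxtyping-style dimension string against a concrete shape.

    Handles plain names, "..." and "*name" (variadic, zero or more axes) and the
    empty string (a scalar). Other jaxtyping syntax is deliberately ignored.
    """
    tokens: List[str] = spec.split()  # "" -> [] -> scalar
    variadic = [i for i, tok in enumerate(tokens) if tok == "..." or tok.startswith("*")]
    if len(variadic) > 1:
        raise ValueError("at most one variadic axis per dimension string")

    if variadic:
        v = variadic[0]
        head, tail = tokens[:v], tokens[v + 1:]
        if len(shape) < len(head) + len(tail):
            return False
        # Pair the fixed tokens with the leading and trailing sizes; the
        # variadic token absorbs whatever is left in the middle.
        fixed_tokens = head + tail
        fixed_sizes = shape[:len(head)] + shape[len(shape) - len(tail):]
    else:
        if len(shape) != len(tokens):
            return False
        fixed_tokens, fixed_sizes = tokens, shape

    # A name repeated within one string must bind to the same size.
    bound: Dict[str, int] = {}
    for tok, size in zip(fixed_tokens, fixed_sizes):
        if bound.setdefault(tok, size) != size:
            return False
    return True


print(shape_matches("batch input_output_feature", (4, 8)))  # True
print(shape_matches("*component batch", (8,)))  # True: zero component axes
print(shape_matches("*component batch", (2, 8)))  # True
print(shape_matches("", ()))  # True: scalar
print(shape_matches("", (3,)))  # False
print(shape_matches("batch batch", (2, 3)))  # False: inconsistent sizes
```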
names(*axis) (staticmethod)
Join multiple axes together to represent the dimensions of a tensor.
Example
>>> print(Axis.names(Axis.BATCH, Axis.INPUT_OUTPUT_FEATURE))
batch input_output_feature
Parameters:
Name | Type | Description | Default |
---|---|---|---|
*axis | Axis | Axes to join. | () |
Returns:
Type | Description |
---|---|
str | Axis names joined with spaces, for use as a jaxtyping dimension string. |
Source code in sparse_autoencoder/tensor_types.py