
Neuron activity metric

Neuron activity metric.

NeuronActivityMetric

Bases: Metric

Neuron activity metric.

Example

With a single component and two activation vectors, the metric returns the number of neurons that did not fire in any of the activation vectors seen (the dead neurons). The breakdown by component isn't shown here as there is just one component.

>>> metric = NeuronActivityMetric(num_learned_features=3)
>>> learned_activations = torch.tensor([
...     [1., 0., 1.], # Batch 1 (single component): learned features (2 active neurons)
...     [0., 0., 0.]  # Batch 2 (single component): learned features (0 active neurons)
... ])
>>> metric.forward(learned_activations)
tensor(1)
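
The metric state can also be accumulated over several batches before a single compute step, which is how dead neurons would typically be tracked over a longer horizon of activations during training. A minimal sketch of that pattern, using hypothetical batches and assuming torch and NeuronActivityMetric are imported as in the example above:

>>> metric = NeuronActivityMetric(num_learned_features=3)
>>> for batch in [
...     torch.tensor([[1., 0., 1.]]),  # hypothetical batch 1: neurons 0 and 2 fire
...     torch.tensor([[0., 0., 1.]]),  # hypothetical batch 2: only neuron 2 fires
... ]:
...     metric.update(batch)
>>> metric.compute()  # neuron 1 never fired in the 2 activation vectors seen
tensor(1)
>>> metric.reset()  # clear the accumulated state before starting a new horizon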

Source code in sparse_autoencoder/metrics/train/neuron_activity.py
class NeuronActivityMetric(Metric):
    """Neuron activity metric.

    Example:
        With a single component and two activation vectors, the metric returns the number of
        neurons that did not fire in any of the activation vectors seen (the dead neurons). The
        breakdown by component isn't shown here as there is just one component.

        >>> metric = NeuronActivityMetric(num_learned_features=3)
        >>> learned_activations = torch.tensor([
        ...     [1., 0., 1.], # Batch 1 (single component): learned features (2 active neurons)
        ...     [0., 0., 0.]  # Batch 2 (single component): learned features (0 active neurons)
        ... ])
        >>> metric.forward(learned_activations)
        tensor(1)
    """

    # Torchmetrics settings
    is_differentiable: bool | None = False
    full_state_update: bool | None = True
    plot_lower_bound: float | None = 0.0

    # Metric settings
    _threshold_is_dead_portion_fires: NonNegativeFloat

    # State
    neuron_fired_count: Float[Tensor, Axis.names(Axis.COMPONENT_OPTIONAL, Axis.LEARNT_FEATURE)]
    num_activation_vectors: Int64[Tensor, Axis.SINGLE_ITEM]

    @validate_call
    def __init__(
        self,
        num_learned_features: PositiveInt,
        num_components: PositiveInt | None = None,
        threshold_is_dead_portion_fires: Annotated[float, Field(strict=True, ge=0, le=1)] = 0.0,
    ) -> None:
        """Initialise the metric.

        Args:
            num_learned_features: Number of learned features.
            num_components: Number of components.
            threshold_is_dead_portion_fires: Threshold for counting a neuron as dead (portion of
                activation vectors that it fires for must be less than or equal to this number).
                Commonly used values are 0.0, 1e-5 and 1e-6.
        """
        super().__init__()
        self._threshold_is_dead_portion_fires = threshold_is_dead_portion_fires

        self.add_state(
            "neuron_fired_count",
            default=torch.zeros(
                shape_with_optional_dimensions(num_components, num_learned_features),
                dtype=torch.float,  # Float is needed for dist reduce to work
            ),
            dist_reduce_fx="sum",
        )

        self.add_state(
            "num_activation_vectors",
            default=torch.tensor(0, dtype=torch.int64),
            dist_reduce_fx="sum",
        )

    def update(
        self,
        learned_activations: Float[
            Tensor, Axis.names(Axis.BATCH, Axis.COMPONENT_OPTIONAL, Axis.LEARNT_FEATURE)
        ],
        **kwargs: Any,  # type: ignore # noqa: ARG002, ANN401 (allows combining with other metrics)
    ) -> None:
        """Update the metric state.

        Args:
            learned_activations: The learned activations.
            **kwargs: Ignored keyword arguments (to allow use with other metrics in a collection).
        """
        # Increment the counter of activations seen since the last compute step
        self.num_activation_vectors += learned_activations.shape[0]

        # Count the number of active neurons in the batch
        neuron_has_fired: Bool[
            Tensor, Axis.names(Axis.BATCH, Axis.COMPONENT_OPTIONAL, Axis.LEARNT_FEATURE)
        ] = torch.gt(learned_activations, 0)

        self.neuron_fired_count += neuron_has_fired.sum(dim=0, dtype=torch.float)

    def compute(self) -> Int64[Tensor, Axis.COMPONENT_OPTIONAL]:
        """Compute the metric.

        Note that torchmetrics converts shape `[0]` tensors into scalars (shape `0`).
        """
        threshold_activations: Float[Tensor, Axis.SINGLE_ITEM] = (
            self._threshold_is_dead_portion_fires * self.num_activation_vectors
        )

        return torch.sum(
            self.neuron_fired_count <= threshold_activations, dim=-1, dtype=torch.int64
        )

__init__(num_learned_features, num_components=None, threshold_is_dead_portion_fires=0.0)

Initialise the metric.

Parameters:

num_learned_features (PositiveInt): Number of learned features. Required.

num_components (PositiveInt | None): Number of components. Default: None.

threshold_is_dead_portion_fires (Annotated[float, Field(strict=True, ge=0, le=1)]): Threshold for counting a neuron as dead (the portion of activation vectors that it fires for must be less than or equal to this number). Commonly used values are 0.0, 1e-5 and 1e-6. Default: 0.0.
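
As a rough illustration of how the threshold is applied (using made-up counts rather than this class's API), a neuron is counted as dead when its fired count is at most threshold_is_dead_portion_fires multiplied by the number of activation vectors seen:

>>> neuron_fired_count = torch.tensor([0., 5., 250.])  # hypothetical counts for 3 neurons
>>> num_activation_vectors = 10_000  # hypothetical number of activation vectors seen
>>> threshold_activations = 1e-3 * num_activation_vectors  # threshold_is_dead_portion_fires = 1e-3
>>> torch.sum(neuron_fired_count <= threshold_activations, dim=-1, dtype=torch.int64)
tensor(2)

With the default threshold of 0.0, only neurons that never fire at all are counted as dead.
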
Source code in sparse_autoencoder/metrics/train/neuron_activity.py
@validate_call
def __init__(
    self,
    num_learned_features: PositiveInt,
    num_components: PositiveInt | None = None,
    threshold_is_dead_portion_fires: Annotated[float, Field(strict=True, ge=0, le=1)] = 0.0,
) -> None:
    """Initialise the metric.

    Args:
        num_learned_features: Number of learned features.
        num_components: Number of components.
        threshold_is_dead_portion_fires: Threshold for counting a neuron as dead (portion of
            activation vectors that it fires for must be less than or equal to this number).
            Commonly used values are 0.0, 1e-5 and 1e-6.
    """
    super().__init__()
    self._threshold_is_dead_portion_fires = threshold_is_dead_portion_fires

    self.add_state(
        "neuron_fired_count",
        default=torch.zeros(
            shape_with_optional_dimensions(num_components, num_learned_features),
            dtype=torch.float,  # Float is needed for dist reduce to work
        ),
        dist_reduce_fx="sum",
    )

    self.add_state(
        "num_activation_vectors",
        default=torch.tensor(0, dtype=torch.int64),
        dist_reduce_fx="sum",
    )

compute()

Compute the metric.

Note that torchmetrics converts shape [0] tensors into scalars (shape 0).
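
As a sketch with made-up activations, when num_components is set the result is one dead-neuron count per component rather than a scalar:

>>> metric = NeuronActivityMetric(num_learned_features=3, num_components=2)
>>> learned_activations = torch.tensor([
...     [[1., 0., 1.], [0., 1., 0.]],  # batch item 1: component 0, component 1
...     [[0., 0., 1.], [0., 1., 0.]],  # batch item 2: component 0, component 1
... ])
>>> metric.forward(learned_activations)
tensor([1, 2])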

Source code in sparse_autoencoder/metrics/train/neuron_activity.py
def compute(self) -> Int64[Tensor, Axis.COMPONENT_OPTIONAL]:
    """Compute the metric.

    Note that torchmetrics converts shape `[0]` tensors into scalars (shape `0`).
    """
    threshold_activations: Float[Tensor, Axis.SINGLE_ITEM] = (
        self._threshold_is_dead_portion_fires * self.num_activation_vectors
    )

    return torch.sum(
        self.neuron_fired_count <= threshold_activations, dim=-1, dtype=torch.int64
    )

update(learned_activations, **kwargs)

Update the metric state.

Parameters:

learned_activations (Float[Tensor, names(BATCH, COMPONENT_OPTIONAL, LEARNT_FEATURE)]): The learned activations. Required.

**kwargs (Any): Ignored keyword arguments (to allow use with other metrics in a collection). Default: {}.
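
Because extra keyword arguments are accepted and ignored, the metric can be driven with a shared call signature alongside other metrics (for instance in a torchmetrics MetricCollection). A minimal sketch, where decoded_activations is a hypothetical extra argument that this metric simply discards:

>>> metric = NeuronActivityMetric(num_learned_features=3)
>>> metric.update(
...     learned_activations=torch.tensor([[1., 0., 1.]]),
...     decoded_activations=torch.zeros(1, 3),  # hypothetical extra kwarg, ignored
... )
>>> metric.compute()
tensor(1)
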
Source code in sparse_autoencoder/metrics/train/neuron_activity.py
def update(
    self,
    learned_activations: Float[
        Tensor, Axis.names(Axis.BATCH, Axis.COMPONENT_OPTIONAL, Axis.LEARNT_FEATURE)
    ],
    **kwargs: Any,  # type: ignore # noqa: ARG002, ANN401 (allows combining with other metrics)
) -> None:
    """Update the metric state.

    Args:
        learned_activations: The learned activations.
        **kwargs: Ignored keyword arguments (to allow use with other metrics in a collection).
    """
    # Increment the counter of activations seen since the last compute step
    self.num_activation_vectors += learned_activations.shape[0]

    # Count the number of active neurons in the batch
    neuron_has_fired: Bool[
        Tensor, Axis.names(Axis.BATCH, Axis.COMPONENT_OPTIONAL, Axis.LEARNT_FEATURE)
    ] = torch.gt(learned_activations, 0)

    self.neuron_fired_count += neuron_has_fired.sum(dim=0, dtype=torch.float)