
Neuron fired count metric

Neuron fired count metric.

NeuronFiredCountMetric

Bases: Metric

Neuron activity metric.

Example

metric = NeuronFiredCountMetric(num_learned_features=3)
learned_activations = torch.tensor([
    [1., 0., 1.],  # Batch 1 (single component): learned features (2 active neurons)
    [0., 0., 0.]   # Batch 2 (single component): learned features (0 active neurons)
])
metric.forward(learned_activations)
tensor([1, 0, 1])

Source code in sparse_autoencoder/metrics/train/neuron_fired_count.py
class NeuronFiredCountMetric(Metric):
    """Neuron activity metric.

    Example:
        >>> metric = NeuronFiredCountMetric(num_learned_features=3)
        >>> learned_activations = torch.tensor([
        ...     [1., 0., 1.], # Batch 1 (single component): learned features (2 active neurons)
        ...     [0., 0., 0.]  # Batch 2 (single component): learned features (0 active neurons)
        ... ])
        >>> metric.forward(learned_activations)
        tensor([1, 0, 1])
    """

    # Torchmetrics settings
    is_differentiable: bool | None = True
    full_state_update: bool | None = True
    plot_lower_bound: float | None = 0.0

    # State
    neuron_fired_count: Float[Tensor, Axis.names(Axis.COMPONENT_OPTIONAL, Axis.LEARNT_FEATURE)]

    @validate_call
    def __init__(
        self,
        num_learned_features: PositiveInt,
        num_components: PositiveInt | None = None,
    ) -> None:
        """Initialise the metric.

        Args:
            num_learned_features: Number of learned features.
            num_components: Number of components.
        """
        super().__init__()
        self.add_state(
            "neuron_fired_count",
            default=torch.zeros(
                shape_with_optional_dimensions(num_components, num_learned_features),
                dtype=torch.float,  # Float is needed for dist reduce to work
            ),
            dist_reduce_fx="sum",
        )

    def update(
        self,
        learned_activations: Float[
            Tensor, Axis.names(Axis.BATCH, Axis.COMPONENT_OPTIONAL, Axis.LEARNT_FEATURE)
        ],
        **kwargs: Any,  # type: ignore # noqa: ARG002, ANN401 (allows combining with other metrics)
    ) -> None:
        """Update the metric state.

        Args:
            learned_activations: The learned activations.
            **kwargs: Ignored keyword arguments (to allow use with other metrics in a collection).
        """
        neuron_has_fired: Bool[
            Tensor, Axis.names(Axis.BATCH, Axis.COMPONENT_OPTIONAL, Axis.LEARNT_FEATURE)
        ] = torch.gt(learned_activations, 0)

        self.neuron_fired_count += neuron_has_fired.sum(dim=0, dtype=torch.float)

    def compute(self) -> Int[Tensor, Axis.names(Axis.COMPONENT_OPTIONAL, Axis.LEARNT_FEATURE)]:
        """Compute the metric."""
        return self.neuron_fired_count.to(dtype=torch.int64)
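
A minimal usage sketch (the import path is assumed from the source location shown above): counts accumulate across update calls, and compute returns the per-feature totals.

import torch

from sparse_autoencoder.metrics.train.neuron_fired_count import NeuronFiredCountMetric

metric = NeuronFiredCountMetric(num_learned_features=3)

# Each update adds, per learned feature, the number of strictly positive
# activations seen in that batch.
for batch in (
    torch.tensor([[1.0, 0.0, 1.0], [0.0, 0.0, 0.0]]),
    torch.tensor([[0.0, 2.0, 0.5], [0.0, 0.0, 0.0]]),
):
    metric.update(batch)

print(metric.compute())  # tensor([1, 1, 2])

Calling the metric directly (forward, as in the class docstring example) also updates this state while returning the result for the current batch, following the usual torchmetrics behaviour.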

__init__(num_learned_features, num_components=None)

Initialise the metric.

Parameters:

    num_learned_features (PositiveInt): Number of learned features. Required.
    num_components (PositiveInt | None): Number of components. Default: None.
Source code in sparse_autoencoder/metrics/train/neuron_fired_count.py
@validate_call
def __init__(
    self,
    num_learned_features: PositiveInt,
    num_components: PositiveInt | None = None,
) -> None:
    """Initialise the metric.

    Args:
        num_learned_features: Number of learned features.
        num_components: Number of components.
    """
    super().__init__()
    self.add_state(
        "neuron_fired_count",
        default=torch.zeros(
            shape_with_optional_dimensions(num_components, num_learned_features),
            dtype=torch.float,  # Float is needed for dist reduce to work
        ),
        dist_reduce_fx="sum",
    )
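
A rough sketch of how num_components affects the metric state, assuming shape_with_optional_dimensions simply drops the dimension when num_components is None (which is what the COMPONENT_OPTIONAL annotation suggests):

from sparse_autoencoder.metrics.train.neuron_fired_count import NeuronFiredCountMetric

# Without num_components the state is a 1-D per-feature count.
single = NeuronFiredCountMetric(num_learned_features=4)
print(single.neuron_fired_count.shape)  # torch.Size([4])

# With num_components the state gains a leading component axis.
multi = NeuronFiredCountMetric(num_learned_features=4, num_components=2)
print(multi.neuron_fired_count.shape)  # torch.Size([2, 4])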

compute()

Compute the metric.

Source code in sparse_autoencoder/metrics/train/neuron_fired_count.py
def compute(self) -> Int[Tensor, Axis.names(Axis.COMPONENT_OPTIONAL, Axis.LEARNT_FEATURE)]:
    """Compute the metric."""
    return self.neuron_fired_count.to(dtype=torch.int64)
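
A small sketch of the dtype behaviour: the running state is kept as a float tensor (so the distributed sum reduction works), while compute casts the accumulated totals to int64.

import torch

from sparse_autoencoder.metrics.train.neuron_fired_count import NeuronFiredCountMetric

metric = NeuronFiredCountMetric(num_learned_features=2)
metric.update(torch.tensor([[0.5, 0.0]]))

print(metric.neuron_fired_count.dtype)  # torch.float32 (state kept as float for dist reduce)
print(metric.compute())                 # tensor([1, 0]), cast to int64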

update(learned_activations, **kwargs)

Update the metric state.

Parameters:

    learned_activations (Float[Tensor, names(BATCH, COMPONENT_OPTIONAL, LEARNT_FEATURE)]):
        The learned activations. Required.
    **kwargs (Any): Ignored keyword arguments (to allow use with other metrics in a collection).
        Default: {}.
Source code in sparse_autoencoder/metrics/train/neuron_fired_count.py
def update(
    self,
    learned_activations: Float[
        Tensor, Axis.names(Axis.BATCH, Axis.COMPONENT_OPTIONAL, Axis.LEARNT_FEATURE)
    ],
    **kwargs: Any,  # type: ignore # noqa: ARG002, ANN401 (allows combining with other metrics)
) -> None:
    """Update the metric state.

    Args:
        learned_activations: The learned activations.
        **kwargs: Ignored keyword arguments (to allow use with other metrics in a collection).
    """
    neuron_has_fired: Bool[
        Tensor, Axis.names(Axis.BATCH, Axis.COMPONENT_OPTIONAL, Axis.LEARNT_FEATURE)
    ] = torch.gt(learned_activations, 0)

    self.neuron_fired_count += neuron_has_fired.sum(dim=0, dtype=torch.float)
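
For the multi-component case, a hedged sketch of the expected shapes: update sums over the batch axis only, so with num_components set the input is (batch, component, learned feature) and the accumulated counts are (component, learned feature).

import torch

from sparse_autoencoder.metrics.train.neuron_fired_count import NeuronFiredCountMetric

metric = NeuronFiredCountMetric(num_learned_features=3, num_components=2)

# Shape: (batch=2, components=2, learned features=3)
learned_activations = torch.tensor([
    [[1.0, 0.0, 0.0], [0.0, 1.0, 1.0]],
    [[1.0, 0.0, 0.5], [0.0, 0.0, 0.0]],
])
metric.update(learned_activations)

print(metric.compute())
# tensor([[2, 0, 1],
#         [0, 1, 1]])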