Bases: Metric
Learned activations L0 norm metric.
The L0 norm is the number of non-zero elements in a learned activation vector, averaged over the
number of activation vectors.
Examples:
>>> metric = L0NormMetric()
>>> learned_activations = torch.tensor([
... [1., 0., 1.], # Batch 1 (single component): learned features (2 active neurons)
... [0., 1., 0.] # Batch 2 (single component): learned features (1 active neuron)
... ])
>>> metric.forward(learned_activations)
tensor(1.5000)
With 2 components, the metric will return the average number of active (non-zero)
neurons as a 1d tensor.
>>> metric = L0NormMetric(num_components=2)
>>> learned_activations = torch.tensor([
... [ # Batch 1
... [1., 0., 1.], # Component 1: learned features (2 active neurons)
... [1., 0., 1.] # Component 2: learned features (2 active neurons)
... ],
... [ # Batch 2
... [0., 1., 0.], # Component 1: learned features (1 active neuron)
... [1., 0., 1.] # Component 2: learned features (2 active neurons)
... ]
... ])
>>> metric.forward(learned_activations)
tensor([1.5000, 2.0000])
Source code in sparse_autoencoder/metrics/train/l0_norm.py
class L0NormMetric(Metric):
    """Learned activations L0 norm metric.

    The L0 norm counts the non-zero elements of each learned activation vector; the metric
    reports this count averaged over the number of activation vectors seen.

    Examples:
        >>> metric = L0NormMetric()
        >>> learned_activations = torch.tensor([
        ...     [1., 0., 1.], # Batch 1 (single component): learned features (2 active neurons)
        ...     [0., 1., 0.]  # Batch 2 (single component): learned features (1 active neuron)
        ... ])
        >>> metric.forward(learned_activations)
        tensor(1.5000)

        With 2 components, the metric will return the average number of active (non-zero)
        neurons as a 1d tensor.

        >>> metric = L0NormMetric(num_components=2)
        >>> learned_activations = torch.tensor([
        ...     [ # Batch 1
        ...         [1., 0., 1.], # Component 1: learned features (2 active neurons)
        ...         [1., 0., 1.]  # Component 2: learned features (2 active neurons)
        ...     ],
        ...     [ # Batch 2
        ...         [0., 1., 0.], # Component 1: learned features (1 active neuron)
        ...         [1., 0., 1.]  # Component 2: learned features (2 active neurons)
        ...     ]
        ... ])
        >>> metric.forward(learned_activations)
        tensor([1.5000, 2.0000])
    """

    # Torchmetrics settings
    is_differentiable: bool | None = False
    full_state_update: bool | None = False
    plot_lower_bound: float | None = 0.0

    # State accumulated across update() calls
    active_neurons_count: Float[Tensor, Axis.COMPONENT_OPTIONAL]
    num_activation_vectors: Int64[Tensor, Axis.SINGLE_ITEM]

    @validate_call
    def __init__(self, num_components: PositiveInt | None = None) -> None:
        """Initialize the metric."""
        super().__init__()
        # Float dtype is needed for the distributed reduce to work.
        count_default = torch.zeros(
            shape_with_optional_dimensions(num_components), dtype=torch.float
        )
        self.add_state("active_neurons_count", default=count_default, dist_reduce_fx="sum")
        self.add_state(
            "num_activation_vectors",
            default=torch.tensor(0, dtype=torch.int64),
            dist_reduce_fx="sum",
        )

    def update(
        self,
        learned_activations: Float[
            Tensor, Axis.names(Axis.BATCH, Axis.COMPONENT_OPTIONAL, Axis.LEARNT_FEATURE)
        ],
        **kwargs: Any,  # type: ignore # noqa: ARG002, ANN401 (allows combining with other metrics)
    ) -> None:
        """Update the metric state.

        Args:
            learned_activations: The learned activations.
            **kwargs: Ignored keyword arguments (to allow use with other metrics in a collection).
        """
        # One activation vector per batch item.
        self.num_activation_vectors += learned_activations.shape[0]
        # Non-zero count along the learnt-feature axis, totalled over the batch.
        active_per_vector = torch.count_nonzero(learned_activations, dim=-1)
        self.active_neurons_count += active_per_vector.sum(dim=0, dtype=torch.int64)

    def compute(
        self,
    ) -> Float[Tensor, Axis.names(Axis.COMPONENT_OPTIONAL)]:
        """Compute the metric.

        Note that torchmetrics converts shape `[0]` tensors into scalars (shape `0`).
        """
        return self.active_neurons_count / self.num_activation_vectors
__init__(num_components=None)
Initialize the metric.
Source code in sparse_autoencoder/metrics/train/l0_norm.py
@validate_call
def __init__(self, num_components: PositiveInt | None = None) -> None:
    """Initialize the metric."""
    super().__init__()
    # Float dtype is needed for the distributed reduce to work.
    count_default = torch.zeros(
        shape_with_optional_dimensions(num_components), dtype=torch.float
    )
    self.add_state("active_neurons_count", default=count_default, dist_reduce_fx="sum")
    self.add_state(
        "num_activation_vectors",
        default=torch.tensor(0, dtype=torch.int64),
        dist_reduce_fx="sum",
    )
compute()
Compute the metric.
Note that torchmetrics converts shape `[0]` tensors into scalars (shape `0`).
Source code in sparse_autoencoder/metrics/train/l0_norm.py
def compute(
    self,
) -> Float[Tensor, Axis.names(Axis.COMPONENT_OPTIONAL)]:
    """Compute the metric.

    Note that torchmetrics converts shape `[0]` tensors into scalars (shape `0`).
    """
    # Average active-neuron count per activation vector (per component if present).
    return self.active_neurons_count / self.num_activation_vectors
update(learned_activations, **kwargs)
Update the metric state.
Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `learned_activations` | `Float[Tensor, names(BATCH, COMPONENT_OPTIONAL, LEARNT_FEATURE)]` | The learned activations. | *required* |
| `**kwargs` | `Any` | Ignored keyword arguments (to allow use with other metrics in a collection). | `{}` |
Source code in sparse_autoencoder/metrics/train/l0_norm.py
def update(
    self,
    learned_activations: Float[
        Tensor, Axis.names(Axis.BATCH, Axis.COMPONENT_OPTIONAL, Axis.LEARNT_FEATURE)
    ],
    **kwargs: Any,  # type: ignore # noqa: ARG002, ANN401 (allows combining with other metrics)
) -> None:
    """Update the metric state.

    Args:
        learned_activations: The learned activations.
        **kwargs: Ignored keyword arguments (to allow use with other metrics in a collection).
    """
    # One activation vector per batch item.
    self.num_activation_vectors += learned_activations.shape[0]
    # Non-zero count along the learnt-feature axis, totalled over the batch.
    active_per_vector = torch.count_nonzero(learned_activations, dim=-1)
    self.active_neurons_count += active_per_vector.sum(dim=0, dtype=torch.int64)