Skip to content

Reconstruction score metric¤

Reconstruction score metric.

ReconstructionScoreMetric ¤

Bases: Metric

Model reconstruction score.

Creates a score that measures how well the model can reconstruct the data.

\[ \begin{align*} v &= \text{number of validation items} \\ l \in{\mathbb{R}^v} &= \text{loss with no changes to the source model} \\ l_\text{recon} \in{\mathbb{R}^v} &= \text{loss with reconstruction} \\ l_\text{zero} \in{\mathbb{R}^v} &= \text{loss with zero ablation} \\ s &= \text{reconstruction score} \\ s &= \frac{\sum_{i=1}^v l_{\text{zero}, i} - \sum_{i=1}^v l_{\text{recon}, i}}{\sum_{i=1}^v l_{\text{zero}, i} - \sum_{i=1}^v l_i} \end{align*} \]
Example

>>> metric = ReconstructionScoreMetric(num_components=1)
>>> source_model_loss=torch.tensor([2.0, 2.0, 2.0])
>>> source_model_loss_with_reconstruction=torch.tensor([3.0, 3.0, 3.0])
>>> source_model_loss_with_zero_ablation=torch.tensor([5.0, 5.0, 5.0])
>>> metric.forward(
...     source_model_loss=source_model_loss,
...     source_model_loss_with_reconstruction=source_model_loss_with_reconstruction,
...     source_model_loss_with_zero_ablation=source_model_loss_with_zero_ablation
... )
tensor(0.6667)

Source code in sparse_autoencoder/metrics/validate/reconstruction_score.py
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
class ReconstructionScoreMetric(Metric):
    r"""Model reconstruction score.

    Measures how well the model can reconstruct the data, by comparing the loss with
    reconstruction against the loss with zero ablation and the loss of the unmodified source
    model.

    Each `update` call adds the summed losses of the batch to a running total for the chosen
    component. The score is then the ratio of the accumulated differences:

    $$
    \begin{align*}
        v &= \text{number of validation items} \\
        l \in{\mathbb{R}^v} &= \text{loss with no changes to the source model} \\
        l_\text{recon} \in{\mathbb{R}^v} &= \text{loss with reconstruction} \\
        l_\text{zero} \in{\mathbb{R}^v} &= \text{loss with zero ablation} \\
        s &= \text{reconstruction score} \\
        s &= \frac{\sum_{i=1}^v l_{\text{zero}, i} - \sum_{i=1}^v l_{\text{recon}, i}}
            {\sum_{i=1}^v l_{\text{zero}, i} - \sum_{i=1}^v l_i}
    \end{align*}
    $$

    Example:
        >>> metric = ReconstructionScoreMetric(num_components=1)
        >>> source_model_loss=torch.tensor([2.0, 2.0, 2.0])
        >>> source_model_loss_with_reconstruction=torch.tensor([3.0, 3.0, 3.0])
        >>> source_model_loss_with_zero_ablation=torch.tensor([5.0, 5.0, 5.0])
        >>> metric.forward(
        ...     source_model_loss=source_model_loss,
        ...     source_model_loss_with_reconstruction=source_model_loss_with_reconstruction,
        ...     source_model_loss_with_zero_ablation=source_model_loss_with_zero_ablation
        ... )
        tensor(0.6667)
    """

    # Torchmetrics settings
    is_differentiable: bool | None = False
    full_state_update: bool | None = True

    # State: running per-component loss totals, registered in `__init__`.
    source_model_loss: Float[Tensor, Axis.COMPONENT_OPTIONAL]
    source_model_loss_with_zero_ablation: Float[Tensor, Axis.COMPONENT_OPTIONAL]
    source_model_loss_with_reconstruction: Float[Tensor, Axis.COMPONENT_OPTIONAL]
    # NOTE(review): annotated here, but `__init__` never registers it with `add_state` and
    # neither `update` nor `compute` reads it.
    num_activation_vectors: Int64[Tensor, Axis.SINGLE_ITEM]

    @validate_call
    def __init__(self, num_components: PositiveInt = 1) -> None:
        """Initialise the metric.

        Args:
            num_components: Length of each state tensor (one running total per component).
        """
        super().__init__()

        # All three running totals share the same shape and reduction. They are registered in
        # declaration order, each with its own freshly allocated zero tensor.
        for state_name in (
            "source_model_loss",
            "source_model_loss_with_zero_ablation",
            "source_model_loss_with_reconstruction",
        ):
            self.add_state(state_name, default=torch.zeros(num_components), dist_reduce_fx="sum")

    def update(
        self,
        source_model_loss: Float[Tensor, Axis.COMPONENT_OPTIONAL],
        source_model_loss_with_reconstruction: Float[Tensor, Axis.COMPONENT_OPTIONAL],
        source_model_loss_with_zero_ablation: Float[Tensor, Axis.COMPONENT_OPTIONAL],
        component_idx: int = 0,
    ) -> None:
        """Accumulate a batch of losses into the running totals of one component.

        Each input tensor is reduced to a scalar with `.sum()` and added in place to the entry
        of the matching state tensor selected by `component_idx`.

        Args:
            source_model_loss: Loss with no changes to the source model.
            source_model_loss_with_reconstruction: Loss with SAE reconstruction.
            source_model_loss_with_zero_ablation: Loss with zero ablation.
            component_idx: Index of the state entry that receives this batch.
        """
        batch_default_total = source_model_loss.sum()
        batch_zero_ablation_total = source_model_loss_with_zero_ablation.sum()
        batch_reconstruction_total = source_model_loss_with_reconstruction.sum()

        self.source_model_loss[component_idx] += batch_default_total
        self.source_model_loss_with_zero_ablation[component_idx] += batch_zero_ablation_total
        self.source_model_loss_with_reconstruction[component_idx] += batch_reconstruction_total

    def compute(
        self,
    ) -> Float[Tensor, Axis.COMPONENT_OPTIONAL]:
        """Compute the metric from the accumulated totals.

        Returns:
            Element-wise `(zero ablation - reconstruction) / (zero ablation - source model)`
            over the state tensors. No guard is applied when the denominator is zero.
        """
        zero_ablation_total = self.source_model_loss_with_zero_ablation

        numerator: Float[Tensor, Axis.COMPONENT_OPTIONAL] = (
            zero_ablation_total - self.source_model_loss_with_reconstruction
        )
        denominator: Float[Tensor, Axis.COMPONENT_OPTIONAL] = (
            zero_ablation_total - self.source_model_loss
        )

        return numerator / denominator

__init__(num_components=1) ¤

Initialise the metric.

Source code in sparse_autoencoder/metrics/validate/reconstruction_score.py
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
@validate_call
def __init__(self, num_components: PositiveInt = 1) -> None:
    """Initialise the metric.

    Args:
        num_components: Length of each state tensor (one running total per component).
    """
    super().__init__()

    # All three running totals share the same shape and reduction. They are registered in the
    # original order, each with its own freshly allocated zero tensor.
    for state_name in (
        "source_model_loss",
        "source_model_loss_with_zero_ablation",
        "source_model_loss_with_reconstruction",
    ):
        self.add_state(state_name, default=torch.zeros(num_components), dist_reduce_fx="sum")

compute() ¤

Compute the metric.

Source code in sparse_autoencoder/metrics/validate/reconstruction_score.py
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
def compute(
    self,
) -> Float[Tensor, Axis.COMPONENT_OPTIONAL]:
    """Compute the metric from the accumulated totals.

    Returns:
        Element-wise `(zero ablation - reconstruction) / (zero ablation - source model)` over
        the state tensors. No guard is applied when the denominator is zero.
    """
    zero_ablation_total = self.source_model_loss_with_zero_ablation

    numerator: Float[Tensor, Axis.COMPONENT_OPTIONAL] = (
        zero_ablation_total - self.source_model_loss_with_reconstruction
    )
    denominator: Float[Tensor, Axis.COMPONENT_OPTIONAL] = (
        zero_ablation_total - self.source_model_loss
    )

    return numerator / denominator

update(source_model_loss, source_model_loss_with_reconstruction, source_model_loss_with_zero_ablation, component_idx=0) ¤

Update the metric state.

Parameters:

Name Type Description Default
source_model_loss Float[Tensor, COMPONENT_OPTIONAL]

Loss with no changes to the source model.

required
source_model_loss_with_reconstruction Float[Tensor, COMPONENT_OPTIONAL]

Loss with SAE reconstruction.

required
source_model_loss_with_zero_ablation Float[Tensor, COMPONENT_OPTIONAL]

Loss with zero ablation.

required
component_idx int

Component idx.

0
Source code in sparse_autoencoder/metrics/validate/reconstruction_score.py
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
def update(
    self,
    source_model_loss: Float[Tensor, Axis.COMPONENT_OPTIONAL],
    source_model_loss_with_reconstruction: Float[Tensor, Axis.COMPONENT_OPTIONAL],
    source_model_loss_with_zero_ablation: Float[Tensor, Axis.COMPONENT_OPTIONAL],
    component_idx: int = 0,
) -> None:
    """Accumulate a batch of losses into the running totals of one component.

    Each input tensor is reduced to a scalar with `.sum()` and added in place to the entry of
    the matching state tensor selected by `component_idx`.

    Args:
        source_model_loss: Loss with no changes to the source model.
        source_model_loss_with_reconstruction: Loss with SAE reconstruction.
        source_model_loss_with_zero_ablation: Loss with zero ablation.
        component_idx: Index of the state entry that receives this batch.
    """
    batch_default_total = source_model_loss.sum()
    batch_zero_ablation_total = source_model_loss_with_zero_ablation.sum()
    batch_reconstruction_total = source_model_loss_with_reconstruction.sum()

    self.source_model_loss[component_idx] += batch_default_total
    self.source_model_loss_with_zero_ablation[component_idx] += batch_zero_ablation_total
    self.source_model_loss_with_reconstruction[component_idx] += batch_reconstruction_total