Reconstruction score metric
ReconstructionScoreMetric
Bases: Metric
Model reconstruction score.
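Creates a score that measures how well the sparse autoencoder (SAE) reconstruction preserves the source model's loss. The score is (loss with zero ablation - loss with reconstruction) / (loss with zero ablation - loss without changes), so 1 means the reconstruction fully recovers the original loss and 0 means it is no better than zero-ablating the component.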
Example
>>> import torch
>>> from sparse_autoencoder.metrics.validate.reconstruction_score import ReconstructionScoreMetric
>>> metric = ReconstructionScoreMetric(num_components=1)
>>> source_model_loss = torch.tensor([2.0, 2.0, 2.0])
>>> source_model_loss_with_reconstruction = torch.tensor([3.0, 3.0, 3.0])
>>> source_model_loss_with_zero_ablation = torch.tensor([5.0, 5.0, 5.0])
>>> metric.forward(
...     source_model_loss=source_model_loss,
...     source_model_loss_with_reconstruction=source_model_loss_with_reconstruction,
...     source_model_loss_with_zero_ablation=source_model_loss_with_zero_ablation
... )
tensor(0.6667)
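As a hand check of the example, the sketch below restates the score in plain Python. It assumes the score is (zero-ablation loss - reconstruction loss) / (zero-ablation loss - source loss), which reproduces the documented tensor(0.6667); the helper name reconstruction_score is hypothetical and not part of the library.

```python
def reconstruction_score(
    source_loss: float,
    reconstruction_loss: float,
    zero_ablation_loss: float,
) -> float:
    """Hypothetical helper: fraction of the zero-ablation loss gap recovered."""
    # Assumed formula: (L_zero - L_recon) / (L_zero - L_source).
    # 1.0 means the reconstruction matches the unmodified model's loss;
    # 0.0 means it is no better than zero-ablating the component.
    denominator = zero_ablation_loss - source_loss
    if denominator == 0:
        raise ValueError("zero-ablation loss equals source loss; score is undefined")
    return (zero_ablation_loss - reconstruction_loss) / denominator


# Values from the example: source 2.0, reconstruction 3.0, zero ablation 5.0.
score = reconstruction_score(2.0, 3.0, 5.0)
print(round(score, 4))  # 0.6667
```

The gap in the example is 5 - 2 = 3, of which the reconstruction recovers 5 - 3 = 2, giving 2/3.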
Source code in sparse_autoencoder/metrics/validate/reconstruction_score.py
__init__(num_components=1)
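Initialise the metric. num_components sets how many components are tracked separately; each call to update() adds to one of them, selected by component_idx.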
compute()
Compute the reconstruction score from the losses accumulated by update().
update(source_model_loss, source_model_loss_with_reconstruction, source_model_loss_with_zero_ablation, component_idx=0)
Update the metric state.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
source_model_loss | Float[Tensor, COMPONENT_OPTIONAL] | Loss of the unmodified source model. | required |
source_model_loss_with_reconstruction | Float[Tensor, COMPONENT_OPTIONAL] | Loss of the source model with the component's activations replaced by the SAE reconstruction. | required |
source_model_loss_with_zero_ablation | Float[Tensor, COMPONENT_OPTIONAL] | Loss of the source model with the component's activations zero-ablated. | required |
component_idx | int | Index of the component whose state is updated. | 0 |
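The update/compute split follows the usual accumulate-then-compute pattern of a stateful metric. The sketch below is a hypothetical pure-Python stand-in, not the library's implementation: it assumes each update call adds the three losses to running totals for the component chosen by component_idx, and that compute applies the score (zero-ablation loss - reconstruction loss) / (zero-ablation loss - source loss) per component. The class name ToyReconstructionScore is illustrative only.

```python
class ToyReconstructionScore:
    """Hypothetical per-component accumulator mirroring update()/compute()."""

    def __init__(self, num_components: int = 1) -> None:
        # One running total per component for each of the three losses.
        self.source = [0.0] * num_components
        self.reconstruction = [0.0] * num_components
        self.zero_ablation = [0.0] * num_components

    def update(
        self,
        source_model_loss: float,
        source_model_loss_with_reconstruction: float,
        source_model_loss_with_zero_ablation: float,
        component_idx: int = 0,
    ) -> None:
        # Only the slots for the selected component change.
        self.source[component_idx] += source_model_loss
        self.reconstruction[component_idx] += source_model_loss_with_reconstruction
        self.zero_ablation[component_idx] += source_model_loss_with_zero_ablation

    def compute(self) -> list[float]:
        # Every update adds to all three totals of a component, so the ratio
        # of differences equals the one computed from mean losses.
        scores = []
        for s, r, z in zip(self.source, self.reconstruction, self.zero_ablation):
            # Undefined when zero ablation does not change the loss (including
            # components that never received an update); report NaN.
            scores.append(float("nan") if z == s else (z - r) / (z - s))
        return scores


metric = ToyReconstructionScore(num_components=2)
metric.update(2.0, 3.0, 5.0, component_idx=0)
metric.update(2.0, 3.0, 5.0, component_idx=0)
metric.update(1.0, 1.5, 3.0, component_idx=1)
print([round(s, 4) for s in metric.compute()])  # [0.6667, 0.75]
```

Because the score is a ratio of loss differences, summing losses across updates gives the same result as averaging them, as long as all three losses are updated together.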