Component slice tensor utils
Utilities for slicing a tensor along its component axis.
get_component_slice_tensor(input_tensor, n_dim_with_component, component_dim, component_idx)
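Get the slice of a tensor for a single component by indexing `component_idx` along the `component_dim` axis, which removes that axis from the result. If the tensor has `n_dim_with_component - 1` dimensions (that is, it has no component axis), it is returned unchanged.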
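A minimal sketch of the indexing idea, assuming the slice is taken with a tuple index: `slice(None)` (the object form of `:`) on every axis except the component axis, which receives the integer `component_idx`. The helper name `select_component` is hypothetical and this is not the library's source; it works on anything that supports tuple indexing, such as a `torch.Tensor` or a NumPy array.

```python
from typing import Any


def select_component(x: Any, n_dim: int, component_dim: int, component_idx: int) -> Any:
    """Hypothetical helper: take index ``component_idx`` along ``component_dim``.

    ``x`` is assumed to have ``n_dim`` axes and to support indexing with a
    tuple of ints and slices (e.g. a ``torch.Tensor`` or a NumPy array).
    Assumes ``component_dim`` is non-negative.
    """
    # slice(None) is equivalent to ":" and keeps that axis whole; the integer
    # at position ``component_dim`` picks one component and drops that axis.
    index = tuple(
        component_idx if dim == component_dim else slice(None)
        for dim in range(n_dim)
    )
    return x[index]


# For a 2-D input, select_component(x, 2, 1, 0) is the same as x[:, 0].
```

For a single non-negative axis, PyTorch's `Tensor.select(dim, index)` returns the same slice; building the tuple explicitly keeps the sketch usable for NumPy arrays as well.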
Examples:
>>> import torch
>>> input_tensor = torch.tensor([[1, 2], [3, 4], [5, 6], [7, 8]])
>>> get_component_slice_tensor(input_tensor, 2, 1, 0)
tensor([1, 3, 5, 7])

>>> input_tensor = torch.tensor([[1, 2], [3, 4], [5, 6], [7, 8]])
>>> get_component_slice_tensor(input_tensor, 3, 1, 0)
tensor([[1, 2],
        [3, 4],
        [5, 6],
        [7, 8]])
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `input_tensor` | `Tensor` | Input tensor. | required |
| `n_dim_with_component` | `int` | Number of dimensions the input tensor has when it includes the component axis. | required |
| `component_dim` | `int` | Position (dimension index) of the component axis. | required |
| `component_idx` | `int` | Index of the component to take the slice for. | required |
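Returns:

| Type | Description |
|---|---|
| `Tensor` | The slice of `input_tensor` at `component_idx` along `component_dim`, or `input_tensor` unchanged if it has no component axis. |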
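To make the return shape concrete, the hypothetical helper below predicts the output shape from the input shape, assuming the behaviour shown in the examples: the component axis is dropped when present, and a tensor with `n_dim_with_component - 1` dimensions comes back unchanged. It is a sketch, not part of the library, and it assumes the input rank is one of those two values. The `(8, 3, 16)` shape is illustrative only.

```python
def component_slice_shape(
    input_shape: tuple[int, ...], n_dim_with_component: int, component_dim: int
) -> tuple[int, ...]:
    """Hypothetical: shape of the slice returned for a given input shape.

    Assumes len(input_shape) is n_dim_with_component or n_dim_with_component - 1.
    """
    if len(input_shape) == n_dim_with_component - 1:
        # No component axis: the tensor is returned unchanged.
        return tuple(input_shape)
    # Component axis present: indexing it with an int removes that axis.
    return tuple(size for dim, size in enumerate(input_shape) if dim != component_dim)


# The two documented examples, both with input shape (4, 2):
print(component_slice_shape((4, 2), 2, 1))  # (4,)
print(component_slice_shape((4, 2), 3, 1))  # (4, 2)
# An illustrative (batch, component, feature) input:
print(component_slice_shape((8, 3, 16), 3, 1))  # (8, 16)
```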
Raises:

| Type | Description |
|---|---|
| `ValueError` | If the input tensor does not have the expected number of dimensions (`n_dim_with_component`, or `n_dim_with_component - 1` when there is no component axis). |
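A sketch of the rank check implied by the examples and the `ValueError` entry, assuming exactly two ranks are accepted: `n_dim_with_component` (component axis present) and `n_dim_with_component - 1` (no component axis). The function name `has_component_axis` and its error message are hypothetical, not the library's.

```python
def has_component_axis(ndim: int, n_dim_with_component: int) -> bool:
    """Hypothetical check: does a tensor of rank ``ndim`` include the component axis?"""
    if ndim == n_dim_with_component:
        return True
    if ndim == n_dim_with_component - 1:
        # One axis short: treated as "no component axis", so no slicing is needed.
        return False
    # Any other rank is unexpected.
    raise ValueError(
        f"Expected {n_dim_with_component} or {n_dim_with_component - 1} dimensions, "
        f"got {ndim}."
    )


print(has_component_axis(2, 2))  # True: slice along the component axis
print(has_component_axis(2, 3))  # False: return the tensor unchanged
```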
Source code in `sparse_autoencoder/activation_resampler/utils/component_slice_tensor.py`