Tensor shape utilities
Tensor shape utilities.
shape_with_optional_dimensions(*shape)
Create a shape from a tuple of optional dimensions.
Motivation
By default, PyTorch raises an error if any axis size in a tensor shape is `None`. This
function lets you pass `None` for an optional axis, and the returned shape simply omits that axis.
Examples:
>>> shape_with_optional_dimensions(1, 2, 3)
(1, 2, 3)
>>> shape_with_optional_dimensions(1, None, 3)
(1, 3)
>>> shape_with_optional_dimensions(1, None, None)
(1,)
>>> shape_with_optional_dimensions(None, None, None)
()
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `*shape` | `int \| None` | Axis sizes, with `None` marking an optional axis to omit. | `()` |
Returns:
| Type | Description |
|---|---|
| `tuple[int, ...]` | Axis sizes. |
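
The behaviour documented above can be reproduced with a short filter over the arguments. The following is a minimal sketch inferred from the examples and the signature on this page, not necessarily the library's actual source. The `n_components` variable and the sizes `4` and `8` in the usage lines are hypothetical and chosen only for illustration.

```python
from __future__ import annotations


def shape_with_optional_dimensions(*shape: int | None) -> tuple[int, ...]:
    """Return the axis sizes as a tuple, skipping every entry that is None."""
    # Compare against None explicitly so that only None entries are dropped.
    return tuple(dimension for dimension in shape if dimension is not None)


# Hypothetical usage: a leading axis that exists only when it is configured.
n_components = None  # Assumption for illustration; set to an int to keep the axis.
weight_shape = shape_with_optional_dimensions(n_components, 4, 8)
print(weight_shape)  # (4, 8)
```

Filtering with `is not None` rather than a truthiness check means a size of `0` would be kept in this sketch; whether the library treats `0` the same way is not stated on this page.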
Source code in sparse_autoencoder/utils/tensor_shape.py