Tensor shape utilities

Tensor shape utilities.

shape_with_optional_dimensions(*shape)

Create a shape from a tuple of optional dimensions.

Motivation

By default, PyTorch raises an error if any axis size in a tensor shape is None. This function lets you pass None for an optional axis; the returned shape simply omits that axis.

Examples:

>>> shape_with_optional_dimensions(1, 2, 3)
(1, 2, 3)
>>> shape_with_optional_dimensions(1, None, 3)
(1, 3)
>>> shape_with_optional_dimensions(1, None, None)
(1,)
>>> shape_with_optional_dimensions(None, None, None)
()

Parameters:

Name    Type        Description                                           Default
*shape  int | None  Axis sizes, with None representing an optional axis.  ()

Returns:

Type             Description
tuple[int, ...]  Axis sizes.

Source code in sparse_autoencoder/utils/tensor_shape.py
def shape_with_optional_dimensions(*shape: int | None) -> tuple[int, ...]:
    """Create a shape from a tuple of optional dimensions.

    Motivation:
        By default, PyTorch raises an error if any axis size in a tensor shape is `None`. This
        helper lets you pass `None` for an optional axis; the returned shape omits that axis.

    Examples:
        >>> shape_with_optional_dimensions(1, 2, 3)
        (1, 2, 3)

        >>> shape_with_optional_dimensions(1, None, 3)
        (1, 3)

        >>> shape_with_optional_dimensions(1, None, None)
        (1,)

        >>> shape_with_optional_dimensions(None, None, None)
        ()

    Args:
        *shape: Axis sizes, with `None` representing an optional axis.

    Returns:
        Axis sizes.
    """
    return tuple(dimension for dimension in shape if dimension is not None)
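
As a rough illustration of how the function might be used, the sketch below builds a weight shape whose leading axis is optional. The helper `weight_shape` and its parameter names are hypothetical and not part of the library. The function body is repeated only so the snippet runs on its own.

```python
from __future__ import annotations


def shape_with_optional_dimensions(*shape: int | None) -> tuple[int, ...]:
    """Same logic as the function above, repeated so this snippet is standalone."""
    return tuple(dimension for dimension in shape if dimension is not None)


def weight_shape(
    n_components: int | None, n_learned_features: int, n_input_features: int
) -> tuple[int, ...]:
    """Hypothetical helper: the leading component axis is dropped when it is None."""
    return shape_with_optional_dimensions(
        n_components, n_learned_features, n_input_features
    )


# No component axis requested, so the shape has two axes.
single = weight_shape(None, 4, 3)

# With a component axis, all three sizes are kept.
batched = weight_shape(2, 4, 3)

# The filter tests `is not None`, so a zero-sized axis is preserved.
empty_axis = shape_with_optional_dimensions(0, None, 3)

print(single, batched, empty_axis)
```

Because the filter checks identity with `None` rather than truthiness, an axis of size 0 stays in the shape. Only an explicit `None` removes an axis.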