Data parallel utils

Data parallel utils.

DataParallelWithModelAttributes

Bases: DataParallel[T], Generic[T]

Data parallel with access to underlying model attributes/methods.

Attribute and method lookups that fail on the wrapper are forwarded to the wrapped model (`self.module`), which the default DataParallel class does not do. Based on: https://pytorch.org/tutorials/beginner/former_torchies/parallelism_tutorial.html

Example

>>> from sparse_autoencoder import SparseAutoencoder, SparseAutoencoderConfig
>>> model = SparseAutoencoder(SparseAutoencoderConfig(
...     n_input_features=2,
...     n_learned_features=4,
... ))
>>> distributed_model = DataParallelWithModelAttributes(model)
>>> distributed_model.config.n_learned_features
4
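The fallback lookup can be illustrated without PyTorch. The sketch below is only an approximation of the pattern: `TinyModule`, `TinyWrapper`, `ForwardingWrapper`, and `Model` are hypothetical stand-ins introduced here, not part of `sparse_autoencoder` or `torch`. It assumes, as `torch.nn.Module` does, that child modules are stored outside the instance `__dict__` and resolved through `__getattr__`.

```python
from typing import Any


class TinyModule:
    """Hypothetical stand-in for torch.nn.Module: children live in `_modules`."""

    def __init__(self) -> None:
        # Store the registry directly in __dict__ so normal lookup finds it.
        object.__setattr__(self, "_modules", {})

    def __setattr__(self, name: str, value: Any) -> None:
        if isinstance(value, TinyModule):
            self._modules[name] = value  # child modules go into the registry
        else:
            object.__setattr__(self, name, value)

    def __getattr__(self, name: str) -> Any:
        # Python calls this only after normal attribute lookup has failed.
        modules = self.__dict__.get("_modules", {})
        if name in modules:
            return modules[name]
        msg = f"{type(self).__name__!r} object has no attribute {name!r}"
        raise AttributeError(msg)


class TinyWrapper(TinyModule):
    """Hypothetical stand-in for DataParallel: holds the model as `module`."""

    def __init__(self, module: TinyModule) -> None:
        super().__init__()
        self.module = module


class ForwardingWrapper(TinyWrapper):
    """Same fallback as DataParallelWithModelAttributes.__getattr__."""

    def __getattr__(self, name: str) -> Any:
        try:
            return super().__getattr__(name)  # wrapper's own children first
        except AttributeError:
            return getattr(self.module, name)  # then the wrapped model


class Model(TinyModule):
    """Hypothetical model with one attribute and one method."""

    def __init__(self) -> None:
        super().__init__()
        self.n_learned_features = 4

    def describe(self) -> str:
        return f"model with {self.n_learned_features} learned features"


plain = TinyWrapper(Model())
forwarding = ForwardingWrapper(Model())

print(hasattr(plain, "n_learned_features"))  # plain wrapper does not forward
print(forwarding.n_learned_features)  # resolved on the wrapped model
print(forwarding.describe())  # methods are forwarded as well
```

Because the wrapper's own lookup is tried first, names defined on the wrapper itself (such as `module`) are never shadowed by the wrapped model; only names the wrapper lacks fall through.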

Source code in sparse_autoencoder/utils/data_parallel.py
class DataParallelWithModelAttributes(DataParallel[T], Generic[T]):
    """Data parallel with access to underlying model attributes/methods.

    Allows access to underlying model attributes/methods, which is not possible with the default
    `DataParallel` class. Based on:
    https://pytorch.org/tutorials/beginner/former_torchies/parallelism_tutorial.html

    Example:
        >>> from sparse_autoencoder import SparseAutoencoder, SparseAutoencoderConfig
        >>> model = SparseAutoencoder(SparseAutoencoderConfig(
        ...     n_input_features=2,
        ...     n_learned_features=4,
        ... ))
        >>> distributed_model = DataParallelWithModelAttributes(model)
        >>> distributed_model.config.n_learned_features
        4
    """

    def __getattr__(self, name: str) -> Any:  # noqa: ANN401
        """Allow access to underlying model attributes/methods.

        Args:
            name: Attribute/method name.

        Returns:
            Attribute value/method.
        """
        try:
            return super().__getattr__(name)
        except AttributeError:
            return getattr(self.module, name)

__getattr__(name)

Allow access to underlying model attributes/methods.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| name | str | Attribute/method name. | required |

Returns:

| Type | Description |
| --- | --- |
| Any | Attribute value/method. |

Source code in sparse_autoencoder/utils/data_parallel.py
def __getattr__(self, name: str) -> Any:  # noqa: ANN401
    """Allow access to underlying model attributes/methods.

    Args:
        name: Attribute/method name.

    Returns:
        Attribute value/method.
    """
    try:
        return super().__getattr__(name)
    except AttributeError:
        return getattr(self.module, name)