
Get the device that the model is on


get_model_device(model)

Get the device on which a PyTorch model is located.

Parameters:

Name   Type                                     Description         Default
model  Module | DataParallel | LightningModule  The PyTorch model.  required

Returns:

Type           Description
device | None  The device (for example 'cuda' or 'cpu') where the model is located, or None for a LightningModule, whose device placement Lightning handles itself.

Raises:

Type        Description
ValueError  If the model has no parameters.
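Because the return type is `device | None`, callers need a small guard before moving data. The sketch below uses a hypothetical helper, `move_to_model_device`, which is not part of sparse_autoencoder: it moves a tensor to the reported device and returns it unchanged when the device is None, on the assumption that Lightning is then responsible for placement.

```python
from __future__ import annotations

import torch


def move_to_model_device(tensor: torch.Tensor, device: torch.device | None) -> torch.Tensor:
    """Hypothetical helper: place `tensor` on the device returned by get_model_device.

    A device of None is taken to mean Lightning manages placement, so the
    tensor is returned as-is.
    """
    if device is None:
        return tensor
    return tensor.to(device)


batch = torch.zeros(4, 3)
# Explicit device: the tensor is moved (here it is already on the CPU).
on_cpu = move_to_model_device(batch, torch.device("cpu"))
# None: leave placement to Lightning.
untouched = move_to_model_device(batch, None)
```

Returning the same tensor object for None keeps the helper a no-op in Lightning code paths, so the same call site works with and without Lightning.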

Source code in sparse_autoencoder/train/utils/get_model_device.py
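import torch
from lightning.pytorch import LightningModule
from torch.nn import DataParallel, Module


def get_model_device(model: Module | DataParallel | LightningModule) -> torch.device | None:
    """Get the device on which a PyTorch model is located.

    Args:
        model: The PyTorch model.

    Returns:
        The device (for example 'cuda' or 'cpu') where the model is located, or None for a
        LightningModule, whose device placement Lightning handles itself.

    Raises:
        ValueError: If the model has no parameters.
    """
    # Tensors for Lightning should not have a device set (Lightning handles this). This check
    # must run before the `device` attribute check below, because LightningModule also
    # exposes a `device` property and would otherwise never reach this branch.
    if isinstance(model, LightningModule):
        return None

    # DeepSpeed models already have a device property, so just return that
    if hasattr(model, "device"):
        return model.device

    # Take the first parameter without materialising the whole parameter list
    first_parameter = next(model.parameters(), None)
    if first_parameter is None:
        exception_message = "The model has no parameters."
        raise ValueError(exception_message)

    # Return the device of the first parameter (for DataParallel, the wrapped module's)
    return first_parameter.device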
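The lookup order can be exercised without Lightning or DeepSpeed installed. The sketch below is a simplified stand-in, `get_model_device_sketch`, that omits the LightningModule branch; `EngineLike` is a hypothetical wrapper that only loosely mimics a DeepSpeed engine by exposing a `device` property. Neither name exists in sparse_autoencoder.

```python
from __future__ import annotations

import torch
from torch import nn


def get_model_device_sketch(model: nn.Module) -> torch.device:
    """Simplified stand-in for get_model_device (LightningModule branch omitted).

    An explicit `device` attribute wins; otherwise the first parameter's device is used.
    """
    if hasattr(model, "device"):
        return model.device
    first_parameter = next(model.parameters(), None)
    if first_parameter is None:
        raise ValueError("The model has no parameters.")
    return first_parameter.device


class EngineLike(nn.Module):
    """Hypothetical wrapper exposing a `device` property, loosely mimicking a DeepSpeed engine."""

    def __init__(self, module: nn.Module, device: torch.device) -> None:
        super().__init__()
        self.module = module
        self._device = device

    @property
    def device(self) -> torch.device:
        return self._device


linear = nn.Linear(3, 2)
# Plain module: no `device` attribute, so the first parameter's device is used.
plain_device = get_model_device_sketch(linear)
# Wrapper with a `device` property: the property wins, even though the weights are on the CPU.
engine_device = get_model_device_sketch(EngineLike(linear, torch.device("cuda", 0)))
```

Because `hasattr` matches any object that exposes a `device` attribute, and LightningModule itself exposes a `device` property, a type-specific branch such as the LightningModule check only takes effect if it runs before the attribute check.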