Get the device that the model is on
get_model_device(model)
Get the device on which a PyTorch model is located.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
model | Module \| DataParallel \| LightningModule | The PyTorch model. | required |
Returns:
Type | Description |
---|---|
device \| None | The device ('cuda' or 'cpu') where the model is located. |
Raises:
Type | Description |
---|---|
ValueError | If the model has no parameters. |
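A minimal sketch of how a helper with this contract could be written, assuming it reports the device of the model's first parameter. `get_model_device_sketch` is a hypothetical name; this is an illustration, not the code in `sparse_autoencoder/train/utils/get_model_device.py`. It relies only on PyTorch: `DataParallel` and Lightning's `LightningModule` both subclass `torch.nn.Module`, so `model.parameters()` covers all three accepted types without importing Lightning.

```python
import torch
from torch import nn


def get_model_device_sketch(model: nn.Module) -> torch.device:
    """Hypothetical re-implementation: return the device of the model's first parameter."""
    # nn.DataParallel and LightningModule are nn.Module subclasses,
    # so .parameters() is available on every accepted model type.
    try:
        first_param = next(model.parameters())
    except StopIteration:
        # A module without parameters has no device to report.
        raise ValueError("The model has no parameters.") from None
    return first_param.device


if __name__ == "__main__":
    print(get_model_device_sketch(nn.Linear(4, 2)))  # cpu
    try:
        get_model_device_sketch(nn.ReLU())  # ReLU has no parameters
    except ValueError as err:
        print(f"ValueError: {err}")
```

Because the sketch inspects only the first parameter, it assumes all parameters share one device; a model sharded across several devices would be reported as being on whichever device holds that first parameter.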
Source code in sparse_autoencoder/train/utils/get_model_device.py