Skip to content

Commit

Permalink
handle device properly
Browse files Browse the repository at this point in the history
  • Loading branch information
fabian-sp committed Apr 22, 2024
1 parent ccbd52b commit 39cf7da
Showing 1 changed file with 8 additions and 3 deletions.
11 changes: 8 additions & 3 deletions src/ncopt/functions/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,12 +28,17 @@ def __init__(
self.prepare_inputs = prepare_inputs
self.is_differentiable = is_differentiable

# If no device is provided, set it to the same as the first model parameter
# this might fail for distributed models
# If no device is provided, set it to the device of the model parameters.
# If the model has no parameters, we set the device to CPU.
if not self.device:
if sum(p.numel() for p in model.parameters() if p.requires_grad) > 0:
self.device = next(model.parameters()).device
devices = set([p.device for p in model.parameters()])
if len(devices) == 1:
self.device = devices.pop()
else:
raise KeyError(
"Model parameters lie on more than one device. Currently not supported."
)
else:
self.device = torch.device("cpu")

Expand Down

0 comments on commit 39cf7da

Please sign in to comment.