Skip to content

Commit

Permalink
Add automatic device choice explicitly
Browse files Browse the repository at this point in the history
  • Loading branch information
bentaculum committed Jul 2, 2024
1 parent a7b3c27 commit 6c0016e
Show file tree
Hide file tree
Showing 3 changed files with 6 additions and 17 deletions.
1 change: 1 addition & 0 deletions tests/test_inference_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ def test_api(example_data):
imgs, masks = example_data()
model = Trackastra.from_pretrained(
name="ctc",
device="cpu",
)

# TODO store predictions already on trackastra.TrackGraph
Expand Down
16 changes: 1 addition & 15 deletions trackastra/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ def cli():
)
p_track.add_argument(
"--device",
choices=["cuda", "mps", "cpu"],
choices=["cuda", "mps", "cpu", "automatic"],
default=None,
help=(
"Device to use. If not set, tries to use cuda/mps if available, otherwise"
Expand All @@ -89,20 +89,6 @@ def cli():


def _track_from_disk(args):
# if torch.cuda.is_available() and args.device == "cuda":
# device = "cuda"
# elif (
# torch.backends.mps.is_available()
# and os.getenv("PYTORCH_ENABLE_MPS_FALLBACK") is not None
# and os.getenv("PYTORCH_ENABLE_MPS_FALLBACK") != "0"
# and args.device == "mps"
# ):
# device = "mps"
# elif args.device == "cpu":
# device = "cpu"
# else:
# device = None

if args.model_pretrained is None == args.model_custom is None:
raise ValueError(
"Please pick a Trackastra model for tracking, either pretrained or a local"
Expand Down
6 changes: 4 additions & 2 deletions trackastra/model/model_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def __init__(
self,
transformer: TrackingTransformer,
train_args: dict,
device: Literal["cuda", "mps", "cpu", None] = None,
device: Literal["cuda", "mps", "cpu", "automatic", None] = None,
):
if device == "cuda":
if torch.cuda.is_available():
Expand All @@ -45,7 +45,7 @@ def __init__(
self.device = "cpu"
elif device == "cpu":
self.device = "cpu"
else:
elif device == "automatic" or device is None:
should_use_mps = (
torch.backends.mps.is_available()
and os.getenv("PYTORCH_ENABLE_MPS_FALLBACK") is not None
Expand All @@ -60,6 +60,8 @@ def __init__(
else "cpu"
)
)
else:
raise ValueError(f"Device {device} not recognized.")

logger.info(f"Using device {self.device}")

Expand Down

0 comments on commit 6c0016e

Please sign in to comment.