diff --git a/tests/test_inference_api.py b/tests/test_inference_api.py index ad5c3a5..5e3baed 100644 --- a/tests/test_inference_api.py +++ b/tests/test_inference_api.py @@ -22,6 +22,7 @@ def test_api(example_data): imgs, masks = example_data() model = Trackastra.from_pretrained( name="ctc", + device="cpu", ) # TODO store predictions already on trackastra.TrackGraph diff --git a/trackastra/cli.py b/trackastra/cli.py index 99db4eb..7d24cc5 100644 --- a/trackastra/cli.py +++ b/trackastra/cli.py @@ -70,7 +70,7 @@ def cli(): ) p_track.add_argument( "--device", - choices=["cuda", "mps", "cpu"], + choices=["cuda", "mps", "cpu", "automatic"], default=None, help=( "Device to use. If not set, tries to use cuda/mps if available, otherwise" @@ -89,20 +89,6 @@ def cli(): def _track_from_disk(args): - # if torch.cuda.is_available() and args.device == "cuda": - # device = "cuda" - # elif ( - # torch.backends.mps.is_available() - # and os.getenv("PYTORCH_ENABLE_MPS_FALLBACK") is not None - # and os.getenv("PYTORCH_ENABLE_MPS_FALLBACK") != "0" - # and args.device == "mps" - # ): - # device = "mps" - # elif args.device == "cpu": - # device = "cpu" - # else: - # device = None - if args.model_pretrained is None == args.model_custom is None: raise ValueError( "Please pick a Trackastra model for tracking, either pretrained or a local" diff --git a/trackastra/model/model_api.py b/trackastra/model/model_api.py index e350716..ecdb552 100644 --- a/trackastra/model/model_api.py +++ b/trackastra/model/model_api.py @@ -25,7 +25,7 @@ def __init__( self, transformer: TrackingTransformer, train_args: dict, - device: Literal["cuda", "mps", "cpu", None] = None, + device: Literal["cuda", "mps", "cpu", "automatic", None] = None, ): if device == "cuda": if torch.cuda.is_available(): @@ -45,7 +45,7 @@ def __init__( self.device = "cpu" elif device == "cpu": self.device = "cpu" - else: + elif device == "automatic" or device is None: should_use_mps = ( torch.backends.mps.is_available() and os.getenv("PYTORCH_ENABLE_MPS_FALLBACK") is not None @@ -60,6 +60,8 @@ def __init__( else "cpu" ) ) + else: + raise ValueError(f"Device {device} not recognized.") logger.info(f"Using device {self.device}")