Skip to content

Commit

Permalink
refactor(runtime): simplify runtime options configuration and add uni…
Browse files Browse the repository at this point in the history
…t tests for get_runtime function
  • Loading branch information
giuseppeambrosio97 committed Jan 7, 2025
1 parent fbcf145 commit 295021d
Show file tree
Hide file tree
Showing 2 changed files with 112 additions and 19 deletions.
27 changes: 8 additions & 19 deletions focoos/runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
import numpy as np
import onnxruntime as ort
import supervision as sv
from PIL import Image

from focoos.ports import (
FocoosTask,
Expand Down Expand Up @@ -373,22 +372,12 @@ def get_runtime(
Returns:
ONNXRuntime: A fully configured ONNXRuntime instance.
"""
if runtime_type == RuntimeTypes.ONNX_CUDA32:
opts = OnnxEngineOpts(
cuda=True, verbose=False, fp16=False, warmup_iter=warmup_iter
)
elif runtime_type == RuntimeTypes.ONNX_TRT32:
opts = OnnxEngineOpts(
cuda=False, verbose=False, trt=True, fp16=False, warmup_iter=warmup_iter
)
elif runtime_type == RuntimeTypes.ONNX_TRT16:
opts = OnnxEngineOpts(
cuda=False, verbose=False, trt=True, fp16=True, warmup_iter=warmup_iter
)
elif runtime_type == RuntimeTypes.ONNX_CPU:
opts = OnnxEngineOpts(cuda=False, verbose=False, warmup_iter=warmup_iter)
elif runtime_type == RuntimeTypes.ONNX_COREML:
opts = OnnxEngineOpts(
cuda=False, verbose=False, coreml=True, warmup_iter=warmup_iter
)
opts = OnnxEngineOpts(
cuda=runtime_type == RuntimeTypes.ONNX_CUDA32,
trt=runtime_type in [RuntimeTypes.ONNX_TRT32, RuntimeTypes.ONNX_TRT16],
fp16=runtime_type == RuntimeTypes.ONNX_TRT16,
warmup_iter=warmup_iter,
coreml=runtime_type == RuntimeTypes.ONNX_COREML,
verbose=False,
)
return ONNXRuntime(model_path, opts, model_metadata)
104 changes: 104 additions & 0 deletions tests/test_runtime.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
import pathlib
from unittest.mock import MagicMock

import pytest
from pytest_mock import MockerFixture

from focoos.ports import ModelMetadata, OnnxEngineOpts, RuntimeTypes
from focoos.runtime import ONNXRuntime, get_runtime


@pytest.mark.parametrize(
"runtime_type, expected_opts",
[
(
RuntimeTypes.ONNX_CUDA32,
OnnxEngineOpts(
cuda=True,
trt=False,
fp16=False,
coreml=False,
verbose=False,
warmup_iter=2,
),
),
(
RuntimeTypes.ONNX_TRT32,
OnnxEngineOpts(
cuda=False,
trt=True,
fp16=False,
coreml=False,
verbose=False,
warmup_iter=2,
),
),
(
RuntimeTypes.ONNX_TRT16,
OnnxEngineOpts(
cuda=False,
trt=True,
fp16=True,
coreml=False,
verbose=False,
warmup_iter=2,
),
),
(
RuntimeTypes.ONNX_CPU,
OnnxEngineOpts(
cuda=False,
trt=False,
fp16=False,
coreml=False,
verbose=False,
warmup_iter=2,
),
),
(
RuntimeTypes.ONNX_COREML,
OnnxEngineOpts(
cuda=False,
trt=False,
fp16=False,
coreml=True,
verbose=False,
warmup_iter=2,
),
),
],
)
def test_get_run_time(mocker: MockerFixture, tmp_path, runtime_type, expected_opts):
# mock model path
model_path = pathlib.Path(tmp_path) / "fakeref" / "model.onnx"
model_path.parent.mkdir(parents=True, exist_ok=True)
model_path.touch()
model_path = model_path.as_posix()

# mock model metadata
mock_model_metadata = MagicMock(spec=ModelMetadata)

# mock opts
mock_onnxruntime_class = mocker.patch("focoos.runtime.ONNXRuntime", autospec=True)
mock_onnxruntime_class.return_value = MagicMock(
spec=ONNXRuntime, opts=expected_opts
)

# warmup_iter
warmup_iter = 2

# call the function to test
onnx_runtime = get_runtime(
runtime_type=runtime_type,
model_path=model_path,
model_metadata=mock_model_metadata,
warmup_iter=warmup_iter,
)

# assertions
assert onnx_runtime is not None
mock_onnxruntime_class.assert_called_once_with(
model_path,
expected_opts,
mock_model_metadata,
)

0 comments on commit 295021d

Please sign in to comment.