Skip to content

Commit

Permalink
Add option for the user to pass in ffmpeg thread count (#291)
Browse files Browse the repository at this point in the history
Co-authored-by: Nicolas Hug <[email protected]>
  • Loading branch information
ahmadsharif1 and NicolasHug authored Oct 28, 2024
1 parent fedfeba commit 59af4b7
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 3 deletions.
11 changes: 10 additions & 1 deletion src/torchcodec/decoders/_video_decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,11 @@ class VideoDecoder:
This can be either "NCHW" (default) or "NHWC", where N is the batch
size, C is the number of channels, H is the height, and W is the
width of the frames.
num_ffmpeg_threads (int, optional): The number of threads to use for decoding.
Use 1 for single-threaded decoding which may be best if you are running multiple
instances of ``VideoDecoder`` in parallel. Use a higher number for multi-threaded
decoding which is best if you are running a single instance of ``VideoDecoder``.
Default: 1.
.. note::
Expand All @@ -58,6 +63,7 @@ def __init__(
*,
stream_index: Optional[int] = None,
dimension_order: Literal["NCHW", "NHWC"] = "NCHW",
num_ffmpeg_threads: int = 1,
):
if isinstance(source, str):
self._decoder = core.create_from_file(source)
Expand All @@ -82,7 +88,10 @@ def __init__(

core.scan_all_streams_to_update_metadata(self._decoder)
core.add_video_stream(
self._decoder, stream_index=stream_index, dimension_order=dimension_order
self._decoder,
stream_index=stream_index,
dimension_order=dimension_order,
num_threads=num_ffmpeg_threads,
)

self.metadata, self.stream_index = _get_and_validate_stream_metadata(
Expand Down
5 changes: 3 additions & 2 deletions test/decoders/test_video_decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,8 +55,9 @@ def test_create_fails(self):
with pytest.raises(ValueError, match="No valid stream found"):
decoder = VideoDecoder(NASA_VIDEO.path, stream_index=1) # noqa

def test_getitem_int(self):
decoder = VideoDecoder(NASA_VIDEO.path)
@pytest.mark.parametrize("num_ffmpeg_threads", (1, 4))
def test_getitem_int(self, num_ffmpeg_threads):
decoder = VideoDecoder(NASA_VIDEO.path, num_ffmpeg_threads=num_ffmpeg_threads)

ref_frame0 = NASA_VIDEO.get_frame_data_by_index(0)
ref_frame1 = NASA_VIDEO.get_frame_data_by_index(1)
Expand Down

0 comments on commit 59af4b7

Please sign in to comment.