From ffd07d4b26b3fa7e0c687ee565c997dab1d58363 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Tue, 11 Feb 2025 09:58:57 +0000 Subject: [PATCH] Remove streamIndex from core and ops APIs --- .../decoders/_core/VideoDecoderOps.cpp | 29 +++----- .../decoders/_core/VideoDecoderOps.h | 14 +--- src/torchcodec/decoders/_video_decoder.py | 17 ++--- test/decoders/test_video_decoder_ops.py | 74 ++++++------------- 4 files changed, 38 insertions(+), 96 deletions(-) diff --git a/src/torchcodec/decoders/_core/VideoDecoderOps.cpp b/src/torchcodec/decoders/_core/VideoDecoderOps.cpp index 398232b5..c9f7981c 100644 --- a/src/torchcodec/decoders/_core/VideoDecoderOps.cpp +++ b/src/torchcodec/decoders/_core/VideoDecoderOps.cpp @@ -39,24 +39,23 @@ TORCH_LIBRARY(torchcodec_ns, m) { m.def( "get_frame_at_pts(Tensor(a!) decoder, float seconds) -> (Tensor, Tensor, Tensor)"); m.def( - "get_frame_at_index(Tensor(a!) decoder, *, int stream_index, int frame_index) -> (Tensor, Tensor, Tensor)"); + "get_frame_at_index(Tensor(a!) decoder, *, int frame_index) -> (Tensor, Tensor, Tensor)"); m.def( - "get_frames_at_indices(Tensor(a!) decoder, *, int stream_index, int[] frame_indices) -> (Tensor, Tensor, Tensor)"); + "get_frames_at_indices(Tensor(a!) decoder, *, int[] frame_indices) -> (Tensor, Tensor, Tensor)"); m.def( - "get_frames_in_range(Tensor(a!) decoder, *, int stream_index, int start, int stop, int? step=None) -> (Tensor, Tensor, Tensor)"); + "get_frames_in_range(Tensor(a!) decoder, *, int start, int stop, int? step=None) -> (Tensor, Tensor, Tensor)"); m.def( - "get_frames_by_pts_in_range(Tensor(a!) decoder, *, int stream_index, float start_seconds, float stop_seconds) -> (Tensor, Tensor, Tensor)"); + "get_frames_by_pts_in_range(Tensor(a!) decoder, *, float start_seconds, float stop_seconds) -> (Tensor, Tensor, Tensor)"); m.def( - "get_frames_by_pts(Tensor(a!) decoder, *, int stream_index, float[] timestamps) -> (Tensor, Tensor, Tensor)"); - m.def( - "_get_key_frame_indices(Tensor(a!) decoder, int stream_index) -> Tensor"); + "get_frames_by_pts(Tensor(a!) decoder, *, float[] timestamps) -> (Tensor, Tensor, Tensor)"); + m.def("_get_key_frame_indices(Tensor(a!) decoder) -> Tensor"); m.def("get_json_metadata(Tensor(a!) decoder) -> str"); m.def("get_container_json_metadata(Tensor(a!) decoder) -> str"); m.def( "get_stream_json_metadata(Tensor(a!) decoder, int stream_index) -> str"); m.def("_get_json_ffmpeg_library_versions() -> str"); m.def( - "_test_frame_pts_equality(Tensor(a!) decoder, *, int stream_index, int frame_index, float pts_seconds_to_test) -> bool"); + "_test_frame_pts_equality(Tensor(a!) decoder, *, int frame_index, float pts_seconds_to_test) -> bool"); m.def("scan_all_streams_to_update_metadata(Tensor(a!) decoder) -> ()"); } @@ -251,10 +250,7 @@ OpsFrameOutput get_frame_at_pts(at::Tensor& decoder, double seconds) { return makeOpsFrameOutput(result); } -OpsFrameOutput get_frame_at_index( - at::Tensor& decoder, - [[maybe_unused]] int64_t stream_index, - int64_t frame_index) { +OpsFrameOutput get_frame_at_index(at::Tensor& decoder, int64_t frame_index) { auto videoDecoder = unwrapTensorToGetDecoder(decoder); auto result = videoDecoder->getFrameAtIndex(frame_index); return makeOpsFrameOutput(result); @@ -262,7 +258,6 @@ OpsFrameOutput get_frame_at_index( OpsFrameBatchOutput get_frames_at_indices( at::Tensor& decoder, - [[maybe_unused]] int64_t stream_index, at::IntArrayRef frame_indices) { auto videoDecoder = unwrapTensorToGetDecoder(decoder); std::vector frameIndicesVec( @@ -273,7 +268,6 @@ OpsFrameBatchOutput get_frames_at_indices( OpsFrameBatchOutput get_frames_in_range( at::Tensor& decoder, - [[maybe_unused]] int64_t stream_index, int64_t start, int64_t stop, std::optional step) { @@ -284,7 +278,6 @@ OpsFrameBatchOutput get_frames_in_range( OpsFrameBatchOutput get_frames_by_pts( at::Tensor& decoder, - [[maybe_unused]] int64_t stream_index, at::ArrayRef timestamps) { auto videoDecoder = unwrapTensorToGetDecoder(decoder); std::vector timestampsVec(timestamps.begin(), timestamps.end()); @@ -294,7 +287,6 @@ OpsFrameBatchOutput get_frames_by_pts( OpsFrameBatchOutput get_frames_by_pts_in_range( at::Tensor& decoder, - [[maybe_unused]] int64_t stream_index, double start_seconds, double stop_seconds) { auto videoDecoder = unwrapTensorToGetDecoder(decoder); @@ -327,7 +319,6 @@ std::string mapToJson(const std::map& metadataMap) { bool _test_frame_pts_equality( at::Tensor& decoder, - [[maybe_unused]] int64_t stream_index, int64_t frame_index, double pts_seconds_to_test) { auto videoDecoder = unwrapTensorToGetDecoder(decoder); @@ -335,9 +326,7 @@ bool _test_frame_pts_equality( videoDecoder->getPtsSecondsForFrame(frame_index); } -torch::Tensor _get_key_frame_indices( - at::Tensor& decoder, - [[maybe_unused]] int64_t stream_index) { +torch::Tensor _get_key_frame_indices(at::Tensor& decoder) { auto videoDecoder = unwrapTensorToGetDecoder(decoder); return videoDecoder->getKeyFrameIndices(); } diff --git a/src/torchcodec/decoders/_core/VideoDecoderOps.h b/src/torchcodec/decoders/_core/VideoDecoderOps.h index 241d8098..8bdd05cd 100644 --- a/src/torchcodec/decoders/_core/VideoDecoderOps.h +++ b/src/torchcodec/decoders/_core/VideoDecoderOps.h @@ -85,14 +85,10 @@ OpsFrameOutput get_frame_at_pts(at::Tensor& decoder, double seconds); // Return the frames at given ptss for a given stream OpsFrameBatchOutput get_frames_by_pts( at::Tensor& decoder, - int64_t stream_index, at::ArrayRef timestamps); // Return the frame that is visible at a given index in the video. -OpsFrameOutput get_frame_at_index( - at::Tensor& decoder, - int64_t stream_index, - int64_t frame_index); +OpsFrameOutput get_frame_at_index(at::Tensor& decoder, int64_t frame_index); // Get the next frame from the video as a tuple that has the frame data, pts and // duration as tensors. @@ -101,14 +97,12 @@ OpsFrameOutput get_next_frame(at::Tensor& decoder); // Return the frames at given indices for a given stream OpsFrameBatchOutput get_frames_at_indices( at::Tensor& decoder, - int64_t stream_index, at::IntArrayRef frame_indices); // Return the frames inside a range as a single stacked Tensor. The range is // defined as [start, stop). OpsFrameBatchOutput get_frames_in_range( at::Tensor& decoder, - int64_t stream_index, int64_t start, int64_t stop, std::optional step = std::nullopt); @@ -118,7 +112,6 @@ OpsFrameBatchOutput get_frames_in_range( // order. OpsFrameBatchOutput get_frames_by_pts_in_range( at::Tensor& decoder, - int64_t stream_index, double start_seconds, double stop_seconds); @@ -128,16 +121,15 @@ OpsFrameBatchOutput get_frames_by_pts_in_range( // We want to make sure that the value is preserved exactly, bit-for-bit, during // this process. // -// Returns true if for the given decoder, in the stream stream_index, the pts +// Returns true if for the given decoder, the pts // value when converted to seconds as a double is exactly pts_seconds_to_test. // Returns false otherwise. bool _test_frame_pts_equality( at::Tensor& decoder, - int64_t stream_index, int64_t frame_index, double pts_seconds_to_test); -torch::Tensor _get_key_frame_indices(at::Tensor& decoder, int64_t stream_index); +torch::Tensor _get_key_frame_indices(at::Tensor& decoder); // Get the metadata from the video as a string. std::string get_json_metadata(at::Tensor& decoder); diff --git a/src/torchcodec/decoders/_video_decoder.py b/src/torchcodec/decoders/_video_decoder.py index ecf70d966..6ab59e0c 100644 --- a/src/torchcodec/decoders/_video_decoder.py +++ b/src/torchcodec/decoders/_video_decoder.py @@ -152,9 +152,7 @@ def _getitem_int(self, key: int) -> Tensor: f"Index {key} is out of bounds; length is {self._num_frames}" ) - frame_data, *_ = core.get_frame_at_index( - self._decoder, frame_index=key, stream_index=self.stream_index - ) + frame_data, *_ = core.get_frame_at_index(self._decoder, frame_index=key) return frame_data def _getitem_slice(self, key: slice) -> Tensor: @@ -163,7 +161,6 @@ def _getitem_slice(self, key: slice) -> Tensor: start, stop, step = key.indices(len(self)) frame_data, *_ = core.get_frames_in_range( self._decoder, - stream_index=self.stream_index, start=start, stop=stop, step=step, @@ -189,9 +186,7 @@ def __getitem__(self, key: Union[numbers.Integral, slice]) -> Tensor: ) def _get_key_frame_indices(self) -> list[int]: - return core._get_key_frame_indices( - self._decoder, stream_index=self.stream_index - ) + return core._get_key_frame_indices(self._decoder) def get_frame_at(self, index: int) -> Frame: """Return a single frame at the given index. @@ -208,7 +203,7 @@ def get_frame_at(self, index: int) -> Frame: f"Index {index} is out of bounds; must be in the range [0, {self._num_frames})." ) data, pts_seconds, duration_seconds = core.get_frame_at_index( - self._decoder, frame_index=index, stream_index=self.stream_index + self._decoder, frame_index=index ) return Frame( data=data, @@ -234,7 +229,7 @@ def get_frames_at(self, indices: list[int]) -> FrameBatch: """ data, pts_seconds, duration_seconds = core.get_frames_at_indices( - self._decoder, stream_index=self.stream_index, frame_indices=indices + self._decoder, frame_indices=indices ) return FrameBatch( data=data, @@ -268,7 +263,6 @@ def get_frames_in_range(self, start: int, stop: int, step: int = 1) -> FrameBatc raise IndexError(f"Step ({step}) must be greater than 0.") frames = core.get_frames_in_range( self._decoder, - stream_index=self.stream_index, start=start, stop=stop, step=step, @@ -316,7 +310,7 @@ def get_frames_played_at(self, seconds: list[float]) -> FrameBatch: FrameBatch: The frames that are played at ``seconds``. """ data, pts_seconds, duration_seconds = core.get_frames_by_pts( - self._decoder, timestamps=seconds, stream_index=self.stream_index + self._decoder, timestamps=seconds ) return FrameBatch( data=data, @@ -359,7 +353,6 @@ def get_frames_played_in_range( ) frames = core.get_frames_by_pts_in_range( self._decoder, - stream_index=self.stream_index, start_seconds=start_seconds, stop_seconds=stop_seconds, ) diff --git a/test/decoders/test_video_decoder_ops.py b/test/decoders/test_video_decoder_ops.py index 9b41126f..c6fed986 100644 --- a/test/decoders/test_video_decoder_ops.py +++ b/test/decoders/test_video_decoder_ops.py @@ -120,11 +120,11 @@ def test_get_frame_at_pts(self, device): def test_get_frame_at_index(self, device): decoder = create_from_file(str(NASA_VIDEO.path)) add_video_stream(decoder, device=device) - frame0, _, _ = get_frame_at_index(decoder, stream_index=3, frame_index=0) + frame0, _, _ = get_frame_at_index(decoder, frame_index=0) reference_frame0 = NASA_VIDEO.get_frame_data_by_index(0) assert_frames_equal(frame0, reference_frame0.to(device)) # The frame that is played at 6 seconds is frame 180 from a 0-based index. - frame6, _, _ = get_frame_at_index(decoder, stream_index=3, frame_index=180) + frame6, _, _ = get_frame_at_index(decoder, frame_index=180) reference_frame6 = NASA_VIDEO.get_frame_data_by_index( INDEX_OF_FRAME_AT_6_SECONDS ) @@ -134,9 +134,7 @@ def test_get_frame_at_index(self, device): def test_get_frame_with_info_at_index(self, device): decoder = create_from_file(str(NASA_VIDEO.path)) add_video_stream(decoder, device=device) - frame6, pts, duration = get_frame_at_index( - decoder, stream_index=3, frame_index=180 - ) + frame6, pts, duration = get_frame_at_index(decoder, frame_index=180) reference_frame6 = NASA_VIDEO.get_frame_data_by_index( INDEX_OF_FRAME_AT_6_SECONDS ) @@ -148,9 +146,7 @@ def test_get_frame_with_info_at_index(self, device): def test_get_frames_at_indices(self, device): decoder = create_from_file(str(NASA_VIDEO.path)) add_video_stream(decoder, device=device) - frames0and180, *_ = get_frames_at_indices( - decoder, stream_index=3, frame_indices=[0, 180] - ) + frames0and180, *_ = get_frames_at_indices(decoder, frame_indices=[0, 180]) reference_frame0 = NASA_VIDEO.get_frame_data_by_index(0) reference_frame180 = NASA_VIDEO.get_frame_data_by_index( INDEX_OF_FRAME_AT_6_SECONDS @@ -162,20 +158,16 @@ def test_get_frames_at_indices(self, device): def test_get_frames_at_indices_unsorted_indices(self, device): decoder = create_from_file(str(NASA_VIDEO.path)) _add_video_stream(decoder, device=device) - stream_index = 3 frame_indices = [2, 0, 1, 0, 2] expected_frames = [ - get_frame_at_index( - decoder, stream_index=stream_index, frame_index=frame_index - )[0] + get_frame_at_index(decoder, frame_index=frame_index)[0] for frame_index in frame_indices ] frames, *_ = get_frames_at_indices( decoder, - stream_index=stream_index, frame_indices=frame_indices, ) for frame, expected_frame in zip(frames, expected_frames): @@ -193,7 +185,6 @@ def test_get_frames_at_indices_unsorted_indices(self, device): def test_get_frames_by_pts(self, device): decoder = create_from_file(str(NASA_VIDEO.path)) _add_video_stream(decoder, device=device) - stream_index = 3 # Note: 13.01 should give the last video frame for the NASA video timestamps = [2, 0, 1, 0 + 1e-3, 13.01, 2 + 1e-3] @@ -204,7 +195,6 @@ def test_get_frames_by_pts(self, device): frames, *_ = get_frames_by_pts( decoder, - stream_index=stream_index, timestamps=timestamps, ) for frame, expected_frame in zip(frames, expected_frames): @@ -233,12 +223,9 @@ def test_pts_apis_against_index_ref(self, device): num_frames = metadata_dict["numFrames"] assert num_frames == 390 - stream_index = 3 _, all_pts_seconds_ref, _ = zip( *[ - get_frame_at_index( - decoder, stream_index=stream_index, frame_index=frame_index - ) + get_frame_at_index(decoder, frame_index=frame_index) for frame_index in range(num_frames) ] ) @@ -254,7 +241,6 @@ def test_pts_apis_against_index_ref(self, device): _, pts_seconds, _ = get_frames_by_pts_in_range( decoder, - stream_index=stream_index, start_seconds=0, stop_seconds=all_pts_seconds_ref[-1] + 1e-4, ) @@ -264,7 +250,6 @@ def test_pts_apis_against_index_ref(self, device): *[ get_frames_by_pts_in_range( decoder, - stream_index=stream_index, start_seconds=pts, stop_seconds=pts + 1e-4, ) @@ -275,7 +260,7 @@ def test_pts_apis_against_index_ref(self, device): torch.testing.assert_close(pts_seconds, all_pts_seconds_ref, atol=0, rtol=0) _, pts_seconds, _ = get_frames_by_pts( - decoder, stream_index=stream_index, timestamps=all_pts_seconds_ref.tolist() + decoder, timestamps=all_pts_seconds_ref.tolist() ) torch.testing.assert_close(pts_seconds, all_pts_seconds_ref, atol=0, rtol=0) @@ -286,47 +271,37 @@ def test_get_frames_in_range(self, device): # ensure that the degenerate case of a range of size 1 works ref_frame0 = NASA_VIDEO.get_frame_data_by_range(0, 1) - bulk_frame0, *_ = get_frames_in_range(decoder, stream_index=3, start=0, stop=1) + bulk_frame0, *_ = get_frames_in_range(decoder, start=0, stop=1) assert_frames_equal(bulk_frame0, ref_frame0.to(device)) ref_frame1 = NASA_VIDEO.get_frame_data_by_range(1, 2) - bulk_frame1, *_ = get_frames_in_range(decoder, stream_index=3, start=1, stop=2) + bulk_frame1, *_ = get_frames_in_range(decoder, start=1, stop=2) assert_frames_equal(bulk_frame1, ref_frame1.to(device)) ref_frame389 = NASA_VIDEO.get_frame_data_by_range(389, 390) - bulk_frame389, *_ = get_frames_in_range( - decoder, stream_index=3, start=389, stop=390 - ) + bulk_frame389, *_ = get_frames_in_range(decoder, start=389, stop=390) assert_frames_equal(bulk_frame389, ref_frame389.to(device)) # contiguous ranges ref_frames0_9 = NASA_VIDEO.get_frame_data_by_range(0, 9) - bulk_frames0_9, *_ = get_frames_in_range( - decoder, stream_index=3, start=0, stop=9 - ) + bulk_frames0_9, *_ = get_frames_in_range(decoder, start=0, stop=9) assert_frames_equal(bulk_frames0_9, ref_frames0_9.to(device)) ref_frames4_8 = NASA_VIDEO.get_frame_data_by_range(4, 8) - bulk_frames4_8, *_ = get_frames_in_range( - decoder, stream_index=3, start=4, stop=8 - ) + bulk_frames4_8, *_ = get_frames_in_range(decoder, start=4, stop=8) assert_frames_equal(bulk_frames4_8, ref_frames4_8.to(device)) # ranges with a stride ref_frames15_35 = NASA_VIDEO.get_frame_data_by_range(15, 36, 5) - bulk_frames15_35, *_ = get_frames_in_range( - decoder, stream_index=3, start=15, stop=36, step=5 - ) + bulk_frames15_35, *_ = get_frames_in_range(decoder, start=15, stop=36, step=5) assert_frames_equal(bulk_frames15_35, ref_frames15_35.to(device)) ref_frames0_9_2 = NASA_VIDEO.get_frame_data_by_range(0, 9, 2) - bulk_frames0_9_2, *_ = get_frames_in_range( - decoder, stream_index=3, start=0, stop=9, step=2 - ) + bulk_frames0_9_2, *_ = get_frames_in_range(decoder, start=0, stop=9, step=2) assert_frames_equal(bulk_frames0_9_2, ref_frames0_9_2.to(device)) # an empty range is valid! - empty_frame, *_ = get_frames_in_range(decoder, stream_index=3, start=5, stop=5) + empty_frame, *_ = get_frames_in_range(decoder, start=5, stop=5) assert_frames_equal(empty_frame, NASA_VIDEO.empty_chw_tensor.to(device)) @pytest.mark.parametrize("device", cpu_and_cuda()) @@ -481,9 +456,9 @@ def test_frame_pts_equality(self): # If this fails, there's a good chance that we accidentally truncated a 64-bit # floating point value to a 32-bit floating value. for i in range(390): - frame, pts, _ = get_frame_at_index(decoder, stream_index=3, frame_index=i) + frame, pts, _ = get_frame_at_index(decoder, frame_index=i) pts_is_equal = _test_frame_pts_equality( - decoder, stream_index=3, frame_index=i, pts_seconds_to_test=pts.item() + decoder, frame_index=i, pts_seconds_to_test=pts.item() ) assert pts_is_equal @@ -564,10 +539,7 @@ def test_color_conversion_library_with_dimension_order( frame0_ref = frame0_ref.permute(1, 2, 0) expected_shape = frame0_ref.shape - stream_index = 3 - frame0, *_ = get_frame_at_index( - decoder, stream_index=stream_index, frame_index=0 - ) + frame0, *_ = get_frame_at_index(decoder, frame_index=0) assert frame0.shape == expected_shape assert_frames_equal(frame0, frame0_ref) @@ -575,21 +547,17 @@ def test_color_conversion_library_with_dimension_order( assert frame0.shape == expected_shape assert_frames_equal(frame0, frame0_ref) - frames, *_ = get_frames_in_range( - decoder, stream_index=stream_index, start=0, stop=3 - ) + frames, *_ = get_frames_in_range(decoder, start=0, stop=3) assert frames.shape[1:] == expected_shape assert_frames_equal(frames[0], frame0_ref) frames, *_ = get_frames_by_pts_in_range( - decoder, stream_index=stream_index, start_seconds=0, stop_seconds=1 + decoder, start_seconds=0, stop_seconds=1 ) assert frames.shape[1:] == expected_shape assert_frames_equal(frames[0], frame0_ref) - frames, *_ = get_frames_at_indices( - decoder, stream_index=stream_index, frame_indices=[0, 1, 3, 4] - ) + frames, *_ = get_frames_at_indices(decoder, frame_indices=[0, 1, 3, 4]) assert frames.shape[1:] == expected_shape assert_frames_equal(frames[0], frame0_ref)