Skip to content

Commit

Permalink
Remove streamIndex from core and ops APIs
Browse files Browse the repository at this point in the history
  • Loading branch information
NicolasHug committed Feb 11, 2025
1 parent 0f50aba commit ffd07d4
Show file tree
Hide file tree
Showing 4 changed files with 38 additions and 96 deletions.
29 changes: 9 additions & 20 deletions src/torchcodec/decoders/_core/VideoDecoderOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,24 +39,23 @@ TORCH_LIBRARY(torchcodec_ns, m) {
m.def(
"get_frame_at_pts(Tensor(a!) decoder, float seconds) -> (Tensor, Tensor, Tensor)");
m.def(
"get_frame_at_index(Tensor(a!) decoder, *, int stream_index, int frame_index) -> (Tensor, Tensor, Tensor)");
"get_frame_at_index(Tensor(a!) decoder, *, int frame_index) -> (Tensor, Tensor, Tensor)");
m.def(
"get_frames_at_indices(Tensor(a!) decoder, *, int stream_index, int[] frame_indices) -> (Tensor, Tensor, Tensor)");
"get_frames_at_indices(Tensor(a!) decoder, *, int[] frame_indices) -> (Tensor, Tensor, Tensor)");
m.def(
"get_frames_in_range(Tensor(a!) decoder, *, int stream_index, int start, int stop, int? step=None) -> (Tensor, Tensor, Tensor)");
"get_frames_in_range(Tensor(a!) decoder, *, int start, int stop, int? step=None) -> (Tensor, Tensor, Tensor)");
m.def(
"get_frames_by_pts_in_range(Tensor(a!) decoder, *, int stream_index, float start_seconds, float stop_seconds) -> (Tensor, Tensor, Tensor)");
"get_frames_by_pts_in_range(Tensor(a!) decoder, *, float start_seconds, float stop_seconds) -> (Tensor, Tensor, Tensor)");
m.def(
"get_frames_by_pts(Tensor(a!) decoder, *, int stream_index, float[] timestamps) -> (Tensor, Tensor, Tensor)");
m.def(
"_get_key_frame_indices(Tensor(a!) decoder, int stream_index) -> Tensor");
"get_frames_by_pts(Tensor(a!) decoder, *, float[] timestamps) -> (Tensor, Tensor, Tensor)");
m.def("_get_key_frame_indices(Tensor(a!) decoder) -> Tensor");
m.def("get_json_metadata(Tensor(a!) decoder) -> str");
m.def("get_container_json_metadata(Tensor(a!) decoder) -> str");
m.def(
"get_stream_json_metadata(Tensor(a!) decoder, int stream_index) -> str");
m.def("_get_json_ffmpeg_library_versions() -> str");
m.def(
"_test_frame_pts_equality(Tensor(a!) decoder, *, int stream_index, int frame_index, float pts_seconds_to_test) -> bool");
"_test_frame_pts_equality(Tensor(a!) decoder, *, int frame_index, float pts_seconds_to_test) -> bool");
m.def("scan_all_streams_to_update_metadata(Tensor(a!) decoder) -> ()");
}

Expand Down Expand Up @@ -251,18 +250,14 @@ OpsFrameOutput get_frame_at_pts(at::Tensor& decoder, double seconds) {
return makeOpsFrameOutput(result);
}

OpsFrameOutput get_frame_at_index(
at::Tensor& decoder,
[[maybe_unused]] int64_t stream_index,
int64_t frame_index) {
OpsFrameOutput get_frame_at_index(at::Tensor& decoder, int64_t frame_index) {
auto videoDecoder = unwrapTensorToGetDecoder(decoder);
auto result = videoDecoder->getFrameAtIndex(frame_index);
return makeOpsFrameOutput(result);
}

OpsFrameBatchOutput get_frames_at_indices(
at::Tensor& decoder,
[[maybe_unused]] int64_t stream_index,
at::IntArrayRef frame_indices) {
auto videoDecoder = unwrapTensorToGetDecoder(decoder);
std::vector<int64_t> frameIndicesVec(
Expand All @@ -273,7 +268,6 @@ OpsFrameBatchOutput get_frames_at_indices(

OpsFrameBatchOutput get_frames_in_range(
at::Tensor& decoder,
[[maybe_unused]] int64_t stream_index,
int64_t start,
int64_t stop,
std::optional<int64_t> step) {
Expand All @@ -284,7 +278,6 @@ OpsFrameBatchOutput get_frames_in_range(

OpsFrameBatchOutput get_frames_by_pts(
at::Tensor& decoder,
[[maybe_unused]] int64_t stream_index,
at::ArrayRef<double> timestamps) {
auto videoDecoder = unwrapTensorToGetDecoder(decoder);
std::vector<double> timestampsVec(timestamps.begin(), timestamps.end());
Expand All @@ -294,7 +287,6 @@ OpsFrameBatchOutput get_frames_by_pts(

OpsFrameBatchOutput get_frames_by_pts_in_range(
at::Tensor& decoder,
[[maybe_unused]] int64_t stream_index,
double start_seconds,
double stop_seconds) {
auto videoDecoder = unwrapTensorToGetDecoder(decoder);
Expand Down Expand Up @@ -327,17 +319,14 @@ std::string mapToJson(const std::map<std::string, std::string>& metadataMap) {

bool _test_frame_pts_equality(
at::Tensor& decoder,
[[maybe_unused]] int64_t stream_index,
int64_t frame_index,
double pts_seconds_to_test) {
auto videoDecoder = unwrapTensorToGetDecoder(decoder);
return pts_seconds_to_test ==
videoDecoder->getPtsSecondsForFrame(frame_index);
}

torch::Tensor _get_key_frame_indices(
at::Tensor& decoder,
[[maybe_unused]] int64_t stream_index) {
torch::Tensor _get_key_frame_indices(at::Tensor& decoder) {
auto videoDecoder = unwrapTensorToGetDecoder(decoder);
return videoDecoder->getKeyFrameIndices();
}
Expand Down
14 changes: 3 additions & 11 deletions src/torchcodec/decoders/_core/VideoDecoderOps.h
Original file line number Diff line number Diff line change
Expand Up @@ -85,14 +85,10 @@ OpsFrameOutput get_frame_at_pts(at::Tensor& decoder, double seconds);
// Return the frames at the given pts values for a given stream
OpsFrameBatchOutput get_frames_by_pts(
at::Tensor& decoder,
int64_t stream_index,
at::ArrayRef<double> timestamps);

// Return the frame that is visible at a given index in the video.
OpsFrameOutput get_frame_at_index(
at::Tensor& decoder,
int64_t stream_index,
int64_t frame_index);
OpsFrameOutput get_frame_at_index(at::Tensor& decoder, int64_t frame_index);

// Get the next frame from the video as a tuple that has the frame data, pts and
// duration as tensors.
Expand All @@ -101,14 +97,12 @@ OpsFrameOutput get_next_frame(at::Tensor& decoder);
// Return the frames at given indices for a given stream
OpsFrameBatchOutput get_frames_at_indices(
at::Tensor& decoder,
int64_t stream_index,
at::IntArrayRef frame_indices);

// Return the frames inside a range as a single stacked Tensor. The range is
// defined as [start, stop).
OpsFrameBatchOutput get_frames_in_range(
at::Tensor& decoder,
int64_t stream_index,
int64_t start,
int64_t stop,
std::optional<int64_t> step = std::nullopt);
Expand All @@ -118,7 +112,6 @@ OpsFrameBatchOutput get_frames_in_range(
// order.
OpsFrameBatchOutput get_frames_by_pts_in_range(
at::Tensor& decoder,
int64_t stream_index,
double start_seconds,
double stop_seconds);

Expand All @@ -128,16 +121,15 @@ OpsFrameBatchOutput get_frames_by_pts_in_range(
// We want to make sure that the value is preserved exactly, bit-for-bit, during
// this process.
//
// Returns true if for the given decoder, in the stream stream_index, the pts
// Returns true if for the given decoder, the pts
// value when converted to seconds as a double is exactly pts_seconds_to_test.
// Returns false otherwise.
bool _test_frame_pts_equality(
at::Tensor& decoder,
int64_t stream_index,
int64_t frame_index,
double pts_seconds_to_test);

torch::Tensor _get_key_frame_indices(at::Tensor& decoder, int64_t stream_index);
torch::Tensor _get_key_frame_indices(at::Tensor& decoder);

// Get the metadata from the video as a string.
std::string get_json_metadata(at::Tensor& decoder);
Expand Down
17 changes: 5 additions & 12 deletions src/torchcodec/decoders/_video_decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,9 +152,7 @@ def _getitem_int(self, key: int) -> Tensor:
f"Index {key} is out of bounds; length is {self._num_frames}"
)

frame_data, *_ = core.get_frame_at_index(
self._decoder, frame_index=key, stream_index=self.stream_index
)
frame_data, *_ = core.get_frame_at_index(self._decoder, frame_index=key)
return frame_data

def _getitem_slice(self, key: slice) -> Tensor:
Expand All @@ -163,7 +161,6 @@ def _getitem_slice(self, key: slice) -> Tensor:
start, stop, step = key.indices(len(self))
frame_data, *_ = core.get_frames_in_range(
self._decoder,
stream_index=self.stream_index,
start=start,
stop=stop,
step=step,
Expand All @@ -189,9 +186,7 @@ def __getitem__(self, key: Union[numbers.Integral, slice]) -> Tensor:
)

def _get_key_frame_indices(self) -> list[int]:
return core._get_key_frame_indices(
self._decoder, stream_index=self.stream_index
)
return core._get_key_frame_indices(self._decoder)

def get_frame_at(self, index: int) -> Frame:
"""Return a single frame at the given index.
Expand All @@ -208,7 +203,7 @@ def get_frame_at(self, index: int) -> Frame:
f"Index {index} is out of bounds; must be in the range [0, {self._num_frames})."
)
data, pts_seconds, duration_seconds = core.get_frame_at_index(
self._decoder, frame_index=index, stream_index=self.stream_index
self._decoder, frame_index=index
)
return Frame(
data=data,
Expand All @@ -234,7 +229,7 @@ def get_frames_at(self, indices: list[int]) -> FrameBatch:
"""

data, pts_seconds, duration_seconds = core.get_frames_at_indices(
self._decoder, stream_index=self.stream_index, frame_indices=indices
self._decoder, frame_indices=indices
)
return FrameBatch(
data=data,
Expand Down Expand Up @@ -268,7 +263,6 @@ def get_frames_in_range(self, start: int, stop: int, step: int = 1) -> FrameBatc
raise IndexError(f"Step ({step}) must be greater than 0.")
frames = core.get_frames_in_range(
self._decoder,
stream_index=self.stream_index,
start=start,
stop=stop,
step=step,
Expand Down Expand Up @@ -316,7 +310,7 @@ def get_frames_played_at(self, seconds: list[float]) -> FrameBatch:
FrameBatch: The frames that are played at ``seconds``.
"""
data, pts_seconds, duration_seconds = core.get_frames_by_pts(
self._decoder, timestamps=seconds, stream_index=self.stream_index
self._decoder, timestamps=seconds
)
return FrameBatch(
data=data,
Expand Down Expand Up @@ -359,7 +353,6 @@ def get_frames_played_in_range(
)
frames = core.get_frames_by_pts_in_range(
self._decoder,
stream_index=self.stream_index,
start_seconds=start_seconds,
stop_seconds=stop_seconds,
)
Expand Down
Loading

0 comments on commit ffd07d4

Please sign in to comment.