Skip to content

Commit

Permalink
Fix test_pipelines_video_classification that was always failing (#3…
Browse files Browse the repository at this point in the history
…5842)

* Fix test_pipelines_video_classification that was always failing

* Update video pipeline docstring to reflect actual return type

---------

Co-authored-by: Louis Groux <[email protected]>
  • Loading branch information
CalOmnie and CalOmnie authored Jan 23, 2025
1 parent 328e2ae commit b5aaf87
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 8 deletions.
6 changes: 3 additions & 3 deletions src/transformers/pipelines/video_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,9 +106,9 @@ def __call__(self, inputs: Union[str, List[str]] = None, **kwargs):
post-processing.
Return:
A dictionary or a list of dictionaries containing result. If the input is a single video, will return a
dictionary, if the input is a list of several videos, will return a list of dictionaries corresponding to
the videos.
A list of dictionaries or a list of list of dictionaries containing result. If the input is a single video,
will return a list of `top_k` dictionaries, if the input is a list of several videos, will return a list of list of
`top_k` dictionaries corresponding to the videos.
The dictionaries contain the following keys:
Expand Down
9 changes: 4 additions & 5 deletions tests/pipelines/test_pipelines_video_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,14 +91,13 @@ def test_small_model_pt(self):
)

video_file_path = hf_hub_download(repo_id="nateraw/video-demo", filename="archery.mp4", repo_type="dataset")
outputs = video_classifier(video_file_path, top_k=2)
output = video_classifier(video_file_path, top_k=2)
self.assertEqual(
nested_simplify(outputs, decimals=4),
nested_simplify(output, decimals=4),
[{"score": 0.5199, "label": "LABEL_0"}, {"score": 0.4801, "label": "LABEL_1"}],
)
for output in outputs:
for element in output:
compare_pipeline_output_to_hub_spec(element, VideoClassificationOutputElement)
for element in output:
compare_pipeline_output_to_hub_spec(element, VideoClassificationOutputElement)

outputs = video_classifier(
[
Expand Down

0 comments on commit b5aaf87

Please sign in to comment.