Skip to content

Commit

Permalink
Update ndx-pose for Lightning pose (#1170)
Browse files Browse the repository at this point in the history
  • Loading branch information
pauladkisson authored Jan 13, 2025
1 parent 9301d7f commit 8046d0d
Show file tree
Hide file tree
Showing 4 changed files with 36 additions and 37 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ video = [
"opencv-python-headless>=4.8.1.78",
]
lightningpose = [
"ndx-pose>=0.1.1",
"ndx-pose>=0.2",
"neuroconv[video]",
]
medpc = [
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,10 @@ def get_metadata_schema(self) -> dict:
description=dict(type="string"),
scorer=dict(type="string"),
source_software=dict(type="string", default="LightningPose"),
camera_name=dict(type="string", default="CameraPoseEstimation"),
),
patternProperties={
"^(?!(name|description|scorer|source_software)$)[a-zA-Z0-9_]+$": dict(
"^(?!(name|description|scorer|source_software|camera_name)$)[a-zA-Z0-9_]+$": dict(
title="PoseEstimationSeries",
type="object",
properties=dict(name=dict(type="string"), description=dict(type="string")),
Expand Down Expand Up @@ -80,22 +81,15 @@ def __init__(
verbose : bool, default: True
controls verbosity. ``True`` by default.
"""
from importlib.metadata import version

# This import is to assure that the ndx_pose is in the global namespace when an pynwb.io object is created
# For more detail, see https://github.com/rly/ndx-pose/issues/36
import ndx_pose # noqa: F401
from packaging import version as version_parse

from neuroconv.datainterfaces.behavior.video.video_utils import (
VideoCaptureContext,
)

ndx_pose_version = version("ndx-pose")

if version_parse.parse(ndx_pose_version) >= version_parse.parse("0.2.0"):
raise ImportError("The ndx-pose version must be less than 0.2.0.")

self._vc = VideoCaptureContext

self.file_path = Path(file_path)
Expand Down Expand Up @@ -170,6 +164,7 @@ def get_metadata(self) -> DeepDict:
description="Contains the pose estimation series for each keypoint.",
scorer=self.scorer_name,
source_software="LightningPose",
camera_name="CameraPoseEstimation",
)
for keypoint_name in self.keypoint_names:
keypoint_name_without_spaces = keypoint_name.replace(" ", "")
Expand Down Expand Up @@ -206,7 +201,7 @@ def add_to_nwbfile(
The description of how the confidence was computed, e.g., 'Softmax output of the deep neural network'.
stub_test : bool, default: False
"""
from ndx_pose import PoseEstimation, PoseEstimationSeries
from ndx_pose import PoseEstimation, PoseEstimationSeries, Skeleton, Skeletons

metadata_copy = deepcopy(metadata)

Expand All @@ -223,15 +218,14 @@ def add_to_nwbfile(
original_video_name = str(self.original_video_file_path)
else:
original_video_name = metadata_copy["Behavior"]["Videos"][0]["name"]

pose_estimation_kwargs = dict(
name=pose_estimation_metadata["name"],
description=pose_estimation_metadata["description"],
source_software=pose_estimation_metadata["source_software"],
scorer=pose_estimation_metadata["scorer"],
original_videos=[original_video_name],
dimensions=[self.dimension],
)
camera_name = pose_estimation_metadata["camera_name"]
if camera_name in nwbfile.devices:
camera = nwbfile.devices[camera_name]
else:
camera = nwbfile.create_device(
name=camera_name,
description="Camera used for behavioral recording and pose estimation.",
)

pose_estimation_data = self.pose_estimation_data if not stub_test else self.pose_estimation_data.head(n=10)
timestamps = self.get_timestamps(stub_test=stub_test)
Expand Down Expand Up @@ -263,8 +257,28 @@ def add_to_nwbfile(

pose_estimation_series.append(PoseEstimationSeries(**pose_estimation_series_kwargs))

pose_estimation_kwargs.update(
# Add Skeleton(s)
nodes = [keypoint_name.replace(" ", "") for keypoint_name in self.keypoint_names]
subject = nwbfile.subject if nwbfile.subject is not None else None
name = f"Skeleton{pose_estimation_name}"
skeleton = Skeleton(name=name, nodes=nodes, subject=subject)
if "Skeletons" in behavior.data_interfaces:
skeletons = behavior.data_interfaces["Skeletons"]
skeletons.add_skeletons(skeleton)
else:
skeletons = Skeletons(skeletons=[skeleton])
behavior.add(skeletons)

pose_estimation_kwargs = dict(
name=pose_estimation_metadata["name"],
description=pose_estimation_metadata["description"],
source_software=pose_estimation_metadata["source_software"],
scorer=pose_estimation_metadata["scorer"],
original_videos=[original_video_name],
dimensions=[self.dimension],
pose_estimation_series=pose_estimation_series,
devices=[camera],
skeleton=skeleton,
)

if self.source_data["labeled_video_file_path"]:
Expand Down
9 changes: 1 addition & 8 deletions tests/test_on_data/behavior/test_behavior_interfaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,14 +42,9 @@

from importlib.metadata import version

from packaging import version as version_parse

ndx_pose_version = version("ndx-pose")


@pytest.mark.skipif(
version_parse.parse(ndx_pose_version) >= version_parse.parse("0.2"), reason="ndx_pose version is smaller than 0.2"
)
class TestLightningPoseDataInterface(DataInterfaceTestMixin, TemporalAlignmentMixin):
data_interface_cls = LightningPoseDataInterface
interface_kwargs = dict(
Expand Down Expand Up @@ -94,6 +89,7 @@ def setup_metadata(self, request):
description="Contains the pose estimation series for each keypoint.",
scorer="heatmap_tracker",
source_software="LightningPose",
camera_name="CameraPoseEstimation",
)
)
cls.expected_metadata[cls.pose_estimation_name].update(
Expand Down Expand Up @@ -165,9 +161,6 @@ def check_read_nwb(self, nwbfile_path: str):
assert_array_equal(pose_estimation_series.data[:], test_data[["x", "y"]].values)


@pytest.mark.skipif(
version_parse.parse(ndx_pose_version) >= version_parse.parse("0.2"), reason="ndx_pose version is smaller than 0.2"
)
class TestLightningPoseDataInterfaceWithStubTest(DataInterfaceTestMixin, TemporalAlignmentMixin):
data_interface_cls = LightningPoseDataInterface
interface_kwargs = dict(
Expand Down
10 changes: 1 addition & 9 deletions tests/test_on_data/behavior/test_lightningpose_converter.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,10 @@
import shutil
import tempfile
from datetime import datetime
from importlib.metadata import version
from pathlib import Path
from warnings import warn

import pytest
from hdmf.testing import TestCase
from packaging import version
from packaging import version as version_parse
from pynwb import NWBHDF5IO
from pynwb.image import ImageSeries

Expand All @@ -19,12 +15,7 @@

from ..setup_paths import BEHAVIOR_DATA_PATH

ndx_pose_version = version("ndx-pose")


@pytest.mark.skipif(
version_parse.parse(ndx_pose_version) >= version_parse.parse("0.2"), reason="ndx_pose version is smaller than 0.2"
)
class TestLightningPoseConverter(TestCase):
@classmethod
def setUpClass(cls) -> None:
Expand Down Expand Up @@ -73,6 +64,7 @@ def setUpClass(cls) -> None:
description="Contains the pose estimation series for each keypoint.",
scorer="heatmap_tracker",
source_software="LightningPose",
camera_name="CameraPoseEstimation",
)

cls.pose_estimation_metadata.update(
Expand Down

0 comments on commit 8046d0d

Please sign in to comment.