Updated Pose Reader for Bonsai-Sleap0.3 #372

Merged 3 commits, Jul 3, 2024
76 changes: 47 additions & 29 deletions aeon/io/reader.py
@@ -286,23 +286,36 @@ def __init__(self, pattern: str, model_root: str = "/ceph/aeon/aeon/data/process
         # `pattern` for this reader should typically be '<hpcnode>_<jobid>*'
         super().__init__(pattern, columns=None)
         self._model_root = model_root
+        self.config_file = None  # requires reading the data file to be set
Contributor comment:

@jkbhagatio @ttngu207 I don't think it makes sense to move config_file to a class field, because the config file is only meaningful relative to a single chunk, i.e. a single call of the read method. We can change models and config files at any point during an experiment, e.g. when mice come in and out or when the model changes.

What was the motivation behind this change? Is it necessary for some part of the ingestion pipeline?

Contributor reply:

@glopesdev you are correct, making config_file a class property is misleading because the value of config_file is only valid within the read call and not reliably valid for the entire class.

The ingestion does need this, but there are other workarounds. I'll make updates to fix this.
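
The concern raised in this thread can be illustrated with a minimal sketch. The classes `StatefulReader` and `ChunkReader` and the file names below are hypothetical and are not part of aeon; the sketch only shows that an attribute set inside `read` describes the most recent call, while a value returned from `read` stays tied to its own chunk.

```python
from __future__ import annotations

from pathlib import Path


class StatefulReader:
    """Hypothetical reader that caches the config path on the instance."""

    def __init__(self) -> None:
        self.config_file: Path | None = None

    def read(self, file: Path) -> str:
        # Overwritten on every call, so it only describes the last chunk read.
        self.config_file = file.with_suffix(".json")
        return file.stem


class ChunkReader:
    """Hypothetical reader that returns the config path alongside the data."""

    def read(self, file: Path) -> tuple[str, Path]:
        config_file = file.with_suffix(".json")  # local to this call
        return file.stem, config_file


stateful = StatefulReader()
first = stateful.read(Path("chunk_a.bin"))
second = stateful.read(Path("chunk_b.bin"))
# The attribute now refers to chunk_b, even if the caller is still handling `first`.
stale_config = stateful.config_file

chunked = ChunkReader()
first_data, first_config = chunked.read(Path("chunk_a.bin"))
second_data, second_config = chunked.read(Path("chunk_b.bin"))
```

Returning the path is only one possible workaround; the thread does not say which one the ingestion code ended up using.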


     def read(self, file: Path) -> pd.DataFrame:
         """Reads data from the Harp-binarized tracking file."""
         # Get config file from `file`, then bodyparts from config file.
-        model_dir = Path(*Path(file.stem.replace("_", "/")).parent.parts[1:])
+        model_dir = Path(*Path(file.stem.replace("_", "/")).parent.parts[-4:])
         config_file_dir = Path(self._model_root) / model_dir
         if not config_file_dir.exists():
             raise FileNotFoundError(f"Cannot find model dir {config_file_dir}")
         config_file = self.get_config_file(config_file_dir)
+        identities = self.get_class_names(config_file)
         parts = self.get_bodyparts(config_file)
+        self.config_file = config_file

         # Using bodyparts, assign column names to Harp register values, and read data in default format.
-        columns = ["class", "class_likelihood"]
-        for part in parts:
-            columns.extend([f"{part}_x", f"{part}_y", f"{part}_likelihood"])
-        self.columns = columns
-        data = super().read(file)
+        try:  # Bonsai.Sleap0.2
Contributor comment:

Ideally I think we would prefer to try with the new data format first (as it will be the default one going forward) and fallback to the legacy format if it doesn't work.

@jkbhagatio (Member, Author), Jul 23, 2024:


I didn't do this because, as currently implemented, if they are swapped then we erroneously read 0.2 data using the 0.3 format.

I.e. we are able to (erroneously) read 0.2 data using the 0.3 format, but the reverse, reading 0.3 data using the 0.2 format, appropriately fails.
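
To make the two layouts under discussion concrete, here is a small sketch that builds the expected column names for each format. The helper name `sleap_columns`, the part names, and the identity names are made up for illustration; only the column lists themselves follow the diff. The sketch does not reproduce how the base reader validates or decodes the payload.

```python
def sleap_columns(parts: list, identities: list, version: str) -> list:
    """Builds the expected column names for one Bonsai.Sleap payload layout."""
    if version == "0.2":
        # A single likelihood column for the predicted identity.
        columns = ["identity", "identity_likelihood"]
    elif version == "0.3":
        # One likelihood column per known identity.
        columns = ["identity"] + [f"{identity}_likelihood" for identity in identities]
    else:
        raise ValueError(f"Unsupported Bonsai.Sleap version: {version}")
    for part in parts:
        columns.extend([f"{part}_x", f"{part}_y", f"{part}_likelihood"])
    return columns


parts = ["nose", "tail"]  # hypothetical body parts
identities = ["mouse_a", "mouse_b"]  # hypothetical identities

cols_v02 = sleap_columns(parts, identities, "0.2")
cols_v03 = sleap_columns(parts, identities, "0.3")
width_difference = len(cols_v03) - len(cols_v02)
```

With two identities the 0.3 layout is one column wider than the 0.2 layout. With a single identity both layouts have the same width, so column count alone cannot tell them apart in that case.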

+            bonsai_sleap_v = 0.2
+            columns = ["identity", "identity_likelihood"]
+            for part in parts:
+                columns.extend([f"{part}_x", f"{part}_y", f"{part}_likelihood"])
+            self.columns = columns
+            data = super().read(file)
+        except ValueError:  # column mismatch; Bonsai.Sleap0.3
+            bonsai_sleap_v = 0.3
+            columns = ["identity"]
+            columns.extend([f"{identity}_likelihood" for identity in identities])
+            for part in parts:
+                columns.extend([f"{part}_x", f"{part}_y", f"{part}_likelihood"])
+            self.columns = columns
+            data = super().read(file)

CodeRabbit comment:

The error handling mechanism for different versions of Bonsai.Sleap is not robust. The code assumes that if a ValueError is raised, it's due to a column mismatch and thus the version must be 0.3. However, a ValueError could be raised due to other issues as well, leading to incorrect assumptions about the version. A more reliable way would be to explicitly check the version of Bonsai.Sleap being used.

-        try:  # Bonsai.Sleap0.2
-            bonsai_sleap_v = 0.2
-            columns = ["identity", "identity_likelihood"]
-            for part in parts:
-                columns.extend([f"{part}_x", f"{part}_y", f"{part}_likelihood"])
-            self.columns = columns
-            data = super().read(file)
-        except ValueError:  # column mismatch; Bonsai.Sleap0.3
-            bonsai_sleap_v = 0.3
-            columns = ["identity"]
-            columns.extend([f"{identity}_likelihood" for identity in identities])
-            for part in parts:
-                columns.extend([f"{part}_x", f"{part}_y", f"{part}_likelihood"])
-            self.columns = columns
-            data = super().read(file)
+        # Check the version of Bonsai.Sleap
+        bonsai_sleap_v = get_bonsai_sleap_version()  # This function needs to be implemented
+        if bonsai_sleap_v == 0.2:
+            columns = ["identity", "identity_likelihood"]
+            for part in parts:
+                columns.extend([f"{part}_x", f"{part}_y", f"{part}_likelihood"])
+            self.columns = columns
+            data = super().read(file)
+        elif bonsai_sleap_v == 0.3:
+            columns = ["identity"]
+            columns.extend([f"{identity}_likelihood" for identity in identities])
+            for part in parts:
+                columns.extend([f"{part}_x", f"{part}_y", f"{part}_likelihood"])
+            self.columns = columns
+            data = super().read(file)
+        else:
+            raise ValueError(f"Unsupported Bonsai.Sleap version: {bonsai_sleap_v}")
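
The suggestion above leaves `get_bonsai_sleap_version()` unimplemented. One possible direction, sketched here under assumptions, is to infer the format from the number of payload columns instead of catching `ValueError`. The function name `infer_bonsai_sleap_version` is hypothetical, and whether the column count is available before decoding is an assumption about the reader that this PR does not confirm. Versions are returned as strings to avoid comparing floats.

```python
def infer_bonsai_sleap_version(n_columns: int, n_parts: int, n_identities: int) -> str:
    """Infers the payload format from its width; raises if unknown or ambiguous."""
    expected_widths = {
        "0.2": 2 + 3 * n_parts,  # identity, identity_likelihood, then x/y/likelihood per part
        "0.3": 1 + n_identities + 3 * n_parts,  # identity, one likelihood per identity, then parts
    }
    matches = [version for version, width in expected_widths.items() if width == n_columns]
    if len(matches) == 1:
        return matches[0]
    if len(matches) > 1:
        # Happens when there is a single identity: both layouts have the same width.
        raise ValueError(f"Ambiguous payload width {n_columns}: matches {matches}")
    raise ValueError(f"Unsupported payload width: {n_columns}")


version_a = infer_bonsai_sleap_version(n_columns=8, n_parts=2, n_identities=2)
version_b = infer_bonsai_sleap_version(n_columns=9, n_parts=2, n_identities=2)
```

The explicit ambiguity error is a design choice of this sketch: it fails loudly in the one case where width cannot distinguish the formats, rather than guessing.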


         # Drop any repeat parts.
         unique_parts, unique_idxs = np.unique(parts, return_index=True)
@@ -315,13 +328,23 @@ def read(self, file: Path) -> pd.DataFrame:
             parts = unique_parts

         # Set new columns, and reformat `data`.
+        data = self.class_int2str(data, config_file_dir)
         n_parts = len(parts)
         part_data_list = [pd.DataFrame()] * n_parts
-        new_columns = ["class", "class_likelihood", "part", "x", "y", "part_likelihood"]
+        new_columns = ["identity", "identity_likelihood", "part", "x", "y", "part_likelihood"]
         new_data = pd.DataFrame(columns=new_columns)
         for i, part in enumerate(parts):
-            part_columns = ["class", "class_likelihood", f"{part}_x", f"{part}_y", f"{part}_likelihood"]
+            part_columns = columns[0 : (len(identities) + 1)] if bonsai_sleap_v == 0.3 else columns[0:2]
+            part_columns.extend([f"{part}_x", f"{part}_y", f"{part}_likelihood"])
             part_data = pd.DataFrame(data[part_columns])
+            if bonsai_sleap_v == 0.3:  # combine all identity_likelihood cols into a single col as dict
+                part_data["identity_likelihood"] = part_data.apply(
+                    lambda row: {identity: row[f"{identity}_likelihood"] for identity in identities}, axis=1
+                )
+                part_data.drop(columns=columns[1 : (len(identities) + 1)], inplace=True)
+                part_data = part_data[  # reorder columns
+                    ["identity", "identity_likelihood", f"{part}_x", f"{part}_y", f"{part}_likelihood"]
+                ]
             part_data.insert(2, "part", part)
             part_data.columns = new_columns
             part_data_list[i] = part_data
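
The 0.3 branch of the loop above collapses the per-identity likelihood columns into a single dict-valued column for each part. The following standalone sketch reproduces that reshaping on a toy frame. The identity names, part name, and values are invented, and a list comprehension is used in place of the `DataFrame.apply` call from the diff; it is meant as an illustration, not as the reader's implementation.

```python
import pandas as pd

identities = ["mouse_a", "mouse_b"]  # hypothetical identity names
part = "nose"  # hypothetical body part

data = pd.DataFrame(
    {
        "identity": ["mouse_a", "mouse_b"],
        "mouse_a_likelihood": [0.9, 0.2],
        "mouse_b_likelihood": [0.1, 0.8],
        "nose_x": [10.0, 30.0],
        "nose_y": [20.0, 40.0],
        "nose_likelihood": [0.95, 0.85],
    }
)

likelihood_columns = [f"{identity}_likelihood" for identity in identities]
part_columns = ["identity", *likelihood_columns, f"{part}_x", f"{part}_y", f"{part}_likelihood"]
part_data = data[part_columns].copy()

# Collapse the per-identity likelihood columns into one dict per row.
part_data["identity_likelihood"] = [
    dict(zip(identities, row))
    for row in part_data[likelihood_columns].itertuples(index=False, name=None)
]
part_data = part_data.drop(columns=likelihood_columns)

# Reorder, add the part name, and rename to the long-format column names.
part_data = part_data[
    ["identity", "identity_likelihood", f"{part}_x", f"{part}_y", f"{part}_likelihood"]
]
part_data.insert(2, "part", part)
part_data.columns = ["identity", "identity_likelihood", "part", "x", "y", "part_likelihood"]
```

Each row of `part_data` now carries one dict mapping every identity to its likelihood, matching the six-column layout named by `new_columns` in the diff.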
@@ -357,12 +380,23 @@ def get_bodyparts(self, file: Path) -> list[str]:
                 raise KeyError(f"Cannot find bodyparts in {file}.") from err
         return parts

+    def class_int2str(self, data: pd.DataFrame, config_file_dir: Path) -> pd.DataFrame:
+        """Converts a class integer in a tracking data dataframe to its associated string (subject id)."""
+        config_file = self.get_config_file(config_file_dir)
+        if config_file.stem == "confmap_config":  # SLEAP
+            with open(config_file) as f:
+                config = json.load(f)
+            try:
+                heads = config["model"]["heads"]
+                classes = util.find_nested_key(heads, "classes")
+            except KeyError as err:
+                raise KeyError(f"Cannot find classes in {config_file}.") from err
+            for i, subj in enumerate(classes):
+                data.loc[data["identity"] == i, "identity"] = subj
+        return data

CodeRabbit comment:

The class_int2str method is now an instance method, but it seems like it could still be a class method as before. It doesn't use any instance-specific data and only operates on the input parameters. This would improve modularity and make the method easier to test independently.

-    def class_int2str(self, data: pd.DataFrame, config_file_dir: Path) -> pd.DataFrame:
+    @classmethod
+    def class_int2str(cls, data: pd.DataFrame, config_file_dir: Path) -> pd.DataFrame:
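
As a rough illustration of the point about testability, the sketch below shows a class method being called without constructing an instance. The class `PoseLabels`, its simplified signatures, and the flat `{"classes": [...]}` config are hypothetical stand-ins; the real method takes a DataFrame and a config directory and looks up classes in the nested SLEAP config.

```python
from __future__ import annotations


class PoseLabels:
    """Hypothetical stand-in for the reader, holding only the label-mapping helpers."""

    @classmethod
    def get_classes(cls, config: dict) -> list[str]:
        """Returns the subject names from a simplified, already-loaded config."""
        return list(config["classes"])

    @classmethod
    def class_int2str(cls, identities: list[int], config: dict) -> list[str]:
        """Maps integer class indices to their subject names."""
        classes = cls.get_classes(config)  # only class-level helpers are used
        return [classes[index] for index in identities]


config = {"classes": ["mouse_a", "mouse_b"]}  # hypothetical config contents
# No reader instance, pattern, or model root is needed to exercise the mapping.
labels = PoseLabels.class_int2str([0, 1, 0], config)
```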


     @classmethod
-    def get_config_file(
-        cls,
-        config_file_dir: Path,
-        config_file_names: None | list[str] = None,
-    ) -> Path:
+    def get_config_file(cls, config_file_dir: Path, config_file_names: None | list[str] = None) -> Path:
         """Returns the config file from a model's config directory."""
         if config_file_names is None:
             config_file_names = ["confmap_config.json"]  # SLEAP (add for other trackers to this list)
@@ -375,22 +409,6 @@ def get_config_file(
             raise FileNotFoundError(f"Cannot find config file in {config_file_dir}")
         return config_file

-    @classmethod
-    def class_int2str(cls, data: pd.DataFrame, config_file_dir: Path) -> pd.DataFrame:
-        """Converts a class integer in a tracking data dataframe to its associated string (subject id)."""
-        config_file = cls.get_config_file(config_file_dir)
-        if config_file.stem == "confmap_config":  # SLEAP
-            with open(config_file) as f:
-                config = json.load(f)
-            try:
-                heads = config["model"]["heads"]
-                classes = util.find_nested_key(heads, "classes")
-            except KeyError as err:
-                raise KeyError(f"Cannot find classes in {config_file}.") from err
-            for i, subj in enumerate(classes):
-                data.loc[data["class"] == i, "class"] = subj
-        return data


 def from_dict(data, pattern=None):
     reader_type = data.get("type", None)