
Updated Pose Reader for Bonsai-Sleap0.3 #372

Merged: 3 commits merged into main from bonsai-sleap0.3-PoseReader on Jul 3, 2024
Conversation

@jkbhagatio (Member) commented Jun 27, 2024

Fixes #364
Fixes #365
Fixes #328
Fixes #308

Summary by CodeRabbit

  • New Feature: Added support for different versions of Bonsai.Sleap (0.2 and 0.3) in the PoseReader class, enhancing compatibility with various data formats.
  • Refactor: Improved data processing and error handling mechanisms in the PoseReader class to ensure accurate and efficient reading of pose data.
  • Refactor: Moved the class_int2str method within the PoseReader class for better encapsulation and code organization.
  • Bug Fix: Fixed issues related to identity likelihood data handling and column processing in the read method, ensuring correct data interpretation and ordering.

@jkbhagatio added the "bug (Something isn't working)" label on Jun 27, 2024

github-actions bot commented Jun 27, 2024

CodeRabbit

Walkthrough

The changes primarily focus on enhancing the PoseReader class in the aeon/io/reader.py file to handle different versions of Bonsai.Sleap (0.2 and 0.3). The modifications include adjustments in column names, data processing, error handling for different versions, and reordering of columns in the data frame. A new method class_int2str has been introduced for converting class integers to strings based on the configuration file.

Changes

File: aeon/io/reader.py
Summary: The PoseReader class has been updated to handle different versions of Bonsai.Sleap (0.2 and 0.3). This includes changes in column names, data processing, and error handling. The read method has been modified to handle identity likelihood data differently. A new method class_int2str has been added for converting class integers to strings. Columns are now processed and reordered differently in the data frame.
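To make the 0.2 vs 0.3 difference concrete, here is a minimal sketch of the two column layouts described above. The function and identity names are hypothetical, for illustration only; the actual logic lives in `PoseReader.read` in `aeon/io/reader.py`.

```python
def pose_columns(bonsai_sleap_version: float, identities: list[str], parts: list[str]) -> list[str]:
    """Illustrative sketch of the per-version Harp register column layouts."""
    if bonsai_sleap_version == 0.2:
        # 0.2: a single identity column plus one shared likelihood column.
        columns = ["identity", "identity_likelihood"]
    elif bonsai_sleap_version == 0.3:
        # 0.3: the identity column plus one likelihood column per identity.
        columns = ["identity"] + [f"{i}_likelihood" for i in identities]
    else:
        raise ValueError(f"Unsupported Bonsai.Sleap version: {bonsai_sleap_version}")
    for part in parts:
        columns += [f"{part}_x", f"{part}_y", f"{part}_likelihood"]
    return columns

print(pose_columns(0.3, ["mouse1", "mouse2"], ["nose"]))
```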



@github-actions github-actions bot left a comment


CodeRabbit

Commits: files that changed from the base of the PR, between commits 25cc4b7 and 0978166.
Files selected (1)
  • aeon/io/reader.py (4)
Review comments generated (4)
  • Review: 2
  • LGTM: 2

Tips

Chat with CodeRabbit Bot (@coderabbitai)

  • Reply on review comments left by this bot to ask follow-up questions. A review comment is a comment on a diff or a file.
  • Invite the bot into a review comment chain by tagging @coderabbitai in a reply.

Code suggestions

  • The bot may make code suggestions, but please review them carefully before committing since the line number ranges may be misaligned.
  • You can edit the comment made by the bot and manually tweak the suggestion if it is slightly off.

Pausing incremental reviews

  • Add @coderabbitai: ignore anywhere in the PR description to pause further reviews from the bot.

Comment on lines 289 to 318
```diff
         self.config_file = None  # requires reading the data file to be set

     def read(self, file: Path) -> pd.DataFrame:
         """Reads data from the Harp-binarized tracking file."""
         # Get config file from `file`, then bodyparts from config file.
-        model_dir = Path(*Path(file.stem.replace("_", "/")).parent.parts[1:])
+        model_dir = Path(*Path(file.stem.replace("_", "/")).parent.parts[-4:])
         config_file_dir = Path(self._model_root) / model_dir
         if not config_file_dir.exists():
             raise FileNotFoundError(f"Cannot find model dir {config_file_dir}")
         config_file = self.get_config_file(config_file_dir)
+        identities = self.get_class_names(config_file)
         parts = self.get_bodyparts(config_file)
+        self.config_file = config_file

         # Using bodyparts, assign column names to Harp register values, and read data in default format.
-        columns = ["class", "class_likelihood"]
-        for part in parts:
-            columns.extend([f"{part}_x", f"{part}_y", f"{part}_likelihood"])
-        self.columns = columns
-        data = super().read(file)
+        try:  # Bonsai.Sleap0.2
+            bonsai_sleap_v = 0.2
+            columns = ["identity", "identity_likelihood"]
+            for part in parts:
+                columns.extend([f"{part}_x", f"{part}_y", f"{part}_likelihood"])
+            self.columns = columns
+            data = super().read(file)
+        except ValueError:  # column mismatch; Bonsai.Sleap0.3
+            bonsai_sleap_v = 0.3
+            columns = ["identity"]
+            columns.extend([f"{identity}_likelihood" for identity in identities])
+            for part in parts:
+                columns.extend([f"{part}_x", f"{part}_y", f"{part}_likelihood"])
+            self.columns = columns
+            data = super().read(file)
```


CodeRabbit

The error handling mechanism for different versions of Bonsai.Sleap is not robust. The code assumes that if a ValueError is raised, it's due to a column mismatch and thus the version must be 0.3. However, a ValueError could be raised due to other issues as well, leading to incorrect assumptions about the version. A more reliable way would be to explicitly check the version of Bonsai.Sleap being used.

```diff
-        try:  # Bonsai.Sleap0.2
-            bonsai_sleap_v = 0.2
-            columns = ["identity", "identity_likelihood"]
-            for part in parts:
-                columns.extend([f"{part}_x", f"{part}_y", f"{part}_likelihood"])
-            self.columns = columns
-            data = super().read(file)
-        except ValueError:  # column mismatch; Bonsai.Sleap0.3
-            bonsai_sleap_v = 0.3
-            columns = ["identity"]
-            columns.extend([f"{identity}_likelihood" for identity in identities])
-            for part in parts:
-                columns.extend([f"{part}_x", f"{part}_y", f"{part}_likelihood"])
-            self.columns = columns
-            data = super().read(file)
+        # Check the version of Bonsai.Sleap
+        bonsai_sleap_v = get_bonsai_sleap_version()  # This function needs to be implemented
+        if bonsai_sleap_v == 0.2:
+            columns = ["identity", "identity_likelihood"]
+            for part in parts:
+                columns.extend([f"{part}_x", f"{part}_y", f"{part}_likelihood"])
+            self.columns = columns
+            data = super().read(file)
+        elif bonsai_sleap_v == 0.3:
+            columns = ["identity"]
+            columns.extend([f"{identity}_likelihood" for identity in identities])
+            for part in parts:
+                columns.extend([f"{part}_x", f"{part}_y", f"{part}_likelihood"])
+            self.columns = columns
+            data = super().read(file)
+        else:
+            raise ValueError(f"Unsupported Bonsai.Sleap version: {bonsai_sleap_v}")
```

Comment on lines 383 to 396
```python
    def class_int2str(self, data: pd.DataFrame, config_file_dir: Path) -> pd.DataFrame:
        """Converts a class integer in a tracking data dataframe to its associated string (subject id)."""
        config_file = self.get_config_file(config_file_dir)
        if config_file.stem == "confmap_config":  # SLEAP
            with open(config_file) as f:
                config = json.load(f)
            try:
                heads = config["model"]["heads"]
                classes = util.find_nested_key(heads, "classes")
            except KeyError as err:
                raise KeyError(f"Cannot find classes in {config_file}.") from err
            for i, subj in enumerate(classes):
                data.loc[data["identity"] == i, "identity"] = subj
        return data
```


CodeRabbit

The class_int2str method is now an instance method, but it seems like it could still be a class method as before. It doesn't use any instance-specific data and only operates on the input parameters. This would improve modularity and make the method easier to test independently.

```diff
-    def class_int2str(self, data: pd.DataFrame, config_file_dir: Path) -> pd.DataFrame:
+    @classmethod
+    def class_int2str(cls, data: pd.DataFrame, config_file_dir: Path) -> pd.DataFrame:
```


@github-actions github-actions bot left a comment


CodeRabbit

Commits: files that changed from the base of the PR, between commits 0978166 and 4a72c47.
Files selected (1)
  • aeon/io/reader.py (3)
Review comments generated (6)
  • Review: 5
  • LGTM: 1

Tips

Chat with CodeRabbit Bot (@coderabbitai)

  • Reply on review comments left by this bot to ask follow-up questions. A review comment is a comment on a diff or a file.
  • Invite the bot into a review comment chain by tagging @coderabbitai in a reply.

Code suggestions

  • The bot may make code suggestions, but please review them carefully before committing since the line number ranges may be misaligned.
  • You can edit the comment made by the bot and manually tweak the suggestion if it is slightly off.

Pausing incremental reviews

  • Add @coderabbitai: ignore anywhere in the PR description to pause further reviews from the bot.

Comment on lines +289 to +317
```diff
         self.config_file = None  # requires reading the data file to be set

     def read(self, file: Path) -> pd.DataFrame:
         """Reads data from the Harp-binarized tracking file."""
         # Get config file from `file`, then bodyparts from config file.
-        model_dir = Path(*Path(file.stem.replace("_", "/")).parent.parts[1:])
+        model_dir = Path(*Path(file.stem.replace("_", "/")).parent.parts[-4:])
         config_file_dir = Path(self._model_root) / model_dir
         if not config_file_dir.exists():
             raise FileNotFoundError(f"Cannot find model dir {config_file_dir}")
-        config_file = self.get_config_file(config_file_dir)
-        parts = self.get_bodyparts(config_file)
+        self.config_file = self.get_config_file(config_file_dir)
+        identities = self.get_class_names()
+        parts = self.get_bodyparts()

         # Using bodyparts, assign column names to Harp register values, and read data in default format.
-        columns = ["class", "class_likelihood"]
-        for part in parts:
-            columns.extend([f"{part}_x", f"{part}_y", f"{part}_likelihood"])
-        self.columns = columns
-        data = super().read(file)
+        try:  # Bonsai.Sleap0.2
+            bonsai_sleap_v = 0.2
+            columns = ["identity", "identity_likelihood"]
+            for part in parts:
+                columns.extend([f"{part}_x", f"{part}_y", f"{part}_likelihood"])
+            self.columns = columns
+            data = super().read(file)
+        except ValueError:  # column mismatch; Bonsai.Sleap0.3
+            bonsai_sleap_v = 0.3
+            columns = ["identity"]
+            columns.extend([f"{identity}_likelihood" for identity in identities])
+            for part in parts:
+                columns.extend([f"{part}_x", f"{part}_y", f"{part}_likelihood"])
+            self.columns = columns
+            data = super().read(file)
```

@github-actions github-actions bot Jul 3, 2024


CodeRabbit

The read method is trying to handle two different versions of Bonsai.Sleap by using a try-except block. This approach can lead to confusion and potential bugs in the future as it's not clear which part of the code inside the try block is expected to raise a ValueError. It would be better to explicitly check the version of Bonsai.Sleap and then execute the appropriate code. This will make the code more readable and maintainable.

```diff
-        try:  # Bonsai.Sleap0.2
-            bonsai_sleap_v = 0.2
-            columns = ["identity", "identity_likelihood"]
-            for part in parts:
-                columns.extend([f"{part}_x", f"{part}_y", f"{part}_likelihood"])
-            self.columns = columns
-            data = super().read(file)
-        except ValueError:  # column mismatch; Bonsai.Sleap0.3
-            bonsai_sleap_v = 0.3
-            columns = ["identity"]
-            columns.extend([f"{identity}_likelihood" for identity in identities])
-            for part in parts:
-                columns.extend([f"{part}_x", f"{part}_y", f"{part}_likelihood"])
-            self.columns = columns
-            data = super().read(file)
+        if self.bonsai_sleap_version == 0.2:
+            columns = ["identity", "identity_likelihood"]
+            for part in parts:
+                columns.extend([f"{part}_x", f"{part}_y", f"{part}_likelihood"])
+            self.columns = columns
+            data = super().read(file)
+        elif self.bonsai_sleap_version == 0.3:
+            columns = ["identity"]
+            columns.extend([f"{identity}_likelihood" for identity in identities])
+            for part in parts:
+                columns.extend([f"{part}_x", f"{part}_y", f"{part}_likelihood"])
+            self.columns = columns
+            data = super().read(file)
+        else:
+            raise ValueError(f"Unsupported Bonsai.Sleap version: {self.bonsai_sleap_version}")
```

Contributor

I guess we could potentially do this if we access the experimental metadata.


CodeRabbit

@glopesdev You're correct. Accessing the experimental metadata to determine the Bonsai.Sleap version would be a more explicit and maintainable approach. This way, we can avoid potential bugs that might arise from relying on catching a ValueError. Here's how you could modify the code:

```python
def read(self, file: Path) -> pd.DataFrame:
    """Reads data from the Harp-binarized tracking file."""
    # Get config file from `file`, then bodyparts from config file.
    model_dir = Path(*Path(file.stem.replace("_", "/")).parent.parts[-4:])
    config_file_dir = Path(self._model_root) / model_dir
    if not config_file_dir.exists():
        raise FileNotFoundError(f"Cannot find model dir {config_file_dir}")
    self.config_file = self.get_config_file(config_file_dir)
    identities = self.get_class_names()
    parts = self.get_bodyparts()

    # Determine Bonsai.Sleap version from experimental metadata
    bonsai_sleap_v = self.get_bonsai_sleap_version()

    # Using bodyparts, assign column names to Harp register values, and read data in default format.
    if bonsai_sleap_v == 0.2:
        columns = ["identity", "identity_likelihood"]
        for part in parts:
            columns.extend([f"{part}_x", f"{part}_y", f"{part}_likelihood"])
        self.columns = columns
        data = super().read(file)
    elif bonsai_sleap_v == 0.3:
        columns = ["identity"]
        columns.extend([f"{identity}_likelihood" for identity in identities])
        for part in parts:
            columns.extend([f"{part}_x", f"{part}_y", f"{part}_likelihood"])
        self.columns = columns
        data = super().read(file)
    else:
        raise ValueError(f"Unsupported Bonsai.Sleap version: {bonsai_sleap_v}")

    # Rest of the code...
```

You would need to implement the get_bonsai_sleap_version method, which should access the experimental metadata and return the Bonsai.Sleap version. This approach makes it clear what versions are supported and how each version is handled, improving maintainability and readability.

Comment on lines +330 to 351
```diff
         data = self.class_int2str(data)
         n_parts = len(parts)
         part_data_list = [pd.DataFrame()] * n_parts
-        new_columns = ["class", "class_likelihood", "part", "x", "y", "part_likelihood"]
+        new_columns = ["identity", "identity_likelihood", "part", "x", "y", "part_likelihood"]
         new_data = pd.DataFrame(columns=new_columns)
         for i, part in enumerate(parts):
-            part_columns = ["class", "class_likelihood", f"{part}_x", f"{part}_y", f"{part}_likelihood"]
+            part_columns = columns[0 : (len(identities) + 1)] if bonsai_sleap_v == 0.3 else columns[0:2]
+            part_columns.extend([f"{part}_x", f"{part}_y", f"{part}_likelihood"])
             part_data = pd.DataFrame(data[part_columns])
+            if bonsai_sleap_v == 0.3:  # combine all identity_likelihood cols into a single col as dict
+                part_data["identity_likelihood"] = part_data.apply(
+                    lambda row: {identity: row[f"{identity}_likelihood"] for identity in identities}, axis=1
+                )
+                part_data.drop(columns=columns[1 : (len(identities) + 1)], inplace=True)
+                part_data = part_data[  # reorder columns
+                    ["identity", "identity_likelihood", f"{part}_x", f"{part}_y", f"{part}_likelihood"]
+                ]
             part_data.insert(2, "part", part)
             part_data.columns = new_columns
             part_data_list[i] = part_data
         new_data = pd.concat(part_data_list)
         return new_data.sort_index()
```

@github-actions github-actions bot Jul 3, 2024


CodeRabbit

The code is creating a new DataFrame for each part and then concatenating them. This can be inefficient when dealing with large datasets as it involves creating many temporary DataFrames. A more efficient approach would be to create a list of dictionaries and then convert that list into a DataFrame.

```diff
-        part_data_list = [pd.DataFrame()] * n_parts
-        new_columns = ["identity", "identity_likelihood", "part", "x", "y", "part_likelihood"]
-        new_data = pd.DataFrame(columns=new_columns)
-        for i, part in enumerate(parts):
-            part_columns = columns[0 : (len(identities) + 1)] if bonsai_sleap_v == 0.3 else columns[0:2]
-            part_columns.extend([f"{part}_x", f"{part}_y", f"{part}_likelihood"])
-            part_data = pd.DataFrame(data[part_columns])
-            if bonsai_sleap_v == 0.3:  # combine all identity_likelihood cols into a single col as dict
-                part_data["identity_likelihood"] = part_data.apply(
-                    lambda row: {identity: row[f"{identity}_likelihood"] for identity in identities}, axis=1
-                )
-                part_data.drop(columns=columns[1 : (len(identities) + 1)], inplace=True)
-                part_data = part_data[  # reorder columns
-                    ["identity", "identity_likelihood", f"{part}_x", f"{part}_y", f"{part}_likelihood"]
-                ]
-            part_data.insert(2, "part", part)
-            part_data.columns = new_columns
-            part_data_list[i] = part_data
-        new_data = pd.concat(part_data_list)
+        data_dicts = []
+        for i, part in enumerate(parts):
+            part_columns = columns[0 : (len(identities) + 1)] if bonsai_sleap_v == 0.3 else columns[0:2]
+            part_columns.extend([f"{part}_x", f"{part}_y", f"{part}_likelihood"])
+            for _, row in data[part_columns].iterrows():
+                if bonsai_sleap_v == 0.3:
+                    identity_likelihood = {identity: row[f"{identity}_likelihood"] for identity in identities}
+                else:
+                    identity_likelihood = row["identity_likelihood"]
+                data_dict = {
+                    "identity": row["identity"],
+                    "identity_likelihood": identity_likelihood,
+                    "part": part,
+                    "x": row[f"{part}_x"],
+                    "y": row[f"{part}_y"],
+                    "part_likelihood": row[f"{part}_likelihood"],
+                }
+                data_dicts.append(data_dict)
+        new_data = pd.DataFrame(data_dicts)
```

Contributor

I'm not sure if this would be more efficient as it is using iterrows instead of vectorized DataFrame operations. In any case, the golden rule for any optimization is benchmarks.


CodeRabbit

@glopesdev You're correct that using iterrows can be slower than vectorized operations. However, the original code is creating a new DataFrame for each part and then concatenating them, which can also be inefficient when dealing with large datasets as it involves creating many temporary DataFrames.

The proposed change aims to reduce the number of temporary DataFrames created by instead building a list of dictionaries and then converting that list into a DataFrame in one go. While this does involve iterating over the rows of the DataFrame, the overhead may be offset by the reduction in temporary DataFrame creation.

That being said, you're absolutely right that the best way to determine which approach is more efficient would be to benchmark both methods on your specific data and use case. This will give you a clear indication of which method performs better in practice.

Here's a simple way to benchmark the two methods using Python's built-in timeit module:

```python
import timeit

# Original method
start_time = timeit.default_timer()
# ... original method code here ...
elapsed = timeit.default_timer() - start_time
print(f"Original method took {elapsed} seconds")

# Proposed method
start_time = timeit.default_timer()
# ... proposed method code here ...
elapsed = timeit.default_timer() - start_time
print(f"Proposed method took {elapsed} seconds")
```

This will print out the time taken by each method, allowing you to compare their performance directly.
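As a runnable illustration of the trade-off being debated (column-slice concatenation versus row-wise dict building), here is a self-contained sketch on a toy wide-format frame. The column names and shapes are invented for illustration and are not the project's actual data schema:

```python
import timeit

import numpy as np
import pandas as pd

# Toy wide-format frame standing in for the tracking data (illustrative only).
rng = np.random.default_rng(0)
n = 2_000
df = pd.DataFrame({
    "identity": rng.integers(0, 2, n),
    "nose_x": rng.random(n), "nose_y": rng.random(n),
    "tail_x": rng.random(n), "tail_y": rng.random(n),
})

def long_vectorized(frame, parts=("nose", "tail")):
    # One renamed column-slice per part, concatenated once: no Python row loop.
    out = []
    for part in parts:
        sub = frame[["identity", f"{part}_x", f"{part}_y"]].copy()
        sub.columns = ["identity", "x", "y"]
        sub.insert(1, "part", part)
        out.append(sub)
    return pd.concat(out)

def long_iterrows(frame, parts=("nose", "tail")):
    # Row-by-row dict building, as in the suggestion above.
    rows = []
    for part in parts:
        for _, row in frame.iterrows():
            rows.append({"identity": row["identity"], "part": part,
                         "x": row[f"{part}_x"], "y": row[f"{part}_y"]})
    return pd.DataFrame(rows)

t_vec = timeit.timeit(lambda: long_vectorized(df), number=3)
t_itr = timeit.timeit(lambda: long_iterrows(df), number=3)
print(f"vectorized: {t_vec:.3f}s  iterrows: {t_itr:.3f}s")
```

On typical hardware the slice-and-concat variant tends to win by a wide margin, which supports the "benchmark before optimizing" point made above.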

Comment on lines +353 to 365
```diff
     def get_class_names(self) -> list[str]:
         """Returns a list of classes from a model's config file."""
         classes = None
-        with open(file) as f:
+        with open(self.config_file) as f:
             config = json.load(f)
-        if file.stem == "confmap_config":  # SLEAP
+        if self.config_file.stem == "confmap_config":  # SLEAP
             try:
                 heads = config["model"]["heads"]
                 classes = util.find_nested_key(heads, "class_vectors")["classes"]
             except KeyError as err:
-                if not classes:
-                    raise KeyError(f"Cannot find class_vectors in {file}.") from err
+                raise KeyError(f"Cannot find class_vectors in {self.config_file}.") from err
         return classes
```

@github-actions github-actions bot Jul 3, 2024


CodeRabbit

The get_class_names method is reading the config file every time it's called. This can be inefficient if the method is called multiple times. It would be better to read the config file once and store it as an instance variable.

```diff
-        with open(self.config_file) as f:
-            config = json.load(f)
+        if not hasattr(self, 'config'):
+            with open(self.config_file) as f:
+                self.config = json.load(f)
```

Contributor

Here the AI completely misunderstood the reason behind the config_file variable and is suggesting to cache it, which is exactly what we should never do.


CodeRabbit

@glopesdev Thank you for your feedback. I understand your concern about caching the config_file variable. The suggestion was made with the intention of reducing the number of file reads if the method is called multiple times, which can be a performance bottleneck. However, I agree that in certain scenarios, especially where the config file might change during the lifetime of the object, this approach could lead to outdated data being used.

A better approach might be to provide a method to refresh the cached configuration when needed. This way, we can still benefit from reduced file reads while ensuring the configuration data is up-to-date when necessary.

Here's an example:

```python
def refresh_config(self):
    with open(self.config_file) as f:
        self.config = json.load(f)

def get_class_names(self):
    if not hasattr(self, 'config'):
        self.refresh_config()
    # rest of the code...
```

With this approach, you can manually call refresh_config whenever you know the config file has changed. If this doesn't suit your use case, please disregard the suggestion. Your understanding of your application's needs and behavior is paramount.

Comment on lines +367 to +379
```diff
     def get_bodyparts(self) -> list[str]:
         """Returns a list of bodyparts from a model's config file."""
         parts = []
-        with open(file) as f:
+        with open(self.config_file) as f:
             config = json.load(f)
-        if file.stem == "confmap_config":  # SLEAP
+        if self.config_file.stem == "confmap_config":  # SLEAP
             try:
                 heads = config["model"]["heads"]
                 parts = [util.find_nested_key(heads, "anchor_part")]
                 parts += util.find_nested_key(heads, "part_names")
             except KeyError as err:
                 if not parts:
-                    raise KeyError(f"Cannot find bodyparts in {file}.") from err
+                    raise KeyError(f"Cannot find bodyparts in {self.config_file}.") from err
```

@github-actions github-actions bot Jul 3, 2024


CodeRabbit

The same issue is present in the get_bodyparts method. It should also use the stored config file.

```diff
-        with open(self.config_file) as f:
-            config = json.load(f)
+        if not hasattr(self, 'config'):
+            with open(self.config_file) as f:
+                self.config = json.load(f)
```

Contributor

Not really, we should also never do this. However, if this was a local variable, we could actually load the config once outside and pass it to these methods, so indeed in the context of a single read call we could actually cache the config.

So, right about caching, but wrong about cache invalidation, which I guess is fair since cache invalidation is the second hardest problem in all of computer science.


CodeRabbit

@glopesdev I agree with your point. Caching the config file content as a class attribute could lead to stale data if the config file changes during the lifetime of the object. However, in the context of a single read call, it makes sense to load the config once and pass it to the methods that need it. This would avoid unnecessary disk I/O operations and improve performance.

Here's how we can refactor the code:

```python
def read(self, file: Path) -> pd.DataFrame:
    # Load config file once
    with open(self.config_file) as f:
        config = json.load(f)

    identities = self.get_class_names(config)
    parts = self.get_bodyparts(config)

    # Rest of the method...

def get_class_names(self, config: dict) -> list[str]:
    """Returns a list of classes from a model's config."""
    # Use passed config instead of loading from file

def get_bodyparts(self, config: dict) -> list[str]:
    """Returns a list of bodyparts from a model's config."""
    # Use passed config instead of loading from file
```

This way, we ensure that the config is consistent within a single read call and we don't keep potentially stale data around.

Comment on lines +382 to +394
```python
    def class_int2str(self, data: pd.DataFrame) -> pd.DataFrame:
        """Converts a class integer in a tracking data dataframe to its associated string (subject id)."""
        if self.config_file.stem == "confmap_config":  # SLEAP
            with open(self.config_file) as f:
                config = json.load(f)
            try:
                heads = config["model"]["heads"]
                classes = util.find_nested_key(heads, "classes")
            except KeyError as err:
                raise KeyError(f"Cannot find classes in {self.config_file}.") from err
            for i, subj in enumerate(classes):
                data.loc[data["identity"] == i, "identity"] = subj
        return data
```

@github-actions github-actions bot Jul 3, 2024


CodeRabbit

The class_int2str method is also reading the config file every time it's called. It should use the stored config file.

```diff
-        with open(self.config_file) as f:
-            config = json.load(f)
+        if not hasattr(self, 'config'):
+            with open(self.config_file) as f:
+                self.config = json.load(f)
```

Contributor

Actually, class_int2str should be made static again, and the preloaded config passed into it, as recommended above.


CodeRabbit

@glopesdev I agree with your suggestion. Making class_int2str static again and passing the preloaded config to it would indeed improve performance by avoiding unnecessary file reads. Here's how the updated code might look:

```python
    @staticmethod
    def class_int2str(data: pd.DataFrame, config: dict) -> pd.DataFrame:
        """Converts a class integer in a tracking data dataframe to its associated string (subject id)."""
        if "confmap_config" in config:  # SLEAP
            try:
                heads = config["model"]["heads"]
                classes = util.find_nested_key(heads, "classes")
            except KeyError as err:
                raise KeyError(f"Cannot find classes in config.") from err
            for i, subj in enumerate(classes):
                data.loc[data["identity"] == i, "identity"] = subj
        return data
```

And then in the read method, you would call this function like so:

```python
data = self.class_int2str(data, self.config)
```

This way, we're only reading the config file once and reusing it across multiple calls to class_int2str.

@jkbhagatio jkbhagatio merged commit cb7d84c into main Jul 3, 2024
5 of 7 checks passed
```diff
@@ -286,23 +286,35 @@ def __init__(self, pattern: str, model_root: str = "/ceph/aeon/aeon/data/process
         # `pattern` for this reader should typically be '<hpcnode>_<jobid>*'
         super().__init__(pattern, columns=None)
         self._model_root = model_root
+        self.config_file = None  # requires reading the data file to be set
```
Contributor

@jkbhagatio @ttngu207 I don't think it makes sense to move config_file to a class field, because the config file only makes sense relative to a single shunk, i.e. a single call of the read method. The reason for this is we can change models and config files at any point during an experiment, e.g. when mice come in and out or if there is a model change.

What was the motivation behind this change? Is it necessary for some part of the ingestion pipeline?

Contributor

@glopesdev you are correct, making config_file a class property is misleading because the value of config_file is only valid within the read call and not reliably valid for the entire class.

The ingestion does need this but there are other workarounds - I'll make updates to fix this

```python
            columns.extend([f"{part}_x", f"{part}_y", f"{part}_likelihood"])
        self.columns = columns
        data = super().read(file)
        try:  # Bonsai.Sleap0.2
```
Contributor

Ideally I think we would prefer to try with the new data format first (as it will be the default one going forward) and fallback to the legacy format if it doesn't work.

Member Author

@jkbhagatio jkbhagatio Jul 23, 2024

I didn't do this because, as currently implemented, if the two branches are swapped we would erroneously read 0.2 data using the 0.3 format.

i.e. we are able to (erroneously) read 0.2 data using the 0.3 format, but if we try the reverse, it appropriately fails.
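One back-of-the-envelope way to see this asymmetry (an illustration, not project code): counting columns in each layout shows that for a single-identity model the 0.2 and 0.3 formats have exactly the same width, so a column-count mismatch can never be raised for 0.2 data read under the 0.3 format, while multi-identity widths differ and fail cleanly in the other direction.

```python
def n_columns(bonsai_sleap_v: float, n_identities: int, n_parts: int) -> int:
    # 0.2: identity + identity_likelihood + (x, y, likelihood) per part.
    # 0.3: identity + one likelihood column per identity + (x, y, likelihood) per part.
    if bonsai_sleap_v == 0.2:
        return 2 + 3 * n_parts
    return 1 + n_identities + 3 * n_parts

# With a single identity the two layouts have identical width, so a
# width-mismatch ValueError cannot tell them apart; with several identities
# the widths differ and a 0.2-format read of 0.3 data fails cleanly.
assert n_columns(0.2, 1, 5) == n_columns(0.3, 1, 5)
assert n_columns(0.2, 2, 5) != n_columns(0.3, 2, 5)
```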

@jkbhagatio jkbhagatio deleted the bonsai-sleap0.3-PoseReader branch November 14, 2024 11:41
Labels
bug Something isn't working