From c2a01723bc65314d0546c37c414fc013d51b1c83 Mon Sep 17 00:00:00 2001 From: Michael McKinsey Date: Mon, 17 Jun 2024 17:46:26 -0500 Subject: [PATCH] Change Filter Shallow Copies to Deep Copies (#175) * Change filter copies to deepcopies * Add unit test --- thicket/tests/test_filter_metadata.py | 3 +++ thicket/tests/test_filter_stats.py | 4 ++++ thicket/thicket.py | 8 ++++---- 3 files changed, 11 insertions(+), 4 deletions(-) diff --git a/thicket/tests/test_filter_metadata.py b/thicket/tests/test_filter_metadata.py index 3b833c94..3c63d95a 100644 --- a/thicket/tests/test_filter_metadata.py +++ b/thicket/tests/test_filter_metadata.py @@ -32,6 +32,9 @@ def filter_one_column(th, columns_values): # check if output is a thicket object assert isinstance(new_th, Thicket) + # check filtered Thicket is separate object + assert th.graph is not new_th.graph + # metadata table: compare profile hash keys after filter to expected metadata_profile = new_th.metadata.index.tolist() assert metadata_profile == exp_index diff --git a/thicket/tests/test_filter_stats.py b/thicket/tests/test_filter_stats.py index 1e3acc92..7a2074d2 100644 --- a/thicket/tests/test_filter_stats.py +++ b/thicket/tests/test_filter_stats.py @@ -35,6 +35,10 @@ def check_filter_stats(th, columns_values): # check if output is a thicket object assert isinstance(new_th, Thicket) + # check filtered Thicket is separate object + # We can't check th.graph because of squash in filter_stats + assert th.statsframe.graph is not new_th.statsframe.graph + # filtered nodes in aggregated statistics table stats_nodes = sorted( new_th.statsframe.dataframe.index.drop_duplicates().tolist() diff --git a/thicket/thicket.py b/thicket/thicket.py index fbd893a4..98bc093c 100644 --- a/thicket/thicket.py +++ b/thicket/thicket.py @@ -1090,8 +1090,8 @@ def filter_metadata(self, select_function): # Get index name index_name = self.metadata.index.name - # create a copy of the thicket object - new_thicket = self.copy() + # create a deepcopy of the thicket object + new_thicket = self.deepcopy() # filter metadata table filtered_rows = new_thicket.metadata.apply(select_function, axis=1) @@ -1242,8 +1242,8 @@ def filter_stats(self, filter_function): Returns: (thicket): new thicket object with applied filter function """ - # copy thicket - new_thicket = self.copy() + # deepcopy thicket + new_thicket = self.deepcopy() # filter aggregated statistics table filtered_rows = new_thicket.statsframe.dataframe.apply(filter_function, axis=1)