Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add pytests for decoding pipeline #1155

Merged
merged 10 commits into from
Nov 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ dj.FreeTable(dj.conn(), "common_session.session_group").drop()
- Add testing for python versions 3.9, 3.10, 3.11, 3.12 #1169
- Initialize tables in pytests #1181
- Download test data without credentials, trigger on approved PRs #1180
- Add coverage of decoding pipeline to pytests #1155
- Allow python \< 3.13 #1169
- Remove numpy version restriction #1169
- Merge table delete removes orphaned master entries #1164
Expand Down
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,8 @@ omit = [ # which submodules have no tests
"*/cli/*",
# "*/common/*",
"*/data_import/*",
"*/decoding/*",
"*/decoding/v0/*",
# "*/decoding/*",
"*/figurl_views/*",
# "*/lfp/*",
# "*/linearization/*",
Expand Down
48 changes: 16 additions & 32 deletions src/spyglass/decoding/decoding_merge.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,53 +85,41 @@ def cleanup(self, dry_run=False):
@classmethod
def fetch_results(cls, key):
"""Fetch the decoding results for a given key."""
return cls().merge_get_parent_class(key).fetch_results()
return cls().merge_restrict_class(key).fetch_results()

@classmethod
def fetch_model(cls, key):
"""Fetch the decoding model for a given key."""
return cls().merge_get_parent_class(key).fetch_model()
return cls().merge_restrict_class(key).fetch_model()

@classmethod
def fetch_environments(cls, key):
"""Fetch the decoding environments for a given key."""
decoding_selection_key = cls.merge_get_parent(key).fetch1("KEY")
return (
cls()
.merge_get_parent_class(key)
.fetch_environments(decoding_selection_key)
)
restr_parent = cls().merge_restrict_class(key)
decoding_selection_key = restr_parent.fetch1("KEY")
return restr_parent.fetch_environments(decoding_selection_key)

@classmethod
def fetch_position_info(cls, key):
"""Fetch the decoding position info for a given key."""
decoding_selection_key = cls.merge_get_parent(key).fetch1("KEY")
return (
cls()
.merge_get_parent_class(key)
.fetch_position_info(decoding_selection_key)
)
restr_parent = cls().merge_restrict_class(key)
decoding_selection_key = restr_parent.fetch1("KEY")
return restr_parent.fetch_position_info(decoding_selection_key)

@classmethod
def fetch_linear_position_info(cls, key):
"""Fetch the decoding linear position info for a given key."""
decoding_selection_key = cls.merge_get_parent(key).fetch1("KEY")
return (
cls()
.merge_get_parent_class(key)
.fetch_linear_position_info(decoding_selection_key)
)
restr_parent = cls().merge_restrict_class(key)
decoding_selection_key = restr_parent.fetch1("KEY")
return restr_parent.fetch_linear_position_info(decoding_selection_key)

@classmethod
def fetch_spike_data(cls, key, filter_by_interval=True):
"""Fetch the decoding spike data for a given key."""
decoding_selection_key = cls.merge_get_parent(key).fetch1("KEY")
return (
cls()
.merge_get_parent_class(key)
.fetch_linear_position_info(
decoding_selection_key, filter_by_interval=filter_by_interval
)
restr_parent = cls().merge_restrict_class(key)
decoding_selection_key = restr_parent.fetch1("KEY")
return restr_parent.fetch_spike_data(
decoding_selection_key, filter_by_interval=filter_by_interval
)

@classmethod
Expand Down Expand Up @@ -167,11 +155,7 @@ def create_decoding_view(cls, key, head_direction_name="head_orientation"):
head_dir=position_info[head_direction_name],
)
else:
(
position_info,
position_variable_names,
) = cls.fetch_linear_position_info(key)
return create_1D_decode_view(
posterior=posterior,
linear_position=position_info["linear_position"],
linear_position=cls.fetch_linear_position_info(key),
)
10 changes: 6 additions & 4 deletions src/spyglass/decoding/v1/clusterless.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,10 +59,11 @@ def create_group(
"waveform_features_group_name": group_name,
}
if self & group_key:
raise ValueError(
f"Group {nwb_file_name}: {group_name} already exists",
"please delete the group before creating a new one",
logger.error( # No error on duplicate helps with pytests
f"Group {nwb_file_name}: {group_name} already exists"
+ "please delete the group before creating a new one",
)
return
self.insert1(
group_key,
skip_duplicates=True,
Expand Down Expand Up @@ -586,7 +587,8 @@ def get_ahead_behind_distance(self, track_graph=None, time_slice=None):
classifier.environments[0].track_graph, *traj_data
)
else:
position_info = self.fetch_position_info(self.fetch1("KEY")).loc[
# `fetch_position_info` returns a tuple
position_info = self.fetch_position_info(self.fetch1("KEY"))[0].loc[
edeno marked this conversation as resolved.
Show resolved Hide resolved
time_slice
]
map_position = analysis.maximum_a_posteriori_estimate(posterior)
Expand Down
18 changes: 10 additions & 8 deletions src/spyglass/decoding/v1/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
restore_classes,
)
from spyglass.position.position_merge import PositionOutput # noqa: F401
from spyglass.utils import SpyglassMixin, SpyglassMixinPart
from spyglass.utils import SpyglassMixin, SpyglassMixinPart, logger

schema = dj.schema("decoding_core_v1")

Expand Down Expand Up @@ -56,14 +56,15 @@ class DecodingParameters(SpyglassMixin, dj.Lookup):
@classmethod
def insert_default(cls):
"""Insert default decoding parameters"""
cls.insert(cls.contents, skip_duplicates=True)
cls.super().insert(cls.contents, skip_duplicates=True)

def insert(self, rows, *args, **kwargs):
"""Override insert to convert classes to dict before inserting"""
for row in rows:
row["decoding_params"] = convert_classes_to_dict(
vars(row["decoding_params"])
)
params = row["decoding_params"]
if hasattr(params, "__dict__"):
params = vars(params)
row["decoding_params"] = convert_classes_to_dict(params)
super().insert(rows, *args, **kwargs)

def fetch(self, *args, **kwargs):
Expand Down Expand Up @@ -124,10 +125,11 @@ def create_group(
"position_group_name": group_name,
}
if self & group_key:
raise ValueError(
f"Group {nwb_file_name}: {group_name} already exists",
"please delete the group before creating a new one",
logger.error( # Easier for pytests to not raise error on duplicate
f"Group {nwb_file_name}: {group_name} already exists"
+ "please delete the group before creating a new one"
)
return
self.insert1(
{
**group_key,
Expand Down
2 changes: 1 addition & 1 deletion src/spyglass/spikesorting/analysis/v1/unit_annotation.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ def add_annotation(self, key, **kwargs):
).fetch_nwb()[0]
nwb_field_name = _get_spike_obj_name(nwb_file)
spikes = nwb_file[nwb_field_name]["spike_times"].to_list()
if key["unit_id"] > len(spikes):
if key["unit_id"] > len(spikes) and not self._test_mode:
raise ValueError(
f"unit_id {key['unit_id']} is greater than ",
f"the number of units in {key['spikesorting_merge_id']}",
Expand Down
2 changes: 2 additions & 0 deletions src/spyglass/utils/dj_merge_tables.py
Original file line number Diff line number Diff line change
Expand Up @@ -737,6 +737,8 @@ def merge_get_parent_class(self, source: str) -> dj.Table:
source: Union[str, dict, dj.Table]
Accepts a CamelCase name of the source, or key as a dict, or a part
table.
init: bool, optional
Default False. If True, returns an instance of the class.

Returns
-------
Expand Down
7 changes: 5 additions & 2 deletions tests/common/test_interval.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,9 @@ def interval_list(common):


def test_plot_intervals(mini_insert, interval_list):
fig = interval_list.plot_intervals(return_fig=True)
fig = (interval_list & 'interval_list_name LIKE "raw%"').plot_intervals(
return_fig=True
)
interval_list_name = fig.get_axes()[0].get_yticklabels()[0].get_text()
times_fetch = (
interval_list & {"interval_list_name": interval_list_name}
Expand All @@ -19,7 +21,8 @@ def test_plot_intervals(mini_insert, interval_list):


def test_plot_epoch(mini_insert, interval_list):
fig = interval_list.plot_epoch_pos_raw_intervals(return_fig=True)
restr_interval = interval_list & "interval_list_name like 'raw%'"
fig = restr_interval.plot_epoch_pos_raw_intervals(return_fig=True)
epoch_label = fig.get_axes()[0].get_yticklabels()[-1].get_text()
assert epoch_label == "epoch", "plot_epoch failed"

Expand Down
173 changes: 173 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -1299,3 +1299,176 @@ def dlc_key(sgp, dlc_selection):
def populate_dlc(sgp, dlc_key):
sgp.v1.DLCPosV1().populate(dlc_key)
yield


# ----------------------- FIXTURES, SPIKESORTING TABLES -----------------------
# ------------------------ Note: Used in decoding tests ------------------------


@pytest.fixture(scope="session")
def spike_v1(common):
from spyglass.spikesorting import v1

yield v1


@pytest.fixture(scope="session")
def pop_rec(spike_v1, mini_dict, team_name):
spike_v1.SortGroup.set_group_by_shank(**mini_dict)
key = {
**mini_dict,
"sort_group_id": 0,
"preproc_param_name": "default",
"interval_list_name": "01_s1",
"team_name": team_name,
}
spike_v1.SpikeSortingRecordingSelection.insert_selection(key)
ssr_pk = (
(spike_v1.SpikeSortingRecordingSelection & key).proj().fetch1("KEY")
)
spike_v1.SpikeSortingRecording.populate(ssr_pk)

yield ssr_pk


@pytest.fixture(scope="session")
def pop_art(spike_v1, mini_dict, pop_rec):
key = {
"recording_id": pop_rec["recording_id"],
"artifact_param_name": "default",
}
spike_v1.ArtifactDetectionSelection.insert_selection(key)
spike_v1.ArtifactDetection.populate()

yield spike_v1.ArtifactDetection().fetch("KEY", as_dict=True)[0]


@pytest.fixture(scope="session")
def spike_merge(spike_v1):
from spyglass.spikesorting.spikesorting_merge import SpikeSortingOutput

yield SpikeSortingOutput()


@pytest.fixture(scope="session")
def sorter_dict():
return {"sorter": "mountainsort4"}


@pytest.fixture(scope="session")
def pop_sort(spike_v1, pop_rec, pop_art, mini_dict, sorter_dict):
key = {
**mini_dict,
**sorter_dict,
"recording_id": pop_rec["recording_id"],
"interval_list_name": str(pop_art["artifact_id"]),
"sorter_param_name": "franklab_tetrode_hippocampus_30KHz",
}
spike_v1.SpikeSortingSelection.insert_selection(key)
spike_v1.SpikeSorting.populate()

yield spike_v1.SpikeSorting().fetch("KEY", as_dict=True)[0]


@pytest.fixture(scope="session")
def sorting_objs(spike_v1, pop_sort):
sort_nwb = (spike_v1.SpikeSorting & pop_sort).fetch_nwb()
sort_si = spike_v1.SpikeSorting.get_sorting(pop_sort)
yield sort_nwb, sort_si


@pytest.fixture(scope="session")
def pop_curation(spike_v1, pop_sort):
spike_v1.CurationV1.insert_curation(
sorting_id=pop_sort["sorting_id"],
description="testing sort",
)

yield (spike_v1.CurationV1() & {"parent_curation_id": -1}).fetch(
"KEY", as_dict=True
)[0]


@pytest.fixture(scope="session")
def pop_metric(spike_v1, pop_sort, pop_curation):
_ = pop_curation # make sure this happens first
key = {
"sorting_id": pop_sort["sorting_id"],
"curation_id": 0,
"waveform_param_name": "default_not_whitened",
"metric_param_name": "franklab_default",
"metric_curation_param_name": "default",
}

spike_v1.MetricCurationSelection.insert_selection(key)
spike_v1.MetricCuration.populate(key)

yield spike_v1.MetricCuration().fetch("KEY", as_dict=True)[0]


@pytest.fixture(scope="session")
def metric_objs(spike_v1, pop_metric):
key = {"metric_curation_id": pop_metric["metric_curation_id"]}
labels = spike_v1.MetricCuration.get_labels(key)
merge_groups = spike_v1.MetricCuration.get_merge_groups(key)
metrics = spike_v1.MetricCuration.get_metrics(key)
yield labels, merge_groups, metrics


@pytest.fixture(scope="session")
def pop_curation_metric(spike_v1, pop_metric, metric_objs):
labels, merge_groups, metrics = metric_objs
parent_dict = {"parent_curation_id": 0}
spike_v1.CurationV1.insert_curation(
sorting_id=(
spike_v1.MetricCurationSelection
& {"metric_curation_id": pop_metric["metric_curation_id"]}
).fetch1("sorting_id"),
**parent_dict,
labels=labels,
merge_groups=merge_groups,
metrics=metrics,
description="after metric curation",
)

yield (spike_v1.CurationV1 & parent_dict).fetch("KEY", as_dict=True)[0]


@pytest.fixture(scope="session")
def pop_spike_merge(
spike_v1, pop_curation_metric, spike_merge, mini_dict, sorter_dict
):
# TODO: add figurl fixtures when kachery_cloud is initialized

spike_merge.insert([pop_curation_metric], part_name="CurationV1")

yield (spike_merge << pop_curation_metric).fetch1("KEY")


@pytest.fixture(scope="session")
def spike_v1_group():
from spyglass.spikesorting.analysis.v1 import group

yield group


@pytest.fixture(scope="session")
def group_name():
yield "test_group"


@pytest.fixture(scope="session")
def pop_spikes_group(
group_name, spike_v1_group, spike_merge, mini_dict, pop_spike_merge
):

_ = pop_spike_merge # make sure this happens first

spike_v1_group.UnitSelectionParams().insert_default()
spike_v1_group.SortedSpikesGroup().create_group(
**mini_dict,
group_name=group_name,
keys=spike_merge.proj(spikesorting_merge_id="merge_id").fetch("KEY"),
unit_filter_params_name="default_exclusion",
)
yield spike_v1_group.SortedSpikesGroup().fetch("KEY", as_dict=True)[0]
Empty file added tests/decoding/__init__.py
Empty file.
Loading
Loading