Skip to content

Commit

Permalink
Merge branch 'master' of https://github.com/LorenFrankLab/spyglass in…
Browse files Browse the repository at this point in the history
…to msc
  • Loading branch information
CBroz1 committed Nov 27, 2024
2 parents 4f49973 + 6faed4c commit 94f7d15
Show file tree
Hide file tree
Showing 24 changed files with 997 additions and 261 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ dj.FreeTable(dj.conn(), "common_session.session_group").drop()
- Add testing for python versions 3.9, 3.10, 3.11, 3.12 #1169
- Initialize tables in pytests #1181
- Download test data without credentials, trigger on approved PRs #1180
- Add coverage of decoding pipeline to pytests #1155
- Allow python \< 3.13 #1169
- Remove numpy version restriction #1169
- Merge table delete removes orphaned master entries #1164
Expand Down Expand Up @@ -85,6 +86,7 @@ dj.FreeTable(dj.conn(), "common_session.session_group").drop()
- Fix bug in `_compute_metric` #1099
- Fix bug in `insert_curation` returned key #1114
- Fix handling of waveform extraction sparse parameter #1132
- Limit Artifact detection intervals to valid times #1196

## [0.5.3] (August 27, 2024)

Expand Down
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,8 @@ omit = [ # which submodules have no tests
"*/cli/*",
# "*/common/*",
"*/data_import/*",
"*/decoding/*",
"*/decoding/v0/*",
# "*/decoding/*",
"*/figurl_views/*",
# "*/lfp/*",
# "*/linearization/*",
Expand Down
48 changes: 16 additions & 32 deletions src/spyglass/decoding/decoding_merge.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,53 +85,41 @@ def cleanup(self, dry_run=False):
@classmethod
def fetch_results(cls, key):
"""Fetch the decoding results for a given key."""
return cls().merge_get_parent_class(key).fetch_results()
return cls().merge_restrict_class(key).fetch_results()

@classmethod
def fetch_model(cls, key):
"""Fetch the decoding model for a given key."""
return cls().merge_get_parent_class(key).fetch_model()
return cls().merge_restrict_class(key).fetch_model()

@classmethod
def fetch_environments(cls, key):
"""Fetch the decoding environments for a given key."""
decoding_selection_key = cls.merge_get_parent(key).fetch1("KEY")
return (
cls()
.merge_get_parent_class(key)
.fetch_environments(decoding_selection_key)
)
restr_parent = cls().merge_restrict_class(key)
decoding_selection_key = restr_parent.fetch1("KEY")
return restr_parent.fetch_environments(decoding_selection_key)

@classmethod
def fetch_position_info(cls, key):
"""Fetch the decoding position info for a given key."""
decoding_selection_key = cls.merge_get_parent(key).fetch1("KEY")
return (
cls()
.merge_get_parent_class(key)
.fetch_position_info(decoding_selection_key)
)
restr_parent = cls().merge_restrict_class(key)
decoding_selection_key = restr_parent.fetch1("KEY")
return restr_parent.fetch_position_info(decoding_selection_key)

@classmethod
def fetch_linear_position_info(cls, key):
"""Fetch the decoding linear position info for a given key."""
decoding_selection_key = cls.merge_get_parent(key).fetch1("KEY")
return (
cls()
.merge_get_parent_class(key)
.fetch_linear_position_info(decoding_selection_key)
)
restr_parent = cls().merge_restrict_class(key)
decoding_selection_key = restr_parent.fetch1("KEY")
return restr_parent.fetch_linear_position_info(decoding_selection_key)

@classmethod
def fetch_spike_data(cls, key, filter_by_interval=True):
"""Fetch the decoding spike data for a given key."""
decoding_selection_key = cls.merge_get_parent(key).fetch1("KEY")
return (
cls()
.merge_get_parent_class(key)
.fetch_linear_position_info(
decoding_selection_key, filter_by_interval=filter_by_interval
)
restr_parent = cls().merge_restrict_class(key)
decoding_selection_key = restr_parent.fetch1("KEY")
return restr_parent.fetch_spike_data(
decoding_selection_key, filter_by_interval=filter_by_interval
)

@classmethod
Expand Down Expand Up @@ -167,11 +155,7 @@ def create_decoding_view(cls, key, head_direction_name="head_orientation"):
head_dir=position_info[head_direction_name],
)
else:
(
position_info,
position_variable_names,
) = cls.fetch_linear_position_info(key)
return create_1D_decode_view(
posterior=posterior,
linear_position=position_info["linear_position"],
linear_position=cls.fetch_linear_position_info(key),
)
10 changes: 6 additions & 4 deletions src/spyglass/decoding/v1/clusterless.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,10 +59,11 @@ def create_group(
"waveform_features_group_name": group_name,
}
if self & group_key:
raise ValueError(
f"Group {nwb_file_name}: {group_name} already exists",
"please delete the group before creating a new one",
logger.error( # No error on duplicate helps with pytests
f"Group {nwb_file_name}: {group_name} already exists"
+ "please delete the group before creating a new one",
)
return
self.insert1(
group_key,
skip_duplicates=True,
Expand Down Expand Up @@ -586,7 +587,8 @@ def get_ahead_behind_distance(self, track_graph=None, time_slice=None):
classifier.environments[0].track_graph, *traj_data
)
else:
position_info = self.fetch_position_info(self.fetch1("KEY")).loc[
# `fetch_position_info` returns a tuple
position_info = self.fetch_position_info(self.fetch1("KEY"))[0].loc[
time_slice
]
map_position = analysis.maximum_a_posteriori_estimate(posterior)
Expand Down
18 changes: 10 additions & 8 deletions src/spyglass/decoding/v1/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
restore_classes,
)
from spyglass.position.position_merge import PositionOutput # noqa: F401
from spyglass.utils import SpyglassMixin, SpyglassMixinPart
from spyglass.utils import SpyglassMixin, SpyglassMixinPart, logger

schema = dj.schema("decoding_core_v1")

Expand Down Expand Up @@ -56,14 +56,15 @@ class DecodingParameters(SpyglassMixin, dj.Lookup):
@classmethod
def insert_default(cls):
"""Insert default decoding parameters"""
cls.insert(cls.contents, skip_duplicates=True)
cls.super().insert(cls.contents, skip_duplicates=True)

def insert(self, rows, *args, **kwargs):
"""Override insert to convert classes to dict before inserting"""
for row in rows:
row["decoding_params"] = convert_classes_to_dict(
vars(row["decoding_params"])
)
params = row["decoding_params"]
if hasattr(params, "__dict__"):
params = vars(params)
row["decoding_params"] = convert_classes_to_dict(params)
super().insert(rows, *args, **kwargs)

def fetch(self, *args, **kwargs):
Expand Down Expand Up @@ -124,10 +125,11 @@ def create_group(
"position_group_name": group_name,
}
if self & group_key:
raise ValueError(
f"Group {nwb_file_name}: {group_name} already exists",
"please delete the group before creating a new one",
logger.error( # Easier for pytests to not raise error on duplicate
f"Group {nwb_file_name}: {group_name} already exists"
+ "please delete the group before creating a new one"
)
return
self.insert1(
{
**group_key,
Expand Down
2 changes: 1 addition & 1 deletion src/spyglass/spikesorting/analysis/v1/unit_annotation.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ def add_annotation(self, key, **kwargs):
).fetch_nwb()[0]
nwb_field_name = _get_spike_obj_name(nwb_file)
spikes = nwb_file[nwb_field_name]["spike_times"].to_list()
if key["unit_id"] > len(spikes):
if key["unit_id"] > len(spikes) and not self._test_mode:
raise ValueError(
f"unit_id {key['unit_id']} is greater than ",
f"the number of units in {key['spikesorting_merge_id']}",
Expand Down
5 changes: 4 additions & 1 deletion src/spyglass/spikesorting/v0/spikesorting_artifact.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,7 +276,10 @@ def _get_artifact_times(
for interval_idx, interval in enumerate(artifact_intervals):
artifact_intervals_s[interval_idx] = [
valid_timestamps[interval[0]] - half_removal_window_s,
valid_timestamps[interval[1]] + half_removal_window_s,
np.minimum(
valid_timestamps[interval[1]] + half_removal_window_s,
valid_timestamps[-1],
),
]
# make the artifact intervals disjoint
if len(artifact_intervals_s) > 1:
Expand Down
5 changes: 4 additions & 1 deletion src/spyglass/spikesorting/v1/artifact.py
Original file line number Diff line number Diff line change
Expand Up @@ -308,7 +308,10 @@ def _get_artifact_times(
),
np.searchsorted(
valid_timestamps,
valid_timestamps[interval[1]] + half_removal_window_s,
np.minimum(
valid_timestamps[interval[1]] + half_removal_window_s,
valid_timestamps[-1],
),
),
]
artifact_intervals_s[interval_idx] = [
Expand Down
2 changes: 2 additions & 0 deletions src/spyglass/utils/dj_merge_tables.py
Original file line number Diff line number Diff line change
Expand Up @@ -737,6 +737,8 @@ def merge_get_parent_class(self, source: str) -> dj.Table:
source: Union[str, dict, dj.Table]
Accepts a CamelCase name of the source, or key as a dict, or a part
table.
init: bool, optional
Default False. If True, returns an instance of the class.
Returns
-------
Expand Down
7 changes: 5 additions & 2 deletions tests/common/test_interval.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,9 @@ def interval_list(common):


def test_plot_intervals(mini_insert, interval_list):
fig = interval_list.plot_intervals(return_fig=True)
fig = (interval_list & 'interval_list_name LIKE "raw%"').plot_intervals(
return_fig=True
)
interval_list_name = fig.get_axes()[0].get_yticklabels()[0].get_text()
times_fetch = (
interval_list & {"interval_list_name": interval_list_name}
Expand All @@ -19,7 +21,8 @@ def test_plot_intervals(mini_insert, interval_list):


def test_plot_epoch(mini_insert, interval_list):
fig = interval_list.plot_epoch_pos_raw_intervals(return_fig=True)
restr_interval = interval_list & "interval_list_name like 'raw%'"
fig = restr_interval.plot_epoch_pos_raw_intervals(return_fig=True)
epoch_label = fig.get_axes()[0].get_yticklabels()[-1].get_text()
assert epoch_label == "epoch", "plot_epoch failed"

Expand Down
Loading

0 comments on commit 94f7d15

Please sign in to comment.