Skip to content

Commit

Permalink
Handle single intervals in spike sorting (#726)
Browse files Browse the repository at this point in the history
* Handle case where ther is only one interval

* Fix settings

* Handle single interval

* Add spyglass mixin
  • Loading branch information
edeno authored Dec 20, 2023
1 parent 3e3dc6f commit 136de86
Show file tree
Hide file tree
Showing 4 changed files with 10 additions and 3 deletions.
2 changes: 1 addition & 1 deletion .vscode/settings.json
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
"editor.defaultFormatter": "ms-python.black-formatter",
"editor.formatOnSave": true,
"editor.codeActionsOnSave": {
"source.organizeImports": true
"source.organizeImports": "explicit"
},
},
"isort.args": [
Expand Down
2 changes: 1 addition & 1 deletion src/spyglass/spikesorting/merge.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@


@schema
class SpikeSortingOutput(_Merge):
class SpikeSortingOutput(_Merge, SpyglassMixin):
definition = """
# Output of spike sorting pipelines.
merge_id: uuid
Expand Down
2 changes: 2 additions & 0 deletions src/spyglass/spikesorting/v1/recording.py
Original file line number Diff line number Diff line change
Expand Up @@ -554,6 +554,8 @@ def _consolidate_intervals(intervals, timestamps):
"""
# Convert intervals to a numpy array if it's not
intervals = np.array(intervals)
if intervals.ndim == 1:
intervals = intervals.reshape(-1, 2)
if intervals.shape[1] != 2:
raise ValueError(
"Input array must have shape (N, 2) where N is the number of intervals."
Expand Down
7 changes: 6 additions & 1 deletion src/spyglass/spikesorting/v1/sorting.py
Original file line number Diff line number Diff line change
Expand Up @@ -335,12 +335,17 @@ def _write_sorting_to_nwb(
name="curation_label",
description="curation label applied to a unit",
)
obs_interval = (
sort_interval
if sort_interval.ndim == 2
else sort_interval.reshape(1, 2)
)
for unit_id in sorting.get_unit_ids():
spike_times = sorting.get_unit_spike_train(unit_id)
nwbf.add_unit(
spike_times=timestamps[spike_times],
id=unit_id,
obs_intervals=sort_interval,
obs_intervals=obs_interval,
curation_label="uncurated",
)
units_object_id = nwbf.units.object_id
Expand Down

0 comments on commit 136de86

Please sign in to comment.