diff --git a/.vscode/settings.json b/.vscode/settings.json index f94239ef5..6e239d6af 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -10,7 +10,7 @@ "editor.defaultFormatter": "ms-python.black-formatter", "editor.formatOnSave": true, "editor.codeActionsOnSave": { - "source.organizeImports": true + "source.organizeImports": "explicit" }, }, "isort.args": [ diff --git a/src/spyglass/spikesorting/merge.py b/src/spyglass/spikesorting/merge.py index ff709ebee..7e32b898b 100644 --- a/src/spyglass/spikesorting/merge.py +++ b/src/spyglass/spikesorting/merge.py @@ -12,7 +12,7 @@ @schema -class SpikeSortingOutput(_Merge): +class SpikeSortingOutput(_Merge, SpyglassMixin): definition = """ # Output of spike sorting pipelines. merge_id: uuid diff --git a/src/spyglass/spikesorting/v1/recording.py b/src/spyglass/spikesorting/v1/recording.py index 053bc85aa..3e279a71c 100644 --- a/src/spyglass/spikesorting/v1/recording.py +++ b/src/spyglass/spikesorting/v1/recording.py @@ -554,6 +554,8 @@ def _consolidate_intervals(intervals, timestamps): """ # Convert intervals to a numpy array if it's not intervals = np.array(intervals) + if intervals.ndim == 1: + intervals = intervals.reshape(-1, 2) if intervals.shape[1] != 2: raise ValueError( "Input array must have shape (N, 2) where N is the number of intervals." diff --git a/src/spyglass/spikesorting/v1/sorting.py b/src/spyglass/spikesorting/v1/sorting.py index 119054d50..a91cc6914 100644 --- a/src/spyglass/spikesorting/v1/sorting.py +++ b/src/spyglass/spikesorting/v1/sorting.py @@ -335,12 +335,17 @@ def _write_sorting_to_nwb( name="curation_label", description="curation label applied to a unit", ) + obs_interval = ( + sort_interval + if sort_interval.ndim == 2 + else sort_interval.reshape(1, 2) + ) for unit_id in sorting.get_unit_ids(): spike_times = sorting.get_unit_spike_train(unit_id) nwbf.add_unit( spike_times=timestamps[spike_times], id=unit_id, - obs_intervals=sort_interval, + obs_intervals=obs_interval, curation_label="uncurated", ) units_object_id = nwbf.units.object_id