Skip to content

Commit

Permalink
Handle sortings in which no units are detected (SpikeSortingV1) (#728)
Browse files Browse the repository at this point in the history
* Save LFP as pynwb.ecephys.LFP

* Fix formatting

* Fix formatting

* Fix SpikeSorting.get_sorting

---------

Co-authored-by: Eric Denovellis <[email protected]>
  • Loading branch information
khl02007 and edeno authored Dec 21, 2023
1 parent 136de86 commit 00bb398
Showing 1 changed file with 45 additions and 17 deletions.
62 changes: 45 additions & 17 deletions src/spyglass/spikesorting/v1/sorting.py
Original file line number Diff line number Diff line change
Expand Up @@ -290,11 +290,34 @@ def get_sorting(cls, key: dict) -> si.BaseSorting:
"""

recording_id = (
SpikeSortingRecording * SpikeSortingSelection & key
).fetch1("recording_id")
recording = SpikeSortingRecording.get_recording(
{"recording_id": recording_id}
)
sampling_frequency = recording.get_sampling_frequency()
analysis_file_name = (cls & key).fetch1("analysis_file_name")
analysis_file_abs_path = AnalysisNwbfile.get_abs_path(
analysis_file_name
)
sorting = se.read_nwb_sorting(analysis_file_abs_path)
with pynwb.NWBHDF5IO(
analysis_file_abs_path, "r", load_namespaces=True
) as io:
nwbf = io.read()
units = nwbf.units.to_dataframe()
units_dict_list = [
{
unit_id: np.searchsorted(recording.get_times(), spike_times)
for unit_id, spike_times in zip(
units.index, units["spike_times"]
)
}
]

sorting = si.NumpySorting.from_unit_dict(
units_dict_list, sampling_frequency=sampling_frequency
)

return sorting

Expand Down Expand Up @@ -331,23 +354,28 @@ def _write_sorting_to_nwb(
load_namespaces=True,
) as io:
nwbf = io.read()
nwbf.add_unit_column(
name="curation_label",
description="curation label applied to a unit",
)
obs_interval = (
sort_interval
if sort_interval.ndim == 2
else sort_interval.reshape(1, 2)
)
for unit_id in sorting.get_unit_ids():
spike_times = sorting.get_unit_spike_train(unit_id)
nwbf.add_unit(
spike_times=timestamps[spike_times],
id=unit_id,
obs_intervals=obs_interval,
curation_label="uncurated",
if sorting.get_num_units() == 0:
nwbf.units = pynwb.misc.Units(
name="units", description="Empty units table."
)
else:
nwbf.add_unit_column(
name="curation_label",
description="curation label applied to a unit",
)
obs_interval = (
sort_interval
if sort_interval.ndim == 2
else sort_interval.reshape(1, 2)
)
for unit_id in sorting.get_unit_ids():
spike_times = sorting.get_unit_spike_train(unit_id)
nwbf.add_unit(
spike_times=timestamps[spike_times],
id=unit_id,
obs_intervals=obs_interval,
curation_label="uncurated",
)
units_object_id = nwbf.units.object_id
io.write(nwbf)
return analysis_nwb_file, units_object_id

0 comments on commit 00bb398

Please sign in to comment.