diff --git a/src/spyglass/spikesorting/v1/sorting.py b/src/spyglass/spikesorting/v1/sorting.py index a91cc6914..f2827ffc0 100644 --- a/src/spyglass/spikesorting/v1/sorting.py +++ b/src/spyglass/spikesorting/v1/sorting.py @@ -290,11 +290,34 @@ def get_sorting(cls, key: dict) -> si.BaseSorting: """ + recording_id = ( + SpikeSortingRecording * SpikeSortingSelection & key + ).fetch1("recording_id") + recording = SpikeSortingRecording.get_recording( + {"recording_id": recording_id} + ) + sampling_frequency = recording.get_sampling_frequency() analysis_file_name = (cls & key).fetch1("analysis_file_name") analysis_file_abs_path = AnalysisNwbfile.get_abs_path( analysis_file_name ) - sorting = se.read_nwb_sorting(analysis_file_abs_path) + with pynwb.NWBHDF5IO( + analysis_file_abs_path, "r", load_namespaces=True + ) as io: + nwbf = io.read() + units = nwbf.units.to_dataframe() + units_dict_list = [ + { + unit_id: np.searchsorted(recording.get_times(), spike_times) + for unit_id, spike_times in zip( + units.index, units["spike_times"] + ) + } + ] + + sorting = si.NumpySorting.from_unit_dict( + units_dict_list, sampling_frequency=sampling_frequency + ) return sorting @@ -331,23 +354,28 @@ def _write_sorting_to_nwb( load_namespaces=True, ) as io: nwbf = io.read() - nwbf.add_unit_column( - name="curation_label", - description="curation label applied to a unit", - ) - obs_interval = ( - sort_interval - if sort_interval.ndim == 2 - else sort_interval.reshape(1, 2) - ) - for unit_id in sorting.get_unit_ids(): - spike_times = sorting.get_unit_spike_train(unit_id) - nwbf.add_unit( - spike_times=timestamps[spike_times], - id=unit_id, - obs_intervals=obs_interval, - curation_label="uncurated", + if sorting.get_num_units() == 0: + nwbf.units = pynwb.misc.Units( + name="units", description="Empty units table." ) + else: + nwbf.add_unit_column( + name="curation_label", + description="curation label applied to a unit", + ) + obs_interval = ( + sort_interval + if sort_interval.ndim == 2 + else sort_interval.reshape(1, 2) + ) + for unit_id in sorting.get_unit_ids(): + spike_times = sorting.get_unit_spike_train(unit_id) + nwbf.add_unit( + spike_times=timestamps[spike_times], + id=unit_id, + obs_intervals=obs_interval, + curation_label="uncurated", + ) units_object_id = nwbf.units.object_id io.write(nwbf) return analysis_nwb_file, units_object_id