From b5c85ff41ea8c53030d7cd9df50767f9f9644742 Mon Sep 17 00:00:00 2001 From: JoeZiminski Date: Thu, 14 Nov 2024 15:50:29 +0000 Subject: [PATCH] Working on the private PC decomposition. --- .../working/load_kilosort_utils.py | 19 +++++++++++++------ .../working/plot_kilosort_drift_map.py | 4 +++- 2 files changed, 16 insertions(+), 7 deletions(-) diff --git a/src/spikeinterface/working/load_kilosort_utils.py b/src/spikeinterface/working/load_kilosort_utils.py index 01b2ac2b81..084c3b37ea 100644 --- a/src/spikeinterface/working/load_kilosort_utils.py +++ b/src/spikeinterface/working/load_kilosort_utils.py @@ -59,10 +59,7 @@ def compute_spike_amplitude_and_depth( center of mass from the first PC (or, second PC if no signal on first PC). See `_get_locations_from_pc_features()` for details. """ - if isinstance(sorter_output, str): - sorter_output = Path(sorter_output) - - if not params["pc_features"]: + if params["pc_features"] is None: raise ValueError("`pc_features` must be loaded into params. Use `load_ks_dir` with `load_pcs=True`.") if localised_spikes_only: @@ -118,10 +115,12 @@ def compute_spike_amplitude_and_depth( def _get_locations_from_pc_features(params): """ + Compute locations from the waveform principal component scores. Notes ----- Location of of each individual spike is computed from its low-dimensional projection. + During sorting, kilosort computes the ' `pc_features` (num_spikes, num_PC, num_channels) holds the PC values for each spike. Taking the first component, the subset of 32 channels associated with this spike are indexed to get the actual channel locations (in um). Then, the channel @@ -131,6 +130,13 @@ def _get_locations_from_pc_features(params): https://github.com/cortex-lab/spikes """ # Compute spike depths + + # for each spike, a PCA is computed just on that spike (n samples x n channels). + # the components are all different between spikes, so are not saved. + # This gives a (n pc = 3, num channels) set of scores. + # but then how it is possible for some spikes to have zero score onto the principal channel? + + breakpoint() pc_features = params["pc_features"][:, 0, :] pc_features[pc_features < 0] = 0 @@ -153,7 +159,7 @@ def _get_locations_from_pc_features(params): pc_features[no_pc1_signal_spikes] = pc_features_2[no_pc1_signal_spikes] - if any(np.sum(pc_features, axis=1) == 0): + if np.any(np.sum(pc_features, axis=1) == 0): raise RuntimeError( "Some spikes do not load at all onto the first" "or second principal component. It is necessary" @@ -319,7 +325,8 @@ def load_ks_dir(sorter_output: Path, exclude_noise: bool = True, load_pcs: bool As this function strips the spikes and units based on only these two data structures, they will work following manual reassignment in Phy. """ - sorter_output = Path(sorter_output) + if isinstance(sorter_output, str): + sorter_output = Path(sorter_output) params = read_python(sorter_output / "params.py") diff --git a/src/spikeinterface/working/plot_kilosort_drift_map.py b/src/spikeinterface/working/plot_kilosort_drift_map.py index ecac38495f..e61b7bddd9 100644 --- a/src/spikeinterface/working/plot_kilosort_drift_map.py +++ b/src/spikeinterface/working/plot_kilosort_drift_map.py @@ -106,8 +106,10 @@ def plot_matplotlib(self, data_plot: dict, **unused_kwargs) -> None: dp = to_attr(data_plot) + params = load_kilosort_utils.load_ks_dir(dp.sorter_output, load_pcs=True, exclude_noise=dp.exclude_noise) + spike_indexes, spike_amplitudes, spike_locations, _ = load_kilosort_utils.compute_spike_amplitude_and_depth( - dp.sorter_output, dp.localised_spikes_only, dp.exclude_noise, dp.gain, dp.localised_spikes_channel_cutoff + params, dp.localised_spikes_only, dp.gain, dp.localised_spikes_channel_cutoff ) spike_times = spike_indexes / 30000 spike_depths = spike_locations[:, 1]