From 63b851c6842c61c448d151d3049e77faaaebd75e Mon Sep 17 00:00:00 2001 From: Charlie Windolf Date: Fri, 31 May 2024 11:38:18 -0400 Subject: [PATCH] Delegate to sample_index_to_time() in estimation --- .../sortingcomponents/motion_estimation.py | 28 +++++++++---------- .../sortingcomponents/motion_utils.py | 2 +- .../tests/test_motion_estimation.py | 4 +-- 3 files changed, 16 insertions(+), 18 deletions(-) diff --git a/src/spikeinterface/sortingcomponents/motion_estimation.py b/src/spikeinterface/sortingcomponents/motion_estimation.py index 3a8b75f8b3..bede0a19bb 100644 --- a/src/spikeinterface/sortingcomponents/motion_estimation.py +++ b/src/spikeinterface/sortingcomponents/motion_estimation.py @@ -683,16 +683,15 @@ def make_2d_motion_histogram( spatial_bin_edges 1d array with spatial bin edges """ - fs = recording.get_sampling_frequency() - num_samples = recording.get_num_samples(segment_index=0) - bin_sample_size = int(bin_duration_s * fs) - sample_bin_edges = np.arange(0, num_samples + bin_sample_size, bin_sample_size) - temporal_bin_edges = sample_bin_edges / fs + n_samples = recording.get_num_samples() + mint_s = recording.sample_index_to_time(0) + maxt_s = recording.sample_index_to_time(n_samples) + temporal_bin_edges = np.arange(mint_s, maxt_s + bin_duration_s, bin_duration_s) if spatial_bin_edges is None: spatial_bin_edges = get_spatial_bin_edges(recording, direction, margin_um, bin_um) arr = np.zeros((peaks.size, 2), dtype="float64") - arr[:, 0] = peaks["sample_index"] + arr[:, 0] = recording.sample_index_to_time(peaks["sample_index"]) arr[:, 1] = peak_locations[direction] if weight_with_amplitude: @@ -700,11 +699,11 @@ def make_2d_motion_histogram( else: weights = None - motion_histogram, edges = np.histogramdd(arr, bins=(sample_bin_edges, spatial_bin_edges), weights=weights) + motion_histogram, edges = np.histogramdd(arr, bins=(temporal_bin_edges, spatial_bin_edges), weights=weights) # average amplitude in each bin if weight_with_amplitude: - bin_counts, _ = np.histogramdd(arr, bins=(sample_bin_edges, spatial_bin_edges)) + bin_counts, _ = np.histogramdd(arr, bins=(temporal_bin_edges, spatial_bin_edges)) bin_counts[bin_counts == 0] = 1 motion_histogram = motion_histogram / bin_counts @@ -759,11 +758,10 @@ def make_3d_motion_histograms( spatial_bin_edges 1d array with spatial bin edges """ - fs = recording.get_sampling_frequency() - num_samples = recording.get_num_samples(segment_index=0) - bin_sample_size = int(bin_duration_s * fs) - sample_bin_edges = np.arange(0, num_samples + bin_sample_size, bin_sample_size) - temporal_bin_edges = sample_bin_edges / fs + n_samples = recording.get_num_samples() + mint_s = recording.sample_index_to_time(0) + maxt_s = recording.sample_index_to_time(n_samples) + temporal_bin_edges = np.arange(mint_s, maxt_s + bin_duration_s, bin_duration_s) if spatial_bin_edges is None: spatial_bin_edges = get_spatial_bin_edges(recording, direction, margin_um, bin_um) @@ -778,14 +776,14 @@ def make_3d_motion_histograms( ) arr = np.zeros((peaks.size, 3), dtype="float64") - arr[:, 0] = peaks["sample_index"] + arr[:, 0] = recording.sample_index_to_time(peaks["sample_index"]) arr[:, 1] = peak_locations[direction] arr[:, 2] = abs_peaks_log_norm motion_histograms, edges = np.histogramdd( arr, bins=( - sample_bin_edges, + temporal_bin_edges, spatial_bin_edges, amplitude_bin_edges, ), diff --git a/src/spikeinterface/sortingcomponents/motion_utils.py b/src/spikeinterface/sortingcomponents/motion_utils.py index 9537b5bf1c..0f19c2a2de 100644 --- a/src/spikeinterface/sortingcomponents/motion_utils.py +++ b/src/spikeinterface/sortingcomponents/motion_utils.py @@ -17,7 +17,7 @@ # * generate drifting signals for test estimate_motion and interpolate_motion: SIMPLE ONE DONE? # * uncomment assert in test_estimate_motion (aka debug torch vs numpy diff): DONE # * delegate times to recording object in -# * estimate motion +# * estimate motion: DONE # * correct_motion_on_peaks() # * interpolate_motion_on_traces() # propagate to benchmark estimate motion diff --git a/src/spikeinterface/sortingcomponents/tests/test_motion_estimation.py b/src/spikeinterface/sortingcomponents/tests/test_motion_estimation.py index 87534ec1bf..7eea4e0bdd 100644 --- a/src/spikeinterface/sortingcomponents/tests/test_motion_estimation.py +++ b/src/spikeinterface/sortingcomponents/tests/test_motion_estimation.py @@ -201,7 +201,7 @@ def test_estimate_motion(): assert motion0 == motion1 motion0 = motions["rigid / decentralized / torch / time_horizon_s"] - motion1 = motions["rigid / decentralized / numpy / time_horizon_s"], + motion1 = motions["rigid / decentralized / numpy / time_horizon_s"] np.testing.assert_array_almost_equal(motion0.displacement, motion1.displacement) motion0 = motions["non-rigid / decentralized / torch"] @@ -209,7 +209,7 @@ def test_estimate_motion(): np.testing.assert_array_almost_equal(motion0.displacement, motion1.displacement) motion0 = motions["non-rigid / decentralized / torch / time_horizon_s"] - motion1 = motions["non-rigid / decentralized / numpy / time_horizon_s"], + motion1 = motions["non-rigid / decentralized / numpy / time_horizon_s"] np.testing.assert_array_almost_equal(motion0.displacement, motion1.displacement) motion0 = motions["non-rigid / decentralized / torch / spatial_prior"]