Skip to content

Commit

Permalink
Merge pull request #2864 from alejoe91/highpass-spatial-dtype
Browse files Browse the repository at this point in the history
Fix highpass-spatial-filter return dtype
  • Loading branch information
samuelgarcia authored May 21, 2024
2 parents 8317eb5 + 8b622cc commit 89968ca
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 2 deletions.
11 changes: 10 additions & 1 deletion src/spikeinterface/preprocessing/highpass_spatial_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import numpy as np

from .basepreprocessor import BasePreprocessor, BasePreprocessorSegment
from .filter import fix_dtype
from ..core import order_channels_by_depth, get_chunk_with_margin
from ..core.core_tools import define_function_from_class

Expand Down Expand Up @@ -47,6 +48,8 @@ class HighpassSpatialFilterRecording(BasePreprocessor):
Order of spatial butterworth filter
highpass_butter_wn : float, default: 0.01
Critical frequency (with respect to Nyquist) of spatial butterworth filter
dtype : dtype, default: None
The dtype of the output traces. If None, the dtype is the same as the input traces
Returns
-------
Expand All @@ -73,6 +76,7 @@ def __init__(
agc_window_length_s=0.1,
highpass_butter_order=3,
highpass_butter_wn=0.01,
dtype=None,
):
BasePreprocessor.__init__(self, recording)

Expand Down Expand Up @@ -117,6 +121,8 @@ def __init__(
butter_kwargs = dict(btype="highpass", N=highpass_butter_order, Wn=highpass_butter_wn)
sos_filter = scipy.signal.butter(**butter_kwargs, output="sos")

dtype = fix_dtype(recording, dtype)

for parent_segment in recording._recording_segments:
rec_segment = HighPassSpatialFilterSegment(
parent_segment,
Expand All @@ -128,6 +134,7 @@ def __init__(
sos_filter,
order_f,
order_r,
dtype=dtype,
)
self.add_recording_segment(rec_segment)

Expand Down Expand Up @@ -155,6 +162,7 @@ def __init__(
sos_filter,
order_f,
order_r,
dtype,
):
BasePreprocessorSegment.__init__(self, parent_recording_segment)
self.parent_recording_segment = parent_recording_segment
Expand All @@ -178,6 +186,7 @@ def __init__(
self.order_r = order_r
# get filter params
self.sos_filter = sos_filter
self.dtype = dtype

def get_traces(self, start_frame, end_frame, channel_indices):
if channel_indices is None:
Expand Down Expand Up @@ -234,7 +243,7 @@ def get_traces(self, start_frame, end_frame, channel_indices):
traces = traces[left_margin:-right_margin, channel_indices]
else:
traces = traces[left_margin:, channel_indices]
return traces
return traces.astype(self.dtype, copy=False)


# function for API
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,34 @@ def test_highpass_spatial_filter_synthetic_data(num_channels, ntr_pad, ntr_tap,
assert raw_traces.shape == si_filtered.shape


@pytest.mark.parametrize("dtype", [np.int16, np.float32, np.float64])
def test_dtype_stability(dtype):
"""
Check that the dtype of the recording and
output data is as expected, as data is cast to float32
during filtering.
"""
num_chan = 32
si_recording = generate_recording(num_channels=num_chan, durations=[2])
si_recording.set_property("gain_to_uV", np.ones(num_chan))
si_recording.set_property("offset_to_uV", np.ones(num_chan))
si_recording = spre.astype(si_recording, dtype)

assert si_recording.dtype == dtype

highpass_spatial_filter = spre.highpass_spatial_filter(si_recording, n_channel_pad=2)

assert highpass_spatial_filter.dtype == dtype

filtered_data_unscaled = highpass_spatial_filter.get_traces(return_scaled=False)

assert filtered_data_unscaled.dtype == dtype

filtered_data_scaled = highpass_spatial_filter.get_traces(return_scaled=True)

assert filtered_data_scaled.dtype == np.float32


# ----------------------------------------------------------------------------------------------------------------------
# Test Utils
# ----------------------------------------------------------------------------------------------------------------------
Expand All @@ -125,7 +153,7 @@ def get_ibl_si_data():
ibl_data = ibl_recording.read(slice(None), slice(None), sync=False)[:, :-1].T # cut sync channel

si_recording = se.read_spikeglx(local_path, stream_id="imec0.ap")
si_recording = spre.scale(si_recording, dtype="float32")
si_recording = spre.astype(si_recording, dtype="float32")

return ibl_data, si_recording

Expand Down

0 comments on commit 89968ca

Please sign in to comment.