diff --git a/src/spikeinterface/preprocessing/filter.py b/src/spikeinterface/preprocessing/filter.py index 3f1a155d0d..8ce1312088 100644 --- a/src/spikeinterface/preprocessing/filter.py +++ b/src/spikeinterface/preprocessing/filter.py @@ -47,6 +47,8 @@ class FilterRecording(BasePreprocessor): Filter coefficients in the filter_mode form. dtype : dtype or None, default: None The dtype of the returned traces. If None, the dtype of the parent recording is used + causal : True or False, default : False + If True, Backward causal filtering is performed to correct hardware induced phase shift {} Returns @@ -70,6 +72,7 @@ def __init__( add_reflect_padding=False, coeff=None, dtype=None, + causal=False, ): import scipy.signal @@ -116,6 +119,7 @@ def __init__( margin_ms=margin_ms, add_reflect_padding=add_reflect_padding, dtype=dtype.str, + causal=causal ) @@ -146,10 +150,16 @@ def get_traces(self, start_frame, end_frame, channel_indices): import scipy.signal if self.filter_mode == "sos": - filtered_traces = scipy.signal.sosfiltfilt(self.coeff, traces_chunk, axis=0) + if causal: + filtered_traces = np.flip(scipy.signal.sosfilt(self.coeff, np.flip(traces_chunk))) + else: + filtered_traces = scipy.signal.sosfiltfilt(self.coeff, traces_chunk, axis=0) elif self.filter_mode == "ba": b, a = self.coeff - filtered_traces = scipy.signal.filtfilt(b, a, traces_chunk, axis=0) + if causal: + filtered_traces = np.flip(scipy.signal.lfilter(b, a, np.flip(traces_chunk), axis=0)) + else: + filtered_traces = scipy.signal.filtfilt(b, a, traces_chunk, axis=0) if right_margin > 0: filtered_traces = filtered_traces[left_margin:-right_margin, :] @@ -277,12 +287,48 @@ def __init__(self, recording, freq=3000, q=30, margin_ms=5.0, dtype=None): self._kwargs = dict(recording=recording, freq=freq, q=q, margin_ms=margin_ms, dtype=dtype.str) +class CausalFilterRecording(FilterRecording): + """ + Implements backwards causal filter to correct for hardware induced phase shift + + Parameters + ---------- + recording : Recording + The recording extractor to be re-referenced + band : float or list, default : [300.0] + If float, cutoff frequency in Hz for "highpass" filter type + If list, band (low, high) in Hz for "bandpass" filter type + margin_ms : float + Margin in ms on border to avoid border effect + dtype : dtype or None + The dtype of the returned traces. If None, the dtype of the parent recording is used + btype : "bandpass" | "highpass", default: "highpass" + Type of the filter + filter_order : int, default : 1 + filter order - the Neuropixels filter is a single pole RC filter. + {} + Returns + ------- + filter_recording : CausalFilterRecording + The CausalFilterRecording recording extractor object + """ + name = "causal_filter" + + def __init__(self, recording, band=[300.0], margin_ms=5.0, btype = "highpass",filter_order = 1, dtype=None,**filter_kwargs): + FilterRecording.__init__( + self, recording, band=band, margin_ms=margin_ms, dtype=dtype, btype = btype, filter_order = filter_order,causal = True, + **filter_kwargs) + dtype = fix_dtype(recording, dtype) + self._kwargs = dict(recording=recording, band=band, margin_ms=margin_ms, dtype=dtype.str) + self._kwargs.update(filter_kwargs) + # functions for API filter = define_function_from_class(source_class=FilterRecording, name="filter") bandpass_filter = define_function_from_class(source_class=BandpassFilterRecording, name="bandpass_filter") notch_filter = define_function_from_class(source_class=NotchFilterRecording, name="notch_filter") highpass_filter = define_function_from_class(source_class=HighpassFilterRecording, name="highpass_filter") +causal_filter = define_function_from_class(source_class=CausalFilterRecording, name="causal_filter") def fix_dtype(recording, dtype):