Skip to content

Commit

Permalink
Add causal filter class and function to filter.py
Browse files Browse the repository at this point in the history
  • Loading branch information
JuanPimientoCaicedo authored Jun 5, 2024
1 parent 8bbccab commit 12f3ad3
Showing 1 changed file with 48 additions and 2 deletions.
50 changes: 48 additions & 2 deletions src/spikeinterface/preprocessing/filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,8 @@ class FilterRecording(BasePreprocessor):
Filter coefficients in the filter_mode form.
dtype : dtype or None, default: None
The dtype of the returned traces. If None, the dtype of the parent recording is used
causal : True or False, default : False
If True, Backward causal filtering is performed to correct hardware induced phase shift
{}
Returns
Expand All @@ -70,6 +72,7 @@ def __init__(
add_reflect_padding=False,
coeff=None,
dtype=None,
causal=False,
):
import scipy.signal

Expand Down Expand Up @@ -116,6 +119,7 @@ def __init__(
margin_ms=margin_ms,
add_reflect_padding=add_reflect_padding,
dtype=dtype.str,
causal=causal
)


Expand Down Expand Up @@ -146,10 +150,16 @@ def get_traces(self, start_frame, end_frame, channel_indices):
import scipy.signal

if self.filter_mode == "sos":
filtered_traces = scipy.signal.sosfiltfilt(self.coeff, traces_chunk, axis=0)
if causal:
filtered_traces = np.flip(scipy.signal.sosfilt(self.coeff, np.flip(traces_chunk)))
else:
filtered_traces = scipy.signal.sosfiltfilt(self.coeff, traces_chunk, axis=0)
elif self.filter_mode == "ba":
b, a = self.coeff
filtered_traces = scipy.signal.filtfilt(b, a, traces_chunk, axis=0)
if causal:
filtered_traces = np.flip(scipy.signal.lfilter(b, a, np.flip(traces_chunk), axis=0))
else:
filtered_traces = scipy.signal.filtfilt(b, a, traces_chunk, axis=0)

if right_margin > 0:
filtered_traces = filtered_traces[left_margin:-right_margin, :]
Expand Down Expand Up @@ -277,12 +287,48 @@ def __init__(self, recording, freq=3000, q=30, margin_ms=5.0, dtype=None):

self._kwargs = dict(recording=recording, freq=freq, q=q, margin_ms=margin_ms, dtype=dtype.str)

class CausalFilterRecording(FilterRecording):
"""
Implements backwards causal filter to correct for hardware induced phase shift
Parameters
----------
recording : Recording
The recording extractor to be re-referenced
band : float or list, default : [300.0]
If float, cutoff frequency in Hz for "highpass" filter type
If list, band (low, high) in Hz for "bandpass" filter type
margin_ms : float
Margin in ms on border to avoid border effect
dtype : dtype or None
The dtype of the returned traces. If None, the dtype of the parent recording is used
btype : "bandpass" | "highpass", default: "highpass"
Type of the filter
filter_order : int, default : 1
filter order - the Neuropixels filter is a single pole RC filter.
{}
Returns
-------
filter_recording : CausalFilterRecording
The CausalFilterRecording recording extractor object
"""
name = "causal_filter"

def __init__(self, recording, band=[300.0], margin_ms=5.0, btype = "highpass",filter_order = 1, dtype=None,**filter_kwargs):
FilterRecording.__init__(
self, recording, band=band, margin_ms=margin_ms, dtype=dtype, btype = btype, filter_order = filter_order,causal = True,
**filter_kwargs)
dtype = fix_dtype(recording, dtype)
self._kwargs = dict(recording=recording, band=band, margin_ms=margin_ms, dtype=dtype.str)
self._kwargs.update(filter_kwargs)


# functions for API
filter = define_function_from_class(source_class=FilterRecording, name="filter")
bandpass_filter = define_function_from_class(source_class=BandpassFilterRecording, name="bandpass_filter")
notch_filter = define_function_from_class(source_class=NotchFilterRecording, name="notch_filter")
highpass_filter = define_function_from_class(source_class=HighpassFilterRecording, name="highpass_filter")
causal_filter = define_function_from_class(source_class=CausalFilterRecording, name="causal_filter")


def fix_dtype(recording, dtype):
Expand Down

0 comments on commit 12f3ad3

Please sign in to comment.