From ea0241c36fe7bf05371e9729605666d081d0b32e Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 5 Jun 2024 15:19:51 +0000 Subject: [PATCH] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/spikeinterface/preprocessing/filter.py | 30 +++++++++++++++------- 1 file changed, 21 insertions(+), 9 deletions(-) diff --git a/src/spikeinterface/preprocessing/filter.py b/src/spikeinterface/preprocessing/filter.py index 8ce1312088..4778d4c1dc 100644 --- a/src/spikeinterface/preprocessing/filter.py +++ b/src/spikeinterface/preprocessing/filter.py @@ -119,7 +119,7 @@ def __init__( margin_ms=margin_ms, add_reflect_padding=add_reflect_padding, dtype=dtype.str, - causal=causal + causal=causal, ) @@ -151,15 +151,15 @@ def get_traces(self, start_frame, end_frame, channel_indices): if self.filter_mode == "sos": if causal: - filtered_traces = np.flip(scipy.signal.sosfilt(self.coeff, np.flip(traces_chunk))) + filtered_traces = np.flip(scipy.signal.sosfilt(self.coeff, np.flip(traces_chunk))) else: - filtered_traces = scipy.signal.sosfiltfilt(self.coeff, traces_chunk, axis=0) + filtered_traces = scipy.signal.sosfiltfilt(self.coeff, traces_chunk, axis=0) elif self.filter_mode == "ba": b, a = self.coeff if causal: - filtered_traces = np.flip(scipy.signal.lfilter(b, a, np.flip(traces_chunk), axis=0)) + filtered_traces = np.flip(scipy.signal.lfilter(b, a, np.flip(traces_chunk), axis=0)) else: - filtered_traces = scipy.signal.filtfilt(b, a, traces_chunk, axis=0) + filtered_traces = scipy.signal.filtfilt(b, a, traces_chunk, axis=0) if right_margin > 0: filtered_traces = filtered_traces[left_margin:-right_margin, :] @@ -287,10 +287,11 @@ def __init__(self, recording, freq=3000, q=30, margin_ms=5.0, dtype=None): self._kwargs = dict(recording=recording, freq=freq, q=q, margin_ms=margin_ms, dtype=dtype.str) + class CausalFilterRecording(FilterRecording): """ Implements backwards causal filter to correct for hardware induced phase shift - + Parameters ---------- recording : Recording @@ -312,12 +313,23 @@ class CausalFilterRecording(FilterRecording): filter_recording : CausalFilterRecording The CausalFilterRecording recording extractor object """ + name = "causal_filter" - def __init__(self, recording, band=[300.0], margin_ms=5.0, btype = "highpass",filter_order = 1, dtype=None,**filter_kwargs): + def __init__( + self, recording, band=[300.0], margin_ms=5.0, btype="highpass", filter_order=1, dtype=None, **filter_kwargs + ): FilterRecording.__init__( - self, recording, band=band, margin_ms=margin_ms, dtype=dtype, btype = btype, filter_order = filter_order,causal = True, - **filter_kwargs) + self, + recording, + band=band, + margin_ms=margin_ms, + dtype=dtype, + btype=btype, + filter_order=filter_order, + causal=True, + **filter_kwargs, + ) dtype = fix_dtype(recording, dtype) self._kwargs = dict(recording=recording, band=band, margin_ms=margin_ms, dtype=dtype.str) self._kwargs.update(filter_kwargs)