Skip to content

Commit

Permalink
delete pp_name_to_function
Browse files Browse the repository at this point in the history
  • Loading branch information
chrishalcrow committed Jan 23, 2025
1 parent 8436eb2 commit 5c19765
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 54 deletions.
30 changes: 15 additions & 15 deletions src/spikeinterface/preprocessing/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import inspect
from copy import deepcopy
from spikeinterface.core.core_tools import is_dict_extractor
from spikeinterface.preprocessing.preprocessinglist import pp_function_to_class, preprocesser_dict, pp_name_to_function
from spikeinterface.preprocessing.preprocessinglist import pp_function_to_class, preprocesser_dict


class PreprocessingPipeline:
Expand All @@ -28,17 +28,17 @@ class PreprocessingPipeline:
"""

def __init__(self, preprocessor_list):
for preprocessor in preprocessor_list:
def __init__(self, preprocessor_dict):
for preprocessor in preprocessor_dict:
assert _is_genuine_preprocessor(
preprocessor
), f"'{preprocessor}' is not a preprocessing step in spikeinterface. To see the full list run:\n\t>>> from spikeinterface.preprocessing import pp_function_to_class\n\t>>> print(pp_function_to_class.keys())"

self.preprocessor_list = preprocessor_list
self.preprocessor_dict = preprocessor_dict

def __repr__(self):
txt = "PreprocessingPipeline: \tRaw Recording \u2192 "
for preprocessor in self.preprocessor_list:
for preprocessor in self.preprocessor_dict:
txt += str(preprocessor) + " \u2192 "
txt += "Preprocessed Recording"
return txt
Expand Down Expand Up @@ -90,9 +90,9 @@ def apply_to(self, recording):
"""

preprocessor_list = self.preprocessor_list
preprocessor_dict = self.preprocessor_dict

for preprocessor, kwargs in preprocessor_list.items():
for preprocessor, kwargs in preprocessor_dict.items():

kwargs.pop("recording", kwargs)
kwargs.pop("parent_recording", kwargs)
Expand All @@ -101,7 +101,7 @@ def apply_to(self, recording):
if using_class_name is True:
pp_output = preprocesser_dict[preprocessor.split(".")[-1]](recording, **kwargs)
else:
pp_output = pp_name_to_function[preprocessor.split(".")[-1]](recording, **kwargs)
pp_output = pp_function_to_class[preprocessor.split(".")[-1]](recording, **kwargs)

if preprocessor == "motion_correct":
pp_output = pp_output[0]
Expand All @@ -111,7 +111,7 @@ def apply_to(self, recording):
return recording


def create_preprocessed(recording=None, preprocessor_dict=None):
def create_preprocessed(recording, preprocessor_dict=None):
"""
Creates a preprocessed recording by applying the preprocessing steps in
`preprocessor_dict` to `recording`.
Expand All @@ -134,9 +134,9 @@ def create_preprocessed(recording=None, preprocessor_dict=None):
>>> from spikeinterface.preprocessing import create_preprocessed
>>> from spikeinterface.generation import generate_recording
>>> rec = generate_recording()
>>> recording = generate_recording()
>>> preprocessor_dict = {'bandpass_filter': {'freq_max': 3000}, 'common_reference': {}}
>>> preprocessed_rec = create_preprocessed(rec, preprocessor_dict)
>>> preprocessed_recording = create_preprocessed(recording, preprocessor_dict)
"""
Expand Down Expand Up @@ -205,7 +205,7 @@ def _is_genuine_preprocessor(preprocessor):
if using_class_name:
genuine_preprocessor = preprocessor in preprocesser_dict.keys()
else:
genuine_preprocessor = preprocessor in pp_name_to_function.keys()
genuine_preprocessor = preprocessor in pp_function_to_class.keys()

return genuine_preprocessor

Expand Down Expand Up @@ -234,10 +234,10 @@ def _load_pp_from_dict(prov_dict, kwargs_dict):
def _get_all_kwargs_and_values(my_pipeline):

all_kwargs = {}
for preprocessor in my_pipeline.preprocessor_list:
for preprocessor in my_pipeline.preprocessor_dict:

preprocessor_name = preprocessor.split(".")[-1]
pp_function = pp_name_to_function[preprocessor.split(".")[-1]]
pp_function = pp_function_to_class[preprocessor.split(".")[-1]]
signature = inspect.signature(pp_function)

all_kwargs[preprocessor_name] = {}
Expand All @@ -254,7 +254,7 @@ def _get_all_kwargs_and_values(my_pipeline):
except:
default_value = None

pipeline_value = my_pipeline.preprocessor_list[preprocessor].get(par_name)
pipeline_value = my_pipeline.preprocessor_dict[preprocessor].get(par_name)

if pipeline_value is None:
if default_value != pipeline_value:
Expand Down
40 changes: 1 addition & 39 deletions src/spikeinterface/preprocessing/preprocessinglist.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,45 +46,6 @@

from .motion import correct_motion

pp_name_to_function = {
# filter stuff
"filter": filter,
"bandpass_filter": bandpass_filter,
"notch_filter": notch_filter,
"highpass_filter": highpass_filter,
"gaussian_filter": gaussian_filter,
# gain offset stuff
"normalize_by_quantile": normalize_by_quantile,
"scale": scale,
"zscore": zscore,
"center": center,
# decorrelation stuff
"whiten": whiten,
# re-reference
"common_reference": common_reference,
"phase_shift": phase_shift,
# misc
"rectify": rectify,
"clip": clip,
"blank_staturation": blank_staturation,
"silence_periods": silence_periods,
"remove_artifacts": remove_artifacts,
"zero_channel_pad": zero_channel_pad,
"deepinterpolate": deepinterpolate,
"resample": resample,
"decimate": decimate,
"highpass_spatial_filter": highpass_spatial_filter,
"interpolate_bad_channels": interpolate_bad_channels,
"depth_order": depth_order,
"average_across_direction": average_across_direction,
"directional_derivative": directional_derivative,
"astype": astype,
"unsigned_to_signed": unsigned_to_signed,
"unsigned_to_signed": unsigned_to_signed,
# motion correction
"correct_motion": correct_motion,
}

pp_function_to_class = {
# filter stuff
"filter": FilterRecording,
Expand Down Expand Up @@ -119,6 +80,7 @@
"directional_derivative": DirectionalDerivativeRecording,
"astype": AstypeRecording,
"unsigned_to_signed": UnsignedToSignedRecording,
"correct_motion": correct_motion,
}


Expand Down

0 comments on commit 5c19765

Please sign in to comment.