diff --git a/docs/README.md b/docs/README.md index 80510daed..8eee3f1a4 100644 --- a/docs/README.md +++ b/docs/README.md @@ -55,3 +55,6 @@ The following items can be commented out in `mkdocs.yml` to reduce build time: - `mkdocs-jupyter`: Generates tutorial pages from notebooks. To end the process in your console, use `ctrl+c`. + +If your new submodule is causing a build error (e.g., "Could not collect ..."), +you may need to add `__init__.py` files to the submodule directories. diff --git a/docs/build-docs.sh b/docs/build-docs.sh index 03d28c07e..b36b0533d 100755 --- a/docs/build-docs.sh +++ b/docs/build-docs.sh @@ -10,13 +10,14 @@ cp ./LICENSE ./docs/src/LICENSE.md mkdir -p ./docs/src/notebooks cp ./notebooks/*ipynb ./docs/src/notebooks/ cp ./notebooks/*md ./docs/src/notebooks/ -cp ./docs/src/notebooks/README.md ./docs/src/notebooks/index.md +mv ./docs/src/notebooks/README.md ./docs/src/notebooks/index.md cp -r ./notebook-images ./docs/src/notebooks/ cp -r ./notebook-images ./docs/src/ # Get major version -FULL_VERSION=$(hatch version) # Most recent tag, may include periods -export MAJOR_VERSION="${FULL_VERSION:0:3}" # First 3 chars of tag +version_line=$(grep "__version__ =" ./src/spyglass/_version.py) +version_string=$(echo "$version_line" | awk -F"[\"']" '{print $2}') +export MAJOR_VERSION="${version_string:0:3}" echo "$MAJOR_VERSION" # Get ahead of errors diff --git a/docs/mkdocs.yml b/docs/mkdocs.yml index 32f789bb9..4def830b4 100644 --- a/docs/mkdocs.yml +++ b/docs/mkdocs.yml @@ -16,9 +16,6 @@ theme: favicon: images/Spyglass.svg features: - toc.follow - # - navigation.expand # CBroz1: removed bc long tutorial list hides rest - # - toc.integrate - # - navigation.sections - navigation.top - navigation.instant # saves loading time - 1 browser page - navigation.tracking # even with above, changes URL by section @@ -55,27 +52,30 @@ nav: - Database Management: misc/database_management.md - Tutorials: - Overview: notebooks/index.md - - General: + - Intro: - Setup: notebooks/00_Setup.ipynb - Insert Data: notebooks/01_Insert_Data.ipynb - Data Sync: notebooks/02_Data_Sync.ipynb - Merge Tables: notebooks/03_Merge_Tables.ipynb - - Ephys: - - Spike Sorting: notebooks/10_Spike_Sorting.ipynb + - Config Populate: notebooks/04_PopulateConfigFile.ipynb + - Spikes: + - Spike Sorting V0: notebooks/10_Spike_SortingV0.ipynb + - Spike Sorting V1: notebooks/10_Spike_SortingV1.ipynb - Curation: notebooks/11_Curation.ipynb - - LFP: notebooks/12_LFP.ipynb - - Theta: notebooks/14_Theta.ipynb - Position: - Position Trodes: notebooks/20_Position_Trodes.ipynb - DLC From Scratch: notebooks/21_Position_DLC_1.ipynb - DLC From Model: notebooks/22_Position_DLC_2.ipynb - DLC Prediction: notebooks/23_Position_DLC_3.ipynb - Linearization: notebooks/24_Linearization.ipynb - - Combined: - - Ripple Detection: notebooks/30_Ripple_Detection.ipynb - - Extract Mark Indicators: notebooks/31_Extract_Mark_Indicators.ipynb - - Decoding with GPUs: notebooks/32_Decoding_with_GPUs.ipynb - - Decoding Clusterless: notebooks/33_Decoding_Clusterless.ipynb + - LFP: + - LFP: notebooks/30_LFP.ipynb + - Theta: notebooks/31_Theta.ipynb + - Ripple Detection: notebooks/32_Ripple_Detection.ipynb + - Decoding: + - Extract Clusterless: notebooks/41_Extracting_Clusterless_Waveform_Features.ipynb + - Decoding Clusterless: notebooks/42_Decoding_Clusterless.ipynb + - Decoding Sorted Spikes: notebooks/43_Decoding_SortedSpikes.ipynb - API Reference: api/ # defer to gen-files + literate-nav - How to Contribute: contribute.md - Change Log: CHANGELOG.md 
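A minimal sketch of the `__init__.py` fix described in the `docs/README.md` hunk above, for when mkdocstrings fails with "Could not collect ..." on a new submodule. The `src/spyglass` root is assumed from this repo's layout, and the helper itself is illustrative rather than part of the change:

```python
# Sketch: create any missing __init__.py files so mkdocstrings can collect
# new submodules during the docs build. The src/spyglass root is an
# assumption based on this repo's layout; adjust the path as needed.
from pathlib import Path

for path in Path("src/spyglass").rglob("*"):
    if path.is_dir() and path.name != "__pycache__":
        init_file = path / "__init__.py"
        if not init_file.exists():
            init_file.touch()
            print(f"created {init_file}")
```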
diff --git a/docs/src/api/make_pages.py b/docs/src/api/make_pages.py index 942f6ae09..6886d50f4 100644 --- a/docs/src/api/make_pages.py +++ b/docs/src/api/make_pages.py @@ -28,11 +28,6 @@ else: break -if add_limit is not None: - from IPython import embed - - embed() - with mkdocs_gen_files.open("api/navigation.md", "w") as nav_file: nav_file.write("* [Overview](../api/index.md)\n") diff --git a/notebooks/README.md b/notebooks/README.md index ab18707cd..33df01ed8 100644 --- a/notebooks/README.md +++ b/notebooks/README.md @@ -8,32 +8,64 @@ described in the categories below. ## 0. Intro -Everyone should complete the [Setup](./00_Setup.ipynb) and [Insert Data](./01_Insert_Data.ipynb) notebooks. +Everyone should complete the [Setup](./00_Setup.ipynb) and +[Insert Data](./01_Insert_Data.ipynb) notebooks. -[Data Sync](./02_Data_Sync.ipynb) is an optional additional tool for collaborators that want to share analysis files. +[Data Sync](./02_Data_Sync.ipynb) is an optional additional tool for +collaborators who want to share analysis files. -The [Merge Tables notebook](./03_Merge_Tables.ipynb) explains details on a new table tier unique to Spyglass that allows the user to use different versions of pipelines on the same data. This is important for understanding the later notebooks. +The [Merge Tables notebook](./03_Merge_Tables.ipynb) explains a new table tier +unique to Spyglass that allows the user to run different versions of pipelines +on the same data. This is important for understanding the later notebooks. ## 1. Spike Sorting Pipeline -This series of notebooks covers the process of spike sorting, from automated spike sorting to optional manual curation of the output of the automated sorting. +This series of notebooks covers the process of spike sorting, from automated +sorting to optional manual curation of its output. ## 2. Position Pipeline -This series of notebooks covers tracking the position(s) of the animal. The user can employ two different methods: +This series of notebooks covers tracking the position(s) of the animal. The +user can employ two different methods: -1. the simple [Trodes](20_Position_Trodes.ipynb) methods of tracking LEDs on the animal's headstage -2. [DLC (DeepLabCut)](./21_Position_DLC_1.ipynb) which uses a neural network to track the animal's body parts +1. the simple [Trodes](20_Position_Trodes.ipynb) method of tracking LEDs on + the animal's headstage +2. [DLC (DeepLabCut)](./21_Position_DLC_1.ipynb), which uses a neural network + to track the animal's body parts -Either case can be followed by the [Linearization notebook](./24_Linearization.ipynb) if the user wants to linearize the position data for later use. +Either case can be followed by the +[Linearization notebook](./24_Linearization.ipynb) if the user wants to +linearize the position data for later use. ## 3. LFP Pipeline -This series of notebooks covers the process of LFP analysis. The [LFP](./30_LFP.ipynb) covers the extraction of the LFP in specific bands from the raw data. The [Theta](./31_Theta.ipynb) notebook shows specifically how to extract the theta band power and phase from the LFP data. Finally the [Ripple Detection](./32_Ripple_Detection.ipynb) notebook shows how to detect ripples in the LFP data. +This series of notebooks covers the process of LFP analysis. The +[LFP](./30_LFP.ipynb) notebook covers the extraction of the LFP in specific +bands from the raw data.
The [Theta](./31_Theta.ipynb) notebook shows specifically how to +extract the theta band power and phase from the LFP data. Finally, the +[Ripple Detection](./32_Ripple_Detection.ipynb) notebook shows how to detect +ripples in the LFP data. ## 4. Decoding Pipeline -This series of notebooks covers the process of decoding the position of the animal from spiking data. It relies on the position data from the Position pipeline and the output of spike sorting from the Spike Sorting pipeline. Decoding can be from sorted or from unsorted data using spike waveform features (so-called clusterless decoding). The first notebook([Extracting Clusterless Waveform Features](./41_Extracting_Clusterless_Waveform_Features.ipynb)) in this series shows how to retrieve the spike waveform features used for clusterless decoding. The second notebook ([Clusterless Decoding](./42_Decoding_Clusterless.ipynb)) shows a detailed example of how to decode the position of the animal from the spike waveform features. The third notebook ([Decoding](./43_Decoding.ipynb)) shows how to decode the position of the animal from the sorted spikes. +This series of notebooks covers the process of decoding the position of the +animal from spiking data. It relies on the position data from the Position +pipeline and the output of spike sorting from the Spike Sorting pipeline. +Decoding can be from sorted or from unsorted data using spike waveform features +(so-called clusterless decoding). + +The first notebook +([Extracting Clusterless Waveform Features](./41_Extracting_Clusterless_Waveform_Features.ipynb)) +in this series shows how to retrieve the spike waveform features used for +clusterless decoding. + +The second notebook +([Clusterless Decoding](./42_Decoding_Clusterless.ipynb)) shows a detailed +example of how to decode the position of the animal from the spike waveform +features. The third notebook +([Decoding Sorted Spikes](./43_Decoding_SortedSpikes.ipynb)) shows how to +decode the position of the animal from the sorted spikes. ## Developer note diff --git a/src/spyglass/common/common_lab.py b/src/spyglass/common/common_lab.py index bdaa0fb25..177fc4424 100644 --- a/src/spyglass/common/common_lab.py +++ b/src/spyglass/common/common_lab.py @@ -108,7 +108,7 @@ def get_djuser_name(cls, dj_user) -> str: Parameters ---------- - user: str + dj_user: str The datajoint user name. Returns diff --git a/src/spyglass/decoding/v0/__init__.py b/src/spyglass/decoding/v0/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/spyglass/decoding/v0/clusterless.py b/src/spyglass/decoding/v0/clusterless.py index 1bd385c20..f6fd9df37 100644 --- a/src/spyglass/decoding/v0/clusterless.py +++ b/src/spyglass/decoding/v0/clusterless.py @@ -70,15 +70,14 @@ @schema class MarkParameters(SpyglassMixin, dj.Manual): - """Defines the type of spike waveform feature computed for a given spike - time.""" + """Defines the type of waveform feature computed for a given spike time.""" definition = """ mark_param_name : varchar(32) # a name for this set of parameters --- # the type of mark. Currently only 'amplitude' is supported mark_type = 'amplitude': varchar(40) - mark_param_dict: BLOB # dictionary of parameters for the mark extraction function + mark_param_dict: BLOB # dict of parameters for the mark extraction function """ # NOTE: See #630, #664. Excessive key length. @@ -99,7 +98,8 @@ def insert_default(self): @staticmethod def supported_mark_type(mark_type): - """checks whether the requested mark type is supported.
+ """Checks whether the requested mark type is supported. + Currently only 'amplitude" is supported. Parameters @@ -108,9 +108,7 @@ def supported_mark_type(mark_type): """ supported_types = ["amplitude"] - if mark_type in supported_types: - return True - return False + return mark_type in supported_types @schema @@ -123,7 +121,9 @@ class UnitMarkParameters(SpyglassMixin, dj.Manual): @schema class UnitMarks(SpyglassMixin, dj.Computed): - """For each spike time, compute a spike waveform feature associated with that + """Compute spike waveform features for each spike time. + + For each spike time, compute a spike waveform feature associated with that spike. Used for clusterless decoding. """ @@ -224,15 +224,16 @@ def make(self, key): AnalysisNwbfile().add(key["nwb_file_name"], key["analysis_file_name"]) self.insert1(key) - def fetch1_dataframe(self): + def fetch1_dataframe(self) -> pd.DataFrame: """Convenience function for returning the marks in a readable format""" return self.fetch_dataframe()[0] - def fetch_dataframe(self): + def fetch_dataframe(self) -> list[pd.DataFrame]: return [self._convert_to_dataframe(data) for data in self.fetch_nwb()] @staticmethod - def _convert_to_dataframe(nwb_data): + def _convert_to_dataframe(nwb_data) -> pd.DataFrame: + """Converts the marks from an NWB object to a pandas dataframe""" n_marks = nwb_data["marks"].data.shape[1] columns = [f"amplitude_{ind:04d}" for ind in range(n_marks)] return pd.DataFrame( @@ -243,23 +244,28 @@ def _convert_to_dataframe(nwb_data): @staticmethod def _get_peak_amplitude( - waveform, peak_sign="neg", estimate_peak_time=False - ): - """Returns the amplitudes of all channels at the time of the peak - amplitude across channels. + waveform: np.array, + peak_sign: str = "neg", + estimate_peak_time: bool = False, + ) -> np.array: + """Returns the amplitudes of all channels at the time of the peak. + + Amplitude across channels. Parameters ---------- - waveform : array-like, shape (n_spikes, n_time, n_channels) - peak_sign : ('pos', 'neg', 'both'), optional - Direction of the peak in the waveform + waveform : np.array + array-like, shape (n_spikes, n_time, n_channels) + peak_sign : str, optional + One of 'pos', 'neg', 'both'. 
+ Direction of the peak in the waveform estimate_peak_time : bool, optional Find the peak times for each spike because some spikesorters do not align the spike time (at index n_time // 2) to the peak Returns ------- - peak_amplitudes : array-like, shape (n_spikes, n_channels) + peak_amplitudes : np.ndarray + array-like, shape (n_spikes, n_channels) """ if estimate_peak_time: @@ -279,19 +285,25 @@ def _get_peak_amplitude( return waveform[:, spike_peak_ind] @staticmethod - def _threshold(timestamps, marks, mark_param_dict): + def _threshold( + timestamps: np.ndarray, marks: np.ndarray, mark_param_dict: dict + ): """Filter the marks by an amplitude threshold Parameters ---------- - timestamps : array-like, shape (n_time,) - marks : array-like, shape (n_time, n_channels) + timestamps : np.ndarray + array-like, shape (n_time,) + marks : np.ndarray + array-like, shape (n_time, n_channels) mark_param_dict : dict Returns ------- - filtered_timestamps : array-like, shape (n_filtered_time,) - filtered_marks : array-like, shape (n_filtered_time, n_channels) + filtered_timestamps : np.ndarray + array-like, shape (n_filtered_time,) + filtered_marks : np.ndarray + array-like, shape (n_filtered_time, n_channels) """ if mark_param_dict["peak_sign"] == "neg": @@ -307,20 +319,24 @@ def _threshold(timestamps, marks, mark_param_dict): @schema class UnitMarksIndicatorSelection(SpyglassMixin, dj.Lookup): - """Bins the spike times and associated spike waveform features for a given - time interval into regular time bins determined by the sampling rate.""" + """Pairing of a UnitMarks entry with a time interval and sampling rate + + Bins the spike times and associated spike waveform features for a given + time interval into regular time bins determined by the sampling rate. + """ definition = """ -> UnitMarks -> IntervalList sampling_rate=500 : float - --- """ @schema class UnitMarksIndicator(SpyglassMixin, dj.Computed): - """Bins the spike times and associated spike waveform features into regular + """Bins spike times and waveforms into regular time bins. + + Bins the spike times and associated spike waveform features into regular time bins according to the sampling rate. Features that fall into the same time bin are averaged. """ @@ -373,7 +389,9 @@ def make(self, key): self.insert1(key) @staticmethod - def get_time_bins_from_interval(interval_times, sampling_rate): + def get_time_bins_from_interval( + interval_times: np.ndarray, sampling_rate: int + ) -> np.ndarray: """Picks the superset of the interval""" start_time, end_time = interval_times[0][0], interval_times[-1][-1] n_samples = int(np.ceil((end_time - start_time) * sampling_rate)) + 1 @@ -382,9 +400,14 @@ def get_time_bins_from_interval( @staticmethod def plot_all_marks( - marks_indicators: xr.DataArray, plot_size=5, s=10, plot_limit=None + marks_indicators: xr.DataArray, + plot_size: int = 5, + marker_size: int = 10, + plot_limit: int = None, ): - """Plots 2D slices of each of the spike features against each other + """Plot all marks for all electrodes. + + Plots 2D slices of each of the spike features against each other for all electrodes. Parameters @@ -393,7 +416,7 @@ def plot_all_marks( Spike times and associated spike waveform features binned into plot_size : int, optional Default 5. Matplotlib figure size for each mark. - s : int, optional + marker_size : int, optional Default 10. Marker size plot_limit : int, optional Default None. Limits to first N electrodes.
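To make the two helpers annotated in the hunks above easier to follow, a small self-contained numpy sketch of the same logic: `_get_peak_amplitude` in the aligned case simply indexes the sample at `n_time // 2` on every channel, and `get_time_bins_from_interval` spans the interval superset with regularly spaced bins. The final `np.linspace` call is an assumption, since the hunk cuts off before the function's return:

```python
import numpy as np

# Toy waveforms, shape (n_spikes, n_time, n_channels).
waveform = np.random.randn(100, 32, 4)

# _get_peak_amplitude, aligned case: spike sorters that center the peak at
# index n_time // 2 let us read the per-channel amplitudes at that sample.
peak_amplitudes = waveform[:, waveform.shape[1] // 2]  # (n_spikes, n_channels)

# get_time_bins_from_interval: regular bins over the interval superset.
interval_times = np.array([[0.0, 1.0], [2.0, 3.0]])  # (n_intervals, 2)
sampling_rate = 500
start_time, end_time = interval_times[0][0], interval_times[-1][-1]
n_samples = int(np.ceil((end_time - start_time) * sampling_rate)) + 1
time_bins = np.linspace(start_time, end_time, n_samples)  # assumed return value
```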
@@ -422,25 +445,28 @@ def plot_all_marks( axes[ax_ind1, ax_ind2].scatter( marks.sel(marks=feature1), marks.sel(marks=feature2), - s=s, + s=marker_size, ) except TypeError: axes.scatter( marks.sel(marks=feature1), marks.sel(marks=feature2), - s=s, + s=marker_size, ) - def fetch1_dataframe(self): + def fetch1_dataframe(self) -> pd.DataFrame: + """Convenience function for returning the first dataframe""" return self.fetch_dataframe()[0] - def fetch_dataframe(self): + def fetch_dataframe(self) -> list[pd.DataFrame]: + """Fetches the marks indicators as a list of pandas dataframes""" return [ data["marks_indicator"].set_index("time") for data in self.fetch_nwb() ] def fetch_xarray(self): + """Fetches the marks indicators as an xarray DataArray""" # sort_group_electrodes = ( # SortGroup.SortGroupElectrode() & # pd.DataFrame(self).to_dict('records')) @@ -474,7 +500,16 @@ def reformat_name(name): ) -def make_default_decoding_parameters_cpu(): +def make_default_decoding_parameters_cpu() -> tuple[dict, dict, dict]: + """Default parameters for decoding on CPU + + Returns + ------- + classifier_parameters : dict + fit_parameters : dict + predict_parameters : dict + """ + classifier_parameters = dict( environments=[_DEFAULT_ENVIRONMENT], observation_models=None, @@ -496,7 +531,16 @@ def make_default_decoding_parameters_cpu(): return classifier_parameters, fit_parameters, predict_parameters -def make_default_decoding_parameters_gpu(): +def make_default_decoding_parameters_gpu() -> tuple[dict, dict, dict]: + """Default parameters for decoding on GPU + + Returns + ------- + classifier_parameters : dict + fit_parameters : dict + predict_parameters : dict + """ + classifier_parameters = dict( environments=[_DEFAULT_ENVIRONMENT], observation_models=None, @@ -524,7 +568,9 @@ def make_default_decoding_parameters_gpu(): @schema class ClusterlessClassifierParameters(SpyglassMixin, dj.Manual): - """Decodes the animal's mental position and some category of interest + """Decodes animal's mental position. 
+ + Decodes the animal's mental position and some category of interest from unclustered spikes and spike waveform features """ @@ -536,7 +582,8 @@ class ClusterlessClassifierParameters(SpyglassMixin, dj.Manual): predict_params : BLOB # prediction parameters """ - def insert_default(self): + def insert_default(self) -> None: + """Insert the default parameter set""" ( classifier_parameters, fit_parameters, @@ -567,10 +614,12 @@ def insert_default(self): skip_duplicates=True, ) - def insert1(self, key, **kwargs): + def insert1(self, key, **kwargs) -> None: + """Custom insert1 to convert classes to dicts""" super().insert1(convert_classes_to_dict(key), **kwargs) - def fetch1(self, *args, **kwargs): + def fetch1(self, *args, **kwargs) -> dict: + """Custom fetch1 to convert dicts to classes""" return restore_classes(super().fetch1(*args, **kwargs)) @@ -619,10 +668,12 @@ def make(self, key): self.insert1(key) - def fetch1_dataframe(self): + def fetch1_dataframe(self) -> pd.DataFrame: + """Convenience function for returning the first dataframe""" return self.fetch_dataframe()[0] - def fetch_dataframe(self): + def fetch_dataframe(self) -> list[pd.DataFrame]: + """Fetches the multiunit firing rate as a list of pandas dataframes""" return [ data["multiunit_firing_rate"].set_index("time") for data in self.fetch_nwb() @@ -631,7 +682,7 @@ def fetch_dataframe(self): @schema class MultiunitHighSynchronyEventsParameters(SpyglassMixin, dj.Manual): - """Parameters for extracting times of high mulitunit activity during immobility.""" + """Params to extract times of high multiunit activity during immobility.""" definition = """ param_name : varchar(80) # a name for this set of parameters @@ -642,6 +693,7 @@ class MultiunitHighSynchronyEventsParameters(SpyglassMixin, dj.Manual): """ def insert_default(self): + """Insert the default parameter set""" self.insert1( { "param_name": "default", @@ -673,7 +725,6 @@ def get_decoding_data_for_epoch( position_info : pd.DataFrame, shape (n_time, n_columns) marks : xr.DataArray, shape (n_time, n_marks, n_electrodes) valid_slices : list[slice] - """ valid_ephys_position_times_by_epoch = ( @@ -744,7 +795,6 @@ def get_data_for_multiple_epochs( marks : xr.DataArray, shape (n_time, n_marks, n_electrodes) valid_slices : dict[str, list[slice]] environment_labels : np.ndarray, shape (n_time,) - """ data = [] environment_labels = [] @@ -780,7 +830,9 @@ def populate_mark_indicators( mark_param_name: str = "default", position_info_param_name: str = "default_decoding", ): - """Populate mark indicators for all units in the given spike sorting selection. + """Populate mark indicators + + Populates mark indicators for all units in a given spike sorting selection. This function is a way to do several pipeline steps at once. It will: 1.
Populate the SpikeSortingSelection table diff --git a/src/spyglass/decoding/v1/__init__.py b/src/spyglass/decoding/v1/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/spyglass/decoding/v1/waveform_features.py b/src/spyglass/decoding/v1/waveform_features.py index 5302c80dd..4bed99f35 100644 --- a/src/spyglass/decoding/v1/waveform_features.py +++ b/src/spyglass/decoding/v1/waveform_features.py @@ -69,8 +69,7 @@ def check_supported_waveform_features(waveform_features: list[str]) -> bool: Parameters ---------- - mark_type : str - + waveform_features : list[str] """ supported_features = set(WAVEFORM_FEATURE_FUNCTIONS) return set(waveform_features).issubset(supported_features) diff --git a/src/spyglass/decoding/visualization/__init__.py b/src/spyglass/decoding/visualization/__init__.py new file mode 100644 index 000000000..e69de29bb
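A hedged usage sketch for `populate_mark_indicators` from `src/spyglass/decoding/v0/clusterless.py`. The keyword names and defaults match the signature in the hunk above, but the first argument's name and its key fields are not shown in the diff, so they are assumptions here:

```python
from spyglass.decoding.v0.clusterless import populate_mark_indicators

# Hypothetical keys into SpikeSortingSelection; field names are assumptions.
spikesorting_selection_keys = [
    {"nwb_file_name": "example20200101_.nwb", "sort_group_id": 0},
]

# Runs several pipeline steps at once, starting by populating
# SpikeSortingSelection (per the docstring above).
populate_mark_indicators(
    spikesorting_selection_keys,
    mark_param_name="default",
    position_info_param_name="default_decoding",
)
```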