Minor decoding fixes (#769)
* Add non-local detector and remove replay_trajectory_classification

* Reorganize

* Fix formatting and imports

* Update .gitignore

* Remove because of circular import

* Fix name of parameter

* Handle case where there is only one interval

* Fix settings

* Handle single interval

* from_unit_dict does not exist in spikeinterface 0.98.2

* Simplify call

* Update for SpikeSorting merge table and add spyglass mixin

* Fix dependencies

* Fix merge conflict

* Update src/spyglass/decoding/v1/clusterless.py

Co-authored-by: Chris Brozdowski <[email protected]>

* Update src/spyglass/decoding/v1/clusterless.py

Co-authored-by: Chris Brozdowski <[email protected]>

* Update src/spyglass/decoding/v1/clusterless.py

Co-authored-by: Chris Brozdowski <[email protected]>

* Update src/spyglass/decoding/v1/clusterless.py

Co-authored-by: Chris Brozdowski <[email protected]>

* Apply suggestions from code review

Co-authored-by: Chris Brozdowski <[email protected]>

* Remove unused imports and format

* Add saving of waveform features

* Don't store electrodes, full waveforms, waveform mean

* Fix spike times and add convenience method

* Add spike location and some formatting

* Remove circular import

* Fix dict expansion

* Initial working clusterless pipeline

* Add position group

* Rename classifier to decoding

* Handle encoding and decoding intervals

* Put old files under v0, try/except for old decoding package

* Rename visualization and remove from v0

v0 visualization is redundant with visualization

* Place parameters and position group in core.py

* Add sorted spikes decoding

* Add objects to init for convenience

* Remove unused imports

* Fix fetching of spike times

* Insert into merge table

* Update CHANGELOG.md

* Function for removing decoding outputs not in DecodingOutput

* Fix name

* Add draft of tutorials and rearrange notebooks

* Fix config loading

* Add 1D decoding and some notes on estimate_parameters kwarg

* Update 43_Decoding_SortedSpikes.ipynb

* Remove old decoding notebook

* Save initial conditions and discrete transitions

* Apply suggestions from code review

Co-authored-by: Chris Brozdowski <[email protected]>

* Be more specific with import error

* Remove unneeded comments

* Remove incorrect dimension name

* Project merge_id from SpikeSortingOutput for clarity

* Update src/spyglass/decoding/v0/clusterless.py

Co-authored-by: Chris Brozdowski <[email protected]>

* Update src/spyglass/decoding/v0/clusterless.py

Co-authored-by: Chris Brozdowski <[email protected]>

* Update src/spyglass/decoding/v0/clusterless.py

Co-authored-by: Chris Brozdowski <[email protected]>

* Fix linting

* Update notebooks

* Ignore .pem

* Add session as a primary key for Groups

* Add some helper methods

* Update notebooks

* Update README.md

* Update pyscripts

* Update 42_Decoding_Clusterless.ipynb

* Update CHANGELOG.md

* Add fetch and insert

* Simplify class conversion

* Do the dictionary conversion of class for the user

* Update CHANGELOG.md

* Update .gitignore

* Use methods in populate

* Avoid fetching interval range if not needed

* Generalize finding class from modules

* Use args/kwargs

* Simplify tuple unpacking

* Make decoding kwargs nullable

* Add function for get_recording and get_sorting to the spikesorting merge table

* Make decoding waveform features agnostic to spikesorting source

* Fix spelling

* Use fetch1_dataframe for position

* Use self instead of class

* Update src/spyglass/decoding/v1/sorted_spikes.py

Co-authored-by: Samuel Bray <[email protected]>

* Be more careful about populating select keys

* Make more readable/remove unused imports

* Save classifier

* Clean up saved model paths

* Add function load_linear_position_info

* Update src/spyglass/decoding/v1/sorted_spikes.py

Co-authored-by: Samuel Bray <[email protected]>

* Update 41_Extracting_Clusterless_Waveform_Features.py

* Update docstring

* Apply suggestions from code review

Co-authored-by: Chris Brozdowski <[email protected]>

* Update src/spyglass/decoding/v1/clusterless.py

Co-authored-by: Chris Brozdowski <[email protected]>

* Update src/spyglass/decoding/v1/clusterless.py

Co-authored-by: Chris Brozdowski <[email protected]>

* Fix linting

* Fix syntax

* Rename variable to avoid confusion

* Restrict UnitWaveformFeaturesGroup and SortedSpikesGroup

* Concatenate linear position and position dataframes

* Static methods don't require instantiating class

* Avoid merge restrict

* Add version to defaults

* Remove unused import

* Fix classifier path

* Add dry run

* Remove non-default

* Handle permissions and file not found

* Keep position info within encoding/decoding interval

* Add methods to get the spike_times, spike_indicators, firing rate

* Fix docstring to match default

* Implement function rather than import

* Remove unused broken imports

* Add decoding cleanup

* Fix import

* Put old vis code back

* Fix import

* Add draft helper functions

* Limit options on input

* Fix logic

* Fix where the key is passed

* Update notebooks

* Host main visualizations in non_local_detector repo

* Update notebooks/py_scripts/41_Extracting_Clusterless_Waveform_Features.py

Co-authored-by: Chris Brozdowski <[email protected]>

* Update src/spyglass/spikesorting/merge.py

Co-authored-by: Chris Brozdowski <[email protected]>

* Update src/spyglass/decoding/decoding_merge.py

Co-authored-by: Chris Brozdowski <[email protected]>

* Revert "Limit options on input"

This reverts commit 386714c.

* Use f-string for version

* Add useful imports to the top level

This would have to change a bit if there were multiple versions of the pipeline.

* Make source class a hidden attribute

* Update CHANGELOG.md

---------

Co-authored-by: Chris Brozdowski <[email protected]>
Co-authored-by: Sam Bray <[email protected]>
3 people authored Jan 19, 2024
1 parent 0089d5e commit ad78ea1
Showing 18 changed files with 1,089 additions and 959 deletions.
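Taken together, these commits replace the old `replay_trajectory_classification`-based decoding with a v1 pipeline built on the `non_local_detector` package. As a rough sketch of the resulting flow (table names come from the diff below; the keys are hypothetical placeholders, not the exact API):

from spyglass.decoding import (
    ClusterlessDecodingSelection,
    ClusterlessDecodingV1,
    UnitWaveformFeatures,
    UnitWaveformFeaturesSelection,
)

# 1. Extract clusterless waveform features (agnostic to the spike sorting source)
UnitWaveformFeaturesSelection.insert1(feature_key, skip_duplicates=True)  # feature_key: hypothetical
UnitWaveformFeatures.populate(feature_key)

# 2. Group position data (PositionGroup) and choose DecodingParameters (see core.py below)

# 3. Run the decoder; results are inserted into the DecodingOutput merge table
ClusterlessDecodingSelection.insert1(selection_key, skip_duplicates=True)  # selection_key: hypothetical
ClusterlessDecodingV1.populate(selection_key)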
2 changes: 1 addition & 1 deletion CHANGELOG.md
@@ -26,7 +26,7 @@
- Refactor input validation in DLC pipeline. #688
- DLC path handling from config, and normalize naming convention. #722
- Decoding:
-  - Add `decoding` pipeline V1. #731
+  - Add `decoding` pipeline V1. #731, #769
  - Add a table to store the decoding results #731
  - Use the new `non_local_detector` package for decoding #731
  - Allow multiple spike waveform features for clusterless decoding #731
10 changes: 2 additions & 8 deletions franklab_scripts/nightly_cleanup.py
@@ -5,14 +5,6 @@
# ignore datajoint+jupyter async warnings
import warnings

-import numpy as np
-
-from spyglass.decoding.clusterless import (
-    MarkParameters,
-    UnitMarkParameters,
-    UnitMarks,
-)
-
warnings.simplefilter("ignore", category=DeprecationWarning)
warnings.simplefilter("ignore", category=ResourceWarning)
# NOTE: "SPIKE_SORTING_STORAGE_DIR" -> "SPYGLASS_SORTING_DIR"
@@ -21,12 +13,14 @@

# import tables so that we can call them easily
from spyglass.common import AnalysisNwbfile
+from spyglass.decoding.decoding_merge import DecodingOutput
from spyglass.spikesorting import SpikeSorting


def main():
    AnalysisNwbfile().nightly_cleanup()
    SpikeSorting().nightly_cleanup()
+    DecodingOutput().cleanup()


if __name__ == "__main__":
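For context, `DecodingOutput.cleanup` (extended with a `dry_run` flag later in this diff) removes decoding files that are no longer referenced by the merge table. A minimal usage sketch:

from spyglass.decoding.decoding_merge import DecodingOutput

# Preview which orphaned result (.nc) and model (.pkl) files would be removed
DecodingOutput().cleanup(dry_run=True)

# Delete files that are not referenced by the merge table
DecodingOutput().cleanup()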
643 changes: 352 additions & 291 deletions notebooks/41_Extracting_Clusterless_Waveform_Features.ipynb

Large diffs are not rendered by default.

597 changes: 232 additions & 365 deletions notebooks/42_Decoding_Clusterless.ipynb

Large diffs are not rendered by default.

331 changes: 175 additions & 156 deletions notebooks/43_Decoding_SortedSpikes.ipynb

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion notebooks/py_scripts/30_LFP.py
@@ -5,7 +5,7 @@
# extension: .py
# format_name: light
# format_version: '1.5'
-# jupytext_version: 1.15.2
+# jupytext_version: 1.16.0
# kernelspec:
# display_name: Python 3.10.5 64-bit
# language: python
notebooks/py_scripts/41_Extracting_Clusterless_Waveform_Features.py
@@ -31,8 +31,6 @@
# The goal of this notebook is to populate the `UnitWaveformFeatures` table, which depends on `SpikeSortingOutput`. This table contains the features of the waveforms of each unit.
#
# While clusterless decoding avoids actual spike sorting, we need to pass through these tables to maintain (relative) pipeline simplicity. Pass-through tables keep spike sorting and clusterless waveform extraction as similar as possible, by using shared steps. Here, "spike sorting" involves simple thresholding (sorter: clusterless_thresholder).
-#
-# Let's start with the following nwb file and time interval:

# +
from pathlib import Path
Expand All @@ -41,13 +39,39 @@
dj.config.load(
    Path("../dj_local_conf.json").absolute()
)  # load config for database connection info
# -

# First, if you haven't inserted the `mediumnwb20230802.nwb` file into the database (see [01_Data_Insert](01_Data_Insert.ipynb)), you should do so now. This is the file that we will use for the decoding tutorials.
#
# It is a truncated version of the full NWB file, so it will run faster, but it is bigger than the minirec file we used in the previous tutorials so that decoding makes sense.

# +
from spyglass.utils.nwb_helper_fn import get_nwb_copy_filename
import spyglass.data_import as sgi
import spyglass.position as sgp

# Insert the nwb file
nwb_file_name = "mediumnwb20230802.nwb"
nwb_copy_file_name = get_nwb_copy_filename(nwb_file_name)
sgi.insert_sessions(nwb_file_name)

# Position
sgp.v1.TrodesPosParams.insert_default()

nwb_copy_file_name = "mediumnwb20230802_.nwb"
interval_list_name = "pos 0 valid times"

trodes_s_key = {
    "nwb_file_name": nwb_copy_file_name,
    "interval_list_name": interval_list_name,
    "trodes_pos_params_name": "default",
}
sgp.v1.TrodesPosSelection.insert1(
    trodes_s_key,
    skip_duplicates=True,
)
sgp.v1.TrodesPosV1.populate(trodes_s_key)
# -

# If you haven't already, run the [Insert Data notebook](./01_Insert_Data.ipynb) to populate the tables.
#
# These next steps are the same as in the [Spike Sorting notebook](./10_Spike_SortingV1.ipynb), but we'll repeat them here for clarity. These are pre-processing steps that are shared between spike sorting and clusterless decoding.
#
# We first set the `SortGroup` to define which contacts are sorted together.
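The tutorial then sets up the `SortGroup` as described above; that step looks roughly like the sketch below (import path and method name are assumed from the spike sorting tutorial, not shown in this diff):

from spyglass.spikesorting.v1 import SortGroup  # import path assumed

# Sort each shank's contacts together by grouping electrodes by shank
SortGroup().set_group_by_shank(nwb_file_name=nwb_copy_file_name)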
78 changes: 28 additions & 50 deletions notebooks/py_scripts/42_Decoding_Clusterless.py
@@ -129,7 +129,7 @@
PositionOutput.TrodesPosV1 & {"nwb_file_name": nwb_copy_file_name}

# +
-from spyglass.decoding.v1.clusterless import PositionGroup
+from spyglass.decoding.v1.core import PositionGroup

position_merge_ids = (
    PositionOutput.TrodesPosV1
@@ -342,33 +342,33 @@
#

# +
-from non_local_detector.visualization import (
-    create_interactive_2D_decoding_figurl,
-)
-
-(
-    position_info,
-    position_variable_names,
-) = ClusterlessDecodingV1.load_position_info(selection_key)
-results_time = decoding_results.acausal_posterior.isel(intervals=0).time.values
-position_info = position_info.loc[results_time[0] : results_time[-1]]
-
-env = ClusterlessDecodingV1.load_environments(selection_key)[0]
-spike_times, _ = ClusterlessDecodingV1.load_spike_data(selection_key)
-
-
-create_interactive_2D_decoding_figurl(
-    position_time=position_info.index.to_numpy(),
-    position=position_info[position_variable_names],
-    env=env,
-    results=decoding_results,
-    posterior=decoding_results.acausal_posterior.isel(intervals=0)
-    .unstack("state_bins")
-    .sum("state"),
-    spike_times=spike_times,
-    head_dir=position_info["orientation"],
-    speed=position_info["speed"],
-)
+# from non_local_detector.visualization import (
+#     create_interactive_2D_decoding_figurl,
+# )
+
+# (
+#     position_info,
+#     position_variable_names,
+# ) = ClusterlessDecodingV1.load_position_info(selection_key)
+# results_time = decoding_results.acausal_posterior.isel(intervals=0).time.values
+# position_info = position_info.loc[results_time[0] : results_time[-1]]
+
+# env = ClusterlessDecodingV1.load_environments(selection_key)[0]
+# spike_times, _ = ClusterlessDecodingV1.load_spike_data(selection_key)
+
+
+# create_interactive_2D_decoding_figurl(
+#     position_time=position_info.index.to_numpy(),
+#     position=position_info[position_variable_names],
+#     env=env,
+#     results=decoding_results,
+#     posterior=decoding_results.acausal_posterior.isel(intervals=0)
+#     .unstack("state_bins")
+#     .sum("state"),
+#     spike_times=spike_times,
+#     head_dir=position_info["orientation"],
+#     speed=position_info["speed"],
+# )
# -

# ## GPUs
@@ -411,25 +411,3 @@
# to monitor GPU usage in the notebook
# - A [terminal program](https://github.com/peci1/nvidia-htop) like nvidia-smi
# with more information about which GPUs are being utilized and by whom.
-#
-# ### Parallelizing Decoding
-#
-# You can also use the [dask_cuda](https://docs.rapids.ai/api/dask-cuda/nightly/) package to parallelize decoding. You will need to install the `dask_cuda` package (see [here](https://docs.rapids.ai/api/dask-cuda/nightly/install/)). You can then run the following code to parallelize decoding:
-
-# +
-# import dask
-# from dask.distributed import Client
-# from dask_cuda import LocalCUDACluster
-
-# cluster = LocalCUDACluster()
-
-# selection_keys = []  # list of selection keys
-
-# with Client(cluster) as client:
-#     results = [
-#         dask.delayed(ClusterlessDecodingV1.populate)(
-#             selection_key, reserve_jobs=True
-#         )
-#         for selection_key in selection_keys
-#     ]
-#     dask.compute(*results)
23 changes: 18 additions & 5 deletions src/spyglass/decoding/__init__.py
@@ -1,7 +1,20 @@
from spyglass.decoding.decoding_merge import DecodingOutput # noqa: E402
-from spyglass.decoding.visualization.core import (  # noqa: E402
-    create_interactive_1D_decoding_figurl,
-    create_interactive_2D_decoding_figurl,
-    make_multi_environment_movie,
-    make_single_environment_movie,
+from spyglass.decoding.v1.clusterless import (  # noqa: E402
+    ClusterlessDecodingSelection,
+    ClusterlessDecodingV1,
+    UnitWaveformFeaturesGroup,
)
+from spyglass.decoding.v1.core import (
+    DecodingParameters,
+    PositionGroup,
+)  # noqa: E402
+from spyglass.decoding.v1.sorted_spikes import (  # noqa: E402
+    SortedSpikesDecodingSelection,
+    SortedSpikesDecodingV1,
+    SortedSpikesGroup,
+)
+from spyglass.decoding.v1.waveform_features import (  # noqa: E402
+    UnitWaveformFeatures,
+    UnitWaveformFeaturesSelection,
+    WaveformFeaturesParams,
+)
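With these re-exports, downstream code can import the main v1 tables from the package root instead of the submodules. For example:

# Convenience imports enabled by the new decoding/__init__.py
from spyglass.decoding import (
    ClusterlessDecodingV1,
    DecodingParameters,
    PositionGroup,
    SortedSpikesDecodingV1,
    UnitWaveformFeatures,
)

As the commit message notes, this layout would need to change if multiple versions of the pipeline were exposed at once.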
109 changes: 105 additions & 4 deletions src/spyglass/decoding/decoding_merge.py
@@ -1,7 +1,12 @@
+import inspect
from itertools import chain
from pathlib import Path

import datajoint as dj
+import numpy as np
+from datajoint.utils import to_camel_case
+from non_local_detector.visualization.figurl_1D import create_1D_decode_view
+from non_local_detector.visualization.figurl_2D import create_2D_decode_view

from spyglass.decoding.v1.clusterless import ClusterlessDecodingV1 # noqa: F401
from spyglass.decoding.v1.sorted_spikes import (
@@ -35,9 +40,12 @@ class SortedSpikesDecodingV1(SpyglassMixin, dj.Part):  # noqa: F811
        -> SortedSpikesDecodingV1
        """

-    def cleanup(self):
+    def cleanup(self, dry_run=False):
        """Remove any decoding outputs that are not in the merge table"""
-        logger.info("Cleaning up decoding outputs")
+        if dry_run:
+            logger.info("Dry run, not removing any files")
+        else:
+            logger.info("Cleaning up decoding outputs")
        table_results_paths = list(
            chain(
                *[
@@ -51,7 +59,11 @@ def cleanup(self):
        for path in Path(config["SPYGLASS_ANALYSIS_DIR"]).glob("**/*.nc"):
            if str(path) not in table_results_paths:
                logger.info(f"Removing {path}")
-                path.unlink()
+                if not dry_run:
+                    try:
+                        path.unlink(missing_ok=True)  # Ignore FileNotFoundError
+                    except PermissionError:
+                        logger.warning(f"Unable to remove {path}, skipping")

        table_model_paths = list(
            chain(
@@ -66,4 +78,93 @@ def cleanup(self):
        for path in Path(config["SPYGLASS_ANALYSIS_DIR"]).glob("**/*.pkl"):
            if str(path) not in table_model_paths:
                logger.info(f"Removing {path}")
-                path.unlink()
+                if not dry_run:
+                    try:
+                        path.unlink()
+                    except (PermissionError, FileNotFoundError):
+                        logger.warning(f"Unable to remove {path}, skipping")

+    @classmethod
+    def _get_source_class(cls, key):
+        if cls._source_class_dict is None:
+            cls._source_class_dict = {}
+            module = inspect.getmodule(cls)
+            for part_name in cls.parts():
+                part_name = to_camel_case(part_name.split("__")[-1].strip("`"))
+                part = getattr(module, part_name)
+                cls._source_class_dict[part_name] = part
+
+        source = (cls & key).fetch1("source")
+        return cls._source_class_dict[source]
+
+    @classmethod
+    def load_results(cls, key):
+        decoding_selection_key = cls.merge_get_parent(key).fetch1("KEY")
+        source_class = cls._get_source_class(key)
+        return (source_class & decoding_selection_key).load_results()
+
+    @classmethod
+    def load_model(cls, key):
+        decoding_selection_key = cls.merge_get_parent(key).fetch1("KEY")
+        source_class = cls._get_source_class(key)
+        return (source_class & decoding_selection_key).load_model()
+
+    @classmethod
+    def load_environments(cls, key):
+        decoding_selection_key = cls.merge_get_parent(key).fetch1("KEY")
+        source_class = cls._get_source_class(key)
+        return source_class.load_environments(decoding_selection_key)
+
+    @classmethod
+    def load_position_info(cls, key):
+        decoding_selection_key = cls.merge_get_parent(key).fetch1("KEY")
+        source_class = cls._get_source_class(key)
+        return source_class.load_position_info(decoding_selection_key)
+
+    @classmethod
+    def load_linear_position_info(cls, key):
+        decoding_selection_key = cls.merge_get_parent(key).fetch1("KEY")
+        source_class = cls._get_source_class(key)
+        return source_class.load_linear_position_info(decoding_selection_key)
+
+    @classmethod
+    def load_spike_data(cls, key, filter_by_interval=True):
+        decoding_selection_key = cls.merge_get_parent(key).fetch1("KEY")
+        source_class = cls._get_source_class(key)
+        return source_class.load_spike_data(
+            decoding_selection_key, filter_by_interval=filter_by_interval
+        )

+    @classmethod
+    def create_decoding_view(cls, key, head_direction_name="head_orientation"):
+        results = cls.load_results(key)
+        posterior = results.acausal_posterior.unstack("state_bins").sum("state")
+        env = cls.load_environments(key)[0]
+
+        if "x_position" in results.coords:
+            position_info, position_variable_names = cls.load_position_info(key)
+            # Not 1D
+            bin_size = (
+                np.nanmedian(np.diff(np.unique(results.x_position.values))),
+                np.nanmedian(np.diff(np.unique(results.y_position.values))),
+            )
+            return create_2D_decode_view(
+                position_time=position_info.index,
+                position=position_info[position_variable_names],
+                interior_place_bin_centers=env.place_bin_centers_[
+                    env.is_track_interior_.ravel(order="C")
+                ],
+                place_bin_size=bin_size,
+                posterior=posterior,
+                head_dir=position_info[head_direction_name],
+            )
+        else:
+            (
+                position_info,
+                position_variable_names,
+            ) = cls.load_linear_position_info(key)
+            return create_1D_decode_view(
+                posterior=posterior,
+                linear_position=position_info["linear_position"],
+                ref_time_sec=position_info.index[0],
+            )
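These classmethods look up the part table that produced an entry (via `_get_source_class`) and dispatch to it, so a result can be loaded from the merge table without knowing which pipeline created it. A hedged sketch, with a hypothetical key:

# merge_id of an existing DecodingOutput entry (hypothetical value)
key = {"merge_id": "c1d2..."}

results = DecodingOutput.load_results(key)  # xarray dataset of posteriors
model = DecodingOutput.load_model(key)  # fitted non_local_detector model
view = DecodingOutput.create_decoding_view(key)  # figurl view; 1D or 2D chosen from the result coords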
src/spyglass/decoding/v0/visualization.py
@@ -10,8 +10,8 @@
from ripple_detection import get_multiunit_population_firing_rate
from tqdm.auto import tqdm

-from spyglass.decoding.visualization.view1D import create_1D_decode_view
-from spyglass.decoding.visualization.view2D import create_2D_decode_view
+from spyglass.decoding.v0.visualization_1D_view import create_1D_decode_view
+from spyglass.decoding.v0.visualization_2D_view import create_2D_decode_view
from spyglass.utils import logger


src/spyglass/decoding/v0/visualization_2D_view.py
@@ -3,6 +3,10 @@
import numpy as np
import sortingview.views.franklab as vvf
import xarray as xr
+from replay_trajectory_classification.environments import (
+    get_grid,
+    get_track_interior,
+)


def create_static_track_animation(