-
Notifications
You must be signed in to change notification settings - Fork 45
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
2 changed files
with
87 additions
and
78 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -69,7 +69,7 @@ | |
dj.config.load("dj_local_conf.json") # load config for database connection info | ||
|
||
import spyglass.common as sgc | ||
import spyglass.spikesorting as sgs | ||
import spyglass.spikesorting.v0 as sgs | ||
|
||
# ignore datajoint+jupyter async warnings | ||
import warnings | ||
|
@@ -84,15 +84,24 @@ | |
# If you haven't already done so, add yourself to `LabTeam` | ||
# | ||
|
||
name, email, dj_user = "Firstname Lastname", "[email protected]", "user" | ||
# Full name, Google email address, DataJoint username, admin | ||
name, email, dj_user, admin = ( | ||
"Firstname Lastname", | ||
"[email protected]", | ||
"user", | ||
0, | ||
) | ||
sgc.LabMember.insert_from_name(name) | ||
sgc.LabMember.LabMemberInfo.insert1( | ||
[name, email, dj_user], skip_duplicates=True | ||
) | ||
sgc.LabTeam.LabTeamMember.insert1( | ||
{"team_name": "My Team", "lab_member_name": name}, | ||
[ | ||
name, | ||
email, | ||
dj_user, | ||
admin, | ||
], | ||
skip_duplicates=True, | ||
) | ||
sgc.LabMember.LabMemberInfo() | ||
|
||
# We can try `fetch` to confirm. | ||
# | ||
|
@@ -151,7 +160,7 @@ | |
# _Note:_ This will delete any existing entries. Answer 'yes' when prompted. | ||
# | ||
|
||
sgs.v0.SortGroup().set_group_by_shank(nwb_file_name) | ||
sgs.SortGroup().set_group_by_shank(nwb_file_name) | ||
|
||
# Each electrode has an `electrode_id` and is associated with an | ||
# `electrode_group_name`, which corresponds with a `sort_group_id`. | ||
|
@@ -161,7 +170,7 @@ | |
# 32 unique `sort_group_id`. | ||
# | ||
|
||
sgs.v0.SortGroup.SortGroupElectrode & {"nwb_file_name": nwb_file_name} | ||
sgs.SortGroup.SortGroupElectrode & {"nwb_file_name": nwb_file_name} | ||
|
||
# #### `IntervalList` | ||
# | ||
|
@@ -203,7 +212,7 @@ def print_interval_duration(interval_list: np.ndarray): | |
# With the above, we can insert into `SortInterval` | ||
# | ||
|
||
sgs.v0.SortInterval.insert1( | ||
sgs.SortInterval.insert1( | ||
{ | ||
"nwb_file_name": nwb_file_name, | ||
"sort_interval_name": sort_interval_name, | ||
|
@@ -217,7 +226,7 @@ def print_interval_duration(interval_list: np.ndarray): | |
|
||
print_interval_duration( | ||
( | ||
sgs.v0.SortInterval | ||
sgs.SortInterval | ||
& { | ||
"nwb_file_name": nwb_file_name, | ||
"sort_interval_name": sort_interval_name, | ||
|
@@ -232,14 +241,14 @@ def print_interval_duration(interval_list: np.ndarray): | |
# recorded data in the spike band prior to sorting. | ||
# | ||
|
||
sgs.v0.SpikeSortingPreprocessingParameters() | ||
sgs.SpikeSortingPreprocessingParameters() | ||
|
||
# Here, we insert the default parameters and then fetch them. | ||
# | ||
|
||
sgs.v0.SpikeSortingPreprocessingParameters().insert_default() | ||
sgs.SpikeSortingPreprocessingParameters().insert_default() | ||
preproc_params = ( | ||
sgs.v0.SpikeSortingPreprocessingParameters() | ||
sgs.SpikeSortingPreprocessingParameters() | ||
& {"preproc_params_name": "default"} | ||
).fetch1("preproc_params") | ||
print(preproc_params) | ||
|
@@ -249,7 +258,7 @@ def print_interval_duration(interval_list: np.ndarray): | |
# | ||
|
||
preproc_params["frequency_min"] = 600 | ||
sgs.v0.SpikeSortingPreprocessingParameters().insert1( | ||
sgs.SpikeSortingPreprocessingParameters().insert1( | ||
{ | ||
"preproc_params_name": "default_hippocampus", | ||
"preproc_params": preproc_params, | ||
|
@@ -281,8 +290,8 @@ def print_interval_duration(interval_list: np.ndarray): | |
# time/tetrode/etc. of the recording we want to extract. | ||
# | ||
|
||
sgs.v0.SpikeSortingRecordingSelection.insert1(ssr_key, skip_duplicates=True) | ||
sgs.v0.SpikeSortingRecordingSelection() & ssr_key | ||
sgs.SpikeSortingRecordingSelection.insert1(ssr_key, skip_duplicates=True) | ||
sgs.SpikeSortingRecordingSelection() & ssr_key | ||
|
||
# ### `SpikeSortingRecording` | ||
# | ||
|
@@ -296,13 +305,13 @@ def print_interval_duration(interval_list: np.ndarray): | |
# and use a list of primary keys when calling `populate`. | ||
# | ||
|
||
ssr_pk = (sgs.v0.SpikeSortingRecordingSelection & ssr_key).proj() | ||
sgs.v0.SpikeSortingRecording.populate([ssr_pk]) | ||
ssr_pk = (sgs.SpikeSortingRecordingSelection & ssr_key).proj() | ||
sgs.SpikeSortingRecording.populate([ssr_pk]) | ||
|
||
# Now we can see our recording in the table. _E x c i t i n g !_ | ||
# | ||
|
||
sgs.v0.SpikeSortingRecording() & ssr_key | ||
sgs.SpikeSortingRecording() & ssr_key | ||
|
||
# ## Artifact Detection | ||
# | ||
|
@@ -314,28 +323,28 @@ def print_interval_duration(interval_list: np.ndarray): | |
# For this demo, we'll use a parameter set to skip this step. | ||
# | ||
|
||
sgs.v0.ArtifactDetectionParameters().insert_default() | ||
artifact_key = (sgs.v0.SpikeSortingRecording() & ssr_key).fetch1("KEY") | ||
sgs.ArtifactDetectionParameters().insert_default() | ||
artifact_key = (sgs.SpikeSortingRecording() & ssr_key).fetch1("KEY") | ||
artifact_key["artifact_params_name"] = "none" | ||
|
||
# We then pair artifact detection parameters in `ArtifactParameters` with a | ||
# recording extracted through population of `SpikeSortingRecording` and insert | ||
# into `ArtifactDetectionSelection`. | ||
# | ||
|
||
sgs.v0.ArtifactDetectionSelection().insert1(artifact_key) | ||
sgs.v0.ArtifactDetectionSelection() & artifact_key | ||
sgs.ArtifactDetectionSelection().insert1(artifact_key) | ||
sgs.ArtifactDetectionSelection() & artifact_key | ||
|
||
# Then, we can populate `ArtifactDetection`, which will find periods where there | ||
# are artifacts, as specified by the parameters. | ||
# | ||
|
||
sgs.v0.ArtifactDetection.populate(artifact_key) | ||
sgs.ArtifactDetection.populate(artifact_key) | ||
|
||
# Populating `ArtifactDetection` also inserts an entry into `ArtifactRemovedIntervalList`, which stores the interval without detected artifacts. | ||
# | ||
|
||
sgs.v0.ArtifactRemovedIntervalList() & artifact_key | ||
sgs.ArtifactRemovedIntervalList() & artifact_key | ||
|
||
# ## Spike sorting | ||
# | ||
|
@@ -346,12 +355,12 @@ def print_interval_duration(interval_list: np.ndarray): | |
# | ||
|
||
# + | ||
sgs.v0.SpikeSorterParameters().insert_default() | ||
sgs.SpikeSorterParameters().insert_default() | ||
|
||
# Let's look at the default params | ||
sorter_name = "mountainsort4" | ||
ms4_default_params = ( | ||
sgs.v0.SpikeSorterParameters | ||
sgs.SpikeSorterParameters | ||
& {"sorter": sorter_name, "sorter_params_name": "default"} | ||
).fetch1() | ||
print(ms4_default_params) | ||
|
@@ -385,7 +394,7 @@ def print_interval_duration(interval_list: np.ndarray): | |
# | ||
|
||
sorter_params_name = "hippocampus_tutorial" | ||
sgs.v0.SpikeSorterParameters.insert1( | ||
sgs.SpikeSorterParameters.insert1( | ||
{ | ||
"sorter": sorter_name, | ||
"sorter_params_name": sorter_params_name, | ||
|
@@ -394,7 +403,7 @@ def print_interval_duration(interval_list: np.ndarray): | |
skip_duplicates=True, | ||
) | ||
( | ||
sgs.v0.SpikeSorterParameters | ||
sgs.SpikeSorterParameters | ||
& {"sorter": sorter_name, "sorter_params_name": sorter_params_name} | ||
).fetch1() | ||
|
||
|
@@ -409,16 +418,16 @@ def print_interval_duration(interval_list: np.ndarray): | |
# | ||
|
||
ss_key = dict( | ||
**(sgs.v0.ArtifactDetection & ssr_key).fetch1("KEY"), | ||
**(sgs.v0.ArtifactRemovedIntervalList() & ssr_key).fetch1("KEY"), | ||
**(sgs.ArtifactDetection & ssr_key).fetch1("KEY"), | ||
**(sgs.ArtifactRemovedIntervalList() & ssr_key).fetch1("KEY"), | ||
sorter=sorter_name, | ||
sorter_params_name=sorter_params_name, | ||
) | ||
ss_key.pop("artifact_params_name") | ||
ss_key | ||
|
||
sgs.v0.SpikeSortingSelection.insert1(ss_key, skip_duplicates=True) | ||
(sgs.v0.SpikeSortingSelection & ss_key) | ||
sgs.SpikeSortingSelection.insert1(ss_key, skip_duplicates=True) | ||
(sgs.SpikeSortingSelection & ss_key) | ||
|
||
# ### `SpikeSorting` | ||
# | ||
|
@@ -428,12 +437,12 @@ def print_interval_duration(interval_list: np.ndarray): | |
# | ||
|
||
# [(sgs.SpikeSortingSelection & ss_key).proj()] | ||
sgs.v0.SpikeSorting.populate() | ||
sgs.SpikeSorting.populate() | ||
|
||
# #### Check to make sure the table populated | ||
# | ||
|
||
sgs.v0.SpikeSorting() & ss_key | ||
sgs.SpikeSorting() & ss_key | ||
|
||
# ## Next Steps | ||
# | ||
|
Oops, something went wrong.