Skip to content

Commit

Permalink
Fix versioning
Browse files Browse the repository at this point in the history
  • Loading branch information
edeno committed Jan 27, 2024
1 parent e566e3e commit 26fdd0a
Show file tree
Hide file tree
Showing 2 changed files with 87 additions and 78 deletions.
79 changes: 44 additions & 35 deletions notebooks/py_scripts/10_Spike_SortingV0.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@
dj.config.load("dj_local_conf.json") # load config for database connection info

import spyglass.common as sgc
import spyglass.spikesorting as sgs
import spyglass.spikesorting.v0 as sgs

# ignore datajoint+jupyter async warnings
import warnings
Expand All @@ -84,15 +84,24 @@
# If you haven't already done so, add yourself to `LabTeam`
#

name, email, dj_user = "Firstname Lastname", "[email protected]", "user"
# Full name, Google email address, DataJoint username, admin
name, email, dj_user, admin = (
"Firstname Lastname",
"[email protected]",
"user",
0,
)
sgc.LabMember.insert_from_name(name)
sgc.LabMember.LabMemberInfo.insert1(
[name, email, dj_user], skip_duplicates=True
)
sgc.LabTeam.LabTeamMember.insert1(
{"team_name": "My Team", "lab_member_name": name},
[
name,
email,
dj_user,
admin,
],
skip_duplicates=True,
)
sgc.LabMember.LabMemberInfo()

# We can try `fetch` to confirm.
#
Expand Down Expand Up @@ -151,7 +160,7 @@
# _Note:_ This will delete any existing entries. Answer 'yes' when prompted.
#

sgs.v0.SortGroup().set_group_by_shank(nwb_file_name)
sgs.SortGroup().set_group_by_shank(nwb_file_name)

# Each electrode has an `electrode_id` and is associated with an
# `electrode_group_name`, which corresponds with a `sort_group_id`.
Expand All @@ -161,7 +170,7 @@
# 32 unique `sort_group_id`.
#

sgs.v0.SortGroup.SortGroupElectrode & {"nwb_file_name": nwb_file_name}
sgs.SortGroup.SortGroupElectrode & {"nwb_file_name": nwb_file_name}

# #### `IntervalList`
#
Expand Down Expand Up @@ -203,7 +212,7 @@ def print_interval_duration(interval_list: np.ndarray):
# With the above, we can insert into `SortInterval`
#

sgs.v0.SortInterval.insert1(
sgs.SortInterval.insert1(
{
"nwb_file_name": nwb_file_name,
"sort_interval_name": sort_interval_name,
Expand All @@ -217,7 +226,7 @@ def print_interval_duration(interval_list: np.ndarray):

print_interval_duration(
(
sgs.v0.SortInterval
sgs.SortInterval
& {
"nwb_file_name": nwb_file_name,
"sort_interval_name": sort_interval_name,
Expand All @@ -232,14 +241,14 @@ def print_interval_duration(interval_list: np.ndarray):
# recorded data in the spike band prior to sorting.
#

sgs.v0.SpikeSortingPreprocessingParameters()
sgs.SpikeSortingPreprocessingParameters()

# Here, we insert the default parameters and then fetch them.
#

sgs.v0.SpikeSortingPreprocessingParameters().insert_default()
sgs.SpikeSortingPreprocessingParameters().insert_default()
preproc_params = (
sgs.v0.SpikeSortingPreprocessingParameters()
sgs.SpikeSortingPreprocessingParameters()
& {"preproc_params_name": "default"}
).fetch1("preproc_params")
print(preproc_params)
Expand All @@ -249,7 +258,7 @@ def print_interval_duration(interval_list: np.ndarray):
#

preproc_params["frequency_min"] = 600
sgs.v0.SpikeSortingPreprocessingParameters().insert1(
sgs.SpikeSortingPreprocessingParameters().insert1(
{
"preproc_params_name": "default_hippocampus",
"preproc_params": preproc_params,
Expand Down Expand Up @@ -281,8 +290,8 @@ def print_interval_duration(interval_list: np.ndarray):
# time/tetrode/etc. of the recording we want to extract.
#

sgs.v0.SpikeSortingRecordingSelection.insert1(ssr_key, skip_duplicates=True)
sgs.v0.SpikeSortingRecordingSelection() & ssr_key
sgs.SpikeSortingRecordingSelection.insert1(ssr_key, skip_duplicates=True)
sgs.SpikeSortingRecordingSelection() & ssr_key

# ### `SpikeSortingRecording`
#
Expand All @@ -296,13 +305,13 @@ def print_interval_duration(interval_list: np.ndarray):
# and use a list of primary keys when calling `populate`.
#

ssr_pk = (sgs.v0.SpikeSortingRecordingSelection & ssr_key).proj()
sgs.v0.SpikeSortingRecording.populate([ssr_pk])
ssr_pk = (sgs.SpikeSortingRecordingSelection & ssr_key).proj()
sgs.SpikeSortingRecording.populate([ssr_pk])

# Now we can see our recording in the table. _E x c i t i n g !_
#

sgs.v0.SpikeSortingRecording() & ssr_key
sgs.SpikeSortingRecording() & ssr_key

# ## Artifact Detection
#
Expand All @@ -314,28 +323,28 @@ def print_interval_duration(interval_list: np.ndarray):
# For this demo, we'll use a parameter set to skip this step.
#

sgs.v0.ArtifactDetectionParameters().insert_default()
artifact_key = (sgs.v0.SpikeSortingRecording() & ssr_key).fetch1("KEY")
sgs.ArtifactDetectionParameters().insert_default()
artifact_key = (sgs.SpikeSortingRecording() & ssr_key).fetch1("KEY")
artifact_key["artifact_params_name"] = "none"

# We then pair artifact detection parameters in `ArtifactParameters` with a
# recording extracted through population of `SpikeSortingRecording` and insert
# into `ArtifactDetectionSelection`.
#

sgs.v0.ArtifactDetectionSelection().insert1(artifact_key)
sgs.v0.ArtifactDetectionSelection() & artifact_key
sgs.ArtifactDetectionSelection().insert1(artifact_key)
sgs.ArtifactDetectionSelection() & artifact_key

# Then, we can populate `ArtifactDetection`, which will find periods where there
# are artifacts, as specified by the parameters.
#

sgs.v0.ArtifactDetection.populate(artifact_key)
sgs.ArtifactDetection.populate(artifact_key)

# Populating `ArtifactDetection` also inserts an entry into `ArtifactRemovedIntervalList`, which stores the interval without detected artifacts.
#

sgs.v0.ArtifactRemovedIntervalList() & artifact_key
sgs.ArtifactRemovedIntervalList() & artifact_key

# ## Spike sorting
#
Expand All @@ -346,12 +355,12 @@ def print_interval_duration(interval_list: np.ndarray):
#

# +
sgs.v0.SpikeSorterParameters().insert_default()
sgs.SpikeSorterParameters().insert_default()

# Let's look at the default params
sorter_name = "mountainsort4"
ms4_default_params = (
sgs.v0.SpikeSorterParameters
sgs.SpikeSorterParameters
& {"sorter": sorter_name, "sorter_params_name": "default"}
).fetch1()
print(ms4_default_params)
Expand Down Expand Up @@ -385,7 +394,7 @@ def print_interval_duration(interval_list: np.ndarray):
#

sorter_params_name = "hippocampus_tutorial"
sgs.v0.SpikeSorterParameters.insert1(
sgs.SpikeSorterParameters.insert1(
{
"sorter": sorter_name,
"sorter_params_name": sorter_params_name,
Expand All @@ -394,7 +403,7 @@ def print_interval_duration(interval_list: np.ndarray):
skip_duplicates=True,
)
(
sgs.v0.SpikeSorterParameters
sgs.SpikeSorterParameters
& {"sorter": sorter_name, "sorter_params_name": sorter_params_name}
).fetch1()

Expand All @@ -409,16 +418,16 @@ def print_interval_duration(interval_list: np.ndarray):
#

ss_key = dict(
**(sgs.v0.ArtifactDetection & ssr_key).fetch1("KEY"),
**(sgs.v0.ArtifactRemovedIntervalList() & ssr_key).fetch1("KEY"),
**(sgs.ArtifactDetection & ssr_key).fetch1("KEY"),
**(sgs.ArtifactRemovedIntervalList() & ssr_key).fetch1("KEY"),
sorter=sorter_name,
sorter_params_name=sorter_params_name,
)
ss_key.pop("artifact_params_name")
ss_key

sgs.v0.SpikeSortingSelection.insert1(ss_key, skip_duplicates=True)
(sgs.v0.SpikeSortingSelection & ss_key)
sgs.SpikeSortingSelection.insert1(ss_key, skip_duplicates=True)
(sgs.SpikeSortingSelection & ss_key)

# ### `SpikeSorting`
#
Expand All @@ -428,12 +437,12 @@ def print_interval_duration(interval_list: np.ndarray):
#

# [(sgs.SpikeSortingSelection & ss_key).proj()]
sgs.v0.SpikeSorting.populate()
sgs.SpikeSorting.populate()

# #### Check to make sure the table populated
#

sgs.v0.SpikeSorting() & ss_key
sgs.SpikeSorting() & ss_key

# ## Next Steps
#
Expand Down
Loading

0 comments on commit 26fdd0a

Please sign in to comment.