From 26fdd0ac2d83b4a838d8e63d2f5bd04b81712c1b Mon Sep 17 00:00:00 2001
From: Eric Denovellis
Date: Sat, 27 Jan 2024 15:26:29 -0800
Subject: [PATCH] Fix versioning

---
 notebooks/py_scripts/10_Spike_SortingV0.py | 79 +++++++++++---------
 notebooks/py_scripts/10_Spike_SortingV1.py | 86 +++++++++++-----------
 2 files changed, 87 insertions(+), 78 deletions(-)

diff --git a/notebooks/py_scripts/10_Spike_SortingV0.py b/notebooks/py_scripts/10_Spike_SortingV0.py
index e9bc82f09..6daa28200 100644
--- a/notebooks/py_scripts/10_Spike_SortingV0.py
+++ b/notebooks/py_scripts/10_Spike_SortingV0.py
@@ -69,7 +69,7 @@
 dj.config.load("dj_local_conf.json")  # load config for database connection info

 import spyglass.common as sgc
-import spyglass.spikesorting as sgs
+import spyglass.spikesorting.v0 as sgs

 # ignore datajoint+jupyter async warnings
 import warnings
@@ -84,15 +84,24 @@
 # If you haven't already done so, add yourself to `LabTeam`
 #

-name, email, dj_user = "Firstname Lastname", "example@gmail.com", "user"
+# Full name, Google email address, DataJoint username, admin
+name, email, dj_user, admin = (
+    "Firstname Lastname",
+    "example@gmail.com",
+    "user",
+    0,
+)
 sgc.LabMember.insert_from_name(name)
 sgc.LabMember.LabMemberInfo.insert1(
-    [name, email, dj_user], skip_duplicates=True
-)
-sgc.LabTeam.LabTeamMember.insert1(
-    {"team_name": "My Team", "lab_member_name": name},
+    [
+        name,
+        email,
+        dj_user,
+        admin,
+    ],
     skip_duplicates=True,
 )
+sgc.LabMember.LabMemberInfo()

 # We can try `fetch` to confirm.
 #
@@ -151,7 +160,7 @@
 # _Note:_ This will delete any existing entries. Answer 'yes' when prompted.
 #

-sgs.v0.SortGroup().set_group_by_shank(nwb_file_name)
+sgs.SortGroup().set_group_by_shank(nwb_file_name)

 # Each electrode has an `electrode_id` and is associated with an
 # `electrode_group_name`, which corresponds with a `sort_group_id`.
@@ -161,7 +170,7 @@
 # 32 unique `sort_group_id`.
 #

-sgs.v0.SortGroup.SortGroupElectrode & {"nwb_file_name": nwb_file_name}
+sgs.SortGroup.SortGroupElectrode & {"nwb_file_name": nwb_file_name}

 # #### `IntervalList`
 #
@@ -203,7 +212,7 @@ def print_interval_duration(interval_list: np.ndarray):
 # With the above, we can insert into `SortInterval`
 #

-sgs.v0.SortInterval.insert1(
+sgs.SortInterval.insert1(
     {
         "nwb_file_name": nwb_file_name,
         "sort_interval_name": sort_interval_name,
@@ -217,7 +226,7 @@ def print_interval_duration(interval_list: np.ndarray):

 print_interval_duration(
     (
-        sgs.v0.SortInterval
+        sgs.SortInterval
         & {
             "nwb_file_name": nwb_file_name,
             "sort_interval_name": sort_interval_name,
@@ -232,14 +241,14 @@ def print_interval_duration(interval_list: np.ndarray):
 # recorded data in the spike band prior to sorting.
 #

-sgs.v0.SpikeSortingPreprocessingParameters()
+sgs.SpikeSortingPreprocessingParameters()

 # Here, we insert the default parameters and then fetch them.
 #

-sgs.v0.SpikeSortingPreprocessingParameters().insert_default()
+sgs.SpikeSortingPreprocessingParameters().insert_default()
 preproc_params = (
-    sgs.v0.SpikeSortingPreprocessingParameters()
+    sgs.SpikeSortingPreprocessingParameters()
     & {"preproc_params_name": "default"}
 ).fetch1("preproc_params")
 print(preproc_params)
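Before overriding defaults, it can help to see which parameter sets are already registered. A minimal sketch (not part of the patch), assuming only the `sgs` alias from the import hunk above and a working database connection:

# List the names of all existing preprocessing parameter sets.
existing = sgs.SpikeSortingPreprocessingParameters().fetch("preproc_params_name")
print(sorted(existing))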
@@ -249,7 +258,7 @@ def print_interval_duration(interval_list: np.ndarray):
 #

 preproc_params["frequency_min"] = 600
-sgs.v0.SpikeSortingPreprocessingParameters().insert1(
+sgs.SpikeSortingPreprocessingParameters().insert1(
     {
         "preproc_params_name": "default_hippocampus",
         "preproc_params": preproc_params,
@@ -281,8 +290,8 @@ def print_interval_duration(interval_list: np.ndarray):
 # time/tetrode/etc. of the recording we want to extract.
 #

-sgs.v0.SpikeSortingRecordingSelection.insert1(ssr_key, skip_duplicates=True)
-sgs.v0.SpikeSortingRecordingSelection() & ssr_key
+sgs.SpikeSortingRecordingSelection.insert1(ssr_key, skip_duplicates=True)
+sgs.SpikeSortingRecordingSelection() & ssr_key

 # ### `SpikeSortingRecording`
 #
@@ -296,13 +305,13 @@ def print_interval_duration(interval_list: np.ndarray):
 # and use a list of primary keys when calling `populate`.
 #

-ssr_pk = (sgs.v0.SpikeSortingRecordingSelection & ssr_key).proj()
-sgs.v0.SpikeSortingRecording.populate([ssr_pk])
+ssr_pk = (sgs.SpikeSortingRecordingSelection & ssr_key).proj()
+sgs.SpikeSortingRecording.populate([ssr_pk])

 # Now we can see our recording in the table. _E x c i t i n g !_
 #

-sgs.v0.SpikeSortingRecording() & ssr_key
+sgs.SpikeSortingRecording() & ssr_key

 # ## Artifact Detection
 #
@@ -314,8 +323,8 @@ def print_interval_duration(interval_list: np.ndarray):
 # For this demo, we'll use a parameter set to skip this step.
 #

-sgs.v0.ArtifactDetectionParameters().insert_default()
-artifact_key = (sgs.v0.SpikeSortingRecording() & ssr_key).fetch1("KEY")
+sgs.ArtifactDetectionParameters().insert_default()
+artifact_key = (sgs.SpikeSortingRecording() & ssr_key).fetch1("KEY")
 artifact_key["artifact_params_name"] = "none"

 # We then pair artifact detection parameters in `ArtifactParameters` with a
@@ -323,19 +332,19 @@ def print_interval_duration(interval_list: np.ndarray):
 # into `ArtifactDetectionSelection`.
 #

-sgs.v0.ArtifactDetectionSelection().insert1(artifact_key)
-sgs.v0.ArtifactDetectionSelection() & artifact_key
+sgs.ArtifactDetectionSelection().insert1(artifact_key)
+sgs.ArtifactDetectionSelection() & artifact_key

 # Then, we can populate `ArtifactDetection`, which will find periods where there
 # are artifacts, as specified by the parameters.
 #

-sgs.v0.ArtifactDetection.populate(artifact_key)
+sgs.ArtifactDetection.populate(artifact_key)

 # Populating `ArtifactDetection` also inserts an entry into
 # `ArtifactRemovedIntervalList`, which stores the interval without detected
 # artifacts.
 #

-sgs.v0.ArtifactRemovedIntervalList() & artifact_key
+sgs.ArtifactRemovedIntervalList() & artifact_key
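As a quick sanity check between these two stages, one might compare the artifact-removed interval to the original sort interval. A hedged sketch, assuming the valid times are stored in a column named `artifact_removed_valid_times` (check the table heading if your schema differs):

# Fetch the artifact-free valid times and report their total duration with
# the `print_interval_duration` helper defined earlier in the notebook.
valid_times = (sgs.ArtifactRemovedIntervalList() & artifact_key).fetch1(
    "artifact_removed_valid_times"
)
print_interval_duration(valid_times)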
 # ## Spike sorting
 #
 # ...
 #

 # +
-sgs.v0.SpikeSorterParameters().insert_default()
+sgs.SpikeSorterParameters().insert_default()

 # Let's look at the default params
 sorter_name = "mountainsort4"
 ms4_default_params = (
-    sgs.v0.SpikeSorterParameters
+    sgs.SpikeSorterParameters
     & {"sorter": sorter_name, "sorter_params_name": "default"}
 ).fetch1()
 print(ms4_default_params)

@@ -385,7 +394,7 @@ def print_interval_duration(interval_list: np.ndarray):
 #

 sorter_params_name = "hippocampus_tutorial"
-sgs.v0.SpikeSorterParameters.insert1(
+sgs.SpikeSorterParameters.insert1(
     {
         "sorter": sorter_name,
         "sorter_params_name": sorter_params_name,
@@ -394,7 +403,7 @@ def print_interval_duration(interval_list: np.ndarray):
     skip_duplicates=True,
 )
 (
-    sgs.v0.SpikeSorterParameters
+    sgs.SpikeSorterParameters
     & {"sorter": sorter_name, "sorter_params_name": sorter_params_name}
 ).fetch1()

@@ -409,16 +418,16 @@ def print_interval_duration(interval_list: np.ndarray):
 #

 ss_key = dict(
-    **(sgs.v0.ArtifactDetection & ssr_key).fetch1("KEY"),
-    **(sgs.v0.ArtifactRemovedIntervalList() & ssr_key).fetch1("KEY"),
+    **(sgs.ArtifactDetection & ssr_key).fetch1("KEY"),
+    **(sgs.ArtifactRemovedIntervalList() & ssr_key).fetch1("KEY"),
     sorter=sorter_name,
     sorter_params_name=sorter_params_name,
 )
 ss_key.pop("artifact_params_name")
 ss_key

-sgs.v0.SpikeSortingSelection.insert1(ss_key, skip_duplicates=True)
-(sgs.v0.SpikeSortingSelection & ss_key)
+sgs.SpikeSortingSelection.insert1(ss_key, skip_duplicates=True)
+(sgs.SpikeSortingSelection & ss_key)

 # ### `SpikeSorting`
 #
@@ -428,12 +437,12 @@ def print_interval_duration(interval_list: np.ndarray):
 #

 # [(sgs.SpikeSortingSelection & ss_key).proj()]
-sgs.v0.SpikeSorting.populate()
+sgs.SpikeSorting.populate()

 # #### Check to make sure the table populated
 #

-sgs.v0.SpikeSorting() & ss_key
+sgs.SpikeSorting() & ss_key

 # ## Next Steps
 #

diff --git a/notebooks/py_scripts/10_Spike_SortingV1.py b/notebooks/py_scripts/10_Spike_SortingV1.py
index a6f37bae5..ae6c6fb82 100644
--- a/notebooks/py_scripts/10_Spike_SortingV1.py
+++ b/notebooks/py_scripts/10_Spike_SortingV1.py
@@ -34,7 +34,7 @@
 #

 import spyglass.common as sgc
-import spyglass.spikesorting as sgs
+import spyglass.spikesorting.v1 as sgs
 import spyglass.data_import as sgi

 # insert LabMember and Session
 #

 nwb_file_name = "mediumnwb20230802.nwb"
 nwb_file_name2 = "mediumnwb20230802_.nwb"

 sgi.insert_sessions(nwb_file_name)

@@ -52,7 +52,7 @@
 # insert SortGroup
 #

-sgs.v1.SortGroup.set_group_by_shank(nwb_file_name=nwb_file_name2)
+sgs.SortGroup.set_group_by_shank(nwb_file_name=nwb_file_name2)

 # insert SpikeSortingRecordingSelection. Use the `insert_selection` method;
 # this automatically generates a unique recording id.
 #
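The import changes are the point of this patch: the spike-sorting tables now live under explicit version namespaces, so `sgs.v0.X` and `sgs.v1.X` become plain `sgs.X` once the versioned module is imported directly. A small illustration (the commented alternative import is an assumption; use whichever form your Spyglass version exposes):

import spyglass.spikesorting.v1 as sgs  # pinned to V1, as in the hunk above
# from spyglass.spikesorting import v1 as sgs  # may work as well

# `full_table_name` reveals which database schema the alias resolves to.
print(sgs.SortGroup.full_table_name)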
@@ -65,34 +65,34 @@
     "team_name": "Alison Comrie",
 }

-sgs.v1.SpikeSortingRecordingSelection.insert_selection(key)
+sgs.SpikeSortingRecordingSelection.insert_selection(key)

-sgs.v1.SpikeSortingRecordingSelection()
+sgs.SpikeSortingRecordingSelection()

 # preprocess recording (filtering and referencing)
 #

-sgs.v1.SpikeSortingRecording.populate()
+sgs.SpikeSortingRecording.populate()

-sgs.v1.SpikeSortingRecording()
+sgs.SpikeSortingRecording()

 key = (
-    sgs.v1.SpikeSortingRecordingSelection & {"nwb_file_name": nwb_file_name2}
+    sgs.SpikeSortingRecordingSelection & {"nwb_file_name": nwb_file_name2}
 ).fetch1()

 # insert ArtifactDetectionSelection
 #

-sgs.v1.ArtifactDetectionSelection.insert_selection(
+sgs.ArtifactDetectionSelection.insert_selection(
     {"recording_id": key["recording_id"], "artifact_param_name": "default"}
 )

 # detect artifacts; note the output is stored in IntervalList
 #

-sgs.v1.ArtifactDetection.populate()
+sgs.ArtifactDetection.populate()

-sgs.v1.ArtifactDetection()
+sgs.ArtifactDetection()

 # insert SpikeSortingSelection. Again, use the `insert_selection` method.
 #

@@ -110,37 +110,37 @@
     ),
 }

-sgs.v1.SpikeSortingSelection()
+sgs.SpikeSortingSelection()

-sgs.v1.SpikeSortingSelection.insert_selection(key)
+sgs.SpikeSortingSelection.insert_selection(key)

-sgs.v1.SpikeSortingSelection()
+sgs.SpikeSortingSelection()

 # run spike sorting
 #

-sgs.v1.SpikeSorting.populate()
+sgs.SpikeSorting.populate()

-sgs.v1.SpikeSorting()
+sgs.SpikeSorting()

 # We have two main ways of curating spike sorting: by computing quality
 # metrics and applying a threshold, and by manually applying curation labels.
 # To do so, we first insert CurationV1 using the `insert_curation` method.
 #

-sgs.v1.CurationV1.insert_curation(
+sgs.CurationV1.insert_curation(
     sorting_id=(
         sgs.SpikeSortingSelection & {"recording_id": key["recording_id"]}
     ).fetch1("sorting_id"),
     description="testing sort",
 )

-sgs.v1.CurationV1()
+sgs.CurationV1()

 # we will first do an automatic curation based on quality metrics
 #

 key = {
     "sorting_id": (
-        sgs.v1.SpikeSortingSelection & {"recording_id": key["recording_id"]}
+        sgs.SpikeSortingSelection & {"recording_id": key["recording_id"]}
     ).fetch1("sorting_id"),
     "curation_id": 0,
     "waveform_param_name": "default_not_whitened",
@@ -148,32 +148,32 @@
     "metric_param_name": "franklab_default",
     "metric_curation_param_name": "default",
 }

-sgs.v1.MetricCurationSelection.insert_selection(key)
+sgs.MetricCurationSelection.insert_selection(key)

-sgs.v1.MetricCurationSelection()
+sgs.MetricCurationSelection()

-sgs.v1.MetricCuration.populate()
+sgs.MetricCuration.populate()

-sgs.v1.MetricCuration()
+sgs.MetricCuration()

 # To do another round of curation, fetch the relevant info and insert it back
 # into CurationV1 using `insert_curation`.
 #

 key = {
     "metric_curation_id": (
-        sgs.v1.MetricCurationSelection & {"sorting_id": key["sorting_id"]}
+        sgs.MetricCurationSelection & {"sorting_id": key["sorting_id"]}
     ).fetch1("metric_curation_id")
 }

-labels = sgs.v1.MetricCuration.get_labels(key)
+labels = sgs.MetricCuration.get_labels(key)

-merge_groups = sgs.v1.MetricCuration.get_merge_groups(key)
+merge_groups = sgs.MetricCuration.get_merge_groups(key)

-metrics = sgs.v1.MetricCuration.get_metrics(key)
+metrics = sgs.MetricCuration.get_metrics(key)

-sgs.v1.CurationV1.insert_curation(
+sgs.CurationV1.insert_curation(
     sorting_id=(
-        sgs.v1.MetricCurationSelection
+        sgs.MetricCurationSelection
         & {"metric_curation_id": key["metric_curation_id"]}
     ).fetch1("sorting_id"),
     parent_curation_id=0,
@@ -183,12 +183,12 @@
     description="after metric curation",
 )

-sgs.v1.CurationV1()
+sgs.CurationV1()
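Before moving on to manual curation, it can be useful to glance at what `get_metrics` returned. A sketch, assuming each entry maps unit ids to scores (the notebook itself only relies on `metrics.keys()` below):

# Preview the first few per-unit values of each computed quality metric.
for metric_name, unit_values in metrics.items():
    preview = dict(list(unit_values.items())[:3])
    print(f"{metric_name}: {preview}")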
curation", ) -sgs.v1.CurationV1() +sgs.CurationV1() # next we will do manual curation. this is done with figurl. to incorporate info from other stages of processing (e.g. metrics) we have to store that with kachery cloud and get curation uri referring to it. it can be done with `generate_curation_uri`. # -curation_uri = sgs.v1.FigURLCurationSelection.generate_curation_uri( +curation_uri = sgs.FigURLCurationSelection.generate_curation_uri( { "sorting_id": ( sgs.MetricCurationSelection @@ -200,7 +200,7 @@ key = { "sorting_id": ( - sgs.v1.MetricCurationSelection + sgs.MetricCurationSelection & {"metric_curation_id": key["metric_curation_id"]} ).fetch1("sorting_id"), "curation_id": 1, @@ -208,15 +208,15 @@ "metrics_figurl": list(metrics.keys()), } -sgs.v1.FigURLCurationSelection() +sgs.FigURLCurationSelection() -sgs.v1.FigURLCurationSelection.insert_selection(key) +sgs.FigURLCurationSelection.insert_selection(key) -sgs.v1.FigURLCurationSelection() +sgs.FigURLCurationSelection() -sgs.v1.FigURLCuration.populate() +sgs.FigURLCuration.populate() -sgs.v1.FigURLCuration() +sgs.FigURLCuration() # or you can manually specify it if you already have a `curation.json` # @@ -234,20 +234,20 @@ } # - -sgs.v1.FigURLCurationSelection.insert_selection(key) +sgs.FigURLCurationSelection.insert_selection(key) -sgs.v1.FigURLCuration.populate() +sgs.FigURLCuration.populate() -sgs.v1.FigURLCuration() +sgs.FigURLCuration() # once you apply manual curation (curation labels and merge groups) you can store them as nwb by inserting another row in CurationV1. And then you can do more rounds of curation if you want. # -labels = sgs.v1.FigURLCuration.get_labels(gh_curation_uri) +labels = sgs.FigURLCuration.get_labels(gh_curation_uri) -merge_groups = sgs.v1.FigURLCuration.get_merge_groups(gh_curation_uri) +merge_groups = sgs.FigURLCuration.get_merge_groups(gh_curation_uri) -sgs.v1.CurationV1.insert_curation( +sgs.CurationV1.insert_curation( sorting_id=key["sorting_id"], parent_curation_id=1, labels=labels, @@ -256,7 +256,7 @@ description="after figurl curation", ) -sgs.v1.CurationV1() +sgs.CurationV1() # We now insert the curated spike sorting to a `Merge` table for feeding into downstream processing pipelines. #