Restrict UnitWaveformFeaturesGroup and SortedSpikesGroup by nwb_file_name (#758)

* Restrict UnitWaveformFeaturesGroup and SortedSpikesGroup

* Concatenate linear position and position dataframes (see the sketch below)

* Static methods don't require instantiating the class
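The concatenation change replaces the per-column insert calls with a single pd.concat that aligns the linearized output to the position dataframe's time index. A minimal sketch of that pattern with made-up numbers (the column names are illustrative; the real linearized output may carry additional columns):

import pandas as pd

# Stand-in frames: position_df is indexed by time, while the linearized
# result comes back with a plain integer index.
time = pd.Index([0.0, 0.002, 0.004, 0.006], name="time")
position_df = pd.DataFrame(
    {
        "position_x": [1.0, 1.1, 1.2, 1.3],
        "position_y": [2.0, 2.0, 2.1, 2.1],
        "speed": [0.5, 0.6, 0.6, 0.7],
    },
    index=time,
)
linear_position_df = pd.DataFrame({"linear_position": [0.0, 0.1, 0.2, 0.3]})

# Re-index the linearized frame onto the position time axis, then keep both
# sets of columns side by side in one dataframe.
combined = pd.concat(
    [linear_position_df.set_index(position_df.index), position_df],
    axis=1,
)
print(combined)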
edeno authored Jan 11, 2024
1 parent 2251b34 commit 06f29b5
Showing 2 changed files with 18 additions and 17 deletions.
src/spyglass/decoding/v1/clusterless.py (16 changes: 8 additions & 8 deletions)
@@ -326,9 +326,9 @@ def load_position_info(key):

     @staticmethod
     def load_linear_position_info(key):
-        environment = ClusterlessDecodingV1().load_environments(key)[0]
+        environment = ClusterlessDecodingV1.load_environments(key)[0]

-        position_df = ClusterlessDecodingV1().load_position_info(key)[0]
+        position_df = ClusterlessDecodingV1.load_position_info(key)[0]
         position = np.asarray(position_df[["position_x", "position_y"]])

         linear_position_df = get_linearized_position(
@@ -338,11 +338,10 @@ def load_linear_position_info(key):
             edge_spacing=environment.edge_spacing,
         )

-        linear_position_df.insert(4, "speed", np.asarray(position_df.speed))
-
-        linear_position_df.insert(5, "time", np.asarray(position_df.index))
-        linear_position_df.set_index("time", inplace=True)
-        return linear_position_df
+        return pd.concat(
+            [linear_position_df.set_index(position_df.index), position_df],
+            axis=1,
+        )

     @staticmethod
     def _get_interval_range(key):
@@ -379,9 +378,10 @@ def load_spike_data(key, filter_by_interval=True):
             (
                 UnitWaveformFeaturesGroup.UnitFeatures
                 & {
+                    "nwb_file_name": key["nwb_file_name"],
                     "waveform_features_group_name": key[
                         "waveform_features_group_name"
-                    ]
+                    ],
                 }
             )
         ).fetch("KEY")
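The first hunk above drops a needless instantiation: load_environments and load_position_info are staticmethods, so they can be called on the class itself. A minimal standalone illustration with a throwaway class (not the Spyglass one):

class Decoder:
    @staticmethod
    def load_environments(key):
        return [f"environment for {key}"]

# Both forms work, but no instance is needed for a staticmethod.
print(Decoder.load_environments("my_key"))    # class-level call
print(Decoder().load_environments("my_key"))  # also works; the () is unnecessary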
src/spyglass/decoding/v1/sorted_spikes.py (19 changes: 10 additions & 9 deletions)
@@ -317,9 +317,9 @@ def load_position_info(key):

     @staticmethod
     def load_linear_position_info(key):
-        environment = SortedSpikesDecodingV1().load_environments(key)[0]
+        environment = SortedSpikesDecodingV1.load_environments(key)[0]

-        position_df = SortedSpikesDecodingV1().load_position_info(key)[0]
+        position_df = SortedSpikesDecodingV1.load_position_info(key)[0]
         position = np.asarray(position_df[["position_x", "position_y"]])

         linear_position_df = get_linearized_position(
@@ -328,12 +328,10 @@ def load_linear_position_info(key):
             edge_order=environment.edge_order,
             edge_spacing=environment.edge_spacing,
         )
-
-        linear_position_df.insert(4, "speed", np.asarray(position_df.speed))
-
-        linear_position_df.insert(5, "time", np.asarray(position_df.index))
-        linear_position_df.set_index("time", inplace=True)
-        return linear_position_df
+        return pd.concat(
+            [linear_position_df.set_index(position_df.index), position_df],
+            axis=1,
+        )

     @staticmethod
     def _get_interval_range(key):
@@ -369,7 +367,10 @@ def load_spike_data(key, filter_by_interval=True):
         merge_ids = (
             (
                 SortedSpikesGroup.SortGroup
-                & {"sorted_spikes_group_name": key["sorted_spikes_group_name"]}
+                & {
+                    "nwb_file_name": key["nwb_file_name"],
+                    "sorted_spikes_group_name": key["sorted_spikes_group_name"],
+                }
             )
         ).fetch("spikesorting_merge_id")

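In both files the DataJoint restriction now includes nwb_file_name alongside the group name, so the fetch only returns rows belonging to the session named in the key. A pure-Python stand-in for that restriction logic (hypothetical rows and file names; the real query runs against the Spyglass tables):

# Hypothetical rows standing in for SortedSpikesGroup.SortGroup entries.
rows = [
    {"nwb_file_name": "rat1_day1_.nwb", "sorted_spikes_group_name": "all_units"},
    {"nwb_file_name": "rat2_day3_.nwb", "sorted_spikes_group_name": "all_units"},
]

key = {"nwb_file_name": "rat1_day1_.nwb", "sorted_spikes_group_name": "all_units"}

# Restricting by the group name alone would match both rows; adding
# nwb_file_name keeps only the row from the intended session.
restriction = {
    "nwb_file_name": key["nwb_file_name"],
    "sorted_spikes_group_name": key["sorted_spikes_group_name"],
}
matched = [row for row in rows if all(row[k] == v for k, v in restriction.items())]
print(matched)  # only the rat1_day1_ entry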
