Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Blackify 24.1.1 #808

Merged
merged 1 commit into from
Jan 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ repos:
- tomli

- repo: https://github.com/ambv/black
rev: 23.11.0
rev: 24.1.1
hooks:
- id: black
language_version: python3.9
4 changes: 1 addition & 3 deletions config/dj_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,7 @@ def main(*args):
save_method = (
"local"
if filename == "dj_local_conf.json"
else "global"
if filename is None
else "custom"
else "global" if filename is None else "custom"
)

config.save_dj_config(
Expand Down
10 changes: 6 additions & 4 deletions src/spyglass/common/common_behav.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,10 +191,12 @@ def _get_column_names(rp, pos_id):
INDEX_ADJUST = 1 # adjust 0-index to 1-index (e.g., xloc0 -> xloc1)
n_pos_dims = rp.data.shape[1]
column_names = [
col # use existing columns if already numbered
if "1" in rp.description or "2" in rp.description
# else number them by id
else col + str(pos_id + INDEX_ADJUST)
(
col # use existing columns if already numbered
if "1" in rp.description or "2" in rp.description
# else number them by id
else col + str(pos_id + INDEX_ADJUST)
)
for col in rp.description.split(", ")
]
if len(column_names) != n_pos_dims:
Expand Down
6 changes: 3 additions & 3 deletions src/spyglass/common/common_device.py
Original file line number Diff line number Diff line change
Expand Up @@ -476,9 +476,9 @@ def __read_ndx_probe_data(
{
"probe_id": nwb_probe_obj.probe_type,
"probe_type": nwb_probe_obj.probe_type,
"contact_side_numbering": "True"
if nwb_probe_obj.contact_side_numbering
else "False",
"contact_side_numbering": (
"True" if nwb_probe_obj.contact_side_numbering else "False"
),
}
)
# go through the shanks and add each one to the Shank table
Expand Down
6 changes: 3 additions & 3 deletions src/spyglass/common/common_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -500,9 +500,9 @@ def filter_data(
for ii, (start, stop) in enumerate(indices):
extracted_ts = timestamps[start:stop:decimation]

new_timestamps[
ts_offset : ts_offset + len(extracted_ts)
] = extracted_ts
new_timestamps[ts_offset : ts_offset + len(extracted_ts)] = (
extracted_ts
)
ts_offset += len(extracted_ts)

# finally ready to filter data!
Expand Down
8 changes: 5 additions & 3 deletions src/spyglass/common/common_interval.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,9 +56,11 @@ def insert_from_nwbfile(cls, nwbf, *, nwb_file_name):
for _, epoch_data in epochs.iterrows():
epoch_dict = {
"nwb_file_name": nwb_file_name,
"interval_list_name": epoch_data.tags[0]
if epoch_data.tags
else f"interval_{epoch_data[0]}",
"interval_list_name": (
epoch_data.tags[0]
if epoch_data.tags
else f"interval_{epoch_data[0]}"
),
"valid_times": np.asarray(
[[epoch_data.start_time, epoch_data.stop_time]]
),
Expand Down
1 change: 1 addition & 0 deletions src/spyglass/common/common_lab.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Schema for institution, lab team/name/members. Session-independent."""

import datajoint as dj

from spyglass.utils import SpyglassMixin, logger
Expand Down
11 changes: 6 additions & 5 deletions src/spyglass/decoding/v0/clusterless.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
[1] Denovellis, E. L. et al. Hippocampal replay of experience at real-world
speeds. eLife 10, e64505 (2021).
"""

import os
import shutil
import uuid
Expand Down Expand Up @@ -654,11 +655,11 @@ def make(self, key):
key["nwb_file_name"]
)

key[
"multiunit_firing_rate_object_id"
] = nwb_analysis_file.add_nwb_object(
analysis_file_name=key["analysis_file_name"],
nwb_object=multiunit_firing_rate.reset_index(),
key["multiunit_firing_rate_object_id"] = (
nwb_analysis_file.add_nwb_object(
analysis_file_name=key["analysis_file_name"],
nwb_object=multiunit_firing_rate.reset_index(),
)
)

nwb_analysis_file.add(
Expand Down
37 changes: 19 additions & 18 deletions src/spyglass/decoding/v0/dj_decoder_conversion.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Converts decoder classes into dictionaries and dictionaries into classes
so that datajoint can store them in tables."""

from spyglass.utils import logger

try:
Expand Down Expand Up @@ -116,17 +117,17 @@ def restore_classes(params: dict) -> dict:
_convert_env_dict(env_params)
for env_params in params["classifier_params"]["environments"]
]
params["classifier_params"][
"discrete_transition_type"
] = _convert_dict_to_class(
params["classifier_params"]["discrete_transition_type"],
discrete_state_transition_types,
params["classifier_params"]["discrete_transition_type"] = (
_convert_dict_to_class(
params["classifier_params"]["discrete_transition_type"],
discrete_state_transition_types,
)
)
params["classifier_params"][
"initial_conditions_type"
] = _convert_dict_to_class(
params["classifier_params"]["initial_conditions_type"],
initial_conditions_types,
params["classifier_params"]["initial_conditions_type"] = (
_convert_dict_to_class(
params["classifier_params"]["initial_conditions_type"],
initial_conditions_types,
)
)

if params["classifier_params"].get("observation_models"):
Expand Down Expand Up @@ -176,10 +177,10 @@ def convert_classes_to_dict(key: dict) -> dict:
key["classifier_params"]["environments"]
)
]
key["classifier_params"][
"continuous_transition_types"
] = _convert_transitions_to_dict(
key["classifier_params"]["continuous_transition_types"]
key["classifier_params"]["continuous_transition_types"] = (
_convert_transitions_to_dict(
key["classifier_params"]["continuous_transition_types"]
)
)
key["classifier_params"]["discrete_transition_type"] = _to_dict(
key["classifier_params"]["discrete_transition_type"]
Expand All @@ -194,10 +195,10 @@ def convert_classes_to_dict(key: dict) -> dict:
]

try:
key["classifier_params"][
"clusterless_algorithm_params"
] = _convert_algorithm_params(
key["classifier_params"]["clusterless_algorithm_params"]
key["classifier_params"]["clusterless_algorithm_params"] = (
_convert_algorithm_params(
key["classifier_params"]["clusterless_algorithm_params"]
)
)
except KeyError:
pass
Expand Down
1 change: 1 addition & 0 deletions src/spyglass/decoding/v0/sorted_spikes.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
speeds. eLife 10, e64505 (2021).

"""

import pprint

import datajoint as dj
Expand Down
2 changes: 1 addition & 1 deletion src/spyglass/decoding/v0/visualization_2D_view.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def create_static_track_animation(
"xmin": np.min(ul_corners[0]),
"xmax": np.max(ul_corners[0]) + track_rect_width,
"ymin": np.min(ul_corners[1]),
"ymax": np.max(ul_corners[1]) + track_rect_height
"ymax": np.max(ul_corners[1]) + track_rect_height,
# Speed: should this be displayed?
# TODO: Better approach for accommodating further data streams
}
Expand Down
6 changes: 3 additions & 3 deletions src/spyglass/decoding/v1/clusterless.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,9 +240,9 @@ def make(self, key):
vars(classifier).get("discrete_transition_coefficients_")
is not None
):
results[
"discrete_transition_coefficients"
] = classifier.discrete_transition_coefficients_
results["discrete_transition_coefficients"] = (
classifier.discrete_transition_coefficients_
)

# Insert results
# in future use https://github.com/rly/ndx-xarray and analysis nwb file?
Expand Down
1 change: 0 additions & 1 deletion src/spyglass/decoding/v1/dj_decoder_conversion.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
"""Converts decoder classes into dictionaries and dictionaries into classes
so that datajoint can store them in tables."""


import copy

import datajoint as dj
Expand Down
6 changes: 3 additions & 3 deletions src/spyglass/decoding/v1/sorted_spikes.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,9 +232,9 @@ def make(self, key):
vars(classifier).get("discrete_transition_coefficients_")
is not None
):
results[
"discrete_transition_coefficients"
] = classifier.discrete_transition_coefficients_
results["discrete_transition_coefficients"] = (
classifier.discrete_transition_coefficients_
)

# Insert results
# in future use https://github.com/rly/ndx-xarray and analysis nwb file?
Expand Down
18 changes: 9 additions & 9 deletions src/spyglass/position/v1/position_dlc_orient.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,15 +241,15 @@ def interp_orientation(orientation, spans_to_interp, **kwargs):
# TODO: add parameters to refine interpolation
for ind, (span_start, span_stop) in enumerate(spans_to_interp):
if (span_stop + 1) >= len(orientation):
orientation.loc[
idx[span_start:span_stop], idx["orientation"]
] = np.nan
orientation.loc[idx[span_start:span_stop], idx["orientation"]] = (
np.nan
)
print(f"ind: {ind} has no endpoint with which to interpolate")
continue
if span_start < 1:
orientation.loc[
idx[span_start:span_stop], idx["orientation"]
] = np.nan
orientation.loc[idx[span_start:span_stop], idx["orientation"]] = (
np.nan
)
print(f"ind: {ind} has no startpoint with which to interpolate")
continue
orient = [
Expand All @@ -263,7 +263,7 @@ def interp_orientation(orientation, spans_to_interp, **kwargs):
xp=[start_time, stop_time],
fp=[orient[0], orient[-1]],
)
orientation.loc[
idx[start_time:stop_time], idx["orientation"]
] = orientnew
orientation.loc[idx[start_time:stop_time], idx["orientation"]] = (
orientnew
)
return orientation
20 changes: 10 additions & 10 deletions src/spyglass/position/v1/position_dlc_pose_estimation.py
Original file line number Diff line number Diff line change
Expand Up @@ -309,17 +309,17 @@ def make(self, key):
description="video_frame_ind",
)
nwb_analysis_file = AnalysisNwbfile()
key[
"dlc_pose_estimation_position_object_id"
] = nwb_analysis_file.add_nwb_object(
analysis_file_name=key["analysis_file_name"],
nwb_object=position,
key["dlc_pose_estimation_position_object_id"] = (
nwb_analysis_file.add_nwb_object(
analysis_file_name=key["analysis_file_name"],
nwb_object=position,
)
)
key[
"dlc_pose_estimation_likelihood_object_id"
] = nwb_analysis_file.add_nwb_object(
analysis_file_name=key["analysis_file_name"],
nwb_object=likelihood,
key["dlc_pose_estimation_likelihood_object_id"] = (
nwb_analysis_file.add_nwb_object(
analysis_file_name=key["analysis_file_name"],
nwb_object=likelihood,
)
)
nwb_analysis_file.add(
nwb_file_name=key["nwb_file_name"],
Expand Down
20 changes: 10 additions & 10 deletions src/spyglass/position/v1/position_dlc_position.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,17 +248,17 @@ def make(self, key):
comments="no comments",
description="video_frame_ind",
)
key[
"dlc_smooth_interp_position_object_id"
] = nwb_analysis_file.add_nwb_object(
analysis_file_name=key["analysis_file_name"],
nwb_object=position,
key["dlc_smooth_interp_position_object_id"] = (
nwb_analysis_file.add_nwb_object(
analysis_file_name=key["analysis_file_name"],
nwb_object=position,
)
)
key[
"dlc_smooth_interp_info_object_id"
] = nwb_analysis_file.add_nwb_object(
analysis_file_name=key["analysis_file_name"],
nwb_object=video_frame_ind,
key["dlc_smooth_interp_info_object_id"] = (
nwb_analysis_file.add_nwb_object(
analysis_file_name=key["analysis_file_name"],
nwb_object=video_frame_ind,
)
)
nwb_analysis_file.add(
nwb_file_name=key["nwb_file_name"],
Expand Down
6 changes: 3 additions & 3 deletions src/spyglass/sharing/sharing_kachery.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,9 +105,9 @@ def set_resource_url(key: dict):
def reset_resource_url():
KacheryZone.reset_zone()
if default_kachery_resource_url is not None:
os.environ[
kachery_resource_url_envar
] = default_kachery_resource_url
os.environ[kachery_resource_url_envar] = (
default_kachery_resource_url
)


@schema
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -102,16 +102,16 @@ def prepare_spikesortingview_data(
channel_neighborhood_size=channel_neighborhood_size,
)
if len(spike_train) >= 10:
unit_peak_channel_ids[
str(unit_id)
] = peak_channel_id
unit_peak_channel_ids[str(unit_id)] = (
peak_channel_id
)
else:
fallback_unit_peak_channel_ids[
str(unit_id)
] = peak_channel_id
unit_channel_neighborhoods[
str(unit_id)
] = channel_neighborhood
fallback_unit_peak_channel_ids[str(unit_id)] = (
peak_channel_id
)
unit_channel_neighborhoods[str(unit_id)] = (
channel_neighborhood
)
for unit_id in unit_ids:
peak_channel_id = unit_peak_channel_ids.get(str(unit_id), None)
if peak_channel_id is None:
Expand Down
6 changes: 3 additions & 3 deletions src/spyglass/spikesorting/v1/recording.py
Original file line number Diff line number Diff line change
Expand Up @@ -318,9 +318,9 @@ def _get_recording_timestamps(recording):

timestamps = np.zeros((total_frames,))
for i in range(recording.get_num_segments()):
timestamps[
cumsum_frames[i] : cumsum_frames[i + 1]
] = recording.get_times(segment_index=i)
timestamps[cumsum_frames[i] : cumsum_frames[i + 1]] = (
recording.get_times(segment_index=i)
)
else:
timestamps = recording.get_times()
return timestamps
Expand Down
9 changes: 6 additions & 3 deletions src/spyglass/utils/dj_helper_fn.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Helper functions for manipulating information from DataJoint fetch calls."""

import inspect
import os
from typing import Type
Expand Down Expand Up @@ -193,9 +194,11 @@ def get_child_tables(table):
return [
dj.FreeTable(
table.connection,
s
if not s.isdigit()
else next(iter(table.connection.dependencies.children(s))),
(
s
if not s.isdigit()
else next(iter(table.connection.dependencies.children(s)))
),
)
for s in table.children()
]
Loading
Loading