Skip to content

Commit

Permalink
Print session/exp summary, fix lfp target sample fetch (#762)
Browse files Browse the repository at this point in the history
* Print session/exp summary, fix lfp target sample fetch

* Update changelog

* Fix merge conflict reverts
  • Loading branch information
CBroz1 authored Jan 12, 2024
1 parent 582a09c commit e759bce
Show file tree
Hide file tree
Showing 4 changed files with 57 additions and 27 deletions.
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
- Clean up following pre-commit checks. #688
- Add Mixin class to centralize `fetch_nwb` functionality. #692, #734
- Refactor restriction use in `delete_downstream_merge` #703
- Add `cautious_delete` to Mixin class, initial implementation. #711
- Add `cautious_delete` to Mixin class, initial implementation. #711, #762
- Add `deprecation_factory` to facilitate table migration. #717
- Add Spyglass logger. #730
- IntervalList: Add secondary key `pipeline` #742
Expand Down
2 changes: 1 addition & 1 deletion src/spyglass/common/common_lab.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ def get_djuser_name(cls, dj_user) -> str:
if len(query) != 1:
raise ValueError(
f"Could not find name for datajoint user {dj_user}"
+ f"in common.LabMember.LabMemberInfo: {query}"
+ f" in common.LabMember.LabMemberInfo: {query}"
)

return query[0]
Expand Down
5 changes: 4 additions & 1 deletion src/spyglass/lfp/v1/lfp.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,9 @@ def make(self, key):
"sampling_rate", "interval_list_name"
)
sampling_rate = int(np.round(sampling_rate))
target_sampling_rate = (LFPSelection & key).fetch1(
"target_sampling_rate"
)

# to get the list of valid times, we need to combine those from the user with those from the
# raw data
Expand All @@ -96,7 +99,7 @@ def make(self, key):
+ f"{MIN_LFP_INTERVAL_DURATION} sec long."
)
# target user-specified sampling rate
decimation = sampling_rate // key["target_sampling_rate"]
decimation = int(sampling_rate // target_sampling_rate)

# get the LFP filter that matches the raw data
filter = (
Expand Down
75 changes: 51 additions & 24 deletions src/spyglass/utils/dj_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@ class SpyglassMixin:
_nwb_table_dict = {}
_delete_dependencies = []
_merge_delete_func = None
_session_pk = None
_member_pk = None

# ------------------------------- fetch_nwb -------------------------------

Expand Down Expand Up @@ -103,6 +105,8 @@ def _delete_deps(self) -> list:
from spyglass.common import LabMember, LabTeam, Session # noqa F401

self._delete_dependencies = [LabMember, LabTeam, Session]
self._session_pk = Session.primary_key[0]
self._member_pk = LabMember.primary_key[0]
return self._delete_dependencies

@property
Expand All @@ -119,10 +123,9 @@ def _merge_del_func(self) -> callable:
self._merge_delete_func = delete_downstream_merge
return self._merge_delete_func

def _find_session(
def _find_session_link(
self,
table: dj.user_tables.UserTable,
Session: dj.user_tables.UserTable,
search_limit: int = 2,
) -> dj.expression.QueryExpression:
"""Find Session table associated with table.
Expand All @@ -141,26 +144,47 @@ def _find_session(
datajoint.expression.QueryExpression or None
Join of table link with Session table if found, else None.
"""
Session = self._delete_deps[-1]
# TODO: check search_limit default is enough for any table in spyglass
if self.full_table_name == Session.full_table_name:
# if self is Session, return self
return self

elif (
# if Session is not in ancestors of table, search children
Session.full_table_name not in table.ancestors()
and search_limit > 0 # prevent infinite recursion
):
for child in table.children():
table = self._find_session(child, Session, search_limit - 1)
if self._session_pk in table.primary_key:
# joinable with Session
return table * Session

elif search_limit > 0:
for child in table.children(as_objects=True):
table = self._find_session_link(child, search_limit - 1)
if table: # table is link, will valid join to Session
break
return table

elif search_limit < 1: # if no session ancestor found and limit reached
elif not table or search_limit < 1: # if none found and limit reached
return # Err kept in parent func to centralize permission logic

return table * Session

def _get_exp_summary(self, sess_link: dj.expression.QueryExpression):
"""Get summary of experimenters for session(s), including NULL.
Parameters
----------
sess_link : datajoint.expression.QueryExpression
Join of table link with Session table.
Returns
-------
str
Summary of experimenters for session(s).
"""
Session = self._delete_deps[-1]

format = dj.U(self._session_pk, self._member_pk)
exp_missing = format & (sess_link - Session.Experimenter).proj(
**{self._member_pk: "NULL"}
)
exp_present = (
format & (sess_link * Session.Experimenter - exp_missing).proj()
)
return exp_missing + exp_present

def _check_delete_permission(self) -> None:
"""Check user name against lab team assoc. w/ self * Session.
Expand All @@ -181,32 +205,35 @@ def _check_delete_permission(self) -> None:
if dj_user in LabMember().admin: # bypass permission check for admin
return

sess = self._find_session(self, Session)
if not sess: # Permit delete if not linked to a session
sess_link = self._find_session_link(table=self)
if not sess_link: # Permit delete if not linked to a session
logger.warn(
"Could not find lab team associated with "
+ f"{self.__class__.__name__}."
+ "\nBe careful not to delete others' data."
)
return

experimenters = (sess * Session.Experimenter).fetch("lab_member_name")
if len(experimenters) < len(sess):
# TODO: adjust to check each session individually? Expensive but
# prevents against edge case of one sess with mult and another
# with none
sess_summary = self._get_exp_summary(
sess_link.restrict(self.restriction)
)
experimenters = sess_summary.fetch(self._member_pk)
if None in experimenters:
raise PermissionError(
f"Please ensure all Sessions have an experimenter:\n{sess}"
"Please ensure all Sessions have an experimenter in "
+ f"SessionExperimenter:\n{sess_summary}"
)

user_name = LabMember().get_djuser_name(dj_user)
for experimenter in set(experimenters):
if user_name not in LabTeam().get_team_members(experimenter):
sess_w_exp = sess_summary & {self._member_pk: experimenter}
raise PermissionError(
f"User '{user_name}' is not on a team with '{experimenter}'"
+ ", an experimenter for session(s):\n"
+ f"{sess * Session.Experimenter}"
+ f"{sess_w_exp}"
)
logger.info(f"Queueing delete for session(s):\n{sess_summary}")

# Rename to `delete` when we're ready to use it
# TODO: Intercept datajoint delete confirmation prompt for merge deletes
Expand Down

0 comments on commit e759bce

Please sign in to comment.