From 424d4cfc67e826f93716f860ff90cbfab165f2e2 Mon Sep 17 00:00:00 2001
From: CBroz1
Date: Tue, 26 Mar 2024 16:51:42 -0500
Subject: [PATCH] WIP: rebase Export process

---
 docs/mkdocs.yml                       | 2 +
 docs/src/misc/export.md               | 126 ++++++++++
 notebooks/05_Export.ipynb             | 152 ++++++++++++
 notebooks/py_scripts/05_Export.py     | 94 +++++++
 src/spyglass/common/common_lab.py     | 2 +-
 src/spyglass/common/common_usage.py   | 163 ++++++++++++
 src/spyglass/settings.py              | 8 +
 src/spyglass/utils/dj_graph.py        | 342 ++++++++++++++++++++++++++
 src/spyglass/utils/dj_helper_fn.py    | 36 +--
 src/spyglass/utils/dj_merge_tables.py | 11 +-
 src/spyglass/utils/dj_mixin.py        | 165 ++++++++++++-
 11 files changed, 1071 insertions(+), 30 deletions(-)
 create mode 100644 docs/src/misc/export.md
 create mode 100644 notebooks/05_Export.ipynb
 create mode 100644 notebooks/py_scripts/05_Export.py
 create mode 100644 src/spyglass/utils/dj_graph.py

diff --git a/docs/mkdocs.yml b/docs/mkdocs.yml
index 996cb36dc..920b646a7 100644
--- a/docs/mkdocs.yml
+++ b/docs/mkdocs.yml
@@ -51,6 +51,7 @@ nav:
 - Data Sync: notebooks/02_Data_Sync.ipynb
 - Merge Tables: notebooks/03_Merge_Tables.ipynb
 - Config Populate: notebooks/04_PopulateConfigFile.ipynb
+ - Export: notebooks/05_Export.ipynb
 - Spikes:
 - Spike Sorting V0: notebooks/10_Spike_SortingV0.ipynb
 - Spike Sorting V1: notebooks/10_Spike_SortingV1.ipynb
@@ -75,6 +76,7 @@ nav:
 - Insert Data: misc/insert_data.md
 - Merge Tables: misc/merge_tables.md
 - Database Management: misc/database_management.md
+ - Export: misc/export.md
 - API Reference: api/ # defer to gen-files + literate-nav
 - How to Contribute: contribute.md
 - Change Log: CHANGELOG.md

diff --git a/docs/src/misc/export.md b/docs/src/misc/export.md
new file mode 100644
index 000000000..e0cea5d08
--- /dev/null
+++ b/docs/src/misc/export.md
@@ -0,0 +1,126 @@

# Export Process

## Why

DataJoint does not have any built-in functionality for exporting vertical slices
of a database. A lab can maintain a shared DataJoint pipeline across multiple
projects, but conforming to NIH data sharing guidelines may require that data
from only one project be shared during publication.

## Requirements

To export data with the current implementation, you must do the following:

- All custom tables must inherit from `SpyglassMixin` (e.g.,
  `class MyTable(SpyglassMixin, dj.ManualOrOther):`).
- Only one export can be active at a time.
- Start the export process with `ExportSelection.start_export()`, run all
  functions associated with a given analysis, and end the export process with
  `ExportSelection.stop_export()`.

## How

The current implementation relies on two classes in the `spyglass` package,
`SpyglassMixin` and `RestrGraph`, plus the `Export` tables.

- `SpyglassMixin`: See `spyglass/utils/dj_mixin.py`
- `RestrGraph`: See `spyglass/utils/dj_graph.py`
- `Export`: See `spyglass/common/common_usage.py`

### Mixin

The `SpyglassMixin` class is a mixin added to DataJoint table classes (e.g.,
`dj.Manual`). Its methods set an environment variable, `SPYGLASS_EXPORT_ID`,
when an export starts and, while that variable is set, intercept all
`fetch`/`fetch_nwb` calls on tables. When `fetch` is called, the mixin grabs the
table name and the restriction applied to the table and stores them in the
`ExportSelection` part tables.

- `fetch_nwb` is specific to Spyglass and logs all analysis NWB files that are
  fetched.
- `fetch` is a DataJoint method that retrieves data from a table; the mixin
  wraps it to log each call (see the sketch below).
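The sketch below illustrates the interception pattern only; it is not the
actual `SpyglassMixin` code. `ExportLoggingMixin` is a hypothetical name, and
the `print` call stands in for the insert into the `ExportSelection` part
tables. The `full_table_name` and `restriction` attributes come from the
DataJoint table class the mixin is combined with.

```python
import os


class ExportLoggingMixin:
    """Sketch: log the table and restriction whenever `fetch` is called.

    Combine as `class MyTable(ExportLoggingMixin, dj.Manual)` so the mixin
    sits first in the MRO and its `fetch` wraps DataJoint's.
    """

    def _log_fetch_for_export(self):
        export_id = int(os.environ.get("SPYGLASS_EXPORT_ID", 0))
        if not export_id:
            return  # no export in progress, nothing to log
        # The real mixin inserts this into ExportSelection.Table instead.
        print(export_id, self.full_table_name, self.restriction)

    def fetch(self, *args, **kwargs):
        result = super().fetch(*args, **kwargs)  # normal DataJoint fetch
        self._log_fetch_for_export()
        return result
```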
### Graph

The `RestrGraph` class uses DataJoint's networkx graph to store each of the
tables and restrictions intercepted by the `SpyglassMixin`'s `fetch` as
'leaves'. The class then cascades these restrictions up from each leaf to all
ancestors. Usage is modeled in the methods of `ExportSelection`.

```python
from spyglass.utils.dj_graph import RestrGraph

restr_graph = RestrGraph(seed_table=AnyTable, leaves=None, verbose=False)
restr_graph.add_leaves(
    leaves=[
        {
            "table_name": MyTable.full_table_name,
            "restriction": "any_restriction",
        },
        {
            "table_name": AnotherTable.full_table_name,
            "restriction": "another_restriction",
        },
    ]
)
restr_graph.cascade()
restricted_leaves = restr_graph.leaf_ft
all_restricted_tables = restr_graph.all_ft

restr_graph.write_export(paper_id="my_paper_id")  # part of `populate` below
```

A `RestrGraph` object is created with a seed table, which provides access to a
DataJoint connection and the dependency graph. One or more leaves can be added
at initialization or later with the `add_leaves` method. The cascade process is
delayed until `cascade`, or another method that requires the cascade, is called.

Cascading a single leaf involves transforming the leaf's restriction into its
parent's restriction, then repeating the process until all ancestors are
reached. If two leaves share a common ancestor, their restrictions are combined.
This process also accommodates projected fields, which appear as numeric alias
nodes in the graph.

### Export Table

The `ExportSelection` table is where users interact with this process.

```python
from spyglass.common.common_usage import Export, ExportSelection

export_key = {"paper_id": "my_paper_id", "analysis_id": "my_analysis_id"}
ExportSelection().start_export(**export_key)  # rerun with the same key to resume a prior attempt
analysis_data = (MyTable & my_restr).fetch()
analysis_nwb = (MyTable & my_restr).fetch_nwb()
ExportSelection().stop_export()

# Visual inspection
touched_files = ExportSelection().list_file_paths(export_key)
restricted_leaves = ExportSelection().preview_tables(export_key)

# Export
Export().populate()
```

Populating `Export` collects the cascaded restrictions and file paths in its
part tables and invokes `RestrGraph.write_export` to write out a bash script
that exports the data using a series of `mysqldump` commands. The script is
saved to Spyglass's export directory, `base_dir/export/paper_id/`, using
credentials from `dj_config`. To use alternative credentials, create a
[mysql config file](https://dev.mysql.com/doc/refman/8.0/en/option-files.html).

## External Implementation

To implement an export for a non-Spyglass database, you will need to ...

- Create a modified version of `SpyglassMixin`, including ...
    - `_export_table` method to lazy load an export table like `ExportSelection`
    - `export_id` attribute, plus setter and deleter methods, to manage the
      status of the export (see the sketch at the end of this page).
    - `fetch` and other methods to intercept and log exported content.
- Create a modified version of `ExportSelection` that adjusts fields like
  `spyglass_version` to match the new database.

Or, optionally, you can use the `RestrGraph` class to cascade hand-picked tables
and restrictions without the background logging of `SpyglassMixin`.
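To make the `export_id` bullet above concrete, here is a minimal, hypothetical
sketch of that attribute for a non-Spyglass mixin. The environment-variable
name and `MyExportMixin` are placeholders, and the real `SpyglassMixin`
additionally registers an `atexit` hook so an interrupted session does not
leave an export flagged as active.

```python
import os

EXPORT_ENV_VAR = "MY_LAB_EXPORT_ID"  # placeholder; Spyglass uses SPYGLASS_EXPORT_ID


class MyExportMixin:
    """Sketch of export-status management for a non-Spyglass database."""

    @property
    def export_id(self) -> int:
        """ID of the export in progress; 0 means nothing is being logged."""
        return int(os.environ.get(EXPORT_ENV_VAR, 0))

    @export_id.setter
    def export_id(self, value: int):
        if self.export_id not in (0, value):  # refuse to clobber another export
            raise RuntimeError("Export already in progress.")
        os.environ[EXPORT_ENV_VAR] = str(value)

    @export_id.deleter
    def export_id(self):
        os.environ.pop(EXPORT_ENV_VAR, None)  # ending an export clears the flag
```

Storing the ID in an environment variable keeps it visible to every table in
the session without passing state around, but it is not thread safe, so exports
must be run in sequence rather than in parallel.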
diff --git a/notebooks/05_Export.ipynb b/notebooks/05_Export.ipynb new file mode 100644 index 000000000..5ab5989b7 --- /dev/null +++ b/notebooks/05_Export.ipynb @@ -0,0 +1,152 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "tags": [] + }, + "source": [ + "# Export\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Intro\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "_Developer Note:_ if you may make a PR in the future, be sure to copy this\n", + "notebook, and use the `gitignore` prefix `temp` to avoid future conflicts.\n", + "\n", + "This is one notebook in a multi-part series on Spyglass.\n", + "\n", + "- To set up your Spyglass environment and database, see\n", + " [the Setup notebook](./00_Setup.ipynb)\n", + "- To insert data, see [the Insert Data notebook](./01_Insert_Data.ipynb)\n", + "- For additional info on DataJoint syntax, including table definitions and\n", + " inserts, see\n", + " [these additional tutorials](https://github.com/datajoint/datajoint-tutorials)\n", + "- For information on what's goint on behind the scenes of an export, see\n", + " [documentation](https://lorenfranklab.github.io/spyglass/0.5/misc/export/)\n", + "\n", + "In short, Spyglass offers the ability to generate exports of one or more subsets\n", + "of the database required for a specific analysis as long as you do the following:\n", + "\n", + "- Inherit `SpyglassMixin` for all custom tables.\n", + "- Run only one export at a time.\n", + "- Start and stop each export logging process.\n", + "\n", + "**NOTE:** For demonstration purposes, this notebook relies on a more populated\n", + "database to highlight restriction merging capabilities of the export process.\n", + "Adjust the restrictions to suit your own dataset.\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Imports\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Let's start by importing the `spyglass` package, along with a few others.\n" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": { + "tags": [] + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "[2024-01-29 16:15:00,903][INFO]: Connecting root@localhost:3309\n", + "[2024-01-29 16:15:00,912][INFO]: Connected root@localhost:3309\n" + ] + } + ], + "source": [ + "import os\n", + "import datajoint as dj\n", + "\n", + "# change to the upper level folder to detect dj_local_conf.json\n", + "if os.path.basename(os.getcwd()) == \"notebooks\":\n", + " os.chdir(\"..\")\n", + "dj.config.load(\"dj_local_conf.json\") # load config for database connection info\n", + "\n", + "# ignore datajoint+jupyter async warnings\n", + "from spyglass.common.common_usage import Export, ExportSelection\n", + "from spyglass.lfp.analysis.v1 import LFPBandV1\n", + "from spyglass.position.v1 import TrodesPosV1\n", + "from spyglass.spikesorting.v1.curation import CurationV1\n", + "\n", + "# TODO: Add commentary, describe helpers on ExportSelection\n", + "\n", + "paper_key = {\"paper_id\": \"paper1\"}\n", + "ExportSelection().start_export(**paper_key, analysis_id=\"test1\")\n", + "a = (\n", + " LFPBandV1 & \"nwb_file_name LIKE 'med%'\" & {\"filter_name\": \"Theta 5-11 Hz\"}\n", + ").fetch()\n", + "b = (\n", + " LFPBandV1\n", + " & {\n", + " \"nwb_file_name\": \"mediumnwb20230802_.nwb\",\n", + " \"filter_name\": \"Theta 5-10 Hz\",\n", + " }\n", + ").fetch()\n", + "ExportSelection().start_export(**paper_key, analysis_id=\"test2\")\n", 
+ "c = (CurationV1 & \"curation_id = 1\").fetch_nwb()\n", + "d = (TrodesPosV1 & 'trodes_pos_params_name = \"single_led\"').fetch()\n", + "ExportSelection().stop_export()\n", + "Export().populate_paper(**paper_key)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Up Next\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "In the [next notebook](./10_Spike_Sorting.ipynb), we'll start working with\n", + "ephys data with spike sorting.\n" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "spy", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.16" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/notebooks/py_scripts/05_Export.py b/notebooks/py_scripts/05_Export.py new file mode 100644 index 000000000..4acca2335 --- /dev/null +++ b/notebooks/py_scripts/05_Export.py @@ -0,0 +1,94 @@ +# --- +# jupyter: +# jupytext: +# text_representation: +# extension: .py +# format_name: light +# format_version: '1.5' +# jupytext_version: 1.16.0 +# kernelspec: +# display_name: spy +# language: python +# name: python3 +# --- + +# # Export +# + +# ## Intro +# + +# _Developer Note:_ if you may make a PR in the future, be sure to copy this +# notebook, and use the `gitignore` prefix `temp` to avoid future conflicts. +# +# This is one notebook in a multi-part series on Spyglass. +# +# - To set up your Spyglass environment and database, see +# [the Setup notebook](./00_Setup.ipynb) +# - To insert data, see [the Insert Data notebook](./01_Insert_Data.ipynb) +# - For additional info on DataJoint syntax, including table definitions and +# inserts, see +# [these additional tutorials](https://github.com/datajoint/datajoint-tutorials) +# - For information on what's goint on behind the scenes of an export, see +# [documentation](https://lorenfranklab.github.io/spyglass/0.5/misc/export/) +# +# In short, Spyglass offers the ability to generate exports of one or more subsets +# of the database required for a specific analysis as long as you do the following: +# +# - Inherit `SpyglassMixin` for all custom tables. +# - Run only one export at a time. +# - Start and stop each export logging process. +# +# **NOTE:** For demonstration purposes, this notebook relies on a more populated +# database to highlight restriction merging capabilities of the export process. +# Adjust the restrictions to suit your own dataset. +# + +# ## Imports +# + +# Let's start by importing the `spyglass` package, along with a few others. 
+# + +# + +import os +import datajoint as dj + +# change to the upper level folder to detect dj_local_conf.json +if os.path.basename(os.getcwd()) == "notebooks": + os.chdir("..") +dj.config.load("dj_local_conf.json") # load config for database connection info + +# ignore datajoint+jupyter async warnings +from spyglass.common.common_usage import Export, ExportSelection +from spyglass.lfp.analysis.v1 import LFPBandV1 +from spyglass.position.v1 import TrodesPosV1 +from spyglass.spikesorting.v1.curation import CurationV1 + +# TODO: Add commentary, describe helpers on ExportSelection + +paper_key = {"paper_id": "paper1"} +ExportSelection().start_export(**paper_key, analysis_id="test1") +a = ( + LFPBandV1 & "nwb_file_name LIKE 'med%'" & {"filter_name": "Theta 5-11 Hz"} +).fetch() +b = ( + LFPBandV1 + & { + "nwb_file_name": "mediumnwb20230802_.nwb", + "filter_name": "Theta 5-10 Hz", + } +).fetch() +ExportSelection().start_export(**paper_key, analysis_id="test2") +c = (CurationV1 & "curation_id = 1").fetch_nwb() +d = (TrodesPosV1 & 'trodes_pos_params_name = "single_led"').fetch() +ExportSelection().stop_export() +Export().populate_paper(**paper_key) +# - + +# ## Up Next +# + +# In the [next notebook](./10_Spike_Sorting.ipynb), we'll start working with +# ephys data with spike sorting. +# diff --git a/src/spyglass/common/common_lab.py b/src/spyglass/common/common_lab.py index a6a162b2b..c5a6fbc00 100644 --- a/src/spyglass/common/common_lab.py +++ b/src/spyglass/common/common_lab.py @@ -92,7 +92,7 @@ def _load_admin(cls): """Load admin list.""" cls._admin = list( (cls.LabMemberInfo & {"admin": True}).fetch("datajoint_user_name") - ) + ) + ["root"] @property def admin(cls) -> list: diff --git a/src/spyglass/common/common_usage.py b/src/spyglass/common/common_usage.py index fdf7ae99d..47b5172c9 100644 --- a/src/spyglass/common/common_usage.py +++ b/src/spyglass/common/common_usage.py @@ -6,8 +6,15 @@ plan future development of Spyglass. """ +from typing import Union + import datajoint as dj +from spyglass.common.common_nwbfile import AnalysisNwbfile, Nwbfile +from spyglass.utils import SpyglassMixin, logger +from spyglass.utils.dj_graph import RestrGraph +from spyglass.utils.dj_helper_fn import unique_dicts + schema = dj.schema("common_usage") @@ -37,3 +44,159 @@ class InsertError(dj.Manual): error_message: varchar(255) error_raw = null: blob """ + + +@schema +class ExportSelection(SpyglassMixin, dj.Manual): + definition = """ + export_id: int auto_increment + --- + paper_id: varchar(32) + analysis_id: varchar(32) + spyglass_version: varchar(16) + time=CURRENT_TIMESTAMP: timestamp + unique index (paper_id, analysis_id) + """ + + class Table(SpyglassMixin, dj.Part): + definition = """ + -> master + table_id: int + --- + table_name: varchar(64) + restriction: varchar(2048) + """ + + def insert1(self, key, **kwargs): + key = self._auto_increment(key, pk="table_id") + super().insert1(key, **kwargs) + + class File(SpyglassMixin, dj.Part): + definition = """ + -> master + -> AnalysisNwbfile + """ + # Note: only tracks AnalysisNwbfile. list_file_paths also grabs Nwbfile. 
+ + def insert1_return_pk(self, key: dict, **kwargs) -> int: + """Custom insert to return export_id.""" + status = "Resuming" + if not (query := self & key): + super().insert1(key, **kwargs) + status = "Starting" + export_id = query.fetch1("export_id") + export_key = {"export_id": export_id} + if query := (Export & export_key): + query.super_delete(warn=False) + logger.info(f"{status} {export_key}") + return export_id + + def start_export(self, paper_id, analysis_id) -> None: + self._start_export(paper_id, analysis_id) + + def stop_export(self) -> None: + self._stop_export() + + # NOTE: These helpers could be moved to table below, but I think + # end users may want to use them to check what's in the export log + # before actually exporting anything, which is more associated with + # Selection + + def list_file_paths(self, key: dict) -> list[str]: + file_table = self.File & key + analysis_fp = [ + AnalysisNwbfile().get_abs_path(fname) + for fname in file_table.fetch("analysis_file_name") + ] + nwbfile_fp = [ + Nwbfile().get_abs_path(fname) + for fname in (AnalysisNwbfile * file_table).fetch("nwb_file_name") + ] + return [{"file_path": p} for p in list({*analysis_fp, *nwbfile_fp})] + + def get_restr_graph(self, key: dict) -> RestrGraph: + leaves = unique_dicts( + (self.Table & key).fetch("table_name", "restriction", as_dict=True) + ) + return RestrGraph(seed_table=self, leaves=leaves, verbose=True) + + def preview_tables(self, key: dict) -> list[dj.FreeTable]: + return self.get_restr_graph(key).leaf_ft + + def _min_export_id(self, paper_id: str) -> int: + """Return all export_ids for a paper.""" + if isinstance(paper_id, dict): + paper_id = paper_id.get("paper_id") + if not (query := self & {"paper_id": paper_id}): + return None + return min(query.fetch("export_id")) + + def paper_export_id(self, paper_id: str) -> dict: + """Return the minimum export_id for a paper, used to populate Export.""" + return {"export_id": self._min_export_id(paper_id)} + + +@schema +class Export(SpyglassMixin, dj.Computed): + definition = """ + -> ExportSelection + """ + + # In order to get a many-to-one relationship btwn Selection and Export, + # we ignore all but the first export_id. + + class Table(SpyglassMixin, dj.Part): + definition = """ + -> master + table_id: int + --- + table_name: varchar(64) + restriction: varchar(2048) + unique index (table_name) + """ + + class File(SpyglassMixin, dj.Part): + definition = """ + -> master + file_id: int + --- + file_path: varchar(255) + """ + # What's needed? full path? relative path? + + def populate_paper(self, paper_id: Union[str, dict]): + if isinstance(paper_id, dict): + paper_id = paper_id.get("paper_id") + self.populate(ExportSelection().paper_export_id(paper_id)) + + def make(self, key): + query = ExportSelection & key + paper_key = query.fetch("paper_id", as_dict=True)[0] + + # Null insertion if export_id is not the minimum for the paper + min_export_id = query._min_export_id(paper_key) + if key.get("export_id") != min_export_id: + logger.info( + f"Skipping export_id {key['export_id']}, use {min_export_id}" + ) + self.insert1(key) + return + + restr_graph = query.get_restr_graph(paper_key) + file_paths = query.list_file_paths(paper_key) + + table_inserts = [ + {**key, **rd, "table_id": i} + for i, rd in enumerate(restr_graph.as_dict) + ] + file_inserts = [ + {**key, **fp, "file_id": i} for i, fp in enumerate(file_paths) + ] + + # Writes but does not run mysqldump. Assumes single version per paper. 
+ version_key = query.fetch("spyglass_version", as_dict=True)[0] + restr_graph.write_export(**paper_key, **version_key) + + self.insert1(key) + self.Table().insert(table_inserts) # TODO: Duplicate error?? + self.File().insert(file_inserts) diff --git a/src/spyglass/settings.py b/src/spyglass/settings.py index af16e688d..4900f595d 100644 --- a/src/spyglass/settings.py +++ b/src/spyglass/settings.py @@ -68,6 +68,7 @@ def __init__(self, base_dir: str = None, **kwargs): "waveforms": "waveforms", "temp": "tmp", "video": "video", + "export": "export", }, "kachery": { "cloud": ".kachery-cloud", @@ -459,6 +460,7 @@ def _dj_custom(self) -> dict: "waveforms": self.waveforms_dir, "temp": self.temp_dir, "video": self.video_dir, + "export": self.export_dir, }, "kachery_dirs": { "cloud": self.config.get( @@ -516,6 +518,10 @@ def temp_dir(self) -> str: def video_dir(self) -> str: return self.config.get(self.dir_to_var("video")) + @property + def export_dir(self) -> str: + return self.config.get(self.dir_to_var("export")) + @property def debug_mode(self) -> bool: """Returns True if debug_mode is set. @@ -560,6 +566,7 @@ def dlc_output_dir(self) -> str: sorting_dir = None waveforms_dir = None video_dir = None + export_dir = None dlc_project_dir = None dlc_video_dir = None dlc_output_dir = None @@ -573,6 +580,7 @@ def dlc_output_dir(self) -> str: sorting_dir = sg_config.sorting_dir waveforms_dir = sg_config.waveforms_dir video_dir = sg_config.video_dir + export_dir = sg_config.export_dir debug_mode = sg_config.debug_mode test_mode = sg_config.test_mode prepopulate = config.get("prepopulate", False) diff --git a/src/spyglass/utils/dj_graph.py b/src/spyglass/utils/dj_graph.py new file mode 100644 index 000000000..5507eefe5 --- /dev/null +++ b/src/spyglass/utils/dj_graph.py @@ -0,0 +1,342 @@ +"""DataJoint graph traversal and restriction application. + +Note: read `ft` as FreeTable and `restr` as restriction.""" + +from pathlib import Path +from typing import Dict, List + +from datajoint import FreeTable +from datajoint import config as dj_config +from datajoint.condition import make_condition +from datajoint.table import Table + +from spyglass.settings import export_dir +from spyglass.utils import logger +from spyglass.utils.dj_helper_fn import unique_dicts + + +class RestrGraph: + def __init__( + self, + seed_table: Table, + table_name: str = None, + restriction: str = None, + leaves: List[Dict[str, str]] = None, + verbose: bool = False, + **kwargs, + ): + """Use graph to cascade restrictions up from leaves to all ancestors. + + Parameters + ---------- + seed_table : Table + Table to use to establish connection and graph + table_name : str, optional + Table name of single leaf, default None + restriction : str, optional + Restriction to apply to leaf. default None + leaves : Dict[str, str], optional + List of dictionaries with keys table_name and restriction. One + entry per leaf node. Default None. + verbose : bool, optional + Whether to print verbose output. 
Default False + """ + + self.connection = seed_table.connection + self.graph = seed_table.connection.dependencies + self.graph.load() + + self.verbose = verbose + self.cascaded = False + self.ancestors = set() + self.visited = set() + self.leaves = set() + + if table_name and restriction: + self.add_leaf(table_name, restriction) + if leaves: + self.add_leaves(leaves) + + def __repr__(self): + l_str = ",\n\t".join(self.leaves) + "\n" if self.leaves else "" + processed = "Processed" if self.cascaded else "Unprocessed" + return f"{processed} RestrictionGraph(\n\t{l_str})" + + @property + def all_ft(self): + """Get restricted FreeTables from all visited nodes.""" + return [self._get_ft(table, with_restr=True) for table in self.visited] + + @property + def leaf_ft(self): + """Get restricted FreeTables from graph leaves.""" + return [self._get_ft(table, with_restr=True) for table in self.leaves] + + def _get_node(self, table): + """Get node from graph.""" + node = self.graph.nodes.get(table) + if not node: + raise ValueError( + f"Table {table} not found in graph." + + "\n\tPlease import this table and rerun" + ) + return node + + def _set_node(self, table, attr="ft", value=None): + """Set attribute on graph node.""" + _ = self._get_node(table) # Ensure node exists + self.graph.nodes[table][attr] = value + + def _get_ft(self, table, with_restr=False): + """Get FreeTable from graph node. If one doesn't exist, create it.""" + restr = self._get_restr(table) if with_restr else True + if ft := self._get_node(table).get("ft"): + return ft & restr + ft = FreeTable(self.connection, table) + self._set_node(table, "ft", ft) + return ft & restr + + def _get_restr(self, table): + """Get restriction from graph node.""" + table = table if isinstance(table, str) else table.full_table_name + return self._get_node(table).get("restr") + + def _set_restr(self, table, restriction): + """Add restriction to graph node. If one exists, merge with new.""" + ft = self._get_ft(table) + restriction = ( # Convert to condition if list or dict + make_condition(ft, restriction, set()) + if not isinstance(restriction, str) + else restriction + ) + if existing := self._get_restr(table): + if existing == restriction: + return + join = ft & [existing, restriction] + if len(join) == len(ft & existing): + return # restriction is a subset of existing + restriction = unique_dicts(join.fetch("KEY", as_dict=True)) + + self._log_truncate( + f"Set Restr {table}: {type(restriction)} {restriction}" + ) + self._set_node(table, "restr", restriction) + + def _log_truncate(self, log_str, max_len=80): + """Truncate log lines to max_len and print if verbose.""" + if not self.verbose: + return + logger.info( + log_str[:max_len] + "..." if len(log_str) > max_len else log_str + ) + + def _child_to_parent( + self, + child, + parent, + restriction, + attr_map=None, + primary=True, + **kwargs, + ) -> List[Dict[str, str]]: + """Given a child, child's restr, and parent, get parent's restr. + + Parameters + ---------- + child : str + child table name + parent : str + parent table name + restriction : str + restriction to apply to child + attr_map : dict, optional + dictionary mapping aliases across parend/child, as pulled from + DataJoint-assembled graph. Default None. Func will flip this dict + to convert from child to parent fields. + primary : bool, optional + Is parent in child's primary key? Default True. Also derived from + DataJoint-assembled graph. If True, project only primary key fields + to avoid secondary key collisions. 
+ + Returns + ------- + List[Dict[str, str]] + List of dicts containing primary key fields for restricted parent + table. + """ + + # Need to flip attr_map to respect parent's fields + attr_map = ( + {v: k for k, v in attr_map.items() if k != k} if attr_map else {} + ) + child_ft = self._get_ft(child) + parent_ft = self._get_ft(parent).proj() + restr = restriction or self._get_restr(child_ft) or True + restr_child = child_ft & restr + + if primary: # Project only primary key fields to avoid collisions + join = restr_child.proj(**attr_map) * parent_ft + else: # Include all fields + join = restr_child.proj(..., **attr_map) * parent_ft + + ret = unique_dicts(join.fetch(*parent_ft.primary_key, as_dict=True)) + + if len(ret) == len(parent_ft): + self._log_truncate(f"NULL rest {parent}") + + return ret + + def cascade1(self, table, restriction): + """Cascade a restriction up the graph, recursively on parents. + + Parameters + ---------- + table : str + table name + restriction : str + restriction to apply + """ + self._set_restr(table, restriction) + self.visited.add(table) + + for parent, data in self.graph.parents(table).items(): + if parent in self.visited: + continue + + if parent.isnumeric(): + parent, data = self.graph.parents(parent).popitem() + + parent_restr = self._child_to_parent( + child=table, + parent=parent, + restriction=restriction, + **data, + ) + + self.cascade1(parent, parent_restr) # Parent set on recursion + + def cascade(self) -> None: + """Cascade all restrictions up the graph.""" + for table in self.leaves - self.visited: + restr = self._get_restr(table) + self._log_truncate(f"Start {table}: {restr}") + self.cascade1(table, restr) + if not self.visited == self.ancestors: + raise RuntimeError("Cascade: FAIL - incomplete cascade") + + self.cascaded = True + + def add_leaf(self, table_name, restriction, cascade=False) -> None: + """Add leaf to graph and cascade if requested. + + Parameters + ---------- + table_name : str + table name of leaf + restriction : str + restriction to apply to leaf + """ + new_ancestors = set(self._get_ft(table_name).ancestors()) + self.ancestors |= new_ancestors # Add to total ancestors + self.visited -= new_ancestors # Remove from visited to revisit + + self.leaves.add(table_name) + self._set_restr(table_name, restriction) # Redundant if cascaded + + if cascade: + self.cascade1(table_name, restriction) + self.cascaded = True + + def add_leaves(self, leaves: List[Dict[str, str]], cascade=False) -> None: + """Add leaves to graph and cascade if requested. + + Parameters + ---------- + leaves : List[Dict[str, str]] + list of dictionaries containing table_name and restriction + cascade : bool, optional + Whether to cascade the restrictions up the graph. 
Default False + """ + + if not leaves: + return + if not isinstance(leaves, list): + leaves = [leaves] + leaves = unique_dicts(leaves) + for leaf in leaves: + if not ( + (table_name := leaf.get("table_name")) + and (restriction := leaf.get("restriction")) + ): + raise ValueError( + f"Leaf must have table_name and restriction: {leaf}" + ) + self.add_leaf(table_name, restriction, cascade=False) + if cascade: + self.cascade() + + @property + def as_dict(self) -> List[Dict[str, str]]: + """Return as a list of dictionaries of table_name: restriction""" + if not self.cascaded: + self.cascade() + return [ + {"table_name": table, "restriction": self._get_restr(table)} + for table in self.ancestors + if self._get_restr(table) + ] + + def _write_sql_cnf(self): + """Write SQL cnf file to avoid password prompt.""" + cnf_path = Path("~/.my.cnf").expanduser() + + if cnf_path.exists(): + return + + with open(str(cnf_path), "w") as file: + file.write( + "[client]\n" + + "user={}\n".format(dj_config["database.user"]) + + "password={}\n".format(dj_config["database.password"]) + + "host={}\n".format(dj_config["database.host"]) + ) + + def _write_mysqldump( + self, paper_id: str, docker_id=None, spyglass_version=None + ): + """Write mysqlmdump to a temporary file and return the file object""" + paper_dir = Path(export_dir) / paper_id + paper_dir.mkdir(exist_ok=True) + + dump_script = paper_dir / f"_ExportSQL_{paper_id}.sh" + dump_content = paper_dir / f"_Populate_{paper_id}.sql" + + prefix = f"docker exec -i {docker_id}" if docker_id else "" + version = ( # Include spyglass version as comment in dump + f"echo '-- SPYGLASS VERSION: {spyglass_version}' > {dump_content}\n" + if spyglass_version + else "" + ) + + with open(dump_script, "w") as file: + file.write(f"#!/bin/bash\n{version}") + + for table in self.all_ft: + if not (where := table.where_clause()): + continue + database, table_name = table.full_table_name.split(".") + file.write( + f"{prefix}mysqldump {database} {table_name} " + + f'--where="{where}" >> {dump_content}\n' + ) + logger.info(f"Export script written to {dump_script}") + + def write_export( + self, paper_id: str, docker_id=None, spyglass_version=None + ): + if not self.cascaded: + self.cascade() + self._write_sql_cnf() + self._write_mysqldump(paper_id, docker_id, spyglass_version) + + # TODO: export conda env diff --git a/src/spyglass/utils/dj_helper_fn.py b/src/spyglass/utils/dj_helper_fn.py index 4a0495778..de9f2d3a9 100644 --- a/src/spyglass/utils/dj_helper_fn.py +++ b/src/spyglass/utils/dj_helper_fn.py @@ -11,6 +11,11 @@ from spyglass.utils.nwb_helper_fn import get_nwb_file +def unique_dicts(list_of_dict): + """Remove duplicate dictionaries from a list.""" + return [dict(t) for t in {tuple(d.items()) for d in list_of_dict}] + + def deprecated_factory(classes: list, old_module: str = "") -> list: """Creates a list of classes and logs a warning when instantiated @@ -127,33 +132,30 @@ def fetch_nwb(query_expression, nwb_master, *attrs, **kwargs): nwb_objects : list List of dicts containing fetch results and NWB objects. """ - kwargs["as_dict"] = True # force return as dictionary - tbl, attr_name = nwb_master - - if not attrs: - attrs = query_expression.heading.names - - # get the list of analysis or nwb files - file_name_str = ( - "analysis_file_name" if "analysis" in nwb_master[1] else "nwb_file_name" - ) # TODO: avoid this import? 
from ..common.common_nwbfile import AnalysisNwbfile, Nwbfile - file_path_fn = ( - AnalysisNwbfile.get_abs_path - if "analysis" in nwb_master[1] - else Nwbfile.get_abs_path - ) + kwargs["as_dict"] = True # force return as dictionary + attrs = attrs or query_expression.heading.names # if none, all + + tbl, attr_name = nwb_master + + which = "analysis" if "analysis" in attr_name else "nwb" + tbl_map = { # map to file_name_str and file_path_fn + "analysis": ["analysis_file_name", AnalysisNwbfile.get_abs_path], + "nwb": ["nwb_file_name", Nwbfile.get_abs_path], + } + file_name_str, file_path_fn = tbl_map[which] # TODO: check that the query_expression restricts tbl - CBroz nwb_files = ( query_expression * tbl.proj(nwb2load_filepath=attr_name) ).fetch(file_name_str) + for file_name in nwb_files: file_path = file_path_fn(file_name) - if not os.path.exists(file_path): - # retrieve the file from kachery. This also opens the file and stores the file object + if not os.path.exists(file_path): # retrieve the file from kachery. + # This also opens the file and stores the file object get_nwb_file(file_path) rec_dicts = ( diff --git a/src/spyglass/utils/dj_merge_tables.py b/src/spyglass/utils/dj_merge_tables.py index 37dc62fe0..e7ba23c72 100644 --- a/src/spyglass/utils/dj_merge_tables.py +++ b/src/spyglass/utils/dj_merge_tables.py @@ -784,13 +784,14 @@ def delete(self, force_permission=False, *args, **kwargs): ): part.delete(force_permission=force_permission, *args, **kwargs) - def super_delete(self, *args, **kwargs): + def super_delete(self, warn=True, *args, **kwargs): """Alias for datajoint.table.Table.delete. - Added to support MRO of SpyglassMixin""" - logger.warning("!! Using super_delete. Bypassing cautious_delete !!") - - self._log_use(start=time(), super_delete=True) + Added to support MRO of SpyglassMixin + """ + if warn: + logger.warning("!! Bypassing cautious_delete !!") + self._log_use(start=time(), super_delete=True) super().delete(*args, **kwargs) diff --git a/src/spyglass/utils/dj_mixin.py b/src/spyglass/utils/dj_mixin.py index 29978ae88..96555f2fe 100644 --- a/src/spyglass/utils/dj_mixin.py +++ b/src/spyglass/utils/dj_mixin.py @@ -1,9 +1,13 @@ +from atexit import register as exit_register +from atexit import unregister as exit_unregister from collections import OrderedDict from functools import cached_property +from os import environ from time import time from typing import Dict, List, Union import datajoint as dj +from datajoint.condition import make_condition from datajoint.errors import DataJointError from datajoint.expression import QueryExpression from datajoint.logging import logger as dj_logger @@ -17,6 +21,8 @@ from spyglass.utils.dj_merge_tables import RESERVED_PRIMARY_KEY as MERGE_PK from spyglass.utils.logging import logger +EXPORT_ENV_VAR = "SPYGLASS_EXPORT_ID" + class SpyglassMixin: """Mixin for Spyglass DataJoint tables. @@ -119,7 +125,27 @@ def fetch_nwb(self, *attrs, **kwargs): AnalysisNwbfile (i.e., "-> (Analysis)Nwbfile" in definition) or a _nwb_table attribute. If both are present, the attribute takes precedence. + + Additional logic support Export table logging. 
""" + table, tbl_attr = self._nwb_table_tuple + if self.export_id and "analysis" in tbl_attr: + logger.info(f"Export {self.export_id}: fetch_nwb {self.table_name}") + tbl_pk = "analysis_file_name" + fname = (self * table).fetch1(tbl_pk) + self._export_table.File.insert1( + {"export_id": self.export_id, tbl_pk: fname}, + skip_duplicates=True, + ) + self._export_table.Table.insert1( + dict( + export_id=self.export_id, + table_name=self.full_table_name, + restriction=make_condition(self, self.restriction, set()), + ), + skip_duplicates=True, + ) + return fetch_nwb(self, self._nwb_table_tuple, *attrs, **kwargs) # ------------------------ delete_downstream_merge ------------------------ @@ -318,9 +344,7 @@ def _get_exp_summary(self): empty_pk = {self._member_pk: "NULL"} format = dj.U(self._session_pk, self._member_pk) - sess_link = self._session_connection.join( - self.restriction, reverse_order=True - ) + sess_link = self._session_connection.join(self.restriction) exp_missing = format & (sess_link - SesExp).proj(**empty_pk) exp_present = format & (sess_link * SesExp - exp_missing).proj() @@ -330,7 +354,9 @@ def _get_exp_summary(self): @cached_property def _session_connection(self) -> Union[TableChain, bool]: """Path from Session table to self. False if no connection found.""" - connection = TableChain(parent=self._delete_deps[-1], child=self) + connection = TableChain( + parent=self._delete_deps[-1], child=self, reverse=True + ) return connection if connection.has_link else False @cached_property @@ -488,8 +514,133 @@ def delete(self, force_permission=False, *args, **kwargs): """Alias for cautious_delete, overwrites datajoint.table.Table.delete""" self.cautious_delete(force_permission=force_permission, *args, **kwargs) - def super_delete(self, *args, **kwargs): + def super_delete(self, warn=True, *args, **kwargs): """Alias for datajoint.table.Table.delete.""" - logger.warning("!! Using super_delete. Bypassing cautious_delete !!") - self._log_use(start=time(), super_delete=True) + if warn: + logger.warning("!! Bypassing cautious_delete !!") + self._log_use(start=time(), super_delete=True) super().delete(*args, **kwargs) + + # ------------------------------- Export Log ------------------------------- + + @cached_property + def _spyglass_version(self): + """Get Spyglass version from dj.config.""" + from spyglass import __version__ as sg_version + + return ".".join(sg_version.split(".")[:3]) # Major.Minor.Patch + + @cached_property + def _export_table(self): + """Lazy load export selection table.""" + from spyglass.common.common_usage import ExportSelection + + return ExportSelection() + + @property + def export_id(self): + """ID of export in progress. + + NOTE: User of an env variable to store export_id may not be thread safe. + Exports must be run in sequence, not parallel. 
+ """ + + return int(environ.get(EXPORT_ENV_VAR, 0)) + + @export_id.setter + def export_id(self, value): + """Set ID of export using `table.export_id = X` notation.""" + if self.export_id != 0 and self.export_id != value: + raise RuntimeError("Export already in progress.") + environ[EXPORT_ENV_VAR] = str(value) + exit_register(self._export_id_cleanup) # End export on exit + + @export_id.deleter + def export_id(self): + """Delete ID of export using `del table.export_id` notation.""" + self._export_id_cleanup() + + def _export_id_cleanup(self): + """Cleanup export ID.""" + if environ.get(EXPORT_ENV_VAR): + del environ[EXPORT_ENV_VAR] + exit_unregister(self._export_id_cleanup) # Remove exit hook + + def _start_export(self, paper_id, analysis_id): + """Start export process.""" + if self.export_id: + logger.info(f"Export {self.export_id} in progress. Starting new.") + self._stop_export(warn=False) + + self.export_id = self._export_table.insert1_return_pk( + dict( + paper_id=paper_id, + analysis_id=analysis_id, + spyglass_version=self._spyglass_version, + ) + ) + + def _stop_export(self, warn=True): + """End export process.""" + if not self.export_id and warn: + logger.warning("Export not in progress.") + del self.export_id + + def _log_fetch(self): + """Log fetch for export.""" + if ( + not self.export_id + or self.full_table_name == self._export_table.full_table_name + or "dandi_export" in self.full_table_name # for populated table + ): + return + logger.info(f"Export {self.export_id}: fetch() {self.table_name}") + restr_str = make_condition(self, self.restriction, set()) + if isinstance(restr_str, str) and len(restr_str) > 2048: + raise RuntimeError( + "DandiExport cannot handle restrictions > 2048.\n\t" + + "If required, please open an issue on GitHub.\n\t" + + f"Restriction: {restr_str}" + ) + self._export_table.Table.insert1( + dict( + export_id=self.export_id, + table_name=self.full_table_name, + restriction=make_condition(self, restr_str, set()), + ), + skip_duplicates=True, + ) + + def fetch(self, *args, **kwargs): + """Log fetch for export.""" + ret = super().fetch(*args, **kwargs) + self._log_fetch() + return ret + + def fetch1(self, *args, **kwargs): + """Log fetch1 for export.""" + ret = super().fetch1(*args, **kwargs) + self._log_fetch() + return ret + + # ------------------------- Other helper methods ------------------------- + + def _auto_increment(self, key, pk, *args, **kwargs): + """Auto-increment primary key.""" + if not key.get(pk): + key[pk] = (dj.U().aggr(self, n=f"max({pk})").fetch1("n") or 0) + 1 + return key + + def file_like(self, name=None, **kwargs): + """Convenience method for wildcard search on file name fields.""" + if not name: + return self & True + attr = None + for field in self.heading.names: + if "file" in field: + attr = field + break + if not attr: + logger.error(f"No file-like field found in {self.full_table_name}") + return + return self & f"{attr} LIKE '%{name}%'"