From 424d4cfc67e826f93716f860ff90cbfab165f2e2 Mon Sep 17 00:00:00 2001 From: CBroz1 Date: Tue, 26 Mar 2024 16:51:42 -0500 Subject: [PATCH 01/15] WIP: rebase Export process --- docs/mkdocs.yml | 2 + docs/src/misc/export.md | 126 ++++++++++ notebooks/05_Export.ipynb | 152 ++++++++++++ notebooks/py_scripts/05_Export.py | 94 +++++++ src/spyglass/common/common_lab.py | 2 +- src/spyglass/common/common_usage.py | 163 ++++++++++++ src/spyglass/settings.py | 8 + src/spyglass/utils/dj_graph.py | 342 ++++++++++++++++++++++++++ src/spyglass/utils/dj_helper_fn.py | 36 +-- src/spyglass/utils/dj_merge_tables.py | 11 +- src/spyglass/utils/dj_mixin.py | 165 ++++++++++++- 11 files changed, 1071 insertions(+), 30 deletions(-) create mode 100644 docs/src/misc/export.md create mode 100644 notebooks/05_Export.ipynb create mode 100644 notebooks/py_scripts/05_Export.py create mode 100644 src/spyglass/utils/dj_graph.py diff --git a/docs/mkdocs.yml b/docs/mkdocs.yml index 996cb36dc..920b646a7 100644 --- a/docs/mkdocs.yml +++ b/docs/mkdocs.yml @@ -51,6 +51,7 @@ nav: - Data Sync: notebooks/02_Data_Sync.ipynb - Merge Tables: notebooks/03_Merge_Tables.ipynb - Config Populate: notebooks/04_PopulateConfigFile.ipynb + - Export: notebooks/05_Export.ipynb - Spikes: - Spike Sorting V0: notebooks/10_Spike_SortingV0.ipynb - Spike Sorting V1: notebooks/10_Spike_SortingV1.ipynb @@ -75,6 +76,7 @@ nav: - Insert Data: misc/insert_data.md - Merge Tables: misc/merge_tables.md - Database Management: misc/database_management.md + - Export: misc/export.md - API Reference: api/ # defer to gen-files + literate-nav - How to Contribute: contribute.md - Change Log: CHANGELOG.md diff --git a/docs/src/misc/export.md b/docs/src/misc/export.md new file mode 100644 index 000000000..e0cea5d08 --- /dev/null +++ b/docs/src/misc/export.md @@ -0,0 +1,126 @@ +# Export Process + +## Why + +DataJoint does not have any built-in functionality for exporting vertical slices +of a database. A lab can maintain a shared DataJoint pipeline across multiple +projects, but conforming to NIH data sharing guidelines may require that data +from only one project be shared during publication. + +## Requirements + +To export data with the current implementation, you must do the following: + +- All custom tables must inherit from `SpyglassMixin` (e.g., + `class MyTable(SpyglassMixin, dj.ManualOrOther):`) +- Only one export can be active at a time. +- Start the export process with `ExportSelection.start_export()`, run all + functions associated with a given analysis, and end the export process with + `ExportSelection.end_export()`. + +## How + +The current implementation relies on two classes in the `spyglass` package: +`SpyglassMixin` and `RestrGraph` and the `Export` tables. + +- `SpyglassMixin`: See `spyglass/utils/dj_mixin.py` +- `RestrGraph`: See `spyglass/utils/dj_graph.py` +- `Export`: See `spyglass/common/common_usage.py` + +### Mixin + +The `SpyglassMixin` class is a subclass of DataJoint's `Manual` class. A subset +of methods are used to set an environment variable, `SPYGLASS_EXPORT_ID`, and, +while active, intercept all `fetch`/`fetch_nwb` calls to tables. When `fetch` is +called, the mixin grabs the table name and the restriction applied to the table +and stores them in the `ExportSelection` part tables. + +- `fetch_nwb` is specific to Spyglass and logs all analysis nwb files that are + fetched. +- `fetch` is a DataJoint method that retrieves data from a table. 
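
For orientation, a minimal sketch of the interception pattern (simplified from
the mixin code shown later in this patch; the real version converts the
restriction with `make_condition` and skips logging for the export tables
themselves):

```python
class LoggingMixin:
    def fetch(self, *args, **kwargs):
        results = super().fetch(*args, **kwargs)
        if self.export_id:  # 0 unless an export is active
            self._export_table.Table.insert1(
                dict(
                    export_id=self.export_id,
                    table_name=self.full_table_name,
                    restriction=self.restriction,
                ),
                skip_duplicates=True,
            )
        return results
```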
+ +### Graph + +The `RestrGraph` class uses DataJoint's networkx graph to store each of the +tables and restrictions intercepted by the `SpyglassMixin`'s `fetch` as +'leaves'. The class then cascades these restrictions up from each leaf to all +ancestors. Use is modeled in the methods of `ExportSelection`. + +```python +from spyglass.utils.dj_graph import RestrGraph + +restr_graph = RestrGraph(seed_table=AnyTable, leaves=None, verbose=False) +restr_graph.add_leaves( + leaves=[ + { + "table_name": MyTable.full_table_name, + "restriction": "any_restriction", + }, + { + "table_name": AnotherTable.full_table_name, + "restriction": "another_restriction", + }, + ] +) +restr_graph.cascade() +restricted_leaves = restr_graph.leaf_ft +all_restricted_tables = restr_graph.all_ft + +restr_graph.write_export(paper_id="my_paper_id") # part of `populate` below +``` + +By default, a `RestrGraph` object is created with a seed table to have access to +a DataJoint connection and graph. One or more leaves can be added at +initialization or later with the `add_leaves` method. The cascade process is +delayed until `cascade`, or another method that requires the cascade, is called. + +Cascading a single leaf involves transforming the leaf's restriction into its +parent's restriction, then repeating the process until all ancestors are +reached. If two leaves share a common ancestor, the restrictions are combined. +This process also accommodates projected fields, which appear as numeric alias +nodes in the graph. + +### Export Table + +The `ExportSelection` is where users should interact with this process. + +```python +from spyglass.common.common_usage import ExportSelection +from spyglass.common.common_usage import Export + +export_key = {paper_id: "my_paper_id", analysis_id: "my_analysis_id"} +ExportSelection().start_export(**export_key) +ExportSelection().restart_export(**export_key) # to clear previous attempt +analysis_data = (MyTable & my_restr).fetch() +analysis_nwb = (MyTable & my_restr).fetch_nwb() +ExportSelection().end_export() + +# Visual inspection +touched_files = DS().list_file_paths(**export_key) +restricted_leaves = DS().preview_tables(**export_key) + +# Export +Export().populate() +``` + +`Export` will invoke `RestrGraph.write_export` to collect cascaded restrictions +and file paths in its part tables, and write out a bash script to export the +data using a series of `mysqldump` commands. The script is saved to Spyglass's +directory, `base_dir/export/paper_id/`, using credentials from `dj_config`. To +use alternative credentials, create a +[mysql config file](https://dev.mysql.com/doc/refman/8.0/en/option-files.html). + +## External Implementation + +To implement an export for a non-Spyglass database, you will need to ... + +- Create a modified version of `SpyglassMixin`, including ... + - `_export_table` method to lazy load an export table like `ExportSelection` + - `export_id` attribute, plus setter and deleter methods, to manage the status + of the export. + - `fetch` and other methods to intercept and log exported content. +- Create a modified version of `ExportSelection`, that adjusts fields like + `spyglass_version` to match the new database. + +Or, optionally, you can use the `RestrGraph` class to cascade hand-picked tables +and restrictions without the background logging of `SpyglassMixin`. 
diff --git a/notebooks/05_Export.ipynb b/notebooks/05_Export.ipynb new file mode 100644 index 000000000..5ab5989b7 --- /dev/null +++ b/notebooks/05_Export.ipynb @@ -0,0 +1,152 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "tags": [] + }, + "source": [ + "# Export\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Intro\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "_Developer Note:_ if you may make a PR in the future, be sure to copy this\n", + "notebook, and use the `gitignore` prefix `temp` to avoid future conflicts.\n", + "\n", + "This is one notebook in a multi-part series on Spyglass.\n", + "\n", + "- To set up your Spyglass environment and database, see\n", + " [the Setup notebook](./00_Setup.ipynb)\n", + "- To insert data, see [the Insert Data notebook](./01_Insert_Data.ipynb)\n", + "- For additional info on DataJoint syntax, including table definitions and\n", + " inserts, see\n", + " [these additional tutorials](https://github.com/datajoint/datajoint-tutorials)\n", + "- For information on what's goint on behind the scenes of an export, see\n", + " [documentation](https://lorenfranklab.github.io/spyglass/0.5/misc/export/)\n", + "\n", + "In short, Spyglass offers the ability to generate exports of one or more subsets\n", + "of the database required for a specific analysis as long as you do the following:\n", + "\n", + "- Inherit `SpyglassMixin` for all custom tables.\n", + "- Run only one export at a time.\n", + "- Start and stop each export logging process.\n", + "\n", + "**NOTE:** For demonstration purposes, this notebook relies on a more populated\n", + "database to highlight restriction merging capabilities of the export process.\n", + "Adjust the restrictions to suit your own dataset.\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Imports\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Let's start by importing the `spyglass` package, along with a few others.\n" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": { + "tags": [] + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "[2024-01-29 16:15:00,903][INFO]: Connecting root@localhost:3309\n", + "[2024-01-29 16:15:00,912][INFO]: Connected root@localhost:3309\n" + ] + } + ], + "source": [ + "import os\n", + "import datajoint as dj\n", + "\n", + "# change to the upper level folder to detect dj_local_conf.json\n", + "if os.path.basename(os.getcwd()) == \"notebooks\":\n", + " os.chdir(\"..\")\n", + "dj.config.load(\"dj_local_conf.json\") # load config for database connection info\n", + "\n", + "# ignore datajoint+jupyter async warnings\n", + "from spyglass.common.common_usage import Export, ExportSelection\n", + "from spyglass.lfp.analysis.v1 import LFPBandV1\n", + "from spyglass.position.v1 import TrodesPosV1\n", + "from spyglass.spikesorting.v1.curation import CurationV1\n", + "\n", + "# TODO: Add commentary, describe helpers on ExportSelection\n", + "\n", + "paper_key = {\"paper_id\": \"paper1\"}\n", + "ExportSelection().start_export(**paper_key, analysis_id=\"test1\")\n", + "a = (\n", + " LFPBandV1 & \"nwb_file_name LIKE 'med%'\" & {\"filter_name\": \"Theta 5-11 Hz\"}\n", + ").fetch()\n", + "b = (\n", + " LFPBandV1\n", + " & {\n", + " \"nwb_file_name\": \"mediumnwb20230802_.nwb\",\n", + " \"filter_name\": \"Theta 5-10 Hz\",\n", + " }\n", + ").fetch()\n", + "ExportSelection().start_export(**paper_key, analysis_id=\"test2\")\n", 
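    "# a second analysis_id under the same paper is logged as a separate entry\n",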
+ "c = (CurationV1 & \"curation_id = 1\").fetch_nwb()\n", + "d = (TrodesPosV1 & 'trodes_pos_params_name = \"single_led\"').fetch()\n", + "ExportSelection().stop_export()\n", + "Export().populate_paper(**paper_key)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Up Next\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "In the [next notebook](./10_Spike_Sorting.ipynb), we'll start working with\n", + "ephys data with spike sorting.\n" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "spy", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.16" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/notebooks/py_scripts/05_Export.py b/notebooks/py_scripts/05_Export.py new file mode 100644 index 000000000..4acca2335 --- /dev/null +++ b/notebooks/py_scripts/05_Export.py @@ -0,0 +1,94 @@ +# --- +# jupyter: +# jupytext: +# text_representation: +# extension: .py +# format_name: light +# format_version: '1.5' +# jupytext_version: 1.16.0 +# kernelspec: +# display_name: spy +# language: python +# name: python3 +# --- + +# # Export +# + +# ## Intro +# + +# _Developer Note:_ if you may make a PR in the future, be sure to copy this +# notebook, and use the `gitignore` prefix `temp` to avoid future conflicts. +# +# This is one notebook in a multi-part series on Spyglass. +# +# - To set up your Spyglass environment and database, see +# [the Setup notebook](./00_Setup.ipynb) +# - To insert data, see [the Insert Data notebook](./01_Insert_Data.ipynb) +# - For additional info on DataJoint syntax, including table definitions and +# inserts, see +# [these additional tutorials](https://github.com/datajoint/datajoint-tutorials) +# - For information on what's goint on behind the scenes of an export, see +# [documentation](https://lorenfranklab.github.io/spyglass/0.5/misc/export/) +# +# In short, Spyglass offers the ability to generate exports of one or more subsets +# of the database required for a specific analysis as long as you do the following: +# +# - Inherit `SpyglassMixin` for all custom tables. +# - Run only one export at a time. +# - Start and stop each export logging process. +# +# **NOTE:** For demonstration purposes, this notebook relies on a more populated +# database to highlight restriction merging capabilities of the export process. +# Adjust the restrictions to suit your own dataset. +# + +# ## Imports +# + +# Let's start by importing the `spyglass` package, along with a few others. 
+# + +# + +import os +import datajoint as dj + +# change to the upper level folder to detect dj_local_conf.json +if os.path.basename(os.getcwd()) == "notebooks": + os.chdir("..") +dj.config.load("dj_local_conf.json") # load config for database connection info + +# ignore datajoint+jupyter async warnings +from spyglass.common.common_usage import Export, ExportSelection +from spyglass.lfp.analysis.v1 import LFPBandV1 +from spyglass.position.v1 import TrodesPosV1 +from spyglass.spikesorting.v1.curation import CurationV1 + +# TODO: Add commentary, describe helpers on ExportSelection + +paper_key = {"paper_id": "paper1"} +ExportSelection().start_export(**paper_key, analysis_id="test1") +a = ( + LFPBandV1 & "nwb_file_name LIKE 'med%'" & {"filter_name": "Theta 5-11 Hz"} +).fetch() +b = ( + LFPBandV1 + & { + "nwb_file_name": "mediumnwb20230802_.nwb", + "filter_name": "Theta 5-10 Hz", + } +).fetch() +ExportSelection().start_export(**paper_key, analysis_id="test2") +c = (CurationV1 & "curation_id = 1").fetch_nwb() +d = (TrodesPosV1 & 'trodes_pos_params_name = "single_led"').fetch() +ExportSelection().stop_export() +Export().populate_paper(**paper_key) +# - + +# ## Up Next +# + +# In the [next notebook](./10_Spike_Sorting.ipynb), we'll start working with +# ephys data with spike sorting. +# diff --git a/src/spyglass/common/common_lab.py b/src/spyglass/common/common_lab.py index a6a162b2b..c5a6fbc00 100644 --- a/src/spyglass/common/common_lab.py +++ b/src/spyglass/common/common_lab.py @@ -92,7 +92,7 @@ def _load_admin(cls): """Load admin list.""" cls._admin = list( (cls.LabMemberInfo & {"admin": True}).fetch("datajoint_user_name") - ) + ) + ["root"] @property def admin(cls) -> list: diff --git a/src/spyglass/common/common_usage.py b/src/spyglass/common/common_usage.py index fdf7ae99d..47b5172c9 100644 --- a/src/spyglass/common/common_usage.py +++ b/src/spyglass/common/common_usage.py @@ -6,8 +6,15 @@ plan future development of Spyglass. """ +from typing import Union + import datajoint as dj +from spyglass.common.common_nwbfile import AnalysisNwbfile, Nwbfile +from spyglass.utils import SpyglassMixin, logger +from spyglass.utils.dj_graph import RestrGraph +from spyglass.utils.dj_helper_fn import unique_dicts + schema = dj.schema("common_usage") @@ -37,3 +44,159 @@ class InsertError(dj.Manual): error_message: varchar(255) error_raw = null: blob """ + + +@schema +class ExportSelection(SpyglassMixin, dj.Manual): + definition = """ + export_id: int auto_increment + --- + paper_id: varchar(32) + analysis_id: varchar(32) + spyglass_version: varchar(16) + time=CURRENT_TIMESTAMP: timestamp + unique index (paper_id, analysis_id) + """ + + class Table(SpyglassMixin, dj.Part): + definition = """ + -> master + table_id: int + --- + table_name: varchar(64) + restriction: varchar(2048) + """ + + def insert1(self, key, **kwargs): + key = self._auto_increment(key, pk="table_id") + super().insert1(key, **kwargs) + + class File(SpyglassMixin, dj.Part): + definition = """ + -> master + -> AnalysisNwbfile + """ + # Note: only tracks AnalysisNwbfile. list_file_paths also grabs Nwbfile. 
+ + def insert1_return_pk(self, key: dict, **kwargs) -> int: + """Custom insert to return export_id.""" + status = "Resuming" + if not (query := self & key): + super().insert1(key, **kwargs) + status = "Starting" + export_id = query.fetch1("export_id") + export_key = {"export_id": export_id} + if query := (Export & export_key): + query.super_delete(warn=False) + logger.info(f"{status} {export_key}") + return export_id + + def start_export(self, paper_id, analysis_id) -> None: + self._start_export(paper_id, analysis_id) + + def stop_export(self) -> None: + self._stop_export() + + # NOTE: These helpers could be moved to table below, but I think + # end users may want to use them to check what's in the export log + # before actually exporting anything, which is more associated with + # Selection + + def list_file_paths(self, key: dict) -> list[str]: + file_table = self.File & key + analysis_fp = [ + AnalysisNwbfile().get_abs_path(fname) + for fname in file_table.fetch("analysis_file_name") + ] + nwbfile_fp = [ + Nwbfile().get_abs_path(fname) + for fname in (AnalysisNwbfile * file_table).fetch("nwb_file_name") + ] + return [{"file_path": p} for p in list({*analysis_fp, *nwbfile_fp})] + + def get_restr_graph(self, key: dict) -> RestrGraph: + leaves = unique_dicts( + (self.Table & key).fetch("table_name", "restriction", as_dict=True) + ) + return RestrGraph(seed_table=self, leaves=leaves, verbose=True) + + def preview_tables(self, key: dict) -> list[dj.FreeTable]: + return self.get_restr_graph(key).leaf_ft + + def _min_export_id(self, paper_id: str) -> int: + """Return all export_ids for a paper.""" + if isinstance(paper_id, dict): + paper_id = paper_id.get("paper_id") + if not (query := self & {"paper_id": paper_id}): + return None + return min(query.fetch("export_id")) + + def paper_export_id(self, paper_id: str) -> dict: + """Return the minimum export_id for a paper, used to populate Export.""" + return {"export_id": self._min_export_id(paper_id)} + + +@schema +class Export(SpyglassMixin, dj.Computed): + definition = """ + -> ExportSelection + """ + + # In order to get a many-to-one relationship btwn Selection and Export, + # we ignore all but the first export_id. + + class Table(SpyglassMixin, dj.Part): + definition = """ + -> master + table_id: int + --- + table_name: varchar(64) + restriction: varchar(2048) + unique index (table_name) + """ + + class File(SpyglassMixin, dj.Part): + definition = """ + -> master + file_id: int + --- + file_path: varchar(255) + """ + # What's needed? full path? relative path? + + def populate_paper(self, paper_id: Union[str, dict]): + if isinstance(paper_id, dict): + paper_id = paper_id.get("paper_id") + self.populate(ExportSelection().paper_export_id(paper_id)) + + def make(self, key): + query = ExportSelection & key + paper_key = query.fetch("paper_id", as_dict=True)[0] + + # Null insertion if export_id is not the minimum for the paper + min_export_id = query._min_export_id(paper_key) + if key.get("export_id") != min_export_id: + logger.info( + f"Skipping export_id {key['export_id']}, use {min_export_id}" + ) + self.insert1(key) + return + + restr_graph = query.get_restr_graph(paper_key) + file_paths = query.list_file_paths(paper_key) + + table_inserts = [ + {**key, **rd, "table_id": i} + for i, rd in enumerate(restr_graph.as_dict) + ] + file_inserts = [ + {**key, **fp, "file_id": i} for i, fp in enumerate(file_paths) + ] + + # Writes but does not run mysqldump. Assumes single version per paper. 
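+        # The dump script lands in export_dir/<paper_id>/; running it to
+        # produce the .sql dump is left to the user.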
+ version_key = query.fetch("spyglass_version", as_dict=True)[0] + restr_graph.write_export(**paper_key, **version_key) + + self.insert1(key) + self.Table().insert(table_inserts) # TODO: Duplicate error?? + self.File().insert(file_inserts) diff --git a/src/spyglass/settings.py b/src/spyglass/settings.py index af16e688d..4900f595d 100644 --- a/src/spyglass/settings.py +++ b/src/spyglass/settings.py @@ -68,6 +68,7 @@ def __init__(self, base_dir: str = None, **kwargs): "waveforms": "waveforms", "temp": "tmp", "video": "video", + "export": "export", }, "kachery": { "cloud": ".kachery-cloud", @@ -459,6 +460,7 @@ def _dj_custom(self) -> dict: "waveforms": self.waveforms_dir, "temp": self.temp_dir, "video": self.video_dir, + "export": self.export_dir, }, "kachery_dirs": { "cloud": self.config.get( @@ -516,6 +518,10 @@ def temp_dir(self) -> str: def video_dir(self) -> str: return self.config.get(self.dir_to_var("video")) + @property + def export_dir(self) -> str: + return self.config.get(self.dir_to_var("export")) + @property def debug_mode(self) -> bool: """Returns True if debug_mode is set. @@ -560,6 +566,7 @@ def dlc_output_dir(self) -> str: sorting_dir = None waveforms_dir = None video_dir = None + export_dir = None dlc_project_dir = None dlc_video_dir = None dlc_output_dir = None @@ -573,6 +580,7 @@ def dlc_output_dir(self) -> str: sorting_dir = sg_config.sorting_dir waveforms_dir = sg_config.waveforms_dir video_dir = sg_config.video_dir + export_dir = sg_config.export_dir debug_mode = sg_config.debug_mode test_mode = sg_config.test_mode prepopulate = config.get("prepopulate", False) diff --git a/src/spyglass/utils/dj_graph.py b/src/spyglass/utils/dj_graph.py new file mode 100644 index 000000000..5507eefe5 --- /dev/null +++ b/src/spyglass/utils/dj_graph.py @@ -0,0 +1,342 @@ +"""DataJoint graph traversal and restriction application. + +Note: read `ft` as FreeTable and `restr` as restriction.""" + +from pathlib import Path +from typing import Dict, List + +from datajoint import FreeTable +from datajoint import config as dj_config +from datajoint.condition import make_condition +from datajoint.table import Table + +from spyglass.settings import export_dir +from spyglass.utils import logger +from spyglass.utils.dj_helper_fn import unique_dicts + + +class RestrGraph: + def __init__( + self, + seed_table: Table, + table_name: str = None, + restriction: str = None, + leaves: List[Dict[str, str]] = None, + verbose: bool = False, + **kwargs, + ): + """Use graph to cascade restrictions up from leaves to all ancestors. + + Parameters + ---------- + seed_table : Table + Table to use to establish connection and graph + table_name : str, optional + Table name of single leaf, default None + restriction : str, optional + Restriction to apply to leaf. default None + leaves : Dict[str, str], optional + List of dictionaries with keys table_name and restriction. One + entry per leaf node. Default None. + verbose : bool, optional + Whether to print verbose output. 
Default False + """ + + self.connection = seed_table.connection + self.graph = seed_table.connection.dependencies + self.graph.load() + + self.verbose = verbose + self.cascaded = False + self.ancestors = set() + self.visited = set() + self.leaves = set() + + if table_name and restriction: + self.add_leaf(table_name, restriction) + if leaves: + self.add_leaves(leaves) + + def __repr__(self): + l_str = ",\n\t".join(self.leaves) + "\n" if self.leaves else "" + processed = "Processed" if self.cascaded else "Unprocessed" + return f"{processed} RestrictionGraph(\n\t{l_str})" + + @property + def all_ft(self): + """Get restricted FreeTables from all visited nodes.""" + return [self._get_ft(table, with_restr=True) for table in self.visited] + + @property + def leaf_ft(self): + """Get restricted FreeTables from graph leaves.""" + return [self._get_ft(table, with_restr=True) for table in self.leaves] + + def _get_node(self, table): + """Get node from graph.""" + node = self.graph.nodes.get(table) + if not node: + raise ValueError( + f"Table {table} not found in graph." + + "\n\tPlease import this table and rerun" + ) + return node + + def _set_node(self, table, attr="ft", value=None): + """Set attribute on graph node.""" + _ = self._get_node(table) # Ensure node exists + self.graph.nodes[table][attr] = value + + def _get_ft(self, table, with_restr=False): + """Get FreeTable from graph node. If one doesn't exist, create it.""" + restr = self._get_restr(table) if with_restr else True + if ft := self._get_node(table).get("ft"): + return ft & restr + ft = FreeTable(self.connection, table) + self._set_node(table, "ft", ft) + return ft & restr + + def _get_restr(self, table): + """Get restriction from graph node.""" + table = table if isinstance(table, str) else table.full_table_name + return self._get_node(table).get("restr") + + def _set_restr(self, table, restriction): + """Add restriction to graph node. If one exists, merge with new.""" + ft = self._get_ft(table) + restriction = ( # Convert to condition if list or dict + make_condition(ft, restriction, set()) + if not isinstance(restriction, str) + else restriction + ) + if existing := self._get_restr(table): + if existing == restriction: + return + join = ft & [existing, restriction] + if len(join) == len(ft & existing): + return # restriction is a subset of existing + restriction = unique_dicts(join.fetch("KEY", as_dict=True)) + + self._log_truncate( + f"Set Restr {table}: {type(restriction)} {restriction}" + ) + self._set_node(table, "restr", restriction) + + def _log_truncate(self, log_str, max_len=80): + """Truncate log lines to max_len and print if verbose.""" + if not self.verbose: + return + logger.info( + log_str[:max_len] + "..." if len(log_str) > max_len else log_str + ) + + def _child_to_parent( + self, + child, + parent, + restriction, + attr_map=None, + primary=True, + **kwargs, + ) -> List[Dict[str, str]]: + """Given a child, child's restr, and parent, get parent's restr. + + Parameters + ---------- + child : str + child table name + parent : str + parent table name + restriction : str + restriction to apply to child + attr_map : dict, optional + dictionary mapping aliases across parend/child, as pulled from + DataJoint-assembled graph. Default None. Func will flip this dict + to convert from child to parent fields. + primary : bool, optional + Is parent in child's primary key? Default True. Also derived from + DataJoint-assembled graph. If True, project only primary key fields + to avoid secondary key collisions. 
+ + Returns + ------- + List[Dict[str, str]] + List of dicts containing primary key fields for restricted parent + table. + """ + + # Need to flip attr_map to respect parent's fields + attr_map = ( + {v: k for k, v in attr_map.items() if k != k} if attr_map else {} + ) + child_ft = self._get_ft(child) + parent_ft = self._get_ft(parent).proj() + restr = restriction or self._get_restr(child_ft) or True + restr_child = child_ft & restr + + if primary: # Project only primary key fields to avoid collisions + join = restr_child.proj(**attr_map) * parent_ft + else: # Include all fields + join = restr_child.proj(..., **attr_map) * parent_ft + + ret = unique_dicts(join.fetch(*parent_ft.primary_key, as_dict=True)) + + if len(ret) == len(parent_ft): + self._log_truncate(f"NULL rest {parent}") + + return ret + + def cascade1(self, table, restriction): + """Cascade a restriction up the graph, recursively on parents. + + Parameters + ---------- + table : str + table name + restriction : str + restriction to apply + """ + self._set_restr(table, restriction) + self.visited.add(table) + + for parent, data in self.graph.parents(table).items(): + if parent in self.visited: + continue + + if parent.isnumeric(): + parent, data = self.graph.parents(parent).popitem() + + parent_restr = self._child_to_parent( + child=table, + parent=parent, + restriction=restriction, + **data, + ) + + self.cascade1(parent, parent_restr) # Parent set on recursion + + def cascade(self) -> None: + """Cascade all restrictions up the graph.""" + for table in self.leaves - self.visited: + restr = self._get_restr(table) + self._log_truncate(f"Start {table}: {restr}") + self.cascade1(table, restr) + if not self.visited == self.ancestors: + raise RuntimeError("Cascade: FAIL - incomplete cascade") + + self.cascaded = True + + def add_leaf(self, table_name, restriction, cascade=False) -> None: + """Add leaf to graph and cascade if requested. + + Parameters + ---------- + table_name : str + table name of leaf + restriction : str + restriction to apply to leaf + """ + new_ancestors = set(self._get_ft(table_name).ancestors()) + self.ancestors |= new_ancestors # Add to total ancestors + self.visited -= new_ancestors # Remove from visited to revisit + + self.leaves.add(table_name) + self._set_restr(table_name, restriction) # Redundant if cascaded + + if cascade: + self.cascade1(table_name, restriction) + self.cascaded = True + + def add_leaves(self, leaves: List[Dict[str, str]], cascade=False) -> None: + """Add leaves to graph and cascade if requested. + + Parameters + ---------- + leaves : List[Dict[str, str]] + list of dictionaries containing table_name and restriction + cascade : bool, optional + Whether to cascade the restrictions up the graph. 
Default False + """ + + if not leaves: + return + if not isinstance(leaves, list): + leaves = [leaves] + leaves = unique_dicts(leaves) + for leaf in leaves: + if not ( + (table_name := leaf.get("table_name")) + and (restriction := leaf.get("restriction")) + ): + raise ValueError( + f"Leaf must have table_name and restriction: {leaf}" + ) + self.add_leaf(table_name, restriction, cascade=False) + if cascade: + self.cascade() + + @property + def as_dict(self) -> List[Dict[str, str]]: + """Return as a list of dictionaries of table_name: restriction""" + if not self.cascaded: + self.cascade() + return [ + {"table_name": table, "restriction": self._get_restr(table)} + for table in self.ancestors + if self._get_restr(table) + ] + + def _write_sql_cnf(self): + """Write SQL cnf file to avoid password prompt.""" + cnf_path = Path("~/.my.cnf").expanduser() + + if cnf_path.exists(): + return + + with open(str(cnf_path), "w") as file: + file.write( + "[client]\n" + + "user={}\n".format(dj_config["database.user"]) + + "password={}\n".format(dj_config["database.password"]) + + "host={}\n".format(dj_config["database.host"]) + ) + + def _write_mysqldump( + self, paper_id: str, docker_id=None, spyglass_version=None + ): + """Write mysqlmdump to a temporary file and return the file object""" + paper_dir = Path(export_dir) / paper_id + paper_dir.mkdir(exist_ok=True) + + dump_script = paper_dir / f"_ExportSQL_{paper_id}.sh" + dump_content = paper_dir / f"_Populate_{paper_id}.sql" + + prefix = f"docker exec -i {docker_id}" if docker_id else "" + version = ( # Include spyglass version as comment in dump + f"echo '-- SPYGLASS VERSION: {spyglass_version}' > {dump_content}\n" + if spyglass_version + else "" + ) + + with open(dump_script, "w") as file: + file.write(f"#!/bin/bash\n{version}") + + for table in self.all_ft: + if not (where := table.where_clause()): + continue + database, table_name = table.full_table_name.split(".") + file.write( + f"{prefix}mysqldump {database} {table_name} " + + f'--where="{where}" >> {dump_content}\n' + ) + logger.info(f"Export script written to {dump_script}") + + def write_export( + self, paper_id: str, docker_id=None, spyglass_version=None + ): + if not self.cascaded: + self.cascade() + self._write_sql_cnf() + self._write_mysqldump(paper_id, docker_id, spyglass_version) + + # TODO: export conda env diff --git a/src/spyglass/utils/dj_helper_fn.py b/src/spyglass/utils/dj_helper_fn.py index 4a0495778..de9f2d3a9 100644 --- a/src/spyglass/utils/dj_helper_fn.py +++ b/src/spyglass/utils/dj_helper_fn.py @@ -11,6 +11,11 @@ from spyglass.utils.nwb_helper_fn import get_nwb_file +def unique_dicts(list_of_dict): + """Remove duplicate dictionaries from a list.""" + return [dict(t) for t in {tuple(d.items()) for d in list_of_dict}] + + def deprecated_factory(classes: list, old_module: str = "") -> list: """Creates a list of classes and logs a warning when instantiated @@ -127,33 +132,30 @@ def fetch_nwb(query_expression, nwb_master, *attrs, **kwargs): nwb_objects : list List of dicts containing fetch results and NWB objects. """ - kwargs["as_dict"] = True # force return as dictionary - tbl, attr_name = nwb_master - - if not attrs: - attrs = query_expression.heading.names - - # get the list of analysis or nwb files - file_name_str = ( - "analysis_file_name" if "analysis" in nwb_master[1] else "nwb_file_name" - ) # TODO: avoid this import? 
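+    # (Deferred import: likely avoids a circular import at module load time.)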
from ..common.common_nwbfile import AnalysisNwbfile, Nwbfile - file_path_fn = ( - AnalysisNwbfile.get_abs_path - if "analysis" in nwb_master[1] - else Nwbfile.get_abs_path - ) + kwargs["as_dict"] = True # force return as dictionary + attrs = attrs or query_expression.heading.names # if none, all + + tbl, attr_name = nwb_master + + which = "analysis" if "analysis" in attr_name else "nwb" + tbl_map = { # map to file_name_str and file_path_fn + "analysis": ["analysis_file_name", AnalysisNwbfile.get_abs_path], + "nwb": ["nwb_file_name", Nwbfile.get_abs_path], + } + file_name_str, file_path_fn = tbl_map[which] # TODO: check that the query_expression restricts tbl - CBroz nwb_files = ( query_expression * tbl.proj(nwb2load_filepath=attr_name) ).fetch(file_name_str) + for file_name in nwb_files: file_path = file_path_fn(file_name) - if not os.path.exists(file_path): - # retrieve the file from kachery. This also opens the file and stores the file object + if not os.path.exists(file_path): # retrieve the file from kachery. + # This also opens the file and stores the file object get_nwb_file(file_path) rec_dicts = ( diff --git a/src/spyglass/utils/dj_merge_tables.py b/src/spyglass/utils/dj_merge_tables.py index 37dc62fe0..e7ba23c72 100644 --- a/src/spyglass/utils/dj_merge_tables.py +++ b/src/spyglass/utils/dj_merge_tables.py @@ -784,13 +784,14 @@ def delete(self, force_permission=False, *args, **kwargs): ): part.delete(force_permission=force_permission, *args, **kwargs) - def super_delete(self, *args, **kwargs): + def super_delete(self, warn=True, *args, **kwargs): """Alias for datajoint.table.Table.delete. - Added to support MRO of SpyglassMixin""" - logger.warning("!! Using super_delete. Bypassing cautious_delete !!") - - self._log_use(start=time(), super_delete=True) + Added to support MRO of SpyglassMixin + """ + if warn: + logger.warning("!! Bypassing cautious_delete !!") + self._log_use(start=time(), super_delete=True) super().delete(*args, **kwargs) diff --git a/src/spyglass/utils/dj_mixin.py b/src/spyglass/utils/dj_mixin.py index 29978ae88..96555f2fe 100644 --- a/src/spyglass/utils/dj_mixin.py +++ b/src/spyglass/utils/dj_mixin.py @@ -1,9 +1,13 @@ +from atexit import register as exit_register +from atexit import unregister as exit_unregister from collections import OrderedDict from functools import cached_property +from os import environ from time import time from typing import Dict, List, Union import datajoint as dj +from datajoint.condition import make_condition from datajoint.errors import DataJointError from datajoint.expression import QueryExpression from datajoint.logging import logger as dj_logger @@ -17,6 +21,8 @@ from spyglass.utils.dj_merge_tables import RESERVED_PRIMARY_KEY as MERGE_PK from spyglass.utils.logging import logger +EXPORT_ENV_VAR = "SPYGLASS_EXPORT_ID" + class SpyglassMixin: """Mixin for Spyglass DataJoint tables. @@ -119,7 +125,27 @@ def fetch_nwb(self, *attrs, **kwargs): AnalysisNwbfile (i.e., "-> (Analysis)Nwbfile" in definition) or a _nwb_table attribute. If both are present, the attribute takes precedence. + + Additional logic support Export table logging. 
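+        While an export is active, fetched analysis files and the current
+        restriction are also logged to the ExportSelection part tables.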
""" + table, tbl_attr = self._nwb_table_tuple + if self.export_id and "analysis" in tbl_attr: + logger.info(f"Export {self.export_id}: fetch_nwb {self.table_name}") + tbl_pk = "analysis_file_name" + fname = (self * table).fetch1(tbl_pk) + self._export_table.File.insert1( + {"export_id": self.export_id, tbl_pk: fname}, + skip_duplicates=True, + ) + self._export_table.Table.insert1( + dict( + export_id=self.export_id, + table_name=self.full_table_name, + restriction=make_condition(self, self.restriction, set()), + ), + skip_duplicates=True, + ) + return fetch_nwb(self, self._nwb_table_tuple, *attrs, **kwargs) # ------------------------ delete_downstream_merge ------------------------ @@ -318,9 +344,7 @@ def _get_exp_summary(self): empty_pk = {self._member_pk: "NULL"} format = dj.U(self._session_pk, self._member_pk) - sess_link = self._session_connection.join( - self.restriction, reverse_order=True - ) + sess_link = self._session_connection.join(self.restriction) exp_missing = format & (sess_link - SesExp).proj(**empty_pk) exp_present = format & (sess_link * SesExp - exp_missing).proj() @@ -330,7 +354,9 @@ def _get_exp_summary(self): @cached_property def _session_connection(self) -> Union[TableChain, bool]: """Path from Session table to self. False if no connection found.""" - connection = TableChain(parent=self._delete_deps[-1], child=self) + connection = TableChain( + parent=self._delete_deps[-1], child=self, reverse=True + ) return connection if connection.has_link else False @cached_property @@ -488,8 +514,133 @@ def delete(self, force_permission=False, *args, **kwargs): """Alias for cautious_delete, overwrites datajoint.table.Table.delete""" self.cautious_delete(force_permission=force_permission, *args, **kwargs) - def super_delete(self, *args, **kwargs): + def super_delete(self, warn=True, *args, **kwargs): """Alias for datajoint.table.Table.delete.""" - logger.warning("!! Using super_delete. Bypassing cautious_delete !!") - self._log_use(start=time(), super_delete=True) + if warn: + logger.warning("!! Bypassing cautious_delete !!") + self._log_use(start=time(), super_delete=True) super().delete(*args, **kwargs) + + # ------------------------------- Export Log ------------------------------- + + @cached_property + def _spyglass_version(self): + """Get Spyglass version from dj.config.""" + from spyglass import __version__ as sg_version + + return ".".join(sg_version.split(".")[:3]) # Major.Minor.Patch + + @cached_property + def _export_table(self): + """Lazy load export selection table.""" + from spyglass.common.common_usage import ExportSelection + + return ExportSelection() + + @property + def export_id(self): + """ID of export in progress. + + NOTE: User of an env variable to store export_id may not be thread safe. + Exports must be run in sequence, not parallel. 
+ """ + + return int(environ.get(EXPORT_ENV_VAR, 0)) + + @export_id.setter + def export_id(self, value): + """Set ID of export using `table.export_id = X` notation.""" + if self.export_id != 0 and self.export_id != value: + raise RuntimeError("Export already in progress.") + environ[EXPORT_ENV_VAR] = str(value) + exit_register(self._export_id_cleanup) # End export on exit + + @export_id.deleter + def export_id(self): + """Delete ID of export using `del table.export_id` notation.""" + self._export_id_cleanup() + + def _export_id_cleanup(self): + """Cleanup export ID.""" + if environ.get(EXPORT_ENV_VAR): + del environ[EXPORT_ENV_VAR] + exit_unregister(self._export_id_cleanup) # Remove exit hook + + def _start_export(self, paper_id, analysis_id): + """Start export process.""" + if self.export_id: + logger.info(f"Export {self.export_id} in progress. Starting new.") + self._stop_export(warn=False) + + self.export_id = self._export_table.insert1_return_pk( + dict( + paper_id=paper_id, + analysis_id=analysis_id, + spyglass_version=self._spyglass_version, + ) + ) + + def _stop_export(self, warn=True): + """End export process.""" + if not self.export_id and warn: + logger.warning("Export not in progress.") + del self.export_id + + def _log_fetch(self): + """Log fetch for export.""" + if ( + not self.export_id + or self.full_table_name == self._export_table.full_table_name + or "dandi_export" in self.full_table_name # for populated table + ): + return + logger.info(f"Export {self.export_id}: fetch() {self.table_name}") + restr_str = make_condition(self, self.restriction, set()) + if isinstance(restr_str, str) and len(restr_str) > 2048: + raise RuntimeError( + "DandiExport cannot handle restrictions > 2048.\n\t" + + "If required, please open an issue on GitHub.\n\t" + + f"Restriction: {restr_str}" + ) + self._export_table.Table.insert1( + dict( + export_id=self.export_id, + table_name=self.full_table_name, + restriction=make_condition(self, restr_str, set()), + ), + skip_duplicates=True, + ) + + def fetch(self, *args, **kwargs): + """Log fetch for export.""" + ret = super().fetch(*args, **kwargs) + self._log_fetch() + return ret + + def fetch1(self, *args, **kwargs): + """Log fetch1 for export.""" + ret = super().fetch1(*args, **kwargs) + self._log_fetch() + return ret + + # ------------------------- Other helper methods ------------------------- + + def _auto_increment(self, key, pk, *args, **kwargs): + """Auto-increment primary key.""" + if not key.get(pk): + key[pk] = (dj.U().aggr(self, n=f"max({pk})").fetch1("n") or 0) + 1 + return key + + def file_like(self, name=None, **kwargs): + """Convenience method for wildcard search on file name fields.""" + if not name: + return self & True + attr = None + for field in self.heading.names: + if "file" in field: + attr = field + break + if not attr: + logger.error(f"No file-like field found in {self.full_table_name}") + return + return self & f"{attr} LIKE '%{name}%'" From d0e17027c1f60e12b7f66fbb29cd56747efa1c2b Mon Sep 17 00:00:00 2001 From: CBroz1 Date: Tue, 26 Mar 2024 17:05:11 -0500 Subject: [PATCH 02/15] WIP: revise doc --- docs/src/misc/export.md | 21 +++++++++++++-------- 1 file changed, 13 insertions(+), 8 deletions(-) diff --git a/docs/src/misc/export.md b/docs/src/misc/export.md index e0cea5d08..e52ebff81 100644 --- a/docs/src/misc/export.md +++ b/docs/src/misc/export.md @@ -20,8 +20,8 @@ To export data with the current implementation, you must do the following: ## How -The current implementation relies on two classes in the `spyglass` package: 
-`SpyglassMixin` and `RestrGraph` and the `Export` tables. +The current implementation relies on two classes in the Spyglass package +(`SpyglassMixin` and `RestrGraph`) and the `Export` tables. - `SpyglassMixin`: See `spyglass/utils/dj_mixin.py` - `RestrGraph`: See `spyglass/utils/dj_graph.py` @@ -29,8 +29,8 @@ The current implementation relies on two classes in the `spyglass` package: ### Mixin -The `SpyglassMixin` class is a subclass of DataJoint's `Manual` class. A subset -of methods are used to set an environment variable, `SPYGLASS_EXPORT_ID`, and, +The `SpyglassMixin` class adds functionality to DataJoint tables. A subset of +methods are used to set an environment variable, `SPYGLASS_EXPORT_ID`, and, while active, intercept all `fetch`/`fetch_nwb` calls to tables. When `fetch` is called, the mixin grabs the table name and the restriction applied to the table and stores them in the `ExportSelection` part tables. @@ -90,17 +90,16 @@ from spyglass.common.common_usage import Export export_key = {paper_id: "my_paper_id", analysis_id: "my_analysis_id"} ExportSelection().start_export(**export_key) -ExportSelection().restart_export(**export_key) # to clear previous attempt analysis_data = (MyTable & my_restr).fetch() analysis_nwb = (MyTable & my_restr).fetch_nwb() ExportSelection().end_export() # Visual inspection -touched_files = DS().list_file_paths(**export_key) -restricted_leaves = DS().preview_tables(**export_key) +touched_files = ExportSelection.list_file_paths(**export_key) +restricted_leaves = ExportSelection.preview_tables(**export_key) # Export -Export().populate() +Export().populate_paper(**export_key) ``` `Export` will invoke `RestrGraph.write_export` to collect cascaded restrictions @@ -110,6 +109,12 @@ directory, `base_dir/export/paper_id/`, using credentials from `dj_config`. To use alternative credentials, create a [mysql config file](https://dev.mysql.com/doc/refman/8.0/en/option-files.html). +To retain the abilite to delete the logging from a particular analysis, the +`export_id` is a combination of the `paper_id` and `analysis_id` in +`ExportSelection`. When populated, the `Export` table, only the minimum +`export_id` for a given `paper_id` is used, resulting in one shell script per +paper. Each shell script one `mysqldump` command per table. + ## External Implementation To implement an export for a non-Spyglass database, you will need to ... From 8565e8ab48c713c23f673c2908d86d72858bad3d Mon Sep 17 00:00:00 2001 From: CBroz1 Date: Thu, 28 Mar 2024 15:12:06 -0500 Subject: [PATCH 03/15] =?UTF-8?q?=20=E2=9C=85=20:=20Generate=20working=20e?= =?UTF-8?q?xport=20script?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- docs/src/misc/export.md | 4 +- src/spyglass/common/common_usage.py | 49 +++++++--- src/spyglass/utils/dj_graph.py | 133 +++++++++++++++++++++++----- src/spyglass/utils/dj_mixin.py | 4 + 4 files changed, 154 insertions(+), 36 deletions(-) diff --git a/docs/src/misc/export.md b/docs/src/misc/export.md index e52ebff81..d0b0dbff1 100644 --- a/docs/src/misc/export.md +++ b/docs/src/misc/export.md @@ -109,9 +109,9 @@ directory, `base_dir/export/paper_id/`, using credentials from `dj_config`. To use alternative credentials, create a [mysql config file](https://dev.mysql.com/doc/refman/8.0/en/option-files.html). 
-To retain the abilite to delete the logging from a particular analysis, the +To retain the ability to delete the logging from a particular analysis, the `export_id` is a combination of the `paper_id` and `analysis_id` in -`ExportSelection`. When populated, the `Export` table, only the minimum +`ExportSelection`. When populated, the `Export` table, only the maximum `export_id` for a given `paper_id` is used, resulting in one shell script per paper. Each shell script one `mysqldump` command per table. diff --git a/src/spyglass/common/common_usage.py b/src/spyglass/common/common_usage.py index 47b5172c9..51397c4c2 100644 --- a/src/spyglass/common/common_usage.py +++ b/src/spyglass/common/common_usage.py @@ -92,9 +92,11 @@ def insert1_return_pk(self, key: dict, **kwargs) -> int: return export_id def start_export(self, paper_id, analysis_id) -> None: + """Start logging a new export.""" self._start_export(paper_id, analysis_id) def stop_export(self) -> None: + """Stop logging the current export.""" self._stop_export() # NOTE: These helpers could be moved to table below, but I think @@ -103,6 +105,7 @@ def stop_export(self) -> None: # Selection def list_file_paths(self, key: dict) -> list[str]: + """Return a list of unique file paths for a given restriction/key.""" file_table = self.File & key analysis_fp = [ AnalysisNwbfile().get_abs_path(fname) @@ -115,25 +118,36 @@ def list_file_paths(self, key: dict) -> list[str]: return [{"file_path": p} for p in list({*analysis_fp, *nwbfile_fp})] def get_restr_graph(self, key: dict) -> RestrGraph: + """Return a RestrGraph for a restriction/key's tables/restrictions. + + Ignores duplicate entries. + """ leaves = unique_dicts( (self.Table & key).fetch("table_name", "restriction", as_dict=True) ) - return RestrGraph(seed_table=self, leaves=leaves, verbose=True) + return RestrGraph(seed_table=self, leaves=leaves, verbose=False) def preview_tables(self, key: dict) -> list[dj.FreeTable]: + """Return a list of restricted FreeTables for a given restriction/key. + + Useful for checking what will be exported. + """ return self.get_restr_graph(key).leaf_ft - def _min_export_id(self, paper_id: str) -> int: - """Return all export_ids for a paper.""" + def _max_export_id(self, paper_id: str, return_all=False) -> int: + """Return last export associated with a given paper id. + + Used to populate Export table.""" if isinstance(paper_id, dict): paper_id = paper_id.get("paper_id") if not (query := self & {"paper_id": paper_id}): return None - return min(query.fetch("export_id")) + all_export_ids = query.fetch("export_id") + return all_export_ids if return_all else max(all_export_ids) def paper_export_id(self, paper_id: str) -> dict: - """Return the minimum export_id for a paper, used to populate Export.""" - return {"export_id": self._min_export_id(paper_id)} + """Return the maximum export_id for a paper, used to populate Export.""" + return {"export_id": self._max_export_id(paper_id)} @schema @@ -143,7 +157,8 @@ class Export(SpyglassMixin, dj.Computed): """ # In order to get a many-to-one relationship btwn Selection and Export, - # we ignore all but the first export_id. + # we ignore all but the last export_id. If more exports are added above, + # generating a new output will overwrite the old ones. 
class Table(SpyglassMixin, dj.Part): definition = """ @@ -173,14 +188,26 @@ def make(self, key): query = ExportSelection & key paper_key = query.fetch("paper_id", as_dict=True)[0] - # Null insertion if export_id is not the minimum for the paper - min_export_id = query._min_export_id(paper_key) - if key.get("export_id") != min_export_id: + # Null insertion if export_id is not the maximum for the paper + all_export_ids = query._max_export_id(paper_key, return_all=True) + max_export_id = max(all_export_ids) + if key.get("export_id") != max_export_id: logger.info( - f"Skipping export_id {key['export_id']}, use {min_export_id}" + f"Skipping export_id {key['export_id']}, use {max_export_id}" ) self.insert1(key) return + # If lesser ids are present, delete parts yielding null entries + processed_ids = set( + list(self.Table.fetch("export_id")) + + list(self.File.fetch("export_id")) + ) + if overlap := set(all_export_ids) - {max_export_id} & processed_ids: + logger.info(f"Overwriting export_ids {overlap}") + for export_id in overlap: + id_dict = {"export_id": export_id} + (self.Table & id_dict).delete_quick() + (self.Table & id_dict).delete_quick() restr_graph = query.get_restr_graph(paper_key) file_paths = query.list_file_paths(paper_key) diff --git a/src/spyglass/utils/dj_graph.py b/src/spyglass/utils/dj_graph.py index 5507eefe5..af3c85dd3 100644 --- a/src/spyglass/utils/dj_graph.py +++ b/src/spyglass/utils/dj_graph.py @@ -1,6 +1,7 @@ """DataJoint graph traversal and restriction application. -Note: read `ft` as FreeTable and `restr` as restriction.""" +NOTE: read `ft` as FreeTable and `restr` as restriction. +""" from pathlib import Path from typing import Dict, List @@ -115,11 +116,14 @@ def _set_restr(self, table, restriction): join = ft & [existing, restriction] if len(join) == len(ft & existing): return # restriction is a subset of existing - restriction = unique_dicts(join.fetch("KEY", as_dict=True)) + restriction = make_condition( + ft, unique_dicts(join.fetch("KEY", as_dict=True)), set() + ) - self._log_truncate( - f"Set Restr {table}: {type(restriction)} {restriction}" - ) + if not isinstance(restriction, str): + self._log_truncate( + f"Set Restr {table}: {type(restriction)} {restriction}" + ) self._set_node(table, "restr", restriction) def _log_truncate(self, log_str, max_len=80): @@ -217,6 +221,8 @@ def cascade1(self, table, restriction): def cascade(self) -> None: """Cascade all restrictions up the graph.""" + if self.cascaded: + return for table in self.leaves - self.visited: restr = self._get_restr(table) self._log_truncate(f"Start {table}: {restr}") @@ -278,14 +284,21 @@ def add_leaves(self, leaves: List[Dict[str, str]], cascade=False) -> None: @property def as_dict(self) -> List[Dict[str, str]]: """Return as a list of dictionaries of table_name: restriction""" - if not self.cascaded: - self.cascade() + self.cascade() return [ {"table_name": table, "restriction": self._get_restr(table)} for table in self.ancestors if self._get_restr(table) ] + def _get_credentials(self): + """Get credentials for database connection.""" + return { + "user": dj_config["database.user"], + "password": dj_config["database.password"], + "host": dj_config["database.host"], + } + def _write_sql_cnf(self): """Write SQL cnf file to avoid password prompt.""" cnf_path = Path("~/.my.cnf").expanduser() @@ -293,49 +306,123 @@ def _write_sql_cnf(self): if cnf_path.exists(): return + template = "[client]\nuser={user}\npassword={password}\nhost={host}\n" + with open(str(cnf_path), "w") as file: - file.write( - 
"[client]\n" - + "user={}\n".format(dj_config["database.user"]) - + "password={}\n".format(dj_config["database.password"]) - + "host={}\n".format(dj_config["database.host"]) + file.write(template.format(**self._get_credentials())) + cnf_path.chmod(0o600) + + def _bash_escape(self, s): + """Escape restriction string for bash.""" + s = s.strip() + + replace_map = { + "WHERE ": "", # Remove preceding WHERE of dj.where_clause + " ": " ", # Squash double spaces + "( (": "((", # Squash double parens + ") )": ")", + '"': "'", # Replace double quotes with single + "`": "", # Remove backticks + " AND ": " \\\n\tAND ", # Add newline and tab for readability + " OR ": " \\\n\tOR ", # OR extra space to align with AND + ")AND(": ") \\\n\tAND (", + ")OR(": ") \\\n\tOR (", + } + for old, new in replace_map.items(): + s = s.replace(old, new) + if s.startswith("(((") and s.endswith(")))"): + s = s[2:-2] # Remove extra parens for readability + return s + + def _cmd_prefix(self, docker_id=None): + """Get prefix for mysqldump command. Includes docker exec if needed.""" + if not docker_id: + return "mysqldump " + return ( + f"docker exec -i {docker_id} \\\n\tmysqldump " + + "-u {user} --password={password} \\\n\t".format( + **self._get_credentials() ) + ) def _write_mysqldump( self, paper_id: str, docker_id=None, spyglass_version=None ): - """Write mysqlmdump to a temporary file and return the file object""" - paper_dir = Path(export_dir) / paper_id + """Write mysqlmdump.sh script to export data. + + Parameters + ---------- + paper_id : str + Paper ID to use for export file names + docker_id : str, optional + Docker container ID to export from. Default None + spyglass_version : str, optional + Spyglass version to include in export. Default None + """ + paper_dir = Path(export_dir) / paper_id if not docker_id else Path(".") paper_dir.mkdir(exist_ok=True) dump_script = paper_dir / f"_ExportSQL_{paper_id}.sh" dump_content = paper_dir / f"_Populate_{paper_id}.sql" - prefix = f"docker exec -i {docker_id}" if docker_id else "" + prefix = self._cmd_prefix(docker_id) version = ( # Include spyglass version as comment in dump - f"echo '-- SPYGLASS VERSION: {spyglass_version}' > {dump_content}\n" + "echo '--'\n" + + f"echo '-- SPYGLASS VERSION: {spyglass_version} --'\n" + + "echo '--'\n\n" if spyglass_version else "" ) + create_cmd = ( + "echo 'CREATE DATABASE IF NOT EXISTS {database}; " + + "USE {database};'\n\n" + ) + dump_cmd = prefix + '{database} {table} --where="\\\n\t{where}"\n\n' + + tables_by_db = sorted(self.all_ft, key=lambda x: x.full_table_name) with open(dump_script, "w") as file: - file.write(f"#!/bin/bash\n{version}") + file.write( + "#!/bin/bash\n\n" + + f"exec > {dump_content}\n\n" # Redirect output to sql file + + f"{version}" # Include spyglass version as comment + ) - for table in self.all_ft: + prev_db = None + for table in tables_by_db: if not (where := table.where_clause()): continue - database, table_name = table.full_table_name.split(".") + where = self._bash_escape(where) + database, table_name = table.full_table_name.replace( + "`", "" + ).split(".") + if database != prev_db: + file.write(create_cmd.format(database=database)) + prev_db = database file.write( - f"{prefix}mysqldump {database} {table_name} " - + f'--where="{where}" >> {dump_content}\n' + dump_cmd.format( + database=database, table=table_name, where=where + ) ) logger.info(f"Export script written to {dump_script}") def write_export( self, paper_id: str, docker_id=None, spyglass_version=None ): - if not self.cascaded: - 
self.cascade() + """Write export bash script for all tables in graph. + + Also writes a user-specific .my.cnf file to avoid password prompt. + + Parameters + ---------- + paper_id : str + Paper ID to use for export file names + docker_id : str, optional + Docker container ID to export from. Default None + spyglass_version : str, optional + Spyglass version to include in export. Default None + """ + self.cascade() self._write_sql_cnf() self._write_mysqldump(paper_id, docker_id, spyglass_version) diff --git a/src/spyglass/utils/dj_mixin.py b/src/spyglass/utils/dj_mixin.py index 96555f2fe..387ee4107 100644 --- a/src/spyglass/utils/dj_mixin.py +++ b/src/spyglass/utils/dj_mixin.py @@ -161,6 +161,10 @@ def _merge_tables(self) -> Dict[str, dj.FreeTable]: merge_tables = {} def search_descendants(parent): + # TODO: Add check that parents are in the graph. If not, raise error + # asking user to import the table. + # TODO: Make a `is_merge_table` helper, and check for false + # positives in the mixin init. for desc in parent.descendants(as_objects=True): if ( MERGE_PK not in desc.heading.names From cbc62ba554d148b7f90417e4340b985d90febd67 Mon Sep 17 00:00:00 2001 From: CBroz1 Date: Thu, 28 Mar 2024 16:33:28 -0500 Subject: [PATCH 04/15] Cleanup: Expand notebook, migrate export process from graph class to export --- notebooks/05_Export.ipynb | 668 +++++++++++++++++++++++++++- notebooks/py_scripts/05_Export.py | 129 +++++- src/spyglass/common/common_usage.py | 174 +++++++- src/spyglass/utils/dj_graph.py | 142 +----- src/spyglass/utils/dj_mixin.py | 12 +- 5 files changed, 943 insertions(+), 182 deletions(-) diff --git a/notebooks/05_Export.ipynb b/notebooks/05_Export.ipynb index 5ab5989b7..c3b56aee5 100644 --- a/notebooks/05_Export.ipynb +++ b/notebooks/05_Export.ipynb @@ -57,7 +57,8 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "Let's start by importing the `spyglass` package, along with a few others.\n" + "Let's start by connecting to the database and importing some tables that might\n", + "be used in an analysis.\n" ] }, { @@ -71,8 +72,10 @@ "name": "stderr", "output_type": "stream", "text": [ - "[2024-01-29 16:15:00,903][INFO]: Connecting root@localhost:3309\n", - "[2024-01-29 16:15:00,912][INFO]: Connected root@localhost:3309\n" + "[2024-03-28 16:32:49,766][INFO]: Connecting root@localhost:3309\n", + "[2024-03-28 16:32:49,773][INFO]: Connected root@localhost:3309\n", + "/home/cb/miniconda3/envs/spy/lib/python3.9/site-packages/torch/cuda/__init__.py:83: UserWarning: CUDA initialization: CUDA unknown error - this may be due to an incorrectly set up environment, e.g. changing env variable CUDA_VISIBLE_DEVICES after program start. Setting the available devices to be zero. (Triggered internally at ../c10/cuda/CUDAFunctions.cpp:109.)\n", + " return torch._C._cuda_getDeviceCount() > 0\n" ] } ], @@ -85,33 +88,666 @@ " os.chdir(\"..\")\n", "dj.config.load(\"dj_local_conf.json\") # load config for database connection info\n", "\n", - "# ignore datajoint+jupyter async warnings\n", "from spyglass.common.common_usage import Export, ExportSelection\n", "from spyglass.lfp.analysis.v1 import LFPBandV1\n", "from spyglass.position.v1 import TrodesPosV1\n", - "from spyglass.spikesorting.v1.curation import CurationV1\n", + "from spyglass.spikesorting.v1.curation import CurationV1" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Export Tables\n", + "\n", + "The `ExportSelection` table will populate while we conduct the analysis. 
For\n", + "each file opened and each `fetch` call, an entry will be logged in one of its\n", + "part tables.\n" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "\n", + " \n", + " \n", + " \n", + " \n", + "
\n", + " \n", + " \n", + " \n", + "
\n", + "

export_id

\n", + " \n", + "
\n", + "

paper_id

\n", + " \n", + "
\n", + "

analysis_id

\n", + " \n", + "
\n", + "

spyglass_version

\n", + " \n", + "
\n", + "

time

\n", + " \n", + "
\n", + " \n", + "

Total: 0

\n", + " " + ], + "text/plain": [ + "*export_id paper_id analysis_id spyglass_versi time \n", + "+-----------+ +----------+ +------------+ +------------+ +------+\n", + "\n", + " (Total: 0)" + ] + }, + "execution_count": 2, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "ExportSelection()" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "\n", + " \n", + " \n", + " \n", + " \n", + "
\n", + " \n", + " \n", + " \n", + "
\n", + "

export_id

\n", + " \n", + "
\n", + "

table_id

\n", + " \n", + "
\n", + "

table_name

\n", + " \n", + "
\n", + "

restriction

\n", + " \n", + "
\n", + " \n", + "

Total: 0

\n", + " " + ], + "text/plain": [ + "*export_id *table_id table_name restriction \n", + "+-----------+ +----------+ +------------+ +------------+\n", + "\n", + " (Total: 0)" + ] + }, + "execution_count": 3, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "ExportSelection.Table()" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "\n", + " \n", + " \n", + " \n", + " \n", + "
\n", + " \n", + " \n", + " \n", + "
\n", + "

export_id

\n", + " \n", + "
\n", + "

analysis_file_name

\n", + " name of the file\n", + "
\n", + " \n", + "

Total: 0

\n", + " " + ], + "text/plain": [ + "*export_id *analysis_file\n", + "+-----------+ +------------+\n", + "\n", + " (Total: 0)" + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "ExportSelection.File()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Exports are organized around paper and analysis IDs. A single export will be\n", + "generated for each paper, but we can delete/revise logs for each analysis before\n", + "running the export. When we're ready, we can run the `populate_paper` method\n", + "of the `Export` table. By default, export logs will ignore all tables in this\n", + "`common_usage` schema.\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Logging\n", + "\n", + "There are a few restrictions to keep in mind when export logging:\n", + "\n", + "- You can only run _ONE_ export at a time.\n", + "- All tables must inherit `SpyglassMixin`\n", + "\n", + "
How to inherit SpyglassMixin\n", + "\n", + "DataJoint tables all inherit from one of the built-in table types.\n", + "\n", + "```python\n", + "class MyTable(dj.Manual):\n", + " ...\n", + "```\n", "\n", - "# TODO: Add commentary, describe helpers on ExportSelection\n", + "To inherit the mixin, simply add it to the `()` of the class before the\n", + "DataJoint class. This can be done for existing tables without dropping them,\n", + "so long as the change has been made prior to export logging.\n", "\n", + "```python\n", + "from spyglass.utils import SpyglassMixin\n", + "class MyTable(SpyglassMixin, dj.Manual):\n", + " ...\n", + "```\n", + "\n", + "
\n", + "\n", + "Let's start logging for 'paper1'.\n" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "[16:32:51][INFO] Spyglass: Starting {'export_id': 1}\n" + ] + } + ], + "source": [ "paper_key = {\"paper_id\": \"paper1\"}\n", - "ExportSelection().start_export(**paper_key, analysis_id=\"test1\")\n", - "a = (\n", - " LFPBandV1 & \"nwb_file_name LIKE 'med%'\" & {\"filter_name\": \"Theta 5-11 Hz\"}\n", - ").fetch()\n", - "b = (\n", + "\n", + "ExportSelection().start_export(**paper_key, analysis_id=\"analysis1\")\n", + "my_lfp_data = (\n", + " LFPBandV1 # Logging this table\n", + " & \"nwb_file_name LIKE 'med%'\" # using a string restriction\n", + " & {\"filter_name\": \"Theta 5-11 Hz\"} # and a dictionary restriction\n", + ").fetch()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We can check that it was logged. The syntax of the restriction will look\n", + "different from what we see in python, but the `preview_tables` will look\n", + "familiar.\n" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "\n", + " \n", + " \n", + " \n", + " \n", + "
\n", + " \n", + " \n", + " \n", + "\n", + "\n", + "\n", + "
\n", + "

export_id

\n", + " \n", + "
\n", + "

table_id

\n", + " \n", + "
\n", + "

table_name

\n", + " \n", + "
\n", + "

restriction

\n", + " \n", + "
11`lfp_band_v1`.`__l_f_p_band_v1` (( ((nwb_file_name LIKE 'med%%%%%%%%')))AND( ((`filter_name`=\"Theta 5-11 Hz\"))))
\n", + " \n", + "

Total: 1

\n", + " " + ], + "text/plain": [ + "*export_id *table_id table_name restriction \n", + "+-----------+ +----------+ +------------+ +------------+\n", + "1 1 `lfp_band_v1`. (( ((nwb_file\n", + " (Total: 1)" + ] + }, + "execution_count": 6, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "ExportSelection.Table()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "And log more under the same analysis ...\n" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [], + "source": [ + "my_other_lfp_data = (\n", " LFPBandV1\n", " & {\n", " \"nwb_file_name\": \"mediumnwb20230802_.nwb\",\n", " \"filter_name\": \"Theta 5-10 Hz\",\n", " }\n", - ").fetch()\n", - "ExportSelection().start_export(**paper_key, analysis_id=\"test2\")\n", - "c = (CurationV1 & \"curation_id = 1\").fetch_nwb()\n", - "d = (TrodesPosV1 & 'trodes_pos_params_name = \"single_led\"').fetch()\n", - "ExportSelection().stop_export()\n", + ").fetch()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Since these restrictions are mutually exclusive, we can check that the will\n", + "be combined appropriately by priviewing the logged tables...\n" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "[FreeTable(`lfp_band_v1`.`__l_f_p_band_v1`)\n", + " *lfp_merge_id *filter_name *filter_sampli *nwb_file_name *target_interv *lfp_band_samp analysis_file_ interval_list_ lfp_band_objec\n", + " +------------+ +------------+ +------------+ +------------+ +------------+ +------------+ +------------+ +------------+ +------------+\n", + " 0f3bb01e-0ef6- Theta 5-10 Hz 1000 mediumnwb20230 pos 0 valid ti 100 mediumnwb20230 pos 0 valid ti 44e38dc1-3779-\n", + " 0f3bb01e-0ef6- Theta 5-11 Hz 1000 mediumnwb20230 pos 0 valid ti 100 mediumnwb20230 pos 0 valid ti c9b93111-decb-\n", + " (Total: 2)]" + ] + }, + "execution_count": 8, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "ExportSelection().preview_tables(**paper_key)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Let's try adding a new analysis with a fetched nwb file. Starting a new export\n", + "will stop the previous one.\n" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "[16:32:51][INFO] Spyglass: Export 1 in progress. 
Starting new.\n", + "[16:32:51][INFO] Spyglass: Starting {'export_id': 2}\n" + ] + } + ], + "source": [ + "ExportSelection().start_export(**paper_key, analysis_id=\"analysis2\")\n", + "curation_nwb = (CurationV1 & \"curation_id = 1\").fetch_nwb()\n", + "trodes_data = (TrodesPosV1 & 'trodes_pos_params_name = \"single_led\"').fetch()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We can check that the right files were logged with the following...\n" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "[{'file_path': '/home/cb/wrk/alt/data/raw/mediumnwb20230802_.nwb'},\n", + " {'file_path': '/home/cb/wrk/alt/data/analysis/mediumnwb20230802/mediumnwb20230802_ALNN6TZ4L7.nwb'}]" + ] + }, + "execution_count": 10, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "ExportSelection().list_file_paths(paper_key)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "And stop the export with ...\n" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [], + "source": [ + "ExportSelection().stop_export()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Populate\n", + "\n", + "The `Export` table has a `populate_paper` method that will generate an export\n", + "bash script for the tables required by your analysis, including all the upstream\n", + "tables you didn't directly need, like `Subject` and `Session`.\n", + "\n", + "**NOTE:** Populating the export for a given paper will overwrite any previous\n", + "runs. For example, if you ran an export, and then added a third analysis for the\n", + "same paper, generating another export will delete any existing bash script and\n", + "`Export` table entries for the previous run.\n" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "[16:32:51][INFO] Spyglass: Export script written to /home/cb/wrk/alt/data/export/paper1/_ExportSQL_paper1.sh\n" + ] + } + ], + "source": [ "Export().populate_paper(**paper_key)" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "By default the export script will be located in an `export` folder within your\n", + "`SPYGLASS_BASE_DIR`. This default can be changed by adjusting your `dj.config`.\n", + "\n", + "Frank Lab members will need the help of a database admin (e.g., Chris) to\n", + "run the resulting bash script. The result will be a `.sql` file that anyone\n", + "can use to replicate the database entries you used in your analysis.\n" + ] + }, { "cell_type": "markdown", "metadata": {}, diff --git a/notebooks/py_scripts/05_Export.py b/notebooks/py_scripts/05_Export.py index 4acca2335..b180832d4 100644 --- a/notebooks/py_scripts/05_Export.py +++ b/notebooks/py_scripts/05_Export.py @@ -47,7 +47,8 @@ # ## Imports # -# Let's start by importing the `spyglass` package, along with a few others. +# Let's start by connecting to the database and importing some tables that might +# be used in an analysis. 
# # + @@ -59,32 +60,138 @@ os.chdir("..") dj.config.load("dj_local_conf.json") # load config for database connection info -# ignore datajoint+jupyter async warnings from spyglass.common.common_usage import Export, ExportSelection from spyglass.lfp.analysis.v1 import LFPBandV1 from spyglass.position.v1 import TrodesPosV1 from spyglass.spikesorting.v1.curation import CurationV1 -# TODO: Add commentary, describe helpers on ExportSelection +# - + +# ## Export Tables +# +# The `ExportSelection` table will populate while we conduct the analysis. For +# each file opened and each `fetch` call, an entry will be logged in one of its +# part tables. +# + +ExportSelection() + +ExportSelection.Table() + +ExportSelection.File() +# Exports are organized around paper and analysis IDs. A single export will be +# generated for each paper, but we can delete/revise logs for each analysis before +# running the export. When we're ready, we can run the `populate_paper` method +# of the `Export` table. By default, export logs will ignore all tables in this +# `common_usage` schema. +# + +# ## Logging +# +# There are a few restrictions to keep in mind when export logging: +# +# - You can only run _ONE_ export at a time. +# - All tables must inherit `SpyglassMixin` +# +#
How to inherit SpyglassMixin +# +# DataJoint tables all inherit from one of the built-in table types. +# +# ```python +# class MyTable(dj.Manual): +# ... +# ``` +# +# To inherit the mixin, simply add it to the `()` of the class before the +# DataJoint class. This can be done for existing tables without dropping them, +# so long as the change has been made prior to export logging. +# +# ```python +# from spyglass.utils import SpyglassMixin +# class MyTable(SpyglassMixin, dj.Manual): +# ... +# ``` +# +#
+# +# Let's start logging for 'paper1'. +# + +# + paper_key = {"paper_id": "paper1"} -ExportSelection().start_export(**paper_key, analysis_id="test1") -a = ( - LFPBandV1 & "nwb_file_name LIKE 'med%'" & {"filter_name": "Theta 5-11 Hz"} + +ExportSelection().start_export(**paper_key, analysis_id="analysis1") +my_lfp_data = ( + LFPBandV1 # Logging this table + & "nwb_file_name LIKE 'med%'" # using a string restriction + & {"filter_name": "Theta 5-11 Hz"} # and a dictionary restriction ).fetch() -b = ( +# - + +# We can check that it was logged. The syntax of the restriction will look +# different from what we see in python, but the `preview_tables` will look +# familiar. +# + +ExportSelection.Table() + +# And log more under the same analysis ... +# + +my_other_lfp_data = ( LFPBandV1 & { "nwb_file_name": "mediumnwb20230802_.nwb", "filter_name": "Theta 5-10 Hz", } ).fetch() -ExportSelection().start_export(**paper_key, analysis_id="test2") -c = (CurationV1 & "curation_id = 1").fetch_nwb() -d = (TrodesPosV1 & 'trodes_pos_params_name = "single_led"').fetch() + +# Since these restrictions are mutually exclusive, we can check that the will +# be combined appropriately by priviewing the logged tables... +# + +ExportSelection().preview_tables(**paper_key) + +# Let's try adding a new analysis with a fetched nwb file. Starting a new export +# will stop the previous one. +# + +ExportSelection().start_export(**paper_key, analysis_id="analysis2") +curation_nwb = (CurationV1 & "curation_id = 1").fetch_nwb() +trodes_data = (TrodesPosV1 & 'trodes_pos_params_name = "single_led"').fetch() + +# We can check that the right files were logged with the following... +# + +ExportSelection().list_file_paths(paper_key) + +# And stop the export with ... +# + ExportSelection().stop_export() + +# ## Populate +# +# The `Export` table has a `populate_paper` method that will generate an export +# bash script for the tables required by your analysis, including all the upstream +# tables you didn't directly need, like `Subject` and `Session`. +# +# **NOTE:** Populating the export for a given paper will overwrite any previous +# runs. For example, if you ran an export, and then added a third analysis for the +# same paper, generating another export will delete any existing bash script and +# `Export` table entries for the previous run. +# + Export().populate_paper(**paper_key) -# - + +# By default the export script will be located in an `export` folder within your +# `SPYGLASS_BASE_DIR`. This default can be changed by adjusting your `dj.config`. +# +# Frank Lab members will need the help of a database admin (e.g., Chris) to +# run the resulting bash script. The result will be a `.sql` file that anyone +# can use to replicate the database entries you used in your analysis. +# # ## Up Next # diff --git a/src/spyglass/common/common_usage.py b/src/spyglass/common/common_usage.py index 51397c4c2..7443206a4 100644 --- a/src/spyglass/common/common_usage.py +++ b/src/spyglass/common/common_usage.py @@ -6,11 +6,15 @@ plan future development of Spyglass. 
""" -from typing import Union +from pathlib import Path +from typing import List, Union import datajoint as dj +from datajoint import FreeTable +from datajoint import config as dj_config from spyglass.common.common_nwbfile import AnalysisNwbfile, Nwbfile +from spyglass.settings import export_dir from spyglass.utils import SpyglassMixin, logger from spyglass.utils.dj_graph import RestrGraph from spyglass.utils.dj_helper_fn import unique_dicts @@ -63,7 +67,7 @@ class Table(SpyglassMixin, dj.Part): -> master table_id: int --- - table_name: varchar(64) + table_name: varchar(128) restriction: varchar(2048) """ @@ -82,6 +86,7 @@ def insert1_return_pk(self, key: dict, **kwargs) -> int: """Custom insert to return export_id.""" status = "Resuming" if not (query := self & key): + key = self._auto_increment(key, pk="export_id") super().insert1(key, **kwargs) status = "Starting" export_id = query.fetch1("export_id") @@ -95,7 +100,7 @@ def start_export(self, paper_id, analysis_id) -> None: """Start logging a new export.""" self._start_export(paper_id, analysis_id) - def stop_export(self) -> None: + def stop_export(self, **kwargs) -> None: """Stop logging the current export.""" self._stop_export() @@ -127,12 +132,12 @@ def get_restr_graph(self, key: dict) -> RestrGraph: ) return RestrGraph(seed_table=self, leaves=leaves, verbose=False) - def preview_tables(self, key: dict) -> list[dj.FreeTable]: + def preview_tables(self, **kwargs) -> list[dj.FreeTable]: """Return a list of restricted FreeTables for a given restriction/key. Useful for checking what will be exported. """ - return self.get_restr_graph(key).leaf_ft + return self.get_restr_graph(kwargs).leaf_ft def _max_export_id(self, paper_id: str, return_all=False) -> int: """Return last export associated with a given paper id. @@ -165,7 +170,7 @@ class Table(SpyglassMixin, dj.Part): -> master table_id: int --- - table_name: varchar(64) + table_name: varchar(128) restriction: varchar(2048) unique index (table_name) """ @@ -222,8 +227,161 @@ def make(self, key): # Writes but does not run mysqldump. Assumes single version per paper. version_key = query.fetch("spyglass_version", as_dict=True)[0] - restr_graph.write_export(**paper_key, **version_key) + self.write_export( + free_tables=restr_graph.all_ft, **paper_key, **version_key + ) self.insert1(key) - self.Table().insert(table_inserts) # TODO: Duplicate error?? 
+        self.Table().insert(table_inserts)
         self.File().insert(file_inserts)
+
+    def _get_credentials(self):
+        """Get credentials for database connection."""
+        return {
+            "user": dj_config["database.user"],
+            "password": dj_config["database.password"],
+            "host": dj_config["database.host"],
+        }
+
+    def _write_sql_cnf(self):
+        """Write SQL cnf file to avoid password prompt."""
+        cnf_path = Path("~/.my.cnf").expanduser()
+
+        if cnf_path.exists():
+            return
+
+        template = "[client]\nuser={user}\npassword={password}\nhost={host}\n"
+
+        with open(str(cnf_path), "w") as file:
+            file.write(template.format(**self._get_credentials()))
+        cnf_path.chmod(0o600)
+
+    def _bash_escape(self, s):
+        """Escape restriction string for bash."""
+        s = s.strip()
+
+        replace_map = {
+            "WHERE ": "",  # Remove preceding WHERE of dj.where_clause
+            "  ": " ",  # Squash double spaces
+            "( (": "((",  # Squash double parens
+            ") )": ")",
+            '"': "'",  # Replace double quotes with single
+            "`": "",  # Remove backticks
+            " AND ": " \\\n\tAND ",  # Add newline and tab for readability
+            " OR ": " \\\n\tOR ",  # OR extra space to align with AND
+            ")AND(": ") \\\n\tAND (",
+            ")OR(": ") \\\n\tOR (",
+            "#": "\\#",
+        }
+        for old, new in replace_map.items():
+            s = s.replace(old, new)
+        if s.startswith("(((") and s.endswith(")))"):
+            s = s[2:-2]  # Remove extra parens for readability
+        return s
+
+    def _cmd_prefix(self, docker_id=None):
+        """Get prefix for mysqldump command. Includes docker exec if needed."""
+        if not docker_id:
+            return "mysqldump "
+        return (
+            f"docker exec -i {docker_id} \\\n\tmysqldump "
+            + "-u {user} --password={password} \\\n\t".format(
+                **self._get_credentials()
+            )
+        )
+
+    def _write_mysqldump(
+        self,
+        free_tables: List[FreeTable],
+        paper_id: str,
+        docker_id=None,
+        spyglass_version=None,
+    ):
+        """Write mysqldump.sh script to export data.
+
+        Parameters
+        ----------
+        free_tables : List[FreeTable]
+            List of restricted FreeTables to export
+        paper_id : str
+            Paper ID to use for export file names
+        docker_id : str, optional
+            Docker container ID to export from. Default None
+        spyglass_version : str, optional
+            Spyglass version to include in export. 
Default None + """ + paper_dir = Path(export_dir) / paper_id if not docker_id else Path(".") + paper_dir.mkdir(exist_ok=True) + + dump_script = paper_dir / f"_ExportSQL_{paper_id}.sh" + dump_content = paper_dir / f"_Populate_{paper_id}.sql" + + prefix = self._cmd_prefix(docker_id) + version = ( # Include spyglass version as comment in dump + "echo '--'\n" + + f"echo '-- SPYGLASS VERSION: {spyglass_version} --'\n" + + "echo '--'\n\n" + if spyglass_version + else "" + ) + create_cmd = ( + "echo 'CREATE DATABASE IF NOT EXISTS {database}; " + + "USE {database};'\n\n" + ) + dump_cmd = prefix + '{database} {table} --where="\\\n\t{where}"\n\n' + + tables_by_db = sorted(free_tables, key=lambda x: x.full_table_name) + + with open(dump_script, "w") as file: + file.write( + "#!/bin/bash\n\n" + + f"exec > {dump_content}\n\n" # Redirect output to sql file + + f"{version}" # Include spyglass version as comment + ) + + prev_db = None + for table in tables_by_db: + if not (where := table.where_clause()): + continue + where = self._bash_escape(where) + database, table_name = ( + table.full_table_name.replace("`", "") + .replace("#", "\\#") + .split(".") + ) + if database != prev_db: + file.write(create_cmd.format(database=database)) + prev_db = database + file.write( + dump_cmd.format( + database=database, table=table_name, where=where + ) + ) + logger.info(f"Export script written to {dump_script}") + + def write_export( + self, + free_tables: List[FreeTable], + paper_id: str, + docker_id=None, + spyglass_version=None, + ): + """Write export bash script for all tables in graph. + + Also writes a user-specific .my.cnf file to avoid password prompt. + + Parameters + ---------- + free_tables : List[FreeTable] + List of restricted FreeTables to export + paper_id : str + Paper ID to use for export file names + docker_id : str, optional + Docker container ID to export from. Default None + spyglass_version : str, optional + Spyglass version to include in export. Default None + """ + self._write_sql_cnf() + self._write_mysqldump( + free_tables, paper_id, docker_id, spyglass_version + ) + + # TODO: export conda env diff --git a/src/spyglass/utils/dj_graph.py b/src/spyglass/utils/dj_graph.py index af3c85dd3..52c42f3b3 100644 --- a/src/spyglass/utils/dj_graph.py +++ b/src/spyglass/utils/dj_graph.py @@ -3,15 +3,12 @@ NOTE: read `ft` as FreeTable and `restr` as restriction. 
""" -from pathlib import Path from typing import Dict, List from datajoint import FreeTable -from datajoint import config as dj_config from datajoint.condition import make_condition from datajoint.table import Table -from spyglass.settings import export_dir from spyglass.utils import logger from spyglass.utils.dj_helper_fn import unique_dicts @@ -66,6 +63,8 @@ def __repr__(self): @property def all_ft(self): """Get restricted FreeTables from all visited nodes.""" + if not self.cascaded: + self.cascade() return [self._get_ft(table, with_restr=True) for table in self.visited] @property @@ -290,140 +289,3 @@ def as_dict(self) -> List[Dict[str, str]]: for table in self.ancestors if self._get_restr(table) ] - - def _get_credentials(self): - """Get credentials for database connection.""" - return { - "user": dj_config["database.user"], - "password": dj_config["database.password"], - "host": dj_config["database.host"], - } - - def _write_sql_cnf(self): - """Write SQL cnf file to avoid password prompt.""" - cnf_path = Path("~/.my.cnf").expanduser() - - if cnf_path.exists(): - return - - template = "[client]\nuser={user}\npassword={password}\nhost={host}\n" - - with open(str(cnf_path), "w") as file: - file.write(template.format(**self._get_credentials())) - cnf_path.chmod(0o600) - - def _bash_escape(self, s): - """Escape restriction string for bash.""" - s = s.strip() - - replace_map = { - "WHERE ": "", # Remove preceding WHERE of dj.where_clause - " ": " ", # Squash double spaces - "( (": "((", # Squash double parens - ") )": ")", - '"': "'", # Replace double quotes with single - "`": "", # Remove backticks - " AND ": " \\\n\tAND ", # Add newline and tab for readability - " OR ": " \\\n\tOR ", # OR extra space to align with AND - ")AND(": ") \\\n\tAND (", - ")OR(": ") \\\n\tOR (", - } - for old, new in replace_map.items(): - s = s.replace(old, new) - if s.startswith("(((") and s.endswith(")))"): - s = s[2:-2] # Remove extra parens for readability - return s - - def _cmd_prefix(self, docker_id=None): - """Get prefix for mysqldump command. Includes docker exec if needed.""" - if not docker_id: - return "mysqldump " - return ( - f"docker exec -i {docker_id} \\\n\tmysqldump " - + "-u {user} --password={password} \\\n\t".format( - **self._get_credentials() - ) - ) - - def _write_mysqldump( - self, paper_id: str, docker_id=None, spyglass_version=None - ): - """Write mysqlmdump.sh script to export data. - - Parameters - ---------- - paper_id : str - Paper ID to use for export file names - docker_id : str, optional - Docker container ID to export from. Default None - spyglass_version : str, optional - Spyglass version to include in export. 
Default None - """ - paper_dir = Path(export_dir) / paper_id if not docker_id else Path(".") - paper_dir.mkdir(exist_ok=True) - - dump_script = paper_dir / f"_ExportSQL_{paper_id}.sh" - dump_content = paper_dir / f"_Populate_{paper_id}.sql" - - prefix = self._cmd_prefix(docker_id) - version = ( # Include spyglass version as comment in dump - "echo '--'\n" - + f"echo '-- SPYGLASS VERSION: {spyglass_version} --'\n" - + "echo '--'\n\n" - if spyglass_version - else "" - ) - create_cmd = ( - "echo 'CREATE DATABASE IF NOT EXISTS {database}; " - + "USE {database};'\n\n" - ) - dump_cmd = prefix + '{database} {table} --where="\\\n\t{where}"\n\n' - - tables_by_db = sorted(self.all_ft, key=lambda x: x.full_table_name) - - with open(dump_script, "w") as file: - file.write( - "#!/bin/bash\n\n" - + f"exec > {dump_content}\n\n" # Redirect output to sql file - + f"{version}" # Include spyglass version as comment - ) - - prev_db = None - for table in tables_by_db: - if not (where := table.where_clause()): - continue - where = self._bash_escape(where) - database, table_name = table.full_table_name.replace( - "`", "" - ).split(".") - if database != prev_db: - file.write(create_cmd.format(database=database)) - prev_db = database - file.write( - dump_cmd.format( - database=database, table=table_name, where=where - ) - ) - logger.info(f"Export script written to {dump_script}") - - def write_export( - self, paper_id: str, docker_id=None, spyglass_version=None - ): - """Write export bash script for all tables in graph. - - Also writes a user-specific .my.cnf file to avoid password prompt. - - Parameters - ---------- - paper_id : str - Paper ID to use for export file names - docker_id : str, optional - Docker container ID to export from. Default None - spyglass_version : str, optional - Spyglass version to include in export. 
Default None - """ - self.cascade() - self._write_sql_cnf() - self._write_mysqldump(paper_id, docker_id, spyglass_version) - - # TODO: export conda env diff --git a/src/spyglass/utils/dj_mixin.py b/src/spyglass/utils/dj_mixin.py index 387ee4107..b28e7ee7a 100644 --- a/src/spyglass/utils/dj_mixin.py +++ b/src/spyglass/utils/dj_mixin.py @@ -130,7 +130,9 @@ def fetch_nwb(self, *attrs, **kwargs): """ table, tbl_attr = self._nwb_table_tuple if self.export_id and "analysis" in tbl_attr: - logger.info(f"Export {self.export_id}: fetch_nwb {self.table_name}") + logger.debug( + f"Export {self.export_id}: fetch_nwb {self.table_name}" + ) tbl_pk = "analysis_file_name" fname = (self * table).fetch1(tbl_pk) self._export_table.File.insert1( @@ -592,13 +594,9 @@ def _stop_export(self, warn=True): def _log_fetch(self): """Log fetch for export.""" - if ( - not self.export_id - or self.full_table_name == self._export_table.full_table_name - or "dandi_export" in self.full_table_name # for populated table - ): + if not self.export_id or self.database == "common_usage": return - logger.info(f"Export {self.export_id}: fetch() {self.table_name}") + logger.debug(f"Export {self.export_id}: fetch() {self.table_name}") restr_str = make_condition(self, self.restriction, set()) if isinstance(restr_str, str) and len(restr_str) > 2048: raise RuntimeError( From c703c77ffbf3b1f9020dc1aa1a3492600b382c74 Mon Sep 17 00:00:00 2001 From: CBroz1 Date: Thu, 28 Mar 2024 16:49:19 -0500 Subject: [PATCH 05/15] Revert dj_chains related edits --- src/spyglass/utils/dj_mixin.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/spyglass/utils/dj_mixin.py b/src/spyglass/utils/dj_mixin.py index b28e7ee7a..abdefd3cb 100644 --- a/src/spyglass/utils/dj_mixin.py +++ b/src/spyglass/utils/dj_mixin.py @@ -350,7 +350,9 @@ def _get_exp_summary(self): empty_pk = {self._member_pk: "NULL"} format = dj.U(self._session_pk, self._member_pk) - sess_link = self._session_connection.join(self.restriction) + sess_link = self._session_connection.join( + self.restriction, reverse_order=True + ) exp_missing = format & (sess_link - SesExp).proj(**empty_pk) exp_present = format & (sess_link * SesExp - exp_missing).proj() @@ -360,9 +362,7 @@ def _get_exp_summary(self): @cached_property def _session_connection(self) -> Union[TableChain, bool]: """Path from Session table to self. False if no connection found.""" - connection = TableChain( - parent=self._delete_deps[-1], child=self, reverse=True - ) + connection = TableChain(parent=self._delete_deps[-1], child=self) return connection if connection.has_link else False @cached_property From 232523c27eda02a054fe3203edcfadb5036f3a62 Mon Sep 17 00:00:00 2001 From: CBroz1 Date: Thu, 28 Mar 2024 16:55:02 -0500 Subject: [PATCH 06/15] Update changelog --- CHANGELOG.md | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index b4a6406f4..62ccf3129 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,6 +4,7 @@ - Refactor `TableChain` to include `_searched` attribute. #867 - Fix errors in config import #882 +- Add functionality to export vertical slice of database. 
#875 ## [0.5.1] (March 7, 2024) @@ -22,7 +23,7 @@ - Fixes to `_convert_mp4` #834 - Replace deprecated calls to `yaml.safe_load()` #834 - Spikesorting: - - Increase`spikeinterface` version to >=0.99.1, <0.100 #852 + - Increase`spikeinterface` version to >=0.99.1, \<0.100 #852 - Bug fix in single artifact interval edge case #859 - Bug fix in FigURL #871 - LFP @@ -199,3 +200,5 @@ [0.4.2]: https://github.com/LorenFrankLab/spyglass/releases/tag/0.4.2 [0.4.3]: https://github.com/LorenFrankLab/spyglass/releases/tag/0.4.3 [0.5.0]: https://github.com/LorenFrankLab/spyglass/releases/tag/0.5.0 +[0.5.1]: https://github.com/LorenFrankLab/spyglass/releases/tag/0.5.1 +[0.5.2]: https://github.com/LorenFrankLab/spyglass/releases/tag/0.5.2 From 6549b9501eabc52c26909082559cbae4a9cc7bc8 Mon Sep 17 00:00:00 2001 From: CBroz1 Date: Thu, 28 Mar 2024 17:04:41 -0500 Subject: [PATCH 07/15] Revise doc --- docs/src/misc/export.md | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/docs/src/misc/export.md b/docs/src/misc/export.md index d0b0dbff1..ca3884dd7 100644 --- a/docs/src/misc/export.md +++ b/docs/src/misc/export.md @@ -65,8 +65,6 @@ restr_graph.add_leaves( restr_graph.cascade() restricted_leaves = restr_graph.leaf_ft all_restricted_tables = restr_graph.all_ft - -restr_graph.write_export(paper_id="my_paper_id") # part of `populate` below ``` By default, a `RestrGraph` object is created with a seed table to have access to @@ -102,11 +100,11 @@ restricted_leaves = ExportSelection.preview_tables(**export_key) Export().populate_paper(**export_key) ``` -`Export` will invoke `RestrGraph.write_export` to collect cascaded restrictions -and file paths in its part tables, and write out a bash script to export the -data using a series of `mysqldump` commands. The script is saved to Spyglass's -directory, `base_dir/export/paper_id/`, using credentials from `dj_config`. To -use alternative credentials, create a +`Export`'s populate will invoke the `write_export` method to collect cascaded +restrictions and file paths in its part tables, and write out a bash script to +export the data using a series of `mysqldump` commands. The script is saved to +Spyglass's directory, `base_dir/export/paper_id/`, using credentials from +`dj_config`. To use alternative credentials, create a [mysql config file](https://dev.mysql.com/doc/refman/8.0/en/option-files.html). To retain the ability to delete the logging from a particular analysis, the @@ -128,4 +126,6 @@ To implement an export for a non-Spyglass database, you will need to ... `spyglass_version` to match the new database. Or, optionally, you can use the `RestrGraph` class to cascade hand-picked tables -and restrictions without the background logging of `SpyglassMixin`. +and restrictions without the background logging of `SpyglassMixin`. The +assembled list of restricted free tables, `RestrGraph.all_ft`, can be passed to +`Export.write_export` to generate a shell script for exporting the data. 
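+
+A minimal sketch of that handoff, assuming `restr_graph` was assembled and
+cascaded as in the example above (the paper ID and version here are
+placeholders):
+
+```python
+from spyglass.common.common_usage import Export
+
+# `all_ft` cascades the graph if needed, then returns restricted FreeTables
+Export().write_export(
+    free_tables=restr_graph.all_ft,
+    paper_id="my_paper_id",  # names the export directory and scripts
+    spyglass_version="0.5.2",  # optional version comment in the dump
+)
+```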
From 756305021052874a07391436f6555e9a27736da1 Mon Sep 17 00:00:00 2001 From: CBroz1 Date: Tue, 2 Apr 2024 11:19:26 -0500 Subject: [PATCH 08/15] Address review comments #875 --- src/spyglass/settings.py | 1 + src/spyglass/utils/dj_graph.py | 8 +++----- src/spyglass/utils/dj_mixin.py | 26 ++++++++++++++++++-------- 3 files changed, 22 insertions(+), 13 deletions(-) diff --git a/src/spyglass/settings.py b/src/spyglass/settings.py index 4900f595d..be2912c9d 100644 --- a/src/spyglass/settings.py +++ b/src/spyglass/settings.py @@ -60,6 +60,7 @@ def __init__(self, base_dir: str = None, **kwargs): self.relative_dirs = { # {PREFIX}_{KEY}_DIR, default dir relative to base_dir + # NOTE: Adding new dir requires edit to HHMI hub "spyglass": { "raw": "raw", "analysis": "analysis", diff --git a/src/spyglass/utils/dj_graph.py b/src/spyglass/utils/dj_graph.py index 52c42f3b3..c47f527f0 100644 --- a/src/spyglass/utils/dj_graph.py +++ b/src/spyglass/utils/dj_graph.py @@ -63,8 +63,7 @@ def __repr__(self): @property def all_ft(self): """Get restricted FreeTables from all visited nodes.""" - if not self.cascaded: - self.cascade() + self.cascade() return [self._get_ft(table, with_restr=True) for table in self.visited] @property @@ -74,8 +73,7 @@ def leaf_ft(self): def _get_node(self, table): """Get node from graph.""" - node = self.graph.nodes.get(table) - if not node: + if not (node := self.graph.nodes.get(table)): raise ValueError( f"Table {table} not found in graph." + "\n\tPlease import this table and rerun" @@ -185,7 +183,7 @@ def _child_to_parent( ret = unique_dicts(join.fetch(*parent_ft.primary_key, as_dict=True)) if len(ret) == len(parent_ft): - self._log_truncate(f"NULL rest {parent}") + self._log_truncate(f"NULL restr {parent}") return ret diff --git a/src/spyglass/utils/dj_mixin.py b/src/spyglass/utils/dj_mixin.py index abdefd3cb..d83a9075c 100644 --- a/src/spyglass/utils/dj_mixin.py +++ b/src/spyglass/utils/dj_mixin.py @@ -134,9 +134,12 @@ def fetch_nwb(self, *attrs, **kwargs): f"Export {self.export_id}: fetch_nwb {self.table_name}" ) tbl_pk = "analysis_file_name" - fname = (self * table).fetch1(tbl_pk) - self._export_table.File.insert1( - {"export_id": self.export_id, tbl_pk: fname}, + fnames = (self * table).fetch(tbl_pk) + self._export_table.File.insert( + [ + {"export_id": self.export_id, tbl_pk: fname} + for fname in fnames + ], skip_duplicates=True, ) self._export_table.Table.insert1( @@ -592,15 +595,22 @@ def _stop_export(self, warn=True): logger.warning("Export not in progress.") del self.export_id - def _log_fetch(self): + def _log_fetch(self, *args, **kwargs): """Log fetch for export.""" if not self.export_id or self.database == "common_usage": return logger.debug(f"Export {self.export_id}: fetch() {self.table_name}") - restr_str = make_condition(self, self.restriction, set()) + + restr = self.restriction or True + if (limit := kwargs.get("limit")) or (offset := kwargs.get("offset")): + restr = super().fetch( # Use result as restr if limit/offset + restr, as_dict=True, limit=limit, offset=offset + ) + restr_str = make_condition(self, restr, set()) + if isinstance(restr_str, str) and len(restr_str) > 2048: raise RuntimeError( - "DandiExport cannot handle restrictions > 2048.\n\t" + "Export cannot handle restrictions > 2048.\n\t" + "If required, please open an issue on GitHub.\n\t" + f"Restriction: {restr_str}" ) @@ -616,13 +626,13 @@ def _log_fetch(self): def fetch(self, *args, **kwargs): """Log fetch for export.""" ret = super().fetch(*args, **kwargs) - self._log_fetch() + 
self._log_fetch(*args, **kwargs) return ret def fetch1(self, *args, **kwargs): """Log fetch1 for export.""" ret = super().fetch1(*args, **kwargs) - self._log_fetch() + self._log_fetch(*args, **kwargs) return ret # ------------------------- Other helper methods ------------------------- From a3817bd45d948f5096bcd59b583e20925169de7c Mon Sep 17 00:00:00 2001 From: CBroz1 Date: Wed, 3 Apr 2024 15:57:58 -0500 Subject: [PATCH 09/15] Remove walrus in eval --- src/spyglass/utils/dj_mixin.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/spyglass/utils/dj_mixin.py b/src/spyglass/utils/dj_mixin.py index 080dbe57a..8857079f2 100644 --- a/src/spyglass/utils/dj_mixin.py +++ b/src/spyglass/utils/dj_mixin.py @@ -644,7 +644,9 @@ def _log_fetch(self, *args, **kwargs): logger.debug(f"Export {self.export_id}: fetch() {self.table_name}") restr = self.restriction or True - if (limit := kwargs.get("limit")) or (offset := kwargs.get("offset")): + limit = kwargs.get("limit") + offset = kwargs.get("offset") + if limit or offset: restr = super().fetch( # Use result as restr if limit/offset restr, as_dict=True, limit=limit, offset=offset ) From 334e1dac6f5ddb412f808d260f6eee3c9b90d2cd Mon Sep 17 00:00:00 2001 From: CBroz1 Date: Thu, 11 Apr 2024 16:56:37 -0700 Subject: [PATCH 10/15] prevent log on preview --- src/spyglass/utils/dj_mixin.py | 27 ++++++++++++++++++--------- 1 file changed, 18 insertions(+), 9 deletions(-) diff --git a/src/spyglass/utils/dj_mixin.py b/src/spyglass/utils/dj_mixin.py index 8857079f2..9e7877a43 100644 --- a/src/spyglass/utils/dj_mixin.py +++ b/src/spyglass/utils/dj_mixin.py @@ -2,6 +2,7 @@ from atexit import unregister as exit_unregister from collections import OrderedDict from functools import cached_property +from inspect import stack as inspect_stack from os import environ from time import time from typing import Dict, List, Union @@ -21,13 +22,13 @@ from spyglass.utils.dj_merge_tables import RESERVED_PRIMARY_KEY as MERGE_PK from spyglass.utils.logging import logger -EXPORT_ENV_VAR = "SPYGLASS_EXPORT_ID" - try: import pynapple # noqa F401 except ImportError: pynapple = None +EXPORT_ENV_VAR = "SPYGLASS_EXPORT_ID" + class SpyglassMixin: """Mixin for Spyglass DataJoint tables. 
@@ -641,14 +642,20 @@ def _log_fetch(self, *args, **kwargs): """Log fetch for export.""" if not self.export_id or self.database == "common_usage": return + + called = [i.function for i in inspect_stack()] + banned = ["head", "tail", "preview", "_repr_html_"] + if set(banned) & set(called): # if called by any in banned, return + return + logger.debug(f"Export {self.export_id}: fetch() {self.table_name}") restr = self.restriction or True limit = kwargs.get("limit") offset = kwargs.get("offset") - if limit or offset: - restr = super().fetch( # Use result as restr if limit/offset - restr, as_dict=True, limit=limit, offset=offset + if limit or offset: # Use result as restr if limit/offset + restr = self.restrict(restr).fetch( + log_fetch=False, as_dict=True, limit=limit, offset=offset ) restr_str = make_condition(self, restr, set()) @@ -667,16 +674,18 @@ def _log_fetch(self, *args, **kwargs): skip_duplicates=True, ) - def fetch(self, *args, **kwargs): + def fetch(self, log_fetch=True, *args, **kwargs): """Log fetch for export.""" ret = super().fetch(*args, **kwargs) - self._log_fetch(*args, **kwargs) + if log_fetch: + self._log_fetch(*args, **kwargs) return ret - def fetch1(self, *args, **kwargs): + def fetch1(self, log_fetch=True, *args, **kwargs): """Log fetch1 for export.""" ret = super().fetch1(*args, **kwargs) - self._log_fetch(*args, **kwargs) + if log_fetch: + self._log_fetch(*args, **kwargs) return ret # ------------------------- Other helper methods ------------------------- From 759dc800cb05bf6a1859c88b6dcd3a8cdf75eebb Mon Sep 17 00:00:00 2001 From: CBroz1 Date: Mon, 15 Apr 2024 19:01:43 -0500 Subject: [PATCH 11/15] Fix arg order on fetch, iterate over restr --- src/spyglass/utils/dj_mixin.py | 43 +++++++++++++++++++--------------- 1 file changed, 24 insertions(+), 19 deletions(-) diff --git a/src/spyglass/utils/dj_mixin.py b/src/spyglass/utils/dj_mixin.py index 9e7877a43..723be000a 100644 --- a/src/spyglass/utils/dj_mixin.py +++ b/src/spyglass/utils/dj_mixin.py @@ -5,7 +5,7 @@ from inspect import stack as inspect_stack from os import environ from time import time -from typing import Dict, List, Union +from typing import Dict, Iterable, List, Union import datajoint as dj from datajoint.condition import make_condition @@ -650,38 +650,43 @@ def _log_fetch(self, *args, **kwargs): logger.debug(f"Export {self.export_id}: fetch() {self.table_name}") - restr = self.restriction or True + restrictions = self.restriction or True limit = kwargs.get("limit") offset = kwargs.get("offset") if limit or offset: # Use result as restr if limit/offset - restr = self.restrict(restr).fetch( + restrictions = self.restrict(restrictions).fetch( log_fetch=False, as_dict=True, limit=limit, offset=offset ) - restr_str = make_condition(self, restr, set()) - if isinstance(restr_str, str) and len(restr_str) > 2048: - raise RuntimeError( - "Export cannot handle restrictions > 2048.\n\t" - + "If required, please open an issue on GitHub.\n\t" - + f"Restriction: {restr_str}" + if not isinstance(restrictions, Iterable): + restrictions = [restrictions] + + table_inserts = [] + for restr in restrictions: + restr_str = make_condition(self, restr, set()) + if isinstance(restr_str, str) and len(restr_str) > 2048: + raise RuntimeError( + "Export cannot handle restrictions > 2048.\n\t" + + "If required, please open an issue on GitHub.\n\t" + + f"Restriction: {restr_str}" + ) + table_inserts.append( + dict( + export_id=self.export_id, + table_name=self.full_table_name, + restriction=restr_str, + ) ) - 
self._export_table.Table.insert1( - dict( - export_id=self.export_id, - table_name=self.full_table_name, - restriction=make_condition(self, restr_str, set()), - ), - skip_duplicates=True, - ) + self._export_table.Table.insert(table_inserts, skip_duplicates=True) - def fetch(self, log_fetch=True, *args, **kwargs): + def fetch(self, *args, log_fetch=True, **kwargs): """Log fetch for export.""" ret = super().fetch(*args, **kwargs) if log_fetch: self._log_fetch(*args, **kwargs) return ret - def fetch1(self, log_fetch=True, *args, **kwargs): + def fetch1(self, *args, log_fetch=True, **kwargs): """Log fetch1 for export.""" ret = super().fetch1(*args, **kwargs) if log_fetch: From f9280ca2273f7bd6a65bfee743aa8b6a5f5fc6b7 Mon Sep 17 00:00:00 2001 From: CBroz1 Date: Tue, 16 Apr 2024 16:01:39 -0500 Subject: [PATCH 12/15] Add upstream analysis files during cascade. Address false positive fetch --- src/spyglass/common/common_usage.py | 22 +++++++-- src/spyglass/utils/dj_graph.py | 72 ++++++++++++++++++++++++++--- src/spyglass/utils/dj_mixin.py | 55 +++++++++++----------- 3 files changed, 112 insertions(+), 37 deletions(-) diff --git a/src/spyglass/common/common_usage.py b/src/spyglass/common/common_usage.py index 7443206a4..891c57056 100644 --- a/src/spyglass/common/common_usage.py +++ b/src/spyglass/common/common_usage.py @@ -75,6 +75,12 @@ def insert1(self, key, **kwargs): key = self._auto_increment(key, pk="table_id") super().insert1(key, **kwargs) + def insert(self, keys: List[dict], **kwargs): + if not isinstance(keys[0], dict): + raise TypeError("Pass Table Keys as list of dict") + keys = [self._auto_increment(k, pk="table_id") for k in keys] + super().insert(keys, **kwargs) + class File(SpyglassMixin, dj.Part): definition = """ -> master @@ -110,8 +116,12 @@ def stop_export(self, **kwargs) -> None: # Selection def list_file_paths(self, key: dict) -> list[str]: - """Return a list of unique file paths for a given restriction/key.""" - file_table = self.File & key + """Return a list of unique file paths for a given restriction/key. + + Note: This list reflects files fetched during the export process. For + upstream files, use RestrGraph.file_paths. + """ + file_table = self * self.File & key analysis_fp = [ AnalysisNwbfile().get_abs_path(fname) for fname in file_table.fetch("analysis_file_name") @@ -128,7 +138,9 @@ def get_restr_graph(self, key: dict) -> RestrGraph: Ignores duplicate entries. """ leaves = unique_dicts( - (self.Table & key).fetch("table_name", "restriction", as_dict=True) + (self * self.Table & key).fetch( + "table_name", "restriction", as_dict=True + ) ) return RestrGraph(seed_table=self, leaves=leaves, verbose=False) @@ -215,7 +227,9 @@ def make(self, key): (self.Table & id_dict).delete_quick() restr_graph = query.get_restr_graph(paper_key) - file_paths = query.list_file_paths(paper_key) + file_paths = unique_dicts( # Original plus upstream files + query.list_file_paths(paper_key) + restr_graph.file_paths + ) table_inserts = [ {**key, **rd, "table_id": i} diff --git a/src/spyglass/utils/dj_graph.py b/src/spyglass/utils/dj_graph.py index c47f527f0..ea5bd9e29 100644 --- a/src/spyglass/utils/dj_graph.py +++ b/src/spyglass/utils/dj_graph.py @@ -3,12 +3,13 @@ NOTE: read `ft` as FreeTable and `restr` as restriction. 
""" -from typing import Dict, List +from typing import Dict, List, Union from datajoint import FreeTable from datajoint.condition import make_condition from datajoint.table import Table +from spyglass.common import AnalysisNwbfile from spyglass.utils import logger from spyglass.utils.dj_helper_fn import unique_dicts @@ -49,6 +50,7 @@ def __init__( self.ancestors = set() self.visited = set() self.leaves = set() + self.analysis_pk = AnalysisNwbfile().primary_key if table_name and restriction: self.add_leaf(table_name, restriction) @@ -57,7 +59,7 @@ def __init__( def __repr__(self): l_str = ",\n\t".join(self.leaves) + "\n" if self.leaves else "" - processed = "Processed" if self.cascaded else "Unprocessed" + processed = "Cascaded" if self.cascaded else "Uncascaded" return f"{processed} RestrictionGraph(\n\t{l_str})" @property @@ -81,7 +83,7 @@ def _get_node(self, table): return node def _set_node(self, table, attr="ft", value=None): - """Set attribute on graph node.""" + """Set attribute on node. General helper for various attributes.""" _ = self._get_node(table) # Ensure node exists self.graph.nodes[table][attr] = value @@ -99,6 +101,16 @@ def _get_restr(self, table): table = table if isinstance(table, str) else table.full_table_name return self._get_node(table).get("restr") + def _set_files(self, table, ft, restr): + """Set node attribute for analysis files.""" + if not set(self.analysis_pk).issubset(ft.heading.names): + return + self._set_node(table, "files", (ft & restr).fetch(*self.analysis_pk)) + + def _get_files(self, table): + """Get analysis files from graph node.""" + return self._get_node(table).get("files", []) + def _set_restr(self, table, restriction): """Add restriction to graph node. If one exists, merge with new.""" ft = self._get_ft(table) @@ -117,11 +129,22 @@ def _set_restr(self, table, restriction): ft, unique_dicts(join.fetch("KEY", as_dict=True)), set() ) - if not isinstance(restriction, str): - self._log_truncate( - f"Set Restr {table}: {type(restriction)} {restriction}" - ) self._set_node(table, "restr", restriction) + self._set_files(table, ft, restriction) + + def get_restr_ft(self, table: Union[int, str]): + """Get restricted FreeTable from graph node. + + Currently used. May be useful for debugging. + + Parameters + ---------- + table : Union[int, str] + Table name or index in visited set + """ + if isinstance(table, int): + table = list(self.visited)[table] + return self._get_ft(table) & self._get_restr(table) def _log_truncate(self, log_str, max_len=80): """Truncate log lines to max_len and print if verbose.""" @@ -287,3 +310,38 @@ def as_dict(self) -> List[Dict[str, str]]: for table in self.ancestors if self._get_restr(table) ] + + @property + def file_dict(self) -> Dict[str, List[str]]: + """Return dictionary of analysis files from all visited nodes. + + Currently unused, but could be useful for debugging. + """ + if not self.cascaded: + logger.warning("Uncascaded graph. Using leaves only.") + table_list = self.leaves + else: + table_list = self.visited + + return { + table: self._get_files(table) + for table in table_list + if any(self._get_files(table)) + } + + @property + def file_paths(self) -> List[str]: + """Return list of unique analysis files from all visited nodes. + + This covers intermediate analysis files that may not have been fetched + directly by the user. 
+ """ + self.cascade() + unique_files = set( + [file for table in self.visited for file in self._get_files(table)] + ) + return [ + {"file_path": AnalysisNwbfile().get_abs_path(file)} + for file in unique_files + if file is not None + ] diff --git a/src/spyglass/utils/dj_mixin.py b/src/spyglass/utils/dj_mixin.py index 723be000a..e7b528156 100644 --- a/src/spyglass/utils/dj_mixin.py +++ b/src/spyglass/utils/dj_mixin.py @@ -5,7 +5,7 @@ from inspect import stack as inspect_stack from os import environ from time import time -from typing import Dict, Iterable, List, Union +from typing import Dict, List, Union import datajoint as dj from datajoint.condition import make_condition @@ -136,11 +136,11 @@ def fetch_nwb(self, *attrs, **kwargs): """ table, tbl_attr = self._nwb_table_tuple if self.export_id and "analysis" in tbl_attr: - logger.debug( - f"Export {self.export_id}: fetch_nwb {self.table_name}" - ) tbl_pk = "analysis_file_name" fnames = (self * table).fetch(tbl_pk) + logger.debug( + f"Export {self.export_id}: fetch_nwb {self.table_name}, {fnames}" + ) self._export_table.File.insert( [ {"export_id": self.export_id, tbl_pk: fname} @@ -643,41 +643,44 @@ def _log_fetch(self, *args, **kwargs): if not self.export_id or self.database == "common_usage": return + banned = [ + "head", # Prevents on Table().head() call + "tail", # Prevents on Table().tail() call + "preview", # Prevents on Table() call + "_repr_html_", # Prevents on Table() call in notebook + "cautious_delete", # Prevents add on permission check during delete + "get_abs_path", # Assumes that fetch_nwb will catch file/table + ] called = [i.function for i in inspect_stack()] - banned = ["head", "tail", "preview", "_repr_html_"] if set(banned) & set(called): # if called by any in banned, return return logger.debug(f"Export {self.export_id}: fetch() {self.table_name}") - restrictions = self.restriction or True + restr = self.restriction or True limit = kwargs.get("limit") offset = kwargs.get("offset") if limit or offset: # Use result as restr if limit/offset - restrictions = self.restrict(restrictions).fetch( + restr = self.restrict(restr).fetch( log_fetch=False, as_dict=True, limit=limit, offset=offset ) - if not isinstance(restrictions, Iterable): - restrictions = [restrictions] - - table_inserts = [] - for restr in restrictions: - restr_str = make_condition(self, restr, set()) - if isinstance(restr_str, str) and len(restr_str) > 2048: - raise RuntimeError( - "Export cannot handle restrictions > 2048.\n\t" - + "If required, please open an issue on GitHub.\n\t" - + f"Restriction: {restr_str}" - ) - table_inserts.append( - dict( - export_id=self.export_id, - table_name=self.full_table_name, - restriction=restr_str, - ) + restr_str = make_condition(self, restr, set()) + + if isinstance(restr_str, str) and len(restr_str) > 2048: + raise RuntimeError( + "Export cannot handle restrictions > 2048.\n\t" + + "If required, please open an issue on GitHub.\n\t" + + f"Restriction: {restr_str}" ) - self._export_table.Table.insert(table_inserts, skip_duplicates=True) + self._export_table.Table.insert1( + dict( + export_id=self.export_id, + table_name=self.full_table_name, + restriction=restr_str, + ), + skip_duplicates=True, + ) def fetch(self, *args, log_fetch=True, **kwargs): """Log fetch for export.""" From 54274cd852d5983bb7b0ee56e5122998d324fd26 Mon Sep 17 00:00:00 2001 From: CBroz1 Date: Tue, 16 Apr 2024 16:13:54 -0500 Subject: [PATCH 13/15] Avoid regen file list on revisit node --- src/spyglass/utils/dj_graph.py | 19 ++++++++++++------- 1 file 
changed, 12 insertions(+), 7 deletions(-) diff --git a/src/spyglass/utils/dj_graph.py b/src/spyglass/utils/dj_graph.py index ea5bd9e29..9f1eb44c1 100644 --- a/src/spyglass/utils/dj_graph.py +++ b/src/spyglass/utils/dj_graph.py @@ -101,12 +101,6 @@ def _get_restr(self, table): table = table if isinstance(table, str) else table.full_table_name return self._get_node(table).get("restr") - def _set_files(self, table, ft, restr): - """Set node attribute for analysis files.""" - if not set(self.analysis_pk).issubset(ft.heading.names): - return - self._set_node(table, "files", (ft & restr).fetch(*self.analysis_pk)) - def _get_files(self, table): """Get analysis files from graph node.""" return self._get_node(table).get("files", []) @@ -130,7 +124,6 @@ def _set_restr(self, table, restriction): ) self._set_node(table, "restr", restriction) - self._set_files(table, ft, restriction) def get_restr_ft(self, table: Union[int, str]): """Get restricted FreeTable from graph node. @@ -210,6 +203,15 @@ def _child_to_parent( return ret + def cascade_files(self): + """Set node attribute for analysis files.""" + for table in self.visited: + ft = self._get_ft(table) + if not set(self.analysis_pk).issubset(ft.heading.names): + continue + files = (ft & self._get_restr(table)).fetch(*self.analysis_pk) + self._set_node(table, "files", files) + def cascade1(self, table, restriction): """Cascade a restriction up the graph, recursively on parents. @@ -250,6 +252,7 @@ def cascade(self) -> None: if not self.visited == self.ancestors: raise RuntimeError("Cascade: FAIL - incomplete cascade") + self.cascade_files() self.cascaded = True def add_leaf(self, table_name, restriction, cascade=False) -> None: @@ -271,6 +274,7 @@ def add_leaf(self, table_name, restriction, cascade=False) -> None: if cascade: self.cascade1(table_name, restriction) + self.cascade_files() self.cascaded = True def add_leaves(self, leaves: List[Dict[str, str]], cascade=False) -> None: @@ -300,6 +304,7 @@ def add_leaves(self, leaves: List[Dict[str, str]], cascade=False) -> None: self.add_leaf(table_name, restriction, cascade=False) if cascade: self.cascade() + self.cascade_files() @property def as_dict(self) -> List[Dict[str, str]]: From 1b59a1f1dcf0e80f70b8f3356feac3121e39a4d2 Mon Sep 17 00:00:00 2001 From: CBroz1 Date: Thu, 18 Apr 2024 13:08:32 -0500 Subject: [PATCH 14/15] Bump Export.Table.restr to mediumblob --- src/spyglass/common/common_usage.py | 15 +++++--- src/spyglass/utils/dj_graph.py | 55 ++++++++++++++++++++++------- 2 files changed, 53 insertions(+), 17 deletions(-) diff --git a/src/spyglass/common/common_usage.py b/src/spyglass/common/common_usage.py index 891c57056..90d3c5b3f 100644 --- a/src/spyglass/common/common_usage.py +++ b/src/spyglass/common/common_usage.py @@ -132,17 +132,24 @@ def list_file_paths(self, key: dict) -> list[str]: ] return [{"file_path": p} for p in list({*analysis_fp, *nwbfile_fp})] - def get_restr_graph(self, key: dict) -> RestrGraph: + def get_restr_graph(self, key: dict, verbose=False) -> RestrGraph: """Return a RestrGraph for a restriction/key's tables/restrictions. Ignores duplicate entries. + + Parameters + ---------- + key : dict + Any valid restriction key for ExportSelection.Table + verbose : bool, optional + Turn on RestrGraph verbosity. Default False. 
""" leaves = unique_dicts( (self * self.Table & key).fetch( "table_name", "restriction", as_dict=True ) ) - return RestrGraph(seed_table=self, leaves=leaves, verbose=False) + return RestrGraph(seed_table=self, leaves=leaves, verbose=verbose) def preview_tables(self, **kwargs) -> list[dj.FreeTable]: """Return a list of restricted FreeTables for a given restriction/key. @@ -183,7 +190,7 @@ class Table(SpyglassMixin, dj.Part): table_id: int --- table_name: varchar(128) - restriction: varchar(2048) + restriction: mediumblob unique index (table_name) """ @@ -194,7 +201,6 @@ class File(SpyglassMixin, dj.Part): --- file_path: varchar(255) """ - # What's needed? full path? relative path? def populate_paper(self, paper_id: Union[str, dict]): if isinstance(paper_id, dict): @@ -214,6 +220,7 @@ def make(self, key): ) self.insert1(key) return + # If lesser ids are present, delete parts yielding null entries processed_ids = set( list(self.Table.fetch("export_id")) diff --git a/src/spyglass/utils/dj_graph.py b/src/spyglass/utils/dj_graph.py index 9f1eb44c1..59e7497d5 100644 --- a/src/spyglass/utils/dj_graph.py +++ b/src/spyglass/utils/dj_graph.py @@ -8,6 +8,7 @@ from datajoint import FreeTable from datajoint.condition import make_condition from datajoint.table import Table +from tqdm import tqdm from spyglass.common import AnalysisNwbfile from spyglass.utils import logger @@ -55,7 +56,7 @@ def __init__( if table_name and restriction: self.add_leaf(table_name, restriction) if leaves: - self.add_leaves(leaves) + self.add_leaves(leaves, show_progress=verbose) def __repr__(self): l_str = ",\n\t".join(self.leaves) + "\n" if self.leaves else "" @@ -89,6 +90,7 @@ def _set_node(self, table, attr="ft", value=None): def _get_ft(self, table, with_restr=False): """Get FreeTable from graph node. 
diff --git a/src/spyglass/utils/dj_graph.py b/src/spyglass/utils/dj_graph.py
index 9f1eb44c1..59e7497d5 100644
--- a/src/spyglass/utils/dj_graph.py
+++ b/src/spyglass/utils/dj_graph.py
@@ -8,6 +8,7 @@
 from datajoint import FreeTable
 from datajoint.condition import make_condition
 from datajoint.table import Table
+from tqdm import tqdm

 from spyglass.common import AnalysisNwbfile
 from spyglass.utils import logger
@@ -55,7 +56,7 @@ def __init__(
         if table_name and restriction:
             self.add_leaf(table_name, restriction)
         if leaves:
-            self.add_leaves(leaves)
+            self.add_leaves(leaves, show_progress=verbose)

     def __repr__(self):
         l_str = ",\n\t".join(self.leaves) + "\n" if self.leaves else ""
@@ -89,6 +90,7 @@ def _set_node(self, table, attr="ft", value=None):

     def _get_ft(self, table, with_restr=False):
         """Get FreeTable from graph node. If one doesn't exist, create it."""
+        table = table if isinstance(table, str) else table.full_table_name
         restr = self._get_restr(table) if with_restr else True
         if ft := self._get_node(table).get("ft"):
             return ft & restr
@@ -99,7 +101,7 @@ def _get_restr(self, table):
         """Get restriction from graph node."""
         table = table if isinstance(table, str) else table.full_table_name
-        return self._get_node(table).get("restr")
+        return self._get_node(table).get("restr", "False")

     def _get_files(self, table):
         """Get analysis files from graph node."""
@@ -137,7 +143,7 @@ def get_restr_ft(self, table: Union[int, str]):
         """
         if isinstance(table, int):
             table = list(self.visited)[table]
-        return self._get_ft(table) & self._get_restr(table)
+        return self._get_ft(table, with_restr=True)

     def _log_truncate(self, log_str, max_len=80):
         """Truncate log lines to max_len and print if verbose."""
@@ -183,8 +189,8 @@ def _child_to_parent(

         # Need to flip attr_map to respect parent's fields
-        attr_map = (
-            {v: k for k, v in attr_map.items() if k != k} if attr_map else {}
+        attr_reverse = (
+            {v: k for k, v in attr_map.items() if k != v} if attr_map else {}
         )
         child_ft = self._get_ft(child)
         parent_ft = self._get_ft(parent).proj()

         restr_child = child_ft & restr
         if primary:  # Project only primary key fields to avoid collisions
-            join = restr_child.proj(**attr_map) * parent_ft
+            join = restr_child.proj(**attr_reverse) * parent_ft
         else:  # Include all fields
-            join = restr_child.proj(..., **attr_map) * parent_ft
+            join = restr_child.proj(..., **attr_reverse) * parent_ft

         ret = unique_dicts(join.fetch(*parent_ft.primary_key, as_dict=True))
@@ -241,16 +247,30 @@ def cascade1(self, table, restriction):

         self.cascade1(parent, parent_restr)  # Parent set on recursion

-    def cascade(self) -> None:
-        """Cascade all restrictions up the graph."""
+    def cascade(self, show_progress=None) -> None:
+        """Cascade all restrictions up the graph.
+
+        Parameters
+        ----------
+        show_progress : bool, optional
+            Show tqdm progress bar. Defaults to the verbose setting.
+        """
         if self.cascaded:
             return
-        for table in self.leaves - self.visited:
+
+        to_visit = self.leaves - self.visited
+        for table in tqdm(
+            to_visit,
+            desc="RestrGraph: cascading restrictions",
+            total=len(to_visit),
+            disable=not (show_progress or self.verbose),
+        ):
             restr = self._get_restr(table)
             self._log_truncate(f"Start {table}: {restr}")
             self.cascade1(table, restr)
         if not self.visited == self.ancestors:
-            raise RuntimeError("Cascade: FAIL - incomplete cascade")
+            raise RuntimeError(
+                "Cascade: FAIL - incomplete cascade. Please file an issue."
+            )

         self.cascade_files()
         self.cascaded = True
@@ -277,7 +297,9 @@ def add_leaf(self, table_name, restriction, cascade=False) -> None:
             self.cascade_files()
             self.cascaded = True

-    def add_leaves(self, leaves: List[Dict[str, str]], cascade=False) -> None:
+    def add_leaves(
+        self, leaves: List[Dict[str, str]], cascade=False, show_progress=None
+    ) -> None:
         """Add leaves to graph and cascade if requested.

         Parameters
@@ -286,6 +308,8 @@ def add_leaves(self, leaves: List[Dict[str, str]], cascade=False) -> None:
             list of dictionaries containing table_name and restriction
         cascade : bool, optional
             Whether to cascade the restrictions up the graph. Default False
+        show_progress : bool, optional
+            Show tqdm progress bar. Defaults to the verbose setting.
         """

         if not leaves:
             return
         if not isinstance(leaves, list):
             leaves = [leaves]
         leaves = unique_dicts(leaves)
-        for leaf in leaves:
+        for leaf in tqdm(
+            leaves,
+            desc="RestrGraph: adding leaves",
+            total=len(leaves),
+            disable=not (show_progress or self.verbose),
+        ):
             if not (
                 (table_name := leaf.get("table_name"))
                 and (restriction := leaf.get("restriction"))
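Worth highlighting in the `_child_to_parent` hunk above: the old comprehension filtered on `k != k`, which is never true, so the flipped alias map was always empty and projected (aliased) fields were never mapped back to the parent's names. The renamed `attr_reverse` with `k != v` keeps only genuine renames, as this quick check shows — field names are invented for illustration:

```python
attr_map = {"parent_field": "child_field", "same_name": "same_name"}

# Old predicate: `k != k` is False for every pair, so nothing survived.
assert {v: k for k, v in attr_map.items() if k != k} == {}

# Fixed predicate: identity mappings drop out; real renames are flipped
# so the child's column can be projected under the parent's name.
assert {v: k for k, v in attr_map.items() if k != v} == {
    "child_field": "parent_field"
}
```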
From c660fb6fd90dfe91e9e9df8c43094adee0059daa Mon Sep 17 00:00:00 2001
From: CBroz1
Date: Thu, 18 Apr 2024 15:53:51 -0500
Subject: [PATCH 15/15] Revise Export.Table uniqueness to include export_id

---
 src/spyglass/common/common_usage.py | 6 ++++--
 1 file changed, 4 insertions(+), 2 deletions(-)

diff --git a/src/spyglass/common/common_usage.py b/src/spyglass/common/common_usage.py
index 90d3c5b3f..dae4f7842 100644
--- a/src/spyglass/common/common_usage.py
+++ b/src/spyglass/common/common_usage.py
@@ -178,6 +178,8 @@ def paper_export_id(self, paper_id: str) -> dict:
 class Export(SpyglassMixin, dj.Computed):
     definition = """
     -> ExportSelection
+    ---
+    paper_id: varchar(32)
     """

     # In order to get a many-to-one relationship btwn Selection and Export,
@@ -191,7 +193,7 @@ class Table(SpyglassMixin, dj.Part):
         ---
         table_name: varchar(128)
         restriction: mediumblob
-        unique index (table_name)
+        unique index (export_id, table_name)
         """

     class File(SpyglassMixin, dj.Part):
@@ -252,7 +254,7 @@ def make(self, key):
             free_tables=restr_graph.all_ft, **paper_key, **version_key
         )

-        self.insert1(key)
+        self.insert1({**key, **paper_key})
         self.Table().insert(table_inserts)
         self.File().insert(file_inserts)
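With `paper_id` now stored on `Export` and uniqueness widened to `(export_id, table_name)`, the same table may appear in several exports, and a paper's cascaded tables and files can be queried without joining back through `ExportSelection`. A sketch of the lookup this enables, assuming an export for the placeholder `"my_paper_id"` has already been populated:

```python
from spyglass.common.common_usage import Export

# Fetch the export key directly by paper, then its cascaded contents.
key = (Export & {"paper_id": "my_paper_id"}).fetch1("KEY")
tables = (Export.Table & key).fetch("table_name", "restriction", as_dict=True)
files = (Export.File & key).fetch("file_path")
```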