From 59686233be073f7d82c274fe92b14b99afcc7e19 Mon Sep 17 00:00:00 2001 From: Benjamin Pritchard Date: Thu, 9 Jan 2025 10:03:04 -0500 Subject: [PATCH 01/26] Fixup for some durations-as-string issues --- .../qcfractal/components/internal_jobs/socket.py | 2 +- qcfractal/qcfractal/components/serverinfo/socket.py | 2 +- qcfractal/qcfractal/config.py | 12 +++++++++--- qcfractal/qcfractal/test_config.py | 10 ++++++++-- 4 files changed, 19 insertions(+), 7 deletions(-) diff --git a/qcfractal/qcfractal/components/internal_jobs/socket.py b/qcfractal/qcfractal/components/internal_jobs/socket.py index 7c4e0663f..cd09f4225 100644 --- a/qcfractal/qcfractal/components/internal_jobs/socket.py +++ b/qcfractal/qcfractal/components/internal_jobs/socket.py @@ -287,7 +287,7 @@ def delete_old_internal_jobs(self, session: Session) -> None: if self._internal_job_keep <= 0: return - before = now_at_utc() - timedelta(days=self._internal_job_keep) + before = now_at_utc() - timedelta(seconds=self._internal_job_keep) stmt = delete(InternalJobORM) stmt = stmt.where( diff --git a/qcfractal/qcfractal/components/serverinfo/socket.py b/qcfractal/qcfractal/components/serverinfo/socket.py index 60524e121..8c3ce3317 100644 --- a/qcfractal/qcfractal/components/serverinfo/socket.py +++ b/qcfractal/qcfractal/components/serverinfo/socket.py @@ -281,7 +281,7 @@ def delete_old_access_logs(self, session: Session) -> None: if self._access_log_keep <= 0 or not self._access_log_enabled: return - before = now_at_utc() - timedelta(days=self._access_log_keep) + before = now_at_utc() - timedelta(seconds=self._access_log_keep) num_deleted = self.delete_access_logs(before, session=session) self._logger.info(f"Deleted {num_deleted} access logs before {before}") diff --git a/qcfractal/qcfractal/config.py b/qcfractal/qcfractal/config.py index 5fc57adcf..be82344cf 100644 --- a/qcfractal/qcfractal/config.py +++ b/qcfractal/qcfractal/config.py @@ -380,7 +380,7 @@ class FractalConfig(ConfigBase): # Access logging log_access: bool = Field(False, description="Store API access in the database") access_log_keep: int = Field( - 0, description="How far back to keep access logs (in seconds or a string). 0 means keep all" + 0, description="How far back to keep access logs (in days or as a duration string). 0 means keep all" ) # maxmind_account_id: Optional[int] = Field(None, description="Account ID for MaxMind GeoIP2 service") @@ -403,7 +403,7 @@ class FractalConfig(ConfigBase): 1, description="Number of processes for processing internal jobs and async requests" ) internal_job_keep: int = Field( - 0, description="Number of days of finished internal job logs to keep. 0 means keep all" + 0, description="How far back to keep finished internal jobs (in days or as a duration string). 0 means keep all" ) # Homepage settings @@ -461,10 +461,16 @@ def _check_loglevel(cls, v): raise ValidationError(f"{v} is not a valid loglevel. Must be DEBUG, INFO, WARNING, ERROR, or CRITICAL") return v - @validator("service_frequency", "heartbeat_frequency", "access_log_keep", pre=True) + @validator("service_frequency", "heartbeat_frequency", pre=True) def _convert_durations(cls, v): return duration_to_seconds(v) + @validator("access_log_keep", "internal_job_keep", pre=True) + def _convert_durations_days(cls, v): + if isinstance(v, int) or (isinstance(v, str) and v.isdigit()): + return int(v) * 86400 + return duration_to_seconds(v) + class Config(ConfigCommon): env_prefix = "QCF_" diff --git a/qcfractal/qcfractal/test_config.py b/qcfractal/qcfractal/test_config.py index f507e78ab..f86bb080e 100644 --- a/qcfractal/qcfractal/test_config.py +++ b/qcfractal/qcfractal/test_config.py @@ -17,14 +17,16 @@ def test_config_durations_plain(tmp_path): base_config = copy.deepcopy(_base_config) base_config["service_frequency"] = 3600 base_config["heartbeat_frequency"] = 30 - base_config["access_log_keep"] = 100802 + base_config["access_log_keep"] = 31 + base_config["internal_job_keep"] = 7 base_config["api"]["jwt_access_token_expires"] = 7450 base_config["api"]["jwt_refresh_token_expires"] = 637277 cfg = FractalConfig(base_folder=base_folder, **base_config) assert cfg.service_frequency == 3600 assert cfg.heartbeat_frequency == 30 - assert cfg.access_log_keep == 100802 + assert cfg.access_log_keep == 2678400 # interpreted as days + assert cfg.internal_job_keep == 604800 assert cfg.api.jwt_access_token_expires == 7450 assert cfg.api.jwt_refresh_token_expires == 637277 @@ -36,6 +38,7 @@ def test_config_durations_str(tmp_path): base_config["service_frequency"] = "1h" base_config["heartbeat_frequency"] = "30s" base_config["access_log_keep"] = "1d4h2s" + base_config["internal_job_keep"] = "1d4h7s" base_config["api"]["jwt_access_token_expires"] = "2h4m10s" base_config["api"]["jwt_refresh_token_expires"] = "7d9h77s" cfg = FractalConfig(base_folder=base_folder, **base_config) @@ -43,6 +46,7 @@ def test_config_durations_str(tmp_path): assert cfg.service_frequency == 3600 assert cfg.heartbeat_frequency == 30 assert cfg.access_log_keep == 100802 + assert cfg.internal_job_keep == 100807 assert cfg.api.jwt_access_token_expires == 7450 assert cfg.api.jwt_refresh_token_expires == 637277 @@ -54,6 +58,7 @@ def test_config_durations_dhms(tmp_path): base_config["service_frequency"] = "1:00:00" base_config["heartbeat_frequency"] = "30" base_config["access_log_keep"] = "1:04:00:02" + base_config["internal_job_keep"] = "1:04:00:07" base_config["api"]["jwt_access_token_expires"] = "2:04:10" base_config["api"]["jwt_refresh_token_expires"] = "7:09:00:77" cfg = FractalConfig(base_folder=base_folder, **base_config) @@ -61,5 +66,6 @@ def test_config_durations_dhms(tmp_path): assert cfg.service_frequency == 3600 assert cfg.heartbeat_frequency == 30 assert cfg.access_log_keep == 100802 + assert cfg.internal_job_keep == 100807 assert cfg.api.jwt_access_token_expires == 7450 assert cfg.api.jwt_refresh_token_expires == 637277 From b1efeae694aa20e9acb3721bb737fb2134aeb8a5 Mon Sep 17 00:00:00 2001 From: Benjamin Pritchard Date: Thu, 2 Jan 2025 21:51:52 -0500 Subject: [PATCH 02/26] Don't attempt serialization for reponses that are already flask Response objects --- qcfractal/qcfractal/flask_app/api_v1/helpers.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/qcfractal/qcfractal/flask_app/api_v1/helpers.py b/qcfractal/qcfractal/flask_app/api_v1/helpers.py index 7b983cb5b..0d9d24870 100644 --- a/qcfractal/qcfractal/flask_app/api_v1/helpers.py +++ b/qcfractal/qcfractal/flask_app/api_v1/helpers.py @@ -94,7 +94,10 @@ def wrapper(*args, **kwargs): # Now call the function, and validate the output ret = fn(*args, **kwargs) - # Serialize the output + # Serialize the output it it's not a normal flask response + if isinstance(ret, Response): + return ret + serialized = serialize(ret, accept_type) return Response(serialized, content_type=accept_type) From 2db3a87498822ff897b0417db49ea75089496e84 Mon Sep 17 00:00:00 2001 From: Benjamin Pritchard Date: Thu, 9 Jan 2025 12:39:01 -0500 Subject: [PATCH 03/26] Make records table in caches have rowid --- qcportal/qcportal/cache.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/qcportal/qcportal/cache.py b/qcportal/qcportal/cache.py index ea2e4d826..2111204cb 100644 --- a/qcportal/qcportal/cache.py +++ b/qcportal/qcportal/cache.py @@ -74,11 +74,11 @@ def _create_tables(self): self._conn.execute( """ CREATE TABLE IF NOT EXISTS records ( - id INTEGER NOT NULL PRIMARY KEY, + id INTEGER PRIMARY KEY, status TEXT NOT NULL, modified_on DECIMAL NOT NULL, record BLOB NOT NULL - ) WITHOUT ROWID + ) """ ) From bbcd6eb97b5083f5833810d44cda72cfd6369589 Mon Sep 17 00:00:00 2001 From: Benjamin Pritchard Date: Mon, 6 Jan 2025 11:19:20 -0500 Subject: [PATCH 04/26] Expand features of internal jobs --- ...7-c13116948b54_add_progress_description.py | 28 +++++++ .../components/internal_jobs/db_models.py | 1 + .../components/internal_jobs/socket.py | 1 + .../components/internal_jobs/status.py | 10 ++- .../components/internal_jobs/test_client.py | 2 +- qcportal/qcportal/client.py | 3 +- qcportal/qcportal/internal_jobs/models.py | 81 ++++++++++++++++++- 7 files changed, 118 insertions(+), 8 deletions(-) create mode 100644 qcfractal/qcfractal/alembic/versions/2025-01-07-c13116948b54_add_progress_description.py diff --git a/qcfractal/qcfractal/alembic/versions/2025-01-07-c13116948b54_add_progress_description.py b/qcfractal/qcfractal/alembic/versions/2025-01-07-c13116948b54_add_progress_description.py new file mode 100644 index 000000000..58d1d9149 --- /dev/null +++ b/qcfractal/qcfractal/alembic/versions/2025-01-07-c13116948b54_add_progress_description.py @@ -0,0 +1,28 @@ +"""Add progress description + +Revision ID: c13116948b54 +Revises: e798462e0c03 +Create Date: 2025-01-07 10:35:23.654928 + +""" + +import sqlalchemy as sa +from alembic import op + +# revision identifiers, used by Alembic. +revision = "c13116948b54" +down_revision = "e798462e0c03" +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.add_column("internal_jobs", sa.Column("progress_description", sa.String(), nullable=True)) + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.drop_column("internal_jobs", "progress_description") + # ### end Alembic commands ### diff --git a/qcfractal/qcfractal/components/internal_jobs/db_models.py b/qcfractal/qcfractal/components/internal_jobs/db_models.py index dc5abbbd0..1512f822d 100644 --- a/qcfractal/qcfractal/components/internal_jobs/db_models.py +++ b/qcfractal/qcfractal/components/internal_jobs/db_models.py @@ -31,6 +31,7 @@ class InternalJobORM(BaseORM): runner_uuid = Column(String) progress = Column(Integer, nullable=False, default=0) + progress_description = Column(String, nullable=True) function = Column(String, nullable=False) kwargs = Column(JSON, nullable=False) diff --git a/qcfractal/qcfractal/components/internal_jobs/socket.py b/qcfractal/qcfractal/components/internal_jobs/socket.py index cd09f4225..27964e40e 100644 --- a/qcfractal/qcfractal/components/internal_jobs/socket.py +++ b/qcfractal/qcfractal/components/internal_jobs/socket.py @@ -351,6 +351,7 @@ def _run_single(self, session: Session, job_orm: InternalJobORM, logger, job_pro if not job_progress.cancelled(): job_orm.status = InternalJobStatusEnum.complete job_orm.progress = 100 + job_orm.progress_description = "Complete" except Exception: session.rollback() diff --git a/qcfractal/qcfractal/components/internal_jobs/status.py b/qcfractal/qcfractal/components/internal_jobs/status.py index dec8a261d..6270cb813 100644 --- a/qcfractal/qcfractal/components/internal_jobs/status.py +++ b/qcfractal/qcfractal/components/internal_jobs/status.py @@ -1,5 +1,6 @@ from __future__ import annotations +from typing import Optional import threading import weakref @@ -26,6 +27,8 @@ def __init__(self, job_id: int, runner_uuid: str, session: Session, update_frequ .returning(InternalJobORM.status, InternalJobORM.runner_uuid) ) self._progress = 0 + self._description = None + self._cancelled = False self._deleted = False @@ -45,7 +48,9 @@ def __init__(self, job_id: int, runner_uuid: str, session: Session, update_frequ def _update_thread(self, session: Session, end_thread: threading.Event): while True: # Update progress - stmt = self._stmt.values(progress=self._progress, last_updated=now_at_utc()) + stmt = self._stmt.values( + progress=self._progress, progress_description=self._description, last_updated=now_at_utc() + ) ret = session.execute(stmt).one_or_none() session.commit() @@ -80,8 +85,9 @@ def _stop_thread(cancel_event: threading.Event, thread: threading.Thread): def stop(self): self._finalizer() - def update_progress(self, progress: int): + def update_progress(self, progress: int, description: Optional[str] = None): self._progress = progress + self._description = description def cancelled(self) -> bool: return self._cancelled diff --git a/qcfractal/qcfractal/components/internal_jobs/test_client.py b/qcfractal/qcfractal/components/internal_jobs/test_client.py index f23aef068..3e83eb3e6 100644 --- a/qcfractal/qcfractal/components/internal_jobs/test_client.py +++ b/qcfractal/qcfractal/components/internal_jobs/test_client.py @@ -22,7 +22,7 @@ def dummmy_internal_job(self, iterations: int, session, job_progress): for i in range(iterations): time.sleep(1.0) - job_progress.update_progress(100 * ((i + 1) / iterations)) + job_progress.update_progress(100 * ((i + 1) / iterations), f"Interation {i} of {iterations}") print("Dummy internal job counter ", i) if job_progress.cancelled(): diff --git a/qcportal/qcportal/client.py b/qcportal/qcportal/client.py index ba450e3c3..24c6857cd 100644 --- a/qcportal/qcportal/client.py +++ b/qcportal/qcportal/client.py @@ -2685,7 +2685,8 @@ def get_internal_job(self, job_id: int) -> InternalJob: Gets information about an internal job on the server """ - return self.make_request("get", f"api/v1/internal_jobs/{job_id}", InternalJob) + ij_dict = self.make_request("get", f"api/v1/internal_jobs/{job_id}", Dict[str, Any]) + return InternalJob(client=self, **ij_dict) def query_internal_jobs( self, diff --git a/qcportal/qcportal/internal_jobs/models.py b/qcportal/qcportal/internal_jobs/models.py index d8c9ca672..232f6a263 100644 --- a/qcportal/qcportal/internal_jobs/models.py +++ b/qcportal/qcportal/internal_jobs/models.py @@ -1,16 +1,19 @@ +import time from datetime import datetime from enum import Enum from typing import Optional, Dict, Any, List, Union from dateutil.parser import parse as date_parser +from rich.jupyter import display try: - from pydantic.v1 import BaseModel, Extra, validator + from pydantic.v1 import BaseModel, Extra, validator, PrivateAttr except ImportError: - from pydantic import BaseModel, Extra, validator + from pydantic import BaseModel, Extra, validator, PrivateAttr from qcportal.base_models import QueryProjModelBase from ..base_models import QueryIteratorBase +from tqdm import tqdm class InternalJobStatusEnum(str, Enum): @@ -57,6 +60,7 @@ class Config: repeat_delay: Optional[int] progress: int + progress_description: Optional[str] = None function: str kwargs: Dict[str, Any] @@ -65,6 +69,73 @@ class Config: result: Any user: Optional[str] + _client: Any = PrivateAttr(None) + + def __init__(self, client=None, **kwargs): + BaseModel.__init__(self, **kwargs) + self._client = client + + def refresh(self): + """ + Updates the data of this object with information from the server + """ + + if self._client is None: + raise RuntimeError("Client is not set") + + server_data = self._client.get_internal_job(self.id) + for k, v in server_data: + setattr(self, k, v) + + def watch(self, interval: float = 2.0, timeout: Optional[float] = None): + """ + Watch an internal job for completion + + Will poll every `interval` seconds until the job is finished (complete, error, cancelled, etc). + + Parameters + ---------- + interval + Time (in seconds) between polls on the server + timeout + Max amount of time (in seconds) to wait. If None, will wait forever. + + Returns + ------- + + """ + + if self.status not in [InternalJobStatusEnum.waiting, InternalJobStatusEnum.running]: + return + + begin_time = time.time() + + end_time = None + if timeout is not None: + end_time = begin_time + timeout + + pbar = tqdm(initial=self.progress, total=100, desc=self.progress_description) + + while True: + t = time.time() + + self.refresh() + pbar.update(self.progress - pbar.n) + pbar.set_description(self.progress_description) + + if end_time is not None and t >= end_time: + raise TimeoutError("Timed out waiting for job to complete") + + if self.status not in [InternalJobStatusEnum.waiting, InternalJobStatusEnum.running]: + break + curtime = time.time() + + if end_time is not None: + # sleep the normal interval, or up to the timeout time + time.sleep(min(interval, end_time - curtime + 0.1)) + else: + time.sleep(interval) + class InternalJobQueryFilters(QueryProjModelBase): job_id: Optional[List[int]] = None @@ -118,9 +189,11 @@ def __init__(self, client, query_filters: InternalJobQueryFilters): QueryIteratorBase.__init__(self, client, query_filters, batch_limit) def _request(self) -> List[InternalJob]: - return self._client.make_request( + ij_dicts = self._client.make_request( "post", "api/v1/internal_jobs/query", - List[InternalJob], + List[Dict[str, Any]], body=self._query_filters, ) + + return [InternalJob(client=self, **ij_dict) for ij_dict in ij_dicts] From be24f3235d1a7851ede2d0649e5c0793fb734fdd Mon Sep 17 00:00:00 2001 From: Benjamin Pritchard Date: Thu, 9 Jan 2025 09:32:38 -0500 Subject: [PATCH 05/26] Better handling of incorrect after_functions after migration --- ...e0c03_add_repeat_delay_to_internal_jobs.py | 13 +++++++++++ .../components/internal_jobs/socket.py | 23 ++++++++++++------- 2 files changed, 28 insertions(+), 8 deletions(-) diff --git a/qcfractal/qcfractal/alembic/versions/2025-01-02-e798462e0c03_add_repeat_delay_to_internal_jobs.py b/qcfractal/qcfractal/alembic/versions/2025-01-02-e798462e0c03_add_repeat_delay_to_internal_jobs.py index 7af6bc267..abc9d73d8 100644 --- a/qcfractal/qcfractal/alembic/versions/2025-01-02-e798462e0c03_add_repeat_delay_to_internal_jobs.py +++ b/qcfractal/qcfractal/alembic/versions/2025-01-02-e798462e0c03_add_repeat_delay_to_internal_jobs.py @@ -20,6 +20,19 @@ def upgrade(): # ### commands auto generated by Alembic - please adjust! ### op.add_column("internal_jobs", sa.Column("repeat_delay", sa.Integer(), nullable=True)) + + # Remove old periodic tasks + op.execute( + """DELETE FROM internal_jobs WHERE status IN ('waiting', 'running') AND name IN ( + 'delete_old_internal_jobs', + 'delete_old_access_log', + 'iterate_services', + 'geolocate_accesses', + 'check_manager_heartbeats', + 'update_geoip2_file' + ) + """ + ) # ### end Alembic commands ### diff --git a/qcfractal/qcfractal/components/internal_jobs/socket.py b/qcfractal/qcfractal/components/internal_jobs/socket.py index 27964e40e..74d2969da 100644 --- a/qcfractal/qcfractal/components/internal_jobs/socket.py +++ b/qcfractal/qcfractal/components/internal_jobs/socket.py @@ -376,14 +376,21 @@ def _run_single(self, session: Session, job_orm: InternalJobORM, logger, job_pro # Run the function specified to be run after if job_orm.status == InternalJobStatusEnum.complete and job_orm.after_function is not None: - after_func_attr = attrgetter(job_orm.after_function) - after_func = after_func_attr(self.root_socket) - - after_func_params = inspect.signature(after_func).parameters - add_after_kwargs = {} - if "session" in after_func_params: - add_after_kwargs["session"] = session - after_func(**job_orm.after_function_kwargs, **add_after_kwargs) + try: + after_func_attr = attrgetter(job_orm.after_function) + after_func = after_func_attr(self.root_socket) + + after_func_params = inspect.signature(after_func).parameters + add_after_kwargs = {} + if "session" in after_func_params: + add_after_kwargs["session"] = session + after_func(**job_orm.after_function_kwargs, **add_after_kwargs) + except Exception: + # Don't rollback? not sure what to do here + result = traceback.format_exc() + logger.error(f"Job {job_orm.id} failed with exception:\n{result}") + + job_orm.status = InternalJobStatusEnum.error if job_orm.status == InternalJobStatusEnum.complete and job_orm.repeat_delay is not None: self.add( From ba2140774cb59d4bc0273bb67c584404149c4aa8 Mon Sep 17 00:00:00 2001 From: Benjamin Pritchard Date: Fri, 3 Jan 2025 15:30:38 -0500 Subject: [PATCH 06/26] Add external file capability via S3 --- ...3-02afa97249c7_add_external_files_table.py | 43 +++ .../components/external_files/__init__.py | 1 + .../components/external_files/db_models.py | 36 +++ .../components/external_files/routes.py | 12 + .../components/external_files/socket.py | 268 ++++++++++++++++++ .../qcfractal/components/register_all.py | 1 + qcfractal/qcfractal/config.py | 37 +++ qcfractal/qcfractal/db_socket/socket.py | 2 + qcportal/qcportal/external_files/__init__.py | 1 + qcportal/qcportal/external_files/models.py | 42 +++ 10 files changed, 443 insertions(+) create mode 100644 qcfractal/qcfractal/alembic/versions/2025-01-03-02afa97249c7_add_external_files_table.py create mode 100644 qcfractal/qcfractal/components/external_files/__init__.py create mode 100644 qcfractal/qcfractal/components/external_files/db_models.py create mode 100644 qcfractal/qcfractal/components/external_files/routes.py create mode 100644 qcfractal/qcfractal/components/external_files/socket.py create mode 100644 qcportal/qcportal/external_files/__init__.py create mode 100644 qcportal/qcportal/external_files/models.py diff --git a/qcfractal/qcfractal/alembic/versions/2025-01-03-02afa97249c7_add_external_files_table.py b/qcfractal/qcfractal/alembic/versions/2025-01-03-02afa97249c7_add_external_files_table.py new file mode 100644 index 000000000..aec8be66c --- /dev/null +++ b/qcfractal/qcfractal/alembic/versions/2025-01-03-02afa97249c7_add_external_files_table.py @@ -0,0 +1,43 @@ +"""Add external files table + +Revision ID: 02afa97249c7 +Revises: c13116948b54 +Create Date: 2025-01-03 10:01:55.717905 + +""" + +import sqlalchemy as sa +from alembic import op +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. +revision = "02afa97249c7" +down_revision = "c13116948b54" +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.create_table( + "external_file", + sa.Column("id", sa.Integer(), nullable=False), + sa.Column("created_on", sa.TIMESTAMP(), nullable=False), + sa.Column("status", sa.Enum("available", "processing", name="externalfilestatusenum"), nullable=False), + sa.Column("file_type", sa.Enum("dataset_attachment", name="externalfiletypeenum"), nullable=False), + sa.Column("bucket", sa.String(), nullable=False), + sa.Column("file_name", sa.String(), nullable=False), + sa.Column("object_key", sa.String(), nullable=False), + sa.Column("sha256sum", sa.String(), nullable=False), + sa.Column("file_size", sa.BigInteger(), nullable=False), + sa.Column("description", sa.String(), nullable=True), + sa.Column("provenance", postgresql.JSONB(astext_type=sa.Text()), nullable=False), + sa.PrimaryKeyConstraint("id"), + ) + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.drop_table("external_file") + # ### end Alembic commands ### diff --git a/qcfractal/qcfractal/components/external_files/__init__.py b/qcfractal/qcfractal/components/external_files/__init__.py new file mode 100644 index 000000000..ee9c19c68 --- /dev/null +++ b/qcfractal/qcfractal/components/external_files/__init__.py @@ -0,0 +1 @@ +from .socket import ExternalFileSocket diff --git a/qcfractal/qcfractal/components/external_files/db_models.py b/qcfractal/qcfractal/components/external_files/db_models.py new file mode 100644 index 000000000..7b6d776e9 --- /dev/null +++ b/qcfractal/qcfractal/components/external_files/db_models.py @@ -0,0 +1,36 @@ +from __future__ import annotations + +from sqlalchemy import Column, Integer, String, Enum, TIMESTAMP, BigInteger +from sqlalchemy.dialects.postgresql import JSONB + +from qcfractal.db_socket.base_orm import BaseORM +from qcportal.external_files import ExternalFileStatusEnum, ExternalFileTypeEnum +from qcportal.utils import now_at_utc + + +class ExternalFileORM(BaseORM): + """ + Table for storing molecules + """ + + __tablename__ = "external_file" + + id = Column(Integer, primary_key=True) + file_type = Column(Enum(ExternalFileTypeEnum), nullable=False) + + created_on = Column(TIMESTAMP, default=now_at_utc, nullable=False) + status = Column(Enum(ExternalFileStatusEnum), nullable=False) + + file_name = Column(String, nullable=False) + description = Column(String, nullable=True) + provenance = Column(JSONB, nullable=False) + + sha256sum = Column(String, nullable=False) + file_size = Column(BigInteger, nullable=False) + + bucket = Column(String, nullable=False) + object_key = Column(String, nullable=False) + + __mapper_args__ = {"polymorphic_on": "file_type"} + + _qcportal_model_excludes__ = ["object_key", "bucket"] diff --git a/qcfractal/qcfractal/components/external_files/routes.py b/qcfractal/qcfractal/components/external_files/routes.py new file mode 100644 index 000000000..8747db508 --- /dev/null +++ b/qcfractal/qcfractal/components/external_files/routes.py @@ -0,0 +1,12 @@ +from flask import redirect + +from qcfractal.flask_app import storage_socket +from qcfractal.flask_app.api_v1.blueprint import api_v1 +from qcfractal.flask_app.api_v1.helpers import wrap_route + + +@api_v1.route("/external_files//download", methods=["GET"]) +@wrap_route("READ") +def download_external_file_v1(file_id: int): + _, url = storage_socket.external_files.get_url(file_id) + return redirect(url, code=302) diff --git a/qcfractal/qcfractal/components/external_files/socket.py b/qcfractal/qcfractal/components/external_files/socket.py new file mode 100644 index 000000000..1d84f7776 --- /dev/null +++ b/qcfractal/qcfractal/components/external_files/socket.py @@ -0,0 +1,268 @@ +from __future__ import annotations + +import hashlib +import logging +import os +import uuid +from typing import TYPE_CHECKING + +import boto3 + +from qcportal.exceptions import MissingDataError +from qcportal.external_files import ExternalFileTypeEnum, ExternalFileStatusEnum +from .db_models import ExternalFileORM + +if TYPE_CHECKING: + from sqlalchemy.orm.session import Session + from qcfractal.db_socket.socket import SQLAlchemySocket + from typing import Optional, Dict, Any, Tuple, Union, BinaryIO + + +class ExternalFileSocket: + """ + Socket for managing/querying external files + """ + + def __init__(self, root_socket: SQLAlchemySocket): + self.root_socket = root_socket + self._logger = logging.getLogger(__name__) + self._s3_config = root_socket.qcf_config.s3 + + if not self._s3_config.enabled: + self._logger.info("S3 service for external files is not configured") + return + + self._s3_client = boto3.client( + "s3", + endpoint_url=self._s3_config.endpoint_url, + aws_access_key_id=self._s3_config.access_key_id, + aws_secret_access_key=self._s3_config.secret_access_key, + verify=self._s3_config.verify, + ) + + # may raise an exception (bad/missing credentials, etc) + server_bucket_info = self._s3_client.list_buckets() + server_buckets = {k["Name"] for k in server_bucket_info["Buckets"]} + self._logger.info(f"Found {len(server_buckets)} buckets on the S3 server/account") + + # Make sure the buckets we use exist + self._bucket_map = self._s3_config.bucket_map + if self._bucket_map.dataset_attachment not in server_buckets: + raise RuntimeError(f"Bucket {self._bucket_map.dataset_attachment} (for dataset attachments)") + + def _lookup_bucket(self, file_type: Union[ExternalFileTypeEnum, str]) -> str: + # Can sometimes be a string from sqlalchemy (if set because of polymorphic identity) + if isinstance(file_type, str): + return getattr(self._bucket_map, file_type) + elif isinstance(file_type, ExternalFileTypeEnum): + return getattr(self._bucket_map, file_type.value) + else: + raise ValueError("Unknown parameter type for lookup_bucket: ", type(file_type)) + + def add_data( + self, + file_data: BinaryIO, + file_orm: ExternalFileORM, + *, + session: Optional[Session] = None, + ) -> int: + """ + Add raw file data to the database. + + This will fill in the appropriate fields on the given ORM object. + The file_name and file_type must be filled in already. + + In this function, the `file_orm` will be added to the session and the session will be flushed. + This is done at the beginning to show that the file is "processing", and at the end when the + addition is completed. + + Parameters + ---------- + file_data + Binary data to be read from + file_orm + Existing ORM object that will be filled in with metadata + session + An existing SQLAlchemy session to use. If None, one will be created. If an existing session + is used, it will be flushed (but not committed) before returning from this function. + + Returns + ------- + : + ID of the external file (which is also set in the given ORM object) + """ + + bucket = self._lookup_bucket(file_orm.file_type) + + self._logger.info( + f"Uploading data to S3 bucket {bucket}. file_name={file_orm.file_name} type={file_orm.file_type}" + ) + object_key = str(uuid.uuid4()) + sha256 = hashlib.sha256() + file_size = 0 + + multipart_upload = self._s3_client.create_multipart_upload(Bucket=bucket, Key=object_key) + upload_id = multipart_upload["UploadId"] + parts = [] + part_number = 1 + + with self.root_socket.optional_session(session) as session: + file_orm.status = ExternalFileStatusEnum.processing + file_orm.bucket = bucket + file_orm.object_key = object_key + file_orm.sha256sum = "" + file_orm.file_size = 0 + + session.add(file_orm) + session.flush() + + try: + while chunk := file_data.read(10 * 1024 * 1024): + sha256.update(chunk) + file_size += len(chunk) + + response = self._s3_client.upload_part( + Bucket=bucket, Key=object_key, PartNumber=part_number, UploadId=upload_id, Body=chunk + ) + parts.append({"PartNumber": part_number, "ETag": response["ETag"]}) + part_number += 1 + + self._s3_client.complete_multipart_upload( + Bucket=bucket, Key=object_key, UploadId=upload_id, MultipartUpload={"Parts": parts} + ) + + except Exception as e: + self._s3_client.abort_multipart_upload(Bucket=bucket, Key=object_key, UploadId=upload_id) + raise e + + self._logger.info(f"Uploading data to S3 bucket complete. Finishing writing metadata to db") + file_orm.status = ExternalFileStatusEnum.available + file_orm.sha256sum = sha256.hexdigest() + file_orm.file_size = file_size + session.flush() + + return file_orm.id + + def add_file( + self, + file_path: str, + file_orm: ExternalFileORM, + *, + session: Optional[Session] = None, + ) -> int: + """ + Add an existing file to the database + + See documentation for :ref:`add_data` for more information. + + Parameters + ---------- + file_path + Path to an existing file to be read from + file_orm + Existing ORM object that will be filled in with metadata + session + An existing SQLAlchemy session to use. If None, one will be created. If an existing session + is used, it will be flushed (but not committed) before returning from this function. + + Returns + ------- + : + ID of the external file (which is also set in the given ORM object) + """ + + self._logger.info(f"Uploading {file_path} to S3. File size: {os.path.getsize(file_path)/1048576} MiB") + + with open(file_path, "rb") as f: + return self.add_data(f, file_orm, session=session) + + def get_metadata( + self, + file_id: int, + *, + session: Optional[Session] = None, + ) -> Dict[str, Any]: + """ + Obtain external file information + + Parameters + ---------- + file_id + ID for the external file + session + An existing SQLAlchemy session to use. If None, one will be created. If an existing session + is used, it will be flushed (but not committed) before returning from this function. + + Returns + ------- + : + List of molecule data (as dictionaries) in the same order as the given ids. + If missing_ok is True, then this list will contain None where the molecule was missing. + """ + + with self.root_socket.optional_session(session, True) as session: + ef = session.get(ExternalFileORM, file_id) + if ef is None: + raise MissingDataError(f"Cannot find external file with id {file_id} in the database") + + return ef.model_dict() + + # def delete(self, file_id: int, *, session: Optional[Session] = None): + # """ + # Deletes an external file from the database and from remote storage + + # Parameters + # ---------- + # file_id + # ID of the external file to remove + # session + # An existing SQLAlchemy session to use. If None, one will be created. If an existing session + # is used, it will be flushed (but not committed) before returning from this function. + + # Returns + # ------- + # : + # Metadata about what was deleted and any errors that occurred + # """ + + # with self.root_socket.optional_session(session) as session: + # stmt = delete(ExternalFileORM).where(ExternalFileORM.id == file_id) + # session.execute(stmt) + + def get_url(self, file_id: int, *, session: Optional[Session] = None) -> Tuple[str, str]: + """ + Obtain an url that a user can use to download the file directly from the S3 bucket + + Will raise an exception if the file_id does not exist + + Parameters + ---------- + file_id + ID of the external file + session + An existing SQLAlchemy session to use. If None, one will be created. If an existing session + is used, it will be flushed (but not committed) before returning from this function. + + Returns + ------- + : + File name and direct URL to the file + """ + + with self.root_socket.optional_session(session, True) as session: + ef = session.get(ExternalFileORM, file_id) + if ef is None: + raise MissingDataError(f"Cannot find external file with id {file_id} in the database") + + url = self._s3_client.generate_presigned_url( + ClientMethod="get_object", + Params={ + "Bucket": ef.bucket, + "Key": ef.object_key, + "ResponseContentDisposition": f'attachment; filename = "{ef.file_name}"', + }, + HttpMethod="GET", + ExpiresIn=120, + ) + + return ef.file_name, url diff --git a/qcfractal/qcfractal/components/register_all.py b/qcfractal/qcfractal/components/register_all.py index 80ac72dc2..cc24d092c 100644 --- a/qcfractal/qcfractal/components/register_all.py +++ b/qcfractal/qcfractal/components/register_all.py @@ -9,6 +9,7 @@ from .tasks import db_models, routes from .services import db_models from .internal_jobs import db_models, routes +from .external_files import db_models, routes from . import record_db_models, dataset_db_models, record_routes, dataset_routes from .singlepoint import record_db_models, dataset_db_models, routes diff --git a/qcfractal/qcfractal/config.py b/qcfractal/qcfractal/config.py index be82344cf..fadc58b7e 100644 --- a/qcfractal/qcfractal/config.py +++ b/qcfractal/qcfractal/config.py @@ -324,6 +324,37 @@ class Config(ConfigCommon): env_prefix = "QCF_API_" +class S3BucketMap(ConfigBase): + dataset_attachment: str = Field("dataset_attachment", description="Bucket to hold dataset views") + + +class S3Config(ConfigBase): + """ + Settings for using external files with S3 + """ + + enabled: bool = False + verify: bool = True + passthrough: bool = False + endpoint_url: Optional[str] = Field(None, description="S3 endpoint URL") + access_key_id: Optional[str] = Field(None, description="AWS/S3 access key") + secret_access_key: Optional[str] = Field(None, description="AWS/S3 secret key") + + bucket_map: S3BucketMap = Field(S3BucketMap(), description="Configuration for where to store various files") + + class Config(ConfigCommon): + env_prefix = "QCF_S3_" + + @root_validator() + def _check_enabled(cls, values): + if values.get("enabled", False) is True: + for key in ["endpoint_url", "access_key_id", "secret_access_key"]: + if values.get(key, None) is None: + raise ValueError(f"S3 enabled but {key} not set") + + return values + + class FractalConfig(ConfigBase): """ Fractal Server settings @@ -334,6 +365,11 @@ class FractalConfig(ConfigBase): description="The base directory to use as the default for some options (logs, etc). Default is the location of the config file.", ) + temporary_dir: Optional[str] = Field( + None, + description="Temporary directory to use for things such as view creation. If None, uses system default. This may require a lot of space!", + ) + # Info for the REST interface name: str = Field("QCFractal Server", description="The QCFractal server name") @@ -413,6 +449,7 @@ class FractalConfig(ConfigBase): # Other settings blocks database: DatabaseConfig = Field(..., description="Configuration of the settings for the database") api: WebAPIConfig = Field(..., description="Configuration of the REST interface") + s3: S3Config = Field(S3Config(), description="Configuration of the S3 file storage (optional)") api_limits: APILimitConfig = Field(..., description="Configuration of the limits to the api") auto_reset: AutoResetConfig = Field(..., description="Configuration for automatic resetting of tasks") diff --git a/qcfractal/qcfractal/db_socket/socket.py b/qcfractal/qcfractal/db_socket/socket.py index 3635439b0..a128f7f01 100644 --- a/qcfractal/qcfractal/db_socket/socket.py +++ b/qcfractal/qcfractal/db_socket/socket.py @@ -100,6 +100,7 @@ def checkout(dbapi_connection, connection_record, connection_proxy): from ..components.managers.socket import ManagerSocket from ..components.tasks.socket import TaskSocket from ..components.services.socket import ServiceSocket + from ..components.external_files import ExternalFileSocket from ..components.record_socket import RecordSocket from ..components.dataset_socket import DatasetSocket @@ -111,6 +112,7 @@ def checkout(dbapi_connection, connection_record, connection_proxy): self.molecules = MoleculeSocket(self) self.datasets = DatasetSocket(self) self.records = RecordSocket(self) + self.external_files = ExternalFileSocket(self) self.tasks = TaskSocket(self) self.services = ServiceSocket(self) self.managers = ManagerSocket(self) diff --git a/qcportal/qcportal/external_files/__init__.py b/qcportal/qcportal/external_files/__init__.py new file mode 100644 index 000000000..632ac5a16 --- /dev/null +++ b/qcportal/qcportal/external_files/__init__.py @@ -0,0 +1 @@ +from .models import ExternalFileStatusEnum, ExternalFileTypeEnum, ExternalFile diff --git a/qcportal/qcportal/external_files/models.py b/qcportal/qcportal/external_files/models.py new file mode 100644 index 000000000..53eab4865 --- /dev/null +++ b/qcportal/qcportal/external_files/models.py @@ -0,0 +1,42 @@ +from __future__ import annotations + +from datetime import datetime +from enum import Enum +from typing import Dict, Any, Optional + +try: + from pydantic.v1 import BaseModel, Extra, validator, PrivateAttr, Field +except ImportError: + from pydantic import BaseModel, Extra, validator, PrivateAttr, Field + + +class ExternalFileStatusEnum(str, Enum): + """ + The state of an external file + """ + + available = "available" + processing = "processing" + + +class ExternalFileTypeEnum(str, Enum): + """ + The state of an external file + """ + + dataset_attachment = "dataset_attachment" + + +class ExternalFile(BaseModel): + id: int + file_type: ExternalFileTypeEnum + + created_on: datetime + status: ExternalFileStatusEnum + + file_name: str + description: Optional[str] + provenance: Dict[str, Any] + + sha256sum: str + file_size: int From 52034951f6da39e7c965b879744e754bcaebdd72 Mon Sep 17 00:00:00 2001 From: Benjamin Pritchard Date: Mon, 6 Jan 2025 20:20:25 -0500 Subject: [PATCH 07/26] Don't allow redirects in base client --- qcportal/qcportal/client_base.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/qcportal/qcportal/client_base.py b/qcportal/qcportal/client_base.py index f036e4d96..98dbe2f5b 100644 --- a/qcportal/qcportal/client_base.py +++ b/qcportal/qcportal/client_base.py @@ -336,7 +336,9 @@ def _request( try: if not allow_retries: - r = self._req_session.send(prep_req, verify=self._verify, timeout=self.timeout) + r = self._req_session.send(prep_req, verify=self._verify, timeout=self.timeout, allow_redirects=False) + if r.is_redirect: + raise RuntimeError("Redirection is not allowed") else: current_retries = 0 while True: From f283de9e806c1c53a0ca0dfac123f01ffd0d891d Mon Sep 17 00:00:00 2001 From: Benjamin Pritchard Date: Mon, 6 Jan 2025 15:13:30 -0500 Subject: [PATCH 08/26] Functions for downloading external file in client --- .../components/external_files/routes.py | 6 +++ qcportal/qcportal/client.py | 48 +++++++++++++++++++ qcportal/qcportal/client_base.py | 34 +++++++++++++ 3 files changed, 88 insertions(+) diff --git a/qcfractal/qcfractal/components/external_files/routes.py b/qcfractal/qcfractal/components/external_files/routes.py index 8747db508..9e7d077ab 100644 --- a/qcfractal/qcfractal/components/external_files/routes.py +++ b/qcfractal/qcfractal/components/external_files/routes.py @@ -5,6 +5,12 @@ from qcfractal.flask_app.api_v1.helpers import wrap_route +@api_v1.route("/external_files/", methods=["GET"]) +@wrap_route("READ") +def get_external_file_metadata_v1(file_id: int): + return storage_socket.external_files.get_metadata(file_id) + + @api_v1.route("/external_files//download", methods=["GET"]) @wrap_route("READ") def download_external_file_v1(file_id: int): diff --git a/qcportal/qcportal/client.py b/qcportal/qcportal/client.py index 24c6857cd..e1fe97cee 100644 --- a/qcportal/qcportal/client.py +++ b/qcportal/qcportal/client.py @@ -2,12 +2,14 @@ import logging import math +import os from datetime import datetime from typing import Any, Dict, List, Optional, Tuple, Union, Sequence, Iterable, TypeVar, Type from tabulate import tabulate from qcportal.cache import DatasetCache, read_dataset_metadata +from qcportal.external_files import ExternalFile from qcportal.gridoptimization import ( GridoptimizationKeywords, GridoptimizationAddBody, @@ -325,6 +327,52 @@ def delete_dataset(self, dataset_id: int, delete_records: bool): params = DatasetDeleteParams(delete_records=delete_records) return self.make_request("delete", f"api/v1/datasets/{dataset_id}", None, url_params=params) + ############################################################## + # External files + ############################################################## + def download_external_file(self, file_id: int, destination_path: str, overwrite: bool = False) -> Tuple[int, str]: + """ + Downloads an external file to the given path + + The file size and checksum will be checked against the metadata stored on the server + + Parameters + ---------- + file_id + ID of the file to obtain + destination_path + Full path to the destination file (including filename) + overwrite + If True, allow for overwriting an existing file. If False, and a file already exists at the given + destination path, an exception will be raised. + + Returns + ------- + : + A tuple of file size and sha256 checksum. + + """ + meta_url = f"api/v1/external_files/{file_id}" + download_url = f"api/v1/external_files/{file_id}/download" + + # Check for local file existence before doing any requests + if os.path.exists(destination_path) and not overwrite: + raise RuntimeError(f"File already exists at {destination_path}. To overwrite, use `overwrite=True`") + + # First, get the metadata + file_info = self.make_request("get", meta_url, ExternalFile) + + # Now actually download the file + file_size, file_sha256 = self.download_file(download_url, destination_path, overwrite=overwrite) + + if file_size != file_info.file_size: + raise RuntimeError(f"Inconsistent file size. Expected {file_info.file_size}, got {file_size}") + + if file_sha256 != file_info.sha256sum: + raise RuntimeError(f"Inconsistent file checksum. Expected {file_info.sha256sum}, got {file_sha256}") + + return file_size, file_sha256 + ############################################################## # Molecules ############################################################## diff --git a/qcportal/qcportal/client_base.py b/qcportal/qcportal/client_base.py index 98dbe2f5b..c60725831 100644 --- a/qcportal/qcportal/client_base.py +++ b/qcportal/qcportal/client_base.py @@ -20,7 +20,9 @@ except ImportError: import pydantic import requests +from typing import Tuple import yaml +import hashlib from packaging.version import parse as parse_version from . import __version__ @@ -430,6 +432,38 @@ def make_request( else: return pydantic.parse_obj_as(response_model, d) + def download_file(self, endpoint: str, destination_path: str, overwrite: bool = False) -> Tuple[int, str]: + + sha256 = hashlib.sha256() + file_size = 0 + + # Remove if overwrite=True. This allows for any processes still using the old file to keep using it + # (at least on linux) + if os.path.exists(destination_path): + if overwrite: + os.remove(destination_path) + else: + raise RuntimeError(f"File already exists at {destination_path}. To overwrite, use `overwrite=True`") + + full_uri = self.address + endpoint + response = self._req_session.get(full_uri, stream=True, allow_redirects=False) + + if response.is_redirect: + # send again, but using a plain requests object + # that way, we don't pass the JWT to someone else + new_location = response.headers["Location"] + response = requests.get(new_location, stream=True, allow_redirects=True) + + response.raise_for_status() + with open(destination_path, "wb") as f: + for chunk in response.iter_content(chunk_size=None): + if chunk: + f.write(chunk) + sha256.update(chunk) + file_size += len(chunk) + + return file_size, sha256.hexdigest() + def ping(self) -> bool: """ Pings the server to see if it is up From 272dc66f42fc032169b9fd7af31977575855b74e Mon Sep 17 00:00:00 2001 From: Benjamin Pritchard Date: Wed, 8 Jan 2025 20:27:24 -0500 Subject: [PATCH 09/26] Enable passthrough for external files --- .../components/external_files/routes.py | 17 +++++++-- .../components/external_files/socket.py | 35 ++++++++++++++++++- 2 files changed, 48 insertions(+), 4 deletions(-) diff --git a/qcfractal/qcfractal/components/external_files/routes.py b/qcfractal/qcfractal/components/external_files/routes.py index 9e7d077ab..6af0f0d6a 100644 --- a/qcfractal/qcfractal/components/external_files/routes.py +++ b/qcfractal/qcfractal/components/external_files/routes.py @@ -1,4 +1,4 @@ -from flask import redirect +from flask import redirect, Response, stream_with_context, current_app from qcfractal.flask_app import storage_socket from qcfractal.flask_app.api_v1.blueprint import api_v1 @@ -14,5 +14,16 @@ def get_external_file_metadata_v1(file_id: int): @api_v1.route("/external_files//download", methods=["GET"]) @wrap_route("READ") def download_external_file_v1(file_id: int): - _, url = storage_socket.external_files.get_url(file_id) - return redirect(url, code=302) + passthrough = current_app.config["QCFRACTAL_CONFIG"].s3.passthrough + + if passthrough: + file_name, streamer_func = storage_socket.external_files.get_file_streamer(file_id) + + return Response( + stream_with_context(streamer_func()), + content_type="application/octet-stream", + headers={"Content-Disposition": f'attachment; filename="{file_name}"'}, + ) + else: + _, url = storage_socket.external_files.get_url(file_id) + return redirect(url, code=302) diff --git a/qcfractal/qcfractal/components/external_files/socket.py b/qcfractal/qcfractal/components/external_files/socket.py index 1d84f7776..bf0603cc2 100644 --- a/qcfractal/qcfractal/components/external_files/socket.py +++ b/qcfractal/qcfractal/components/external_files/socket.py @@ -15,7 +15,7 @@ if TYPE_CHECKING: from sqlalchemy.orm.session import Session from qcfractal.db_socket.socket import SQLAlchemySocket - from typing import Optional, Dict, Any, Tuple, Union, BinaryIO + from typing import Optional, Dict, Any, Tuple, Union, BinaryIO, Callable, Generator class ExternalFileSocket: @@ -266,3 +266,36 @@ def get_url(self, file_id: int, *, session: Optional[Session] = None) -> Tuple[s ) return ef.file_name, url + + def get_file_streamer( + self, file_id: int, *, session: Optional[Session] = None + ) -> Tuple[str, Callable[[], Generator[bytes, None, None]]]: + """ + Returns a function that streams a file from an S3 bucket. + + Parameters + ---------- + file_id + ID of the external file + session + An existing SQLAlchemy session to use. If None, one will be created. If an existing session + is used, it will be flushed (but not committed) before returning from this function. + + Returns + ------- + : + The recommended filename and a generator function that yields chunks of the file + """ + + with self.root_socket.optional_session(session, True) as session: + ef = session.get(ExternalFileORM, file_id) + if ef is None: + raise MissingDataError(f"Cannot find external file with id {file_id} in the database") + + s3_obj = self._s3_client.get_object(Bucket=ef.bucket, Key=ef.object_key) + + def _generator_func(): + for chunk in s3_obj["Body"].iter_chunks(chunk_size=(8 * 1048576)): + yield chunk + + return ef.file_name, _generator_func From c4f1d2b3c6b03d5a7e2fdd9c02b2363a0e459b65 Mon Sep 17 00:00:00 2001 From: Benjamin Pritchard Date: Thu, 9 Jan 2025 08:38:41 -0500 Subject: [PATCH 10/26] Make boto3 package optional --- qcfractal/pyproject.toml | 3 +++ .../qcfractal/components/external_files/socket.py | 10 +++++++++- 2 files changed, 12 insertions(+), 1 deletion(-) diff --git a/qcfractal/pyproject.toml b/qcfractal/pyproject.toml index 89fc563bc..92f999758 100644 --- a/qcfractal/pyproject.toml +++ b/qcfractal/pyproject.toml @@ -42,6 +42,9 @@ geoip = [ snowflake = [ "qcfractalcompute" ] +s3 = [ + "boto3" +] [project.urls] diff --git a/qcfractal/qcfractal/components/external_files/socket.py b/qcfractal/qcfractal/components/external_files/socket.py index bf0603cc2..eade40799 100644 --- a/qcfractal/qcfractal/components/external_files/socket.py +++ b/qcfractal/qcfractal/components/external_files/socket.py @@ -1,12 +1,18 @@ from __future__ import annotations import hashlib +import importlib import logging import os import uuid from typing import TYPE_CHECKING -import boto3 +# Torsiondrive package is optional +_boto3_spec = importlib.util.find_spec("boto3") + +if _boto3_spec is not None: + boto3 = importlib.util.module_from_spec(_boto3_spec) + _boto3_spec.loader.exec_module(boto3) from qcportal.exceptions import MissingDataError from qcportal.external_files import ExternalFileTypeEnum, ExternalFileStatusEnum @@ -31,6 +37,8 @@ def __init__(self, root_socket: SQLAlchemySocket): if not self._s3_config.enabled: self._logger.info("S3 service for external files is not configured") return + elif _boto3_spec is None: + raise RuntimeError("boto3 package is required for S3 support") self._s3_client = boto3.client( "s3", From 3298c99a07e8c3e7821ad27b88c3160d6b9165cf Mon Sep 17 00:00:00 2001 From: Benjamin Pritchard Date: Fri, 3 Jan 2025 15:31:02 -0500 Subject: [PATCH 11/26] Add view creation and attachment to datasets --- ...285e3620fd_add_dataset_attachment_table.py | 38 ++++ qcfractal/qcfractal/components/create_view.py | 132 ----------- .../qcfractal/components/dataset_db_models.py | 25 ++ .../components/dataset_processing/__init__.py | 1 + .../components/dataset_processing/views.py | 173 ++++++++++++++ .../qcfractal/components/dataset_routes.py | 21 +- .../qcfractal/components/dataset_socket.py | 213 +++++++++++++++++- qcportal/qcportal/dataset_models.py | 87 ++++++- 8 files changed, 552 insertions(+), 138 deletions(-) create mode 100644 qcfractal/qcfractal/alembic/versions/2025-01-03-84285e3620fd_add_dataset_attachment_table.py delete mode 100644 qcfractal/qcfractal/components/create_view.py create mode 100644 qcfractal/qcfractal/components/dataset_processing/__init__.py create mode 100644 qcfractal/qcfractal/components/dataset_processing/views.py diff --git a/qcfractal/qcfractal/alembic/versions/2025-01-03-84285e3620fd_add_dataset_attachment_table.py b/qcfractal/qcfractal/alembic/versions/2025-01-03-84285e3620fd_add_dataset_attachment_table.py new file mode 100644 index 000000000..cc3ada892 --- /dev/null +++ b/qcfractal/qcfractal/alembic/versions/2025-01-03-84285e3620fd_add_dataset_attachment_table.py @@ -0,0 +1,38 @@ +"""Add dataset attachment table + +Revision ID: 84285e3620fd +Revises: 02afa97249c7 +Create Date: 2025-01-03 10:04:16.201770 + +""" + +import sqlalchemy as sa +from alembic import op + +# revision identifiers, used by Alembic. +revision = "84285e3620fd" +down_revision = "02afa97249c7" +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.create_table( + "dataset_attachment", + sa.Column("id", sa.Integer(), nullable=False), + sa.Column("dataset_id", sa.Integer(), nullable=False), + sa.Column("attachment_type", sa.Enum("other", "view", name="datasetattachmenttype"), nullable=False), + sa.ForeignKeyConstraint(["dataset_id"], ["base_dataset.id"], ondelete="cascade"), + sa.ForeignKeyConstraint(["id"], ["external_file.id"], ondelete="cascade"), + sa.PrimaryKeyConstraint("id"), + ) + op.create_index("ix_dataset_attachment_dataset_id", "dataset_attachment", ["dataset_id"], unique=False) + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.drop_index("ix_dataset_attachment_dataset_id", table_name="dataset_attachment") + op.drop_table("dataset_attachment") + # ### end Alembic commands ### diff --git a/qcfractal/qcfractal/components/create_view.py b/qcfractal/qcfractal/components/create_view.py deleted file mode 100644 index 6fceb99eb..000000000 --- a/qcfractal/qcfractal/components/create_view.py +++ /dev/null @@ -1,132 +0,0 @@ -import os - -from sqlalchemy import select, create_engine, Column, String, ForeignKey, LargeBinary -from sqlalchemy.orm import selectinload, sessionmaker, declarative_base - -from qcfractal.db_socket.socket import SQLAlchemySocket -from qcportal.compression import compress, CompressionEnum -from qcportal.serialization import serialize - -ViewBaseORM = declarative_base() - - -class DatasetViewEntry(ViewBaseORM): - __tablename__ = "dataset_entry" - name = Column(String, primary_key=True) - data = Column(LargeBinary, nullable=False) - - -class DatasetViewSpecification(ViewBaseORM): - __tablename__ = "dataset_specification" - name = Column(String, primary_key=True) - data = Column(LargeBinary, nullable=False) - - -class DatasetViewRecord(ViewBaseORM): - __tablename__ = "dataset_record" - entry_name = Column(String, ForeignKey(DatasetViewEntry.name), primary_key=True) - specification_name = Column(String, ForeignKey(DatasetViewSpecification.name), primary_key=True) - data = Column(LargeBinary, nullable=False) - - -class DatasetViewMetadata(ViewBaseORM): - __tablename__ = "dataset_metadata" - - key = Column(String, ForeignKey(DatasetViewEntry.name), primary_key=True) - value = Column(LargeBinary, nullable=False) - - -def _serialize_orm(orm, exclude=None): - s_data = serialize(orm.model_dict(exclude=exclude), "application/msgpack") - c_data, _, _ = compress(s_data, CompressionEnum.zstd, 7) - return c_data - - -def create_dataset_view(dataset_id: int, socket: SQLAlchemySocket, view_file_path: str): - if os.path.exists(view_file_path): - raise RuntimeError(f"File {view_file_path} exists - will not overwrite") - - if os.path.isdir(view_file_path): - raise RuntimeError(f"{view_file_path} is a directory") - - uri = "sqlite:///" + view_file_path - engine = create_engine(uri) - ViewSession = sessionmaker(bind=engine) - - ViewBaseORM.metadata.create_all(engine) - - view_session = ViewSession() - - with socket.session_scope(True) as fractal_session: - ds_type = socket.datasets.lookup_type(dataset_id) - ds_socket = socket.datasets.get_socket(ds_type) - - dataset_orm = ds_socket.dataset_orm - entry_orm = ds_socket.entry_orm - specification_orm = ds_socket.specification_orm - record_item_orm = ds_socket.record_item_orm - - stmt = select(dataset_orm).where(dataset_orm.id == dataset_id) - stmt = stmt.options(selectinload("*")) - ds_orm = fractal_session.execute(stmt).scalar_one() - - # Metadata - metadata_bytes = _serialize_orm(ds_orm) - metadata_orm = DatasetViewMetadata(key="raw_data", value=metadata_bytes) - view_session.add(metadata_orm) - view_session.commit() - - # Entries - stmt = select(entry_orm) - stmt = stmt.options(selectinload("*")) - stmt = stmt.where(entry_orm.dataset_id == dataset_id) - entries = fractal_session.execute(stmt).scalars().all() - - for entry in entries: - entry_bytes = _serialize_orm(entry) - entry_orm = DatasetViewEntry(name=entry.name, data=entry_bytes) - view_session.add(entry_orm) - - view_session.commit() - - # Specifications - stmt = select(specification_orm) - stmt = stmt.options(selectinload("*")) - stmt = stmt.where(specification_orm.dataset_id == dataset_id) - specs = fractal_session.execute(stmt).scalars().all() - - for spec in specs: - spec_bytes = _serialize_orm(spec) - specification_orm = DatasetViewSpecification(name=spec.name, data=spec_bytes) - view_session.add(specification_orm) - - view_session.commit() - - base_stmt = select(record_item_orm).where(record_item_orm.dataset_id == dataset_id) - base_stmt = base_stmt.options(selectinload("*")) - base_stmt = base_stmt.order_by(record_item_orm.record_id.asc()) - - skip = 0 - while True: - stmt = base_stmt.offset(skip).limit(10) - batch = fractal_session.execute(stmt).scalars() - - count = 0 - for item_orm in batch: - item_bytes = _serialize_orm(item_orm) - - view_record_orm = DatasetViewRecord( - entry_name=item_orm.entry_name, - specification_name=item_orm.specification_name, - data=item_bytes, - ) - - view_session.add(view_record_orm) - count += 1 - - view_session.commit() - - if count == 0: - break - - skip += count diff --git a/qcfractal/qcfractal/components/dataset_db_models.py b/qcfractal/qcfractal/components/dataset_db_models.py index 261f7b877..468f04d11 100644 --- a/qcfractal/qcfractal/components/dataset_db_models.py +++ b/qcfractal/qcfractal/components/dataset_db_models.py @@ -13,12 +13,15 @@ ForeignKey, ForeignKeyConstraint, UniqueConstraint, + Enum, ) from sqlalchemy.orm import relationship from sqlalchemy.orm.collections import attribute_keyed_dict from qcfractal.components.auth.db_models import UserIDMapSubquery, GroupIDMapSubquery, UserORM, GroupORM +from qcfractal.components.external_files.db_models import ExternalFileORM from qcfractal.db_socket import BaseORM, MsgpackExt +from qcportal.dataset_models import DatasetAttachmentType class BaseDatasetORM(BaseORM): @@ -78,6 +81,13 @@ class BaseDatasetORM(BaseORM): passive_deletes=True, ) + attachments = relationship( + "DatasetAttachmentORM", + cascade="all, delete-orphan", + passive_deletes=True, + lazy="selectin", + ) + __table_args__ = ( UniqueConstraint("dataset_type", "lname", name="ux_base_dataset_dataset_type_lname"), Index("ix_base_dataset_dataset_type", "dataset_type"), @@ -129,3 +139,18 @@ class ContributedValuesORM(BaseORM): __table_args__ = (Index("ix_contributed_values_dataset_id", "dataset_id"),) _qcportal_model_excludes = ["dataset_id"] + + +class DatasetAttachmentORM(ExternalFileORM): + __tablename__ = "dataset_attachment" + + id = Column(Integer, ForeignKey(ExternalFileORM.id, ondelete="cascade"), primary_key=True) + dataset_id = Column(Integer, ForeignKey("base_dataset.id", ondelete="cascade"), nullable=False) + + attachment_type = Column(Enum(DatasetAttachmentType), nullable=False) + + __mapper_args__ = {"polymorphic_identity": "dataset_attachment"} + + __table_args__ = (Index("ix_dataset_attachment_dataset_id", "dataset_id"),) + + _qcportal_model_excludes = ["dataset_id"] diff --git a/qcfractal/qcfractal/components/dataset_processing/__init__.py b/qcfractal/qcfractal/components/dataset_processing/__init__.py new file mode 100644 index 000000000..4e62cdf45 --- /dev/null +++ b/qcfractal/qcfractal/components/dataset_processing/__init__.py @@ -0,0 +1 @@ +from .views import create_view_file diff --git a/qcfractal/qcfractal/components/dataset_processing/views.py b/qcfractal/qcfractal/components/dataset_processing/views.py new file mode 100644 index 000000000..a2d094dfd --- /dev/null +++ b/qcfractal/qcfractal/components/dataset_processing/views.py @@ -0,0 +1,173 @@ +from __future__ import annotations + +import os +from collections import defaultdict +from typing import TYPE_CHECKING + +from sqlalchemy import select +from sqlalchemy.orm import selectinload + +from qcfractal.components.record_db_models import BaseRecordORM +from qcportal.dataset_models import BaseDataset +from qcportal.record_models import RecordStatusEnum, BaseRecord +from qcportal.utils import chunk_iterable +from qcportal.cache import DatasetCache + +if TYPE_CHECKING: + from sqlalchemy.orm.session import Session + from qcfractal.components.internal_jobs.status import JobProgress + from qcfractal.db_socket.socket import SQLAlchemySocket + from typing import Optional, Iterable + from typing import Iterable + from sqlalchemy.orm.session import Session + + +def create_view_file( + session: Session, + socket: SQLAlchemySocket, + dataset_id: int, + dataset_type: str, + output_path: str, + status: Optional[Iterable[RecordStatusEnum]] = None, + include: Optional[Iterable[str]] = None, + exclude: Optional[Iterable[str]] = None, + *, + include_children: bool = True, + job_progress: Optional[JobProgress] = None, +): + """ + Creates a view file for a dataset + + Note: the job progress object will be filled to 90% to leave room for uploading + + Parameters + ---------- + session + An existing SQLAlchemy session to use. + socket + Full SQLAlchemy socket to use for getting records + dataset_type + Type of the underlying dataset to create a view of (as a string) + dataset_id + ID of the dataset to create the view for + output_path + Full path (including filename) to output the view data to. Must not already exist + status + List of statuses to include. Default is to include records with any status + include + List of specific record fields to include in the export. Default is to include most fields + exclude + List of specific record fields to exclude from the export. Defaults to excluding none. + include_children + Specifies whether child records associated with the main records should also be included (recursively) + in the view file. + job_progress + Object used to track the progress of the job + """ + + if os.path.exists(output_path): + raise RuntimeError(f"File {output_path} exists - will not overwrite") + + if os.path.isdir(output_path): + raise RuntimeError(f"{output_path} is a directory") + + ds_socket = socket.datasets.get_socket(dataset_type) + ptl_dataset_type = BaseDataset.get_subclass(dataset_type) + + ptl_entry_type = ptl_dataset_type._entry_type + ptl_specification_type = ptl_dataset_type._specification_type + + view_db = DatasetCache(output_path, read_only=False, dataset_type=ptl_dataset_type) + + stmt = select(ds_socket.dataset_orm).where(ds_socket.dataset_orm.id == dataset_id) + stmt = stmt.options(selectinload("*")) + ds_orm = session.execute(stmt).scalar_one() + + # Metadata + view_db.update_metadata("dataset_metadata", ds_orm.model_dict()) + + # Entries + if job_progress is not None: + job_progress.update_progress(0, "Processing dataset entries") + + stmt = select(ds_socket.entry_orm) + stmt = stmt.options(selectinload("*")) + stmt = stmt.where(ds_socket.entry_orm.dataset_id == dataset_id) + + entries = session.execute(stmt).scalars().all() + entries = [e.to_model(ptl_entry_type) for e in entries] + view_db.update_entries(entries) + + if job_progress is not None: + job_progress.update_progress(5, "Processing dataset specifications") + + # Specifications + stmt = select(ds_socket.specification_orm) + stmt = stmt.options(selectinload("*")) + stmt = stmt.where(ds_socket.specification_orm.dataset_id == dataset_id) + + specs = session.execute(stmt).scalars().all() + specs = [s.to_model(ptl_specification_type) for s in specs] + view_db.update_specifications(specs) + + if job_progress is not None: + job_progress.update_progress(10, "Loading record information") + + # Now all the records + stmt = select(ds_socket.record_item_orm).where(ds_socket.record_item_orm.dataset_id == dataset_id) + stmt = stmt.order_by(ds_socket.record_item_orm.record_id.asc()) + record_items = session.execute(stmt).scalars().all() + + record_ids = set(ri.record_id for ri in record_items) + all_ids = set(record_ids) + + if include_children: + # Get all the children ids + children_ids = socket.records.get_children_ids(session, record_ids) + all_ids |= set(children_ids) + + ############################################################################ + # Determine the record types of all the ids (top-level and children if desired) + ############################################################################ + stmt = select(BaseRecordORM.id, BaseRecordORM.record_type) + + if status is not None: + stmt = stmt.where(BaseRecordORM.status.in_(status)) + + # Sort into a dictionary with keys being the record type + record_type_map = defaultdict(list) + + for id_chunk in chunk_iterable(all_ids, 500): + stmt2 = stmt.where(BaseRecordORM.id.in_(id_chunk)) + for record_id, record_type in session.execute(stmt2).yield_per(100): + record_type_map[record_type].append(record_id) + + if job_progress is not None: + job_progress.update_progress(15, "Processing individual records") + + ############################################################################ + # Actually fetch the record data now + # We go one over the different types of records, then load them in batches + ############################################################################ + record_count = len(all_ids) + finished_count = 0 + + for record_type_str, record_ids in record_type_map.items(): + record_socket = socket.records.get_socket(record_type_str) + record_type = BaseRecord.get_subclass(record_type_str) + + for id_chunk in chunk_iterable(record_ids, 200): + record_dicts = record_socket.get(id_chunk, include=include, exclude=exclude, session=session) + record_data = [record_type(**r) for r in record_dicts] + view_db.update_records(record_data) + + finished_count += len(id_chunk) + if job_progress is not None: + # Fraction of the 75% left over (15 to start, 10 left over for uploading) + job_progress.update_progress( + 15 + int(75 * finished_count / record_count), "Processing individual records" + ) + + # Update the dataset <-> record association + record_info = [(ri.entry_name, ri.specification_name, ri.record_id) for ri in record_items] + view_db.update_dataset_records(record_info) diff --git a/qcfractal/qcfractal/components/dataset_routes.py b/qcfractal/qcfractal/components/dataset_routes.py index e39381387..59c7f4521 100644 --- a/qcfractal/qcfractal/components/dataset_routes.py +++ b/qcfractal/qcfractal/components/dataset_routes.py @@ -1,4 +1,4 @@ -from typing import Dict, Any +from typing import Dict from flask import current_app, g @@ -12,6 +12,7 @@ DatasetFetchRecordsBody, DatasetFetchEntryBody, DatasetFetchSpecificationBody, + DatasetCreateViewBody, DatasetSubmitBody, DatasetDeleteStrBody, DatasetRecordModifyBody, @@ -161,6 +162,24 @@ def modify_dataset_metadata_v1(dataset_type: str, dataset_id: int, body_data: Da return ds_socket.update_metadata(dataset_id, new_metadata=body_data) +######################### +# Views +######################### +@api_v1.route("/datasets///create_view", methods=["POST"]) +@wrap_route("WRITE") +def create_dataset_view_v1(dataset_type: str, dataset_id: int, body_data: DatasetCreateViewBody): + return storage_socket.datasets.add_create_view_attachment_job( + dataset_id, + dataset_type, + description=body_data.description, + provenance=body_data.provenance, + status=body_data.status, + include=body_data.include, + exclude=body_data.exclude, + include_children=body_data.include_children, + ) + + ######################### # Computation submission ######################### diff --git a/qcfractal/qcfractal/components/dataset_socket.py b/qcfractal/qcfractal/components/dataset_socket.py index 986ee1d71..4786b8ae0 100644 --- a/qcfractal/qcfractal/components/dataset_socket.py +++ b/qcfractal/qcfractal/components/dataset_socket.py @@ -1,29 +1,37 @@ from __future__ import annotations import logging +import os +import tempfile from typing import TYPE_CHECKING from sqlalchemy import select, delete, func, union, text, and_ from sqlalchemy.orm import load_only, lazyload, joinedload, with_polymorphic from sqlalchemy.orm.attributes import flag_modified -from qcfractal.components.dataset_db_models import BaseDatasetORM, ContributedValuesORM +from qcfractal.components.dataset_db_models import BaseDatasetORM, ContributedValuesORM, DatasetAttachmentORM from qcfractal.components.record_db_models import BaseRecordORM from qcfractal.db_socket.helpers import ( get_general, get_query_proj_options, ) -from qcportal.exceptions import AlreadyExistsError, MissingDataError +from qcportal.dataset_models import DatasetAttachmentType +from qcportal.exceptions import AlreadyExistsError, MissingDataError, UserReportableError from qcportal.metadata_models import InsertMetadata, DeleteMetadata, UpdateMetadata from qcportal.record_models import RecordStatusEnum, PriorityEnum -from qcportal.utils import chunk_iterable +from qcportal.utils import chunk_iterable, now_at_utc +from qcfractal.components.dataset_processing import create_view_file if TYPE_CHECKING: from sqlalchemy.orm.session import Session from qcportal.dataset_models import DatasetModifyMetadata + from qcfractal.components.internal_jobs.status import JobProgress from qcfractal.db_socket.socket import SQLAlchemySocket from qcfractal.db_socket.base_orm import BaseORM from typing import Dict, Any, Optional, Sequence, Iterable, Tuple, List, Union + from typing import Iterable, List, Dict, Any + from qcfractal.db_socket.socket import SQLAlchemySocket + from sqlalchemy.orm.session import Session class BaseDatasetSocket: @@ -56,6 +64,8 @@ def __init__( # Use the identity from the ORM object. This keeps everything consistent self.dataset_type = self.dataset_orm.__mapper_args__["polymorphic_identity"] + self._logger = logging.getLogger(__name__) + def _add_specification(self, session, specification) -> Tuple[InsertMetadata, Optional[int]]: raise NotImplementedError("_add_specification must be overridden by the derived class") @@ -1649,3 +1659,200 @@ def get_contributed_values(self, dataset_id: int, *, session: Optional[Session] with self.root_socket.optional_session(session, True) as session: cv = session.execute(stmt).scalars().all() return [x.model_dict() for x in cv] + + def attach_file( + self, + dataset_id: int, + attachment_type: DatasetAttachmentType, + file_path: str, + file_name: str, + description: Optional[str], + provenance: Dict[str, Any], + *, + session: Optional[Session] = None, + ) -> int: + """ + Attach a file to a dataset + + This function uploads the specified file and associates it with the dataset by creating + a corresponding dataset attachment record. This operation requires S3 storage to be enabled. + + Parameters + ---------- + dataset_id + The ID of the dataset to which the file will be attached. + attachment_type + The type of attachment that categorizes the file being added. + file_path + The local file system path to the file that needs to be uploaded. + file_name + The name of the file to be used in the attachment record. This is the filename that is + recommended to the user by default. + description + An optional description of the file + provenance + A dictionary containing metadata regarding the origin or history of the file. + session + An existing SQLAlchemy session to use. If None, one will be created. If an existing session + is used, it will be flushed (but not committed) before returning from this function. + + Raises + ------ + UserReportableError + Raised if S3 storage is not enabled + """ + + if not self.root_socket.qcf_config.s3.enabled: + raise UserReportableError("S3 storage is not enabled. Can not attach file to a dataset") + + self._logger.info(f"Uploading/Attaching dataset-related file: {file_path}") + with self.root_socket.optional_session(session) as session: + ef = DatasetAttachmentORM( + dataset_id=dataset_id, + attachment_type=attachment_type, + file_name=file_name, + description=description, + provenance=provenance, + ) + + file_id = self.root_socket.external_files.add_file(file_path, ef, session=session) + self._logger.info(f"Dataset attachment {file_path} successfully uploaded to S3. ID is {file_id}") + return file_id + + def create_view_attachment( + self, + dataset_id: int, + dataset_type: str, + description: Optional[str], + provenance: Dict[str, Any], + status: Optional[Iterable[RecordStatusEnum]] = None, + include: Optional[Iterable[str]] = None, + exclude: Optional[Iterable[str]] = None, + *, + include_children: bool = True, + job_progress: Optional[JobProgress] = None, + session: Optional[Session] = None, + ): + """ + Creates a dataset view and attaches it to the dataset + + Uses a temporary directory within the globally-configured `temporary_dir` + + Parameters + ---------- + dataset_id : int + ID of the dataset to create the view for + dataset_type + Type of dataset the ID is + description + Optional string describing the view file + provenance + Dictionary with any metadata or other information about the view. Information regarding + the options used to create the view will be added. + status + List of statuses to include. Default is to include records with any status + include + List of specific record fields to include in the export. Default is to include most fields + exclude + List of specific record fields to exclude from the export. Defaults to excluding none. + include_children + Specifies whether child records associated with the main records should also be included (recursively) + in the view file. + job_progress + Object used to track the progress of the job + session + An existing SQLAlchemy session to use. If None, one will be created. If an existing session + is used, it will be flushed (but not committed) before returning from this function. + """ + + if not self.root_socket.qcf_config.s3.enabled: + raise UserReportableError("S3 storage is not enabled. Can not not create view") + + # Add the options for the view creation to the provenance + provenance = provenance | { + "options": { + "status": status, + "include": include, + "exclude": exclude, + "include_children": include_children, + } + } + + with tempfile.TemporaryDirectory(dir=self.root_socket.qcf_config.temporary_dir) as tmpdir: + self._logger.info(f"Using temporary directory {tmpdir} for view creation") + + file_name = f"dataset_{dataset_id}_view.sqlite" + tmp_file_path = os.path.join(tmpdir, file_name) + + create_view_file( + session, + self.root_socket, + dataset_id, + dataset_type, + tmp_file_path, + status=status, + include=include, + exclude=exclude, + include_children=include_children, + job_progress=job_progress, + ) + + self._logger.info(f"View file created. File size is {os.path.getsize(tmp_file_path)/1048576} MiB.") + + if job_progress is not None: + job_progress.update_progress(90, "Uploading view file to S3") + + file_id = self.attach_file( + dataset_id, DatasetAttachmentType.view, tmp_file_path, file_name, description, provenance + ) + + if job_progress is not None: + job_progress.update_progress(100) + + return file_id + + def add_create_view_attachment_job( + self, + dataset_id: int, + dataset_type: str, + description: str, + provenance: Dict[str, Any], + status: Optional[Iterable[RecordStatusEnum]] = None, + include: Optional[Iterable[str]] = None, + exclude: Optional[Iterable[str]] = None, + *, + include_children: bool = True, + session: Optional[Session] = None, + ) -> int: + """ + Creates an internal job for creating and attaching a view to a dataset + + See :ref:`create_view_attachment` for a description of the parameters + + Returns + ------- + : + ID of the created internal job + """ + + if not self.root_socket.qcf_config.s3.enabled: + raise UserReportableError("S3 storage is not enabled. Can not not create view") + + return self.root_socket.internal_jobs.add( + f"create_attach_view_ds_{dataset_id}", + now_at_utc(), + f"datasets.create_view_attachment", + { + "dataset_id": dataset_id, + "dataset_type": dataset_type, + "description": description, + "provenance": provenance, + "status": status, + "include": include, + "exclude": exclude, + "include_children": include_children, + }, + user_id=None, + unique_name=True, + session=session, + ) diff --git a/qcportal/qcportal/dataset_models.py b/qcportal/qcportal/dataset_models.py index 9685b9a6b..ceaf2f498 100644 --- a/qcportal/qcportal/dataset_models.py +++ b/qcportal/qcportal/dataset_models.py @@ -3,6 +3,7 @@ import math import os from datetime import datetime +from enum import Enum from typing import ( TYPE_CHECKING, Optional, @@ -30,17 +31,31 @@ from tqdm import tqdm from qcportal.base_models import RestModelBase, validate_list_to_single, CommonBulkGetBody -from qcportal.metadata_models import DeleteMetadata -from qcportal.metadata_models import InsertMetadata +from qcportal.internal_jobs import InternalJob +from qcportal.metadata_models import DeleteMetadata, InsertMetadata from qcportal.record_models import PriorityEnum, RecordStatusEnum, BaseRecord from qcportal.utils import make_list, chunk_iterable from qcportal.cache import DatasetCache, read_dataset_metadata, get_records_with_cache +from qcportal.external_files import ExternalFile if TYPE_CHECKING: from qcportal.client import PortalClient from pandas import DataFrame +class DatasetAttachmentType(str, Enum): + """ + The type of attachment a file is for a dataset + """ + + other = "other" + view = "view" + + +class DatasetAttachment(ExternalFile): + attached_type = DatasetAttachmentType + + class Citation(BaseModel): """A literature citation.""" @@ -104,6 +119,8 @@ class Config: metadata: Dict[str, Any] extras: Dict[str, Any] + attachments: List[DatasetAttachment] + ######################################## # Caches of information ######################################## @@ -320,6 +337,63 @@ def submit( "post", f"api/v1/datasets/{self.dataset_type}/{self.id}/submit", Any, body=body_data ) + def create_view( + self, + description: str, + provenance: Dict[str, Any], + status: Optional[Iterable[RecordStatusEnum]] = None, + include: Optional[Iterable[str]] = None, + exclude: Optional[Iterable[str]] = None, + *, + include_children: bool = True, + ) -> InternalJob: + """ + Creates a view of this dataset on the server + + This function will return an :ref:`InternalJob` which can be used to watch + for completion if desired. The job will run server side without user interaction. + + Note the ID field of the object if you with to retrieve this internal job later + (via :ref:`PortalClient.get_internal_job`) + + Parameters + ---------- + description + Optional string describing the view file + provenance + Dictionary with any metadata or other information about the view. Information regarding + the options used to create the view will be added. + status + List of statuses to include. Default is to include records with any status + include + List of specific record fields to include in the export. Default is to include most fields + exclude + List of specific record fields to exclude from the export. Defaults to excluding none. + include_children + Specifies whether child records associated with the main records should also be included (recursively) + in the view file. + + Returns + ------- + : + An :ref:`InternalJob` object which can be used to watch for completion. + """ + + body = DatasetCreateViewBody( + description=description, + provenance=provenance, + status=status, + include=include, + exclude=exclude, + include_children=include_children, + ) + + job_id = self._client.make_request( + "post", f"api/v1/datasets/{self.dataset_type}/{self.id}/create_view", int, body=body + ) + + return self._client.get_internal_job(job_id) + ######################################### # Various properties and getters/setters ######################################### @@ -1786,6 +1860,15 @@ class DatasetFetchRecordsBody(RestModelBase): status: Optional[List[RecordStatusEnum]] = None +class DatasetCreateViewBody(RestModelBase): + description: Optional[str] + provenance: Dict[str, Any] + status: Optional[List[RecordStatusEnum]] = (None,) + include: Optional[List[str]] = (None,) + exclude: Optional[List[str]] = (None,) + include_children: bool = (True,) + + class DatasetSubmitBody(RestModelBase): entry_names: Optional[List[str]] = None specification_names: Optional[List[str]] = None From c8dd52310fa3986d8dfaca07d458dac94277bca8 Mon Sep 17 00:00:00 2001 From: Benjamin Pritchard Date: Mon, 6 Jan 2025 15:13:00 -0500 Subject: [PATCH 12/26] Some more helper functions in the cache classes --- qcportal/qcportal/cache.py | 20 ++++++++++++++++++-- 1 file changed, 18 insertions(+), 2 deletions(-) diff --git a/qcportal/qcportal/cache.py b/qcportal/qcportal/cache.py index 2111204cb..bc9b474b6 100644 --- a/qcportal/qcportal/cache.py +++ b/qcportal/qcportal/cache.py @@ -540,21 +540,37 @@ def __init__(self, server_uri: str, cache_dir: Optional[str], max_size: int): self.cache_dir = None + def get_cache_path(self, cache_name: str) -> str: + if not self._is_disk: + raise RuntimeError("Cannot get path to cache for memory-only cache") + + return os.path.join(self.cache_dir, f"{cache_name}.sqlite") + def get_cache_uri(self, cache_name: str) -> str: if self._is_disk: - file_path = os.path.join(self.cache_dir, f"{cache_name}.sqlite") + file_path = self.get_cache_path(cache_name) uri = f"file:{file_path}" else: uri = ":memory:" return uri + def get_dataset_cache_path(self, dataset_id: int) -> str: + return self.get_cache_path(f"dataset_{dataset_id}") + + def get_dataset_cache_uri(self, dataset_id: int) -> str: + return self.get_cache_uri(f"dataset_{dataset_id}") + def get_dataset_cache(self, dataset_id: int, dataset_type: Type[_DATASET_T]) -> DatasetCache: - uri = self.get_cache_uri(f"dataset_{dataset_id}") + uri = self.get_dataset_cache_uri(dataset_id) # If you are asking this for a dataset cache, it should be writable return DatasetCache(uri, False, dataset_type) + @property + def is_disk(self) -> bool: + return self._is_disk + def vacuum(self, cache_name: Optional[str] = None): if self._is_disk: # TODO From 9543608d5917879fa4e691d7f87120e8d2c7edfc Mon Sep 17 00:00:00 2001 From: Benjamin Pritchard Date: Mon, 6 Jan 2025 15:13:45 -0500 Subject: [PATCH 13/26] Add functionality for downloading & using dataset views --- .../qcfractal/components/dataset_db_models.py | 1 - .../qcfractal/components/dataset_routes.py | 23 +++- .../qcfractal/components/dataset_socket.py | 12 ++ qcportal/qcportal/dataset_models.py | 129 +++++++++++++++++- 4 files changed, 156 insertions(+), 9 deletions(-) diff --git a/qcfractal/qcfractal/components/dataset_db_models.py b/qcfractal/qcfractal/components/dataset_db_models.py index 468f04d11..bf3c937ca 100644 --- a/qcfractal/qcfractal/components/dataset_db_models.py +++ b/qcfractal/qcfractal/components/dataset_db_models.py @@ -85,7 +85,6 @@ class BaseDatasetORM(BaseORM): "DatasetAttachmentORM", cascade="all, delete-orphan", passive_deletes=True, - lazy="selectin", ) __table_args__ = ( diff --git a/qcfractal/qcfractal/components/dataset_routes.py b/qcfractal/qcfractal/components/dataset_routes.py index 59c7f4521..62a0fe56f 100644 --- a/qcfractal/qcfractal/components/dataset_routes.py +++ b/qcfractal/qcfractal/components/dataset_routes.py @@ -38,7 +38,14 @@ def get_general_dataset_v1(dataset_id: int, url_params: ProjURLParameters): with storage_socket.session_scope(True) as session: ds_type = storage_socket.datasets.lookup_type(dataset_id, session=session) ds_socket = storage_socket.datasets.get_socket(ds_type) - return ds_socket.get(dataset_id, url_params.include, url_params.exclude, session=session) + + r = ds_socket.get(dataset_id, url_params.include, url_params.exclude, session=session) + + # TODO - remove this eventually + # Don't return attachments by default + r.pop("attachments", None) + + return r @api_v1.route("/datasets/query", methods=["POST"]) @@ -163,7 +170,7 @@ def modify_dataset_metadata_v1(dataset_type: str, dataset_id: int, body_data: Da ######################### -# Views +# Views & Attachments ######################### @api_v1.route("/datasets///create_view", methods=["POST"]) @wrap_route("WRITE") @@ -361,10 +368,16 @@ def revert_dataset_records_v1(dataset_type: str, dataset_id: int, body_data: Dat ) -################### -# Contributed Values -################### +################################# +# Fields not returned by default +################################# @api_v1.route("/datasets//contributed_values", methods=["GET"]) @wrap_route("READ") def fetch_dataset_contributed_values_v1(dataset_id: int): return storage_socket.datasets.get_contributed_values(dataset_id) + + +@api_v1.route("/datasets//attachments", methods=["GET"]) +@wrap_route("READ") +def fetch_dataset_attachments_v1(dataset_id: int): + return storage_socket.datasets.get_attachments(dataset_id) diff --git a/qcfractal/qcfractal/components/dataset_socket.py b/qcfractal/qcfractal/components/dataset_socket.py index 4786b8ae0..638373516 100644 --- a/qcfractal/qcfractal/components/dataset_socket.py +++ b/qcfractal/qcfractal/components/dataset_socket.py @@ -1660,6 +1660,18 @@ def get_contributed_values(self, dataset_id: int, *, session: Optional[Session] cv = session.execute(stmt).scalars().all() return [x.model_dict() for x in cv] + def get_attachments(self, dataset_id: int, *, session: Optional[Session] = None) -> List[Dict[str, Any]]: + """ + Get the attachments for a dataset + """ + + stmt = select(DatasetAttachmentORM) + stmt = stmt.where(DatasetAttachmentORM.dataset_id == dataset_id) + + with self.root_socket.optional_session(session, True) as session: + att = session.execute(stmt).scalars().all() + return [x.model_dict() for x in att] + def attach_file( self, dataset_id: int, diff --git a/qcportal/qcportal/dataset_models.py b/qcportal/qcportal/dataset_models.py index ceaf2f498..d9bbfe8a7 100644 --- a/qcportal/qcportal/dataset_models.py +++ b/qcportal/qcportal/dataset_models.py @@ -53,7 +53,7 @@ class DatasetAttachmentType(str, Enum): class DatasetAttachment(ExternalFile): - attached_type = DatasetAttachmentType + attachment_type: DatasetAttachmentType class Citation(BaseModel): @@ -119,8 +119,6 @@ class Config: metadata: Dict[str, Any] extras: Dict[str, Any] - attachments: List[DatasetAttachment] - ######################################## # Caches of information ######################################## @@ -134,6 +132,7 @@ class Config: # Fields not always included when fetching the dataset ###################################################### contributed_values_: Optional[Dict[str, ContributedValues]] = Field(None, alias="contributed_values") + attachments_: Optional[List[DatasetAttachment]] = Field(None, alias="attachments") ############################# # Private non-pydantic fields @@ -337,6 +336,130 @@ def submit( "post", f"api/v1/datasets/{self.dataset_type}/{self.id}/submit", Any, body=body_data ) + ######################################### + # Attachments + ######################################### + def fetch_attachments(self): + self.assert_is_not_view() + self.assert_online() + + self.attachments_ = self._client.make_request( + "get", + f"api/v1/datasets/{self.id}/attachments", + Optional[List[DatasetAttachment]], + ) + + @property + def attachments(self) -> List[DatasetAttachment]: + if not self.attachments_: + self.fetch_attachments() + + return self.attachments_ + + ######################################### + # View creation and use + ######################################### + def list_views(self): + return [x for x in self.attachments if x.attachment_type == DatasetAttachmentType.view] + + def download_view( + self, + view_file_id: Optional[int] = None, + destination_path: Optional[str] = None, + overwrite: bool = True, + ): + """ + Downloads a view for this dataset + + If a `view_file_id` is not given, the most recent view will be downloaded. + + If destination path is not given, the file will be placed in the current directory, and the + filename determined by what is stored on the server. + + Parameters + ---------- + view_file_id + ID of the view to download. See :ref:`list_views`. If `None`, will download the latest view + destination_path + Full path to the destination file (including filename) + overwrite + If True, any existing file will be overwritten + """ + + my_views = self.list_views() + + if not my_views: + raise ValueError(f"No views available for this dataset") + + if view_file_id is None: + latest_view_ids = max(my_views, key=lambda x: x.created_on) + view_file_id = latest_view_ids.id + + view_map = {x.id: x for x in self.list_views()} + if view_file_id not in view_map: + raise ValueError(f"File id {view_file_id} is not a valid ID for this dataset") + + if destination_path is None: + view_data = view_map[view_file_id] + destination_path = os.path.join(os.getcwd(), view_data.file_name) + + self._client.download_external_file(view_file_id, destination_path, overwrite=overwrite) + + def use_view_cache( + self, + view_file_path: str, + ): + """ + Downloads and loads a view for this dataset + + Parameters + ---------- + view_file_path + Full path to the view file + """ + + cache_uri = f"file:{view_file_path}" + dcache = DatasetCache(cache_uri=cache_uri, read_only=False, dataset_type=type(self)) + + meta = dcache.get_metadata("dataset_metadata") + + if meta["id"] != self.id: + raise ValueError( + f"Info in view file does not match this dataset. ID in the file {meta['id']}, ID of this dataset {self.id}" + ) + + if meta["dataset_type"] != self.dataset_type: + raise ValueError( + f"Info in view file does not match this dataset. Dataset type in the file {meta['dataset_type']}, dataset type of this dataset {self.dataset_type}" + ) + + if meta["name"] != self.name: + raise ValueError( + f"Info in view file does not match this dataset. Dataset name in the file {meta['name']}, name of this dataset {self.name}" + ) + + self._cache_data = dcache + + def preload_cache(self, view_file_id: Optional[int] = None): + """ + Downloads a view file and uses it as the current cache + + Parameters + ---------- + view_file_id + ID of the view to download. See :ref:`list_views`. If `None`, will download the latest view + """ + + self.assert_is_not_view() + self.assert_online() + + if not self._client.cache.is_disk: + raise RuntimeError("Caching to disk is not enabled. Set the cache_dir path when constructing the client") + + destination_path = self._client.cache.get_dataset_cache_path(self.id) + self.download_view(view_file_id=view_file_id, destination_path=destination_path, overwrite=True) + self.use_view_cache(destination_path) + def create_view( self, description: str, From 163cdcea2f968ad1459e31f3ad69f198b321ae17 Mon Sep 17 00:00:00 2001 From: Benjamin Pritchard Date: Thu, 9 Jan 2025 16:47:48 -0500 Subject: [PATCH 14/26] Add dataset to internal job tracking --- ...-5f6f804e11d3_add_dataset_internal_jobs.py | 35 ++++++++ .../qcfractal/components/dataset_db_models.py | 8 ++ .../qcfractal/components/dataset_routes.py | 16 ++++ .../qcfractal/components/dataset_socket.py | 88 ++++++++++++++----- qcportal/qcportal/dataset_models.py | 36 +++++++- qcportal/qcportal/internal_jobs/models.py | 11 ++- 6 files changed, 168 insertions(+), 26 deletions(-) create mode 100644 qcfractal/qcfractal/alembic/versions/2025-01-09-5f6f804e11d3_add_dataset_internal_jobs.py diff --git a/qcfractal/qcfractal/alembic/versions/2025-01-09-5f6f804e11d3_add_dataset_internal_jobs.py b/qcfractal/qcfractal/alembic/versions/2025-01-09-5f6f804e11d3_add_dataset_internal_jobs.py new file mode 100644 index 000000000..647260f50 --- /dev/null +++ b/qcfractal/qcfractal/alembic/versions/2025-01-09-5f6f804e11d3_add_dataset_internal_jobs.py @@ -0,0 +1,35 @@ +"""Add dataset internal jobs + +Revision ID: 5f6f804e11d3 +Revises: 84285e3620fd +Create Date: 2025-01-09 16:25:50.187495 + +""" + +import sqlalchemy as sa +from alembic import op + +# revision identifiers, used by Alembic. +revision = "5f6f804e11d3" +down_revision = "84285e3620fd" +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.create_table( + "dataset_internal_job", + sa.Column("internal_job_id", sa.Integer(), nullable=False), + sa.Column("dataset_id", sa.Integer(), nullable=False), + sa.ForeignKeyConstraint(["dataset_id"], ["base_dataset.id"], ondelete="cascade"), + sa.ForeignKeyConstraint(["internal_job_id"], ["internal_jobs.id"], ondelete="cascade"), + sa.PrimaryKeyConstraint("internal_job_id", "dataset_id"), + ) + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.drop_table("dataset_internal_job") + # ### end Alembic commands ### diff --git a/qcfractal/qcfractal/components/dataset_db_models.py b/qcfractal/qcfractal/components/dataset_db_models.py index bf3c937ca..173fd7748 100644 --- a/qcfractal/qcfractal/components/dataset_db_models.py +++ b/qcfractal/qcfractal/components/dataset_db_models.py @@ -20,6 +20,7 @@ from qcfractal.components.auth.db_models import UserIDMapSubquery, GroupIDMapSubquery, UserORM, GroupORM from qcfractal.components.external_files.db_models import ExternalFileORM +from qcfractal.components.internal_jobs.db_models import InternalJobORM from qcfractal.db_socket import BaseORM, MsgpackExt from qcportal.dataset_models import DatasetAttachmentType @@ -140,6 +141,13 @@ class ContributedValuesORM(BaseORM): _qcportal_model_excludes = ["dataset_id"] +class DatasetInternalJobORM(BaseORM): + __tablename__ = "dataset_internal_job" + + internal_job_id = Column(Integer, ForeignKey(InternalJobORM.id, ondelete="cascade"), primary_key=True) + dataset_id = Column(Integer, ForeignKey("base_dataset.id", ondelete="cascade"), primary_key=True) + + class DatasetAttachmentORM(ExternalFileORM): __tablename__ = "dataset_attachment" diff --git a/qcfractal/qcfractal/components/dataset_routes.py b/qcfractal/qcfractal/components/dataset_routes.py index 62a0fe56f..a3ee0da3d 100644 --- a/qcfractal/qcfractal/components/dataset_routes.py +++ b/qcfractal/qcfractal/components/dataset_routes.py @@ -22,6 +22,7 @@ DatasetQueryRecords, DatasetDeleteParams, DatasetModifyEntryBody, + DatasetGetInternalJobParams, ) from qcportal.exceptions import LimitExceededError @@ -368,6 +369,21 @@ def revert_dataset_records_v1(dataset_type: str, dataset_id: int, body_data: Dat ) +################################# +# Internal Jobs +################################# +@api_v1.route("/datasets//internal_jobs/", methods=["GET"]) +@wrap_route("READ") +def get_dataset_internal_job_v1(dataset_id: int, job_id: int): + return storage_socket.datasets.get_internal_job(dataset_id, job_id) + + +@api_v1.route("/datasets//internal_jobs", methods=["GET"]) +@wrap_route("READ") +def list_dataset_internal_jobs_v1(dataset_id: int, url_params: DatasetGetInternalJobParams): + return storage_socket.datasets.list_internal_jobs(dataset_id, status=url_params.status) + + ################################# # Fields not returned by default ################################# diff --git a/qcfractal/qcfractal/components/dataset_socket.py b/qcfractal/qcfractal/components/dataset_socket.py index 638373516..1c6c84056 100644 --- a/qcfractal/qcfractal/components/dataset_socket.py +++ b/qcfractal/qcfractal/components/dataset_socket.py @@ -9,7 +9,14 @@ from sqlalchemy.orm import load_only, lazyload, joinedload, with_polymorphic from sqlalchemy.orm.attributes import flag_modified -from qcfractal.components.dataset_db_models import BaseDatasetORM, ContributedValuesORM, DatasetAttachmentORM +from qcfractal.components.dataset_db_models import ( + BaseDatasetORM, + ContributedValuesORM, + DatasetAttachmentORM, + DatasetInternalJobORM, +) +from qcfractal.components.dataset_processing import create_view_file +from qcfractal.components.internal_jobs.db_models import InternalJobORM from qcfractal.components.record_db_models import BaseRecordORM from qcfractal.db_socket.helpers import ( get_general, @@ -17,10 +24,10 @@ ) from qcportal.dataset_models import DatasetAttachmentType from qcportal.exceptions import AlreadyExistsError, MissingDataError, UserReportableError +from qcportal.internal_jobs import InternalJobStatusEnum from qcportal.metadata_models import InsertMetadata, DeleteMetadata, UpdateMetadata from qcportal.record_models import RecordStatusEnum, PriorityEnum from qcportal.utils import chunk_iterable, now_at_utc -from qcfractal.components.dataset_processing import create_view_file if TYPE_CHECKING: from sqlalchemy.orm.session import Session @@ -1660,6 +1667,42 @@ def get_contributed_values(self, dataset_id: int, *, session: Optional[Session] cv = session.execute(stmt).scalars().all() return [x.model_dict() for x in cv] + def get_internal_job( + self, + dataset_id: int, + job_id: int, + *, + session: Optional[Session] = None, + ): + stmt = select(InternalJobORM) + stmt = stmt.join(DatasetInternalJobORM, DatasetInternalJobORM.internal_job_id == InternalJobORM.id) + stmt = stmt.where(DatasetInternalJobORM.dataset_id == dataset_id) + stmt = stmt.where(DatasetInternalJobORM.internal_job_id == job_id) + + with self.root_socket.optional_session(session, True) as session: + ij_orm = session.execute(stmt).scalar_one_or_none() + if ij_orm is None: + raise MissingDataError(f"Job id {job_id} not found in dataset {dataset_id}") + return ij_orm.model_dict() + + def list_internal_jobs( + self, + dataset_id: int, + status: Optional[Iterable[InternalJobStatusEnum]] = None, + *, + session: Optional[Session] = None, + ): + stmt = select(InternalJobORM) + stmt = stmt.join(DatasetInternalJobORM, DatasetInternalJobORM.internal_job_id == InternalJobORM.id) + stmt = stmt.where(DatasetInternalJobORM.dataset_id == dataset_id) + + if status is not None: + stmt = stmt.where(InternalJobORM.status.in_(status)) + + with self.root_socket.optional_session(session, True) as session: + ij_orm = session.execute(stmt).scalars().all() + return [i.model_dict() for i in ij_orm] + def get_attachments(self, dataset_id: int, *, session: Optional[Session] = None) -> List[Dict[str, Any]]: """ Get the attachments for a dataset @@ -1850,21 +1893,26 @@ def add_create_view_attachment_job( if not self.root_socket.qcf_config.s3.enabled: raise UserReportableError("S3 storage is not enabled. Can not not create view") - return self.root_socket.internal_jobs.add( - f"create_attach_view_ds_{dataset_id}", - now_at_utc(), - f"datasets.create_view_attachment", - { - "dataset_id": dataset_id, - "dataset_type": dataset_type, - "description": description, - "provenance": provenance, - "status": status, - "include": include, - "exclude": exclude, - "include_children": include_children, - }, - user_id=None, - unique_name=True, - session=session, - ) + with self.root_socket.optional_session(session) as session: + job_id = self.root_socket.internal_jobs.add( + f"create_attach_view_ds_{dataset_id}", + now_at_utc(), + f"datasets.create_view_attachment", + { + "dataset_id": dataset_id, + "dataset_type": dataset_type, + "description": description, + "provenance": provenance, + "status": status, + "include": include, + "exclude": exclude, + "include_children": include_children, + }, + user_id=None, + unique_name=True, + session=session, + ) + + ds_job_orm = DatasetInternalJobORM(dataset_id=dataset_id, internal_job_id=job_id) + session.add(ds_job_orm) + return job_id diff --git a/qcportal/qcportal/dataset_models.py b/qcportal/qcportal/dataset_models.py index d9bbfe8a7..84cf2c062 100644 --- a/qcportal/qcportal/dataset_models.py +++ b/qcportal/qcportal/dataset_models.py @@ -31,7 +31,7 @@ from tqdm import tqdm from qcportal.base_models import RestModelBase, validate_list_to_single, CommonBulkGetBody -from qcportal.internal_jobs import InternalJob +from qcportal.internal_jobs import InternalJob, InternalJobStatusEnum from qcportal.metadata_models import DeleteMetadata, InsertMetadata from qcportal.record_models import PriorityEnum, RecordStatusEnum, BaseRecord from qcportal.utils import make_list, chunk_iterable @@ -336,6 +336,32 @@ def submit( "post", f"api/v1/datasets/{self.dataset_type}/{self.id}/submit", Any, body=body_data ) + ######################################### + # Internal jobs + ######################################### + def get_internal_job(self, job_id: int) -> InternalJob: + self.assert_is_not_view() + self.assert_online() + + ij_dict = self._client.make_request("get", f"/api/v1/datasets/{self.id}/internal_jobs/{job_id}", Dict[str, Any]) + refresh_url = f"/api/v1/datasets/{self.id}/internal_jobs/{ij_dict['id']}" + return InternalJob(client=self._client, refresh_url=refresh_url, **ij_dict) + + def list_internal_jobs( + self, status: Optional[Union[InternalJobStatusEnum, Iterable[InternalJobStatusEnum]]] = None + ) -> List[InternalJob]: + self.assert_is_not_view() + self.assert_online() + + url_params = DatasetGetInternalJobParams(status=make_list(status)) + ij_dicts = self._client.make_request( + "get", f"/api/v1/datasets/{self.id}/internal_jobs", List[Dict[str, Any]], url_params=url_params + ) + return [ + InternalJob(client=self._client, refresh_url=f"/api/v1/datasets/{self.id}/internal_jobs/{ij['id']}", **ij) + for ij in ij_dicts + ] + ######################################### # Attachments ######################################### @@ -477,7 +503,7 @@ def create_view( for completion if desired. The job will run server side without user interaction. Note the ID field of the object if you with to retrieve this internal job later - (via :ref:`PortalClient.get_internal_job`) + (via :ref:`PortalClient.get_internal_job` and :ref:`get_internal_jobs`) Parameters ---------- @@ -515,7 +541,7 @@ def create_view( "post", f"api/v1/datasets/{self.dataset_type}/{self.id}/create_view", int, body=body ) - return self._client.get_internal_job(job_id) + return self.get_internal_job(job_id) ######################################### # Various properties and getters/setters @@ -2037,6 +2063,10 @@ class DatasetModifyEntryBody(RestModelBase): overwrite_attributes: bool = False +class DatasetGetInternalJobParams(RestModelBase): + status: Optional[List[InternalJobStatusEnum]] = None + + def dataset_from_dict(data: Dict[str, Any], client: Any, cache_data: Optional[DatasetCache] = None) -> BaseDataset: """ Create a dataset object from a datamodel diff --git a/qcportal/qcportal/internal_jobs/models.py b/qcportal/qcportal/internal_jobs/models.py index 232f6a263..1ab1dd7cd 100644 --- a/qcportal/qcportal/internal_jobs/models.py +++ b/qcportal/qcportal/internal_jobs/models.py @@ -4,7 +4,6 @@ from typing import Optional, Dict, Any, List, Union from dateutil.parser import parse as date_parser -from rich.jupyter import display try: from pydantic.v1 import BaseModel, Extra, validator, PrivateAttr @@ -70,10 +69,12 @@ class Config: user: Optional[str] _client: Any = PrivateAttr(None) + _refresh_url: Optional[str] = PrivateAttr(None) - def __init__(self, client=None, **kwargs): + def __init__(self, client=None, refresh_url=None, **kwargs): BaseModel.__init__(self, **kwargs) self._client = client + self._refresh_url = refresh_url def refresh(self): """ @@ -83,7 +84,11 @@ def refresh(self): if self._client is None: raise RuntimeError("Client is not set") - server_data = self._client.get_internal_job(self.id) + if self._refresh_url is None: + server_data = self._client.get_internal_job(self.id) + else: + server_data = self._client.make_request("get", self._refresh_url, InternalJob) + for k, v in server_data: setattr(self, k, v) From a7ee70c85decf36a9473d8b68211a30e574287be Mon Sep 17 00:00:00 2001 From: Benjamin Pritchard Date: Thu, 9 Jan 2025 19:08:27 -0500 Subject: [PATCH 15/26] Enable external file deletion --- .../qcfractal/components/dataset_routes.py | 23 +++++--- .../qcfractal/components/dataset_socket.py | 13 +++++ .../components/external_files/socket.py | 55 ++++++++++++------- qcportal/qcportal/dataset_models.py | 12 ++++ 4 files changed, 75 insertions(+), 28 deletions(-) diff --git a/qcfractal/qcfractal/components/dataset_routes.py b/qcfractal/qcfractal/components/dataset_routes.py index a3ee0da3d..d53266eee 100644 --- a/qcfractal/qcfractal/components/dataset_routes.py +++ b/qcfractal/qcfractal/components/dataset_routes.py @@ -385,15 +385,24 @@ def list_dataset_internal_jobs_v1(dataset_id: int, url_params: DatasetGetInterna ################################# -# Fields not returned by default +# Attachments ################################# -@api_v1.route("/datasets//contributed_values", methods=["GET"]) -@wrap_route("READ") -def fetch_dataset_contributed_values_v1(dataset_id: int): - return storage_socket.datasets.get_contributed_values(dataset_id) - - @api_v1.route("/datasets//attachments", methods=["GET"]) @wrap_route("READ") def fetch_dataset_attachments_v1(dataset_id: int): return storage_socket.datasets.get_attachments(dataset_id) + + +@api_v1.route("/datasets//attachments/", methods=["DELETE"]) +@wrap_route("DELETE") +def delete_dataset_attachment_v1(dataset_id: int, attachment_id: int): + return storage_socket.datasets.delete_attachment(dataset_id, attachment_id) + + +################################# +# Contributed Values +################################# +@api_v1.route("/datasets//contributed_values", methods=["GET"]) +@wrap_route("READ") +def fetch_dataset_contributed_values_v1(dataset_id: int): + return storage_socket.datasets.get_contributed_values(dataset_id) diff --git a/qcfractal/qcfractal/components/dataset_socket.py b/qcfractal/qcfractal/components/dataset_socket.py index 1c6c84056..e042a5742 100644 --- a/qcfractal/qcfractal/components/dataset_socket.py +++ b/qcfractal/qcfractal/components/dataset_socket.py @@ -1715,6 +1715,19 @@ def get_attachments(self, dataset_id: int, *, session: Optional[Session] = None) att = session.execute(stmt).scalars().all() return [x.model_dict() for x in att] + def delete_attachment(self, dataset_id: int, file_id: int, *, session: Optional[Session] = None): + stmt = select(DatasetAttachmentORM) + stmt = stmt.where(DatasetAttachmentORM.dataset_id == dataset_id) + stmt = stmt.where(DatasetAttachmentORM.id == file_id) + stmt = stmt.with_for_update() + + with self.root_socket.optional_session(session) as session: + att = session.execute(stmt).scalar_one_or_none() + if att is None: + raise MissingDataError(f"Attachment with file id {file_id} not found in dataset {dataset_id}") + + return self.root_socket.external_files.delete(file_id, session=session) + def attach_file( self, dataset_id: int, diff --git a/qcfractal/qcfractal/components/external_files/socket.py b/qcfractal/qcfractal/components/external_files/socket.py index eade40799..309eacfc1 100644 --- a/qcfractal/qcfractal/components/external_files/socket.py +++ b/qcfractal/qcfractal/components/external_files/socket.py @@ -17,6 +17,7 @@ from qcportal.exceptions import MissingDataError from qcportal.external_files import ExternalFileTypeEnum, ExternalFileStatusEnum from .db_models import ExternalFileORM +from sqlalchemy import select if TYPE_CHECKING: from sqlalchemy.orm.session import Session @@ -215,27 +216,39 @@ def get_metadata( return ef.model_dict() - # def delete(self, file_id: int, *, session: Optional[Session] = None): - # """ - # Deletes an external file from the database and from remote storage - - # Parameters - # ---------- - # file_id - # ID of the external file to remove - # session - # An existing SQLAlchemy session to use. If None, one will be created. If an existing session - # is used, it will be flushed (but not committed) before returning from this function. - - # Returns - # ------- - # : - # Metadata about what was deleted and any errors that occurred - # """ - - # with self.root_socket.optional_session(session) as session: - # stmt = delete(ExternalFileORM).where(ExternalFileORM.id == file_id) - # session.execute(stmt) + def delete(self, file_id: int, *, session: Optional[Session] = None): + """ + Deletes an external file from the database and from remote storage + + Parameters + ---------- + file_id + ID of the external file to remove + session + An existing SQLAlchemy session to use. If None, one will be created. If an existing session + is used, it will be flushed (but not committed) before returning from this function. + + Returns + ------- + : + Metadata about what was deleted and any errors that occurred + """ + + with self.root_socket.optional_session(session) as session: + stmt = select(ExternalFileORM).where(ExternalFileORM.id == file_id) + ef_orm = session.execute(stmt).scalar_one_or_none() + + if ef_orm is None: + raise MissingDataError(f"Cannot find external file with id {file_id} in the database") + + bucket = ef_orm.bucket + object_key = ef_orm.object_key + + session.delete(ef_orm) + session.flush() + + # now delete from S3 storage + self._s3_client.delete_object(Bucket=bucket, Key=object_key) def get_url(self, file_id: int, *, session: Optional[Session] = None) -> Tuple[str, str]: """ diff --git a/qcportal/qcportal/dataset_models.py b/qcportal/qcportal/dataset_models.py index 84cf2c062..62b162efd 100644 --- a/qcportal/qcportal/dataset_models.py +++ b/qcportal/qcportal/dataset_models.py @@ -382,6 +382,18 @@ def attachments(self) -> List[DatasetAttachment]: return self.attachments_ + def delete_attachment(self, file_id: int): + self.assert_is_not_view() + self.assert_online() + + self._client.make_request( + "delete", + f"api/v1/datasets/{self.id}/attachments/{file_id}", + None, + ) + + self.fetch_attachments() + ######################################### # View creation and use ######################################### From 5b0b97ff78f79b8fbe45bfba22c8a6db7073ca47 Mon Sep 17 00:00:00 2001 From: Benjamin Pritchard Date: Fri, 10 Jan 2025 10:16:06 -0500 Subject: [PATCH 16/26] Add serial groups to internal jobs --- ...0c677f8d1_add_internal_job_serial_group.py | 40 ++++++++++++++++ .../qcfractal/components/dataset_socket.py | 10 +++- .../components/internal_jobs/db_models.py | 11 +++++ .../components/internal_jobs/socket.py | 48 +++++++++++++++++-- .../components/internal_jobs/test_socket.py | 44 +++++++++++++++++ qcportal/qcportal/internal_jobs/models.py | 1 + 6 files changed, 149 insertions(+), 5 deletions(-) create mode 100644 qcfractal/qcfractal/alembic/versions/2025-01-10-3690c677f8d1_add_internal_job_serial_group.py diff --git a/qcfractal/qcfractal/alembic/versions/2025-01-10-3690c677f8d1_add_internal_job_serial_group.py b/qcfractal/qcfractal/alembic/versions/2025-01-10-3690c677f8d1_add_internal_job_serial_group.py new file mode 100644 index 000000000..9b4cdd8d4 --- /dev/null +++ b/qcfractal/qcfractal/alembic/versions/2025-01-10-3690c677f8d1_add_internal_job_serial_group.py @@ -0,0 +1,40 @@ +"""Add internal job serial group + +Revision ID: 3690c677f8d1 +Revises: 5f6f804e11d3 +Create Date: 2025-01-10 16:08:36.541807 + +""" + +import sqlalchemy as sa +from alembic import op + +# revision identifiers, used by Alembic. +revision = "3690c677f8d1" +down_revision = "5f6f804e11d3" +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.add_column("internal_jobs", sa.Column("serial_group", sa.String(), nullable=True)) + op.create_index( + "ux_internal_jobs_status_serial_group", + "internal_jobs", + ["status", "serial_group"], + unique=True, + postgresql_where=sa.text("status = 'running'"), + ) + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.drop_index( + "ux_internal_jobs_status_serial_group", + table_name="internal_jobs", + postgresql_where=sa.text("status = 'running'"), + ) + op.drop_column("internal_jobs", "serial_group") + # ### end Alembic commands ### diff --git a/qcfractal/qcfractal/components/dataset_socket.py b/qcfractal/qcfractal/components/dataset_socket.py index e042a5742..97df06157 100644 --- a/qcfractal/qcfractal/components/dataset_socket.py +++ b/qcfractal/qcfractal/components/dataset_socket.py @@ -6,6 +6,7 @@ from typing import TYPE_CHECKING from sqlalchemy import select, delete, func, union, text, and_ +from sqlalchemy.dialects.postgresql import insert from sqlalchemy.orm import load_only, lazyload, joinedload, with_polymorphic from sqlalchemy.orm.attributes import flag_modified @@ -1923,9 +1924,14 @@ def add_create_view_attachment_job( }, user_id=None, unique_name=True, + serial_group="ds_create_view", session=session, ) - ds_job_orm = DatasetInternalJobORM(dataset_id=dataset_id, internal_job_id=job_id) - session.add(ds_job_orm) + stmt = ( + insert(DatasetInternalJobORM) + .values(dataset_id=dataset_id, internal_job_id=job_id) + .on_conflict_do_nothing() + ) + session.execute(stmt) return job_id diff --git a/qcfractal/qcfractal/components/internal_jobs/db_models.py b/qcfractal/qcfractal/components/internal_jobs/db_models.py index 1512f822d..c9184c7cf 100644 --- a/qcfractal/qcfractal/components/internal_jobs/db_models.py +++ b/qcfractal/qcfractal/components/internal_jobs/db_models.py @@ -56,6 +56,9 @@ class InternalJobORM(BaseORM): # it must be unique. null != null always unique_name = Column(String, nullable=True) + # If this job is part of a serial group (only one may run at a time) + serial_group = Column(String, nullable=True) + __table_args__ = ( Index("ix_internal_jobs_added_date", "added_date", postgresql_using="brin"), Index("ix_internal_jobs_scheduled_date", "scheduled_date", postgresql_using="brin"), @@ -64,6 +67,14 @@ class InternalJobORM(BaseORM): Index("ix_internal_jobs_name", "name"), Index("ix_internal_jobs_user_id", "user_id"), UniqueConstraint("unique_name", name="ux_internal_jobs_unique_name"), + # Enforces only one running per serial group + Index( + "ux_internal_jobs_status_serial_group", + "status", + "serial_group", + unique=True, + postgresql_where=(status == InternalJobStatusEnum.running), + ), ) _qcportal_model_excludes = ["unique_name", "user_id"] diff --git a/qcfractal/qcfractal/components/internal_jobs/socket.py b/qcfractal/qcfractal/components/internal_jobs/socket.py index 74d2969da..32fcb3d2c 100644 --- a/qcfractal/qcfractal/components/internal_jobs/socket.py +++ b/qcfractal/qcfractal/components/internal_jobs/socket.py @@ -13,6 +13,7 @@ import psycopg2.extensions from sqlalchemy import select, delete, update, and_, or_ from sqlalchemy.dialects.postgresql import insert +from sqlalchemy.exc import IntegrityError from qcfractal.components.auth.db_models import UserIDMapSubquery from qcfractal.db_socket.helpers import get_query_proj_options @@ -71,6 +72,7 @@ def add( after_function: Optional[str] = None, after_function_kwargs: Optional[Dict[str, Any]] = None, repeat_delay: Optional[int] = None, + serial_group: Optional[str] = None, *, session: Optional[Session] = None, ) -> int: @@ -99,6 +101,8 @@ def add( Arguments to use when calling `after_function` repeat_delay If set, will submit a new, identical job to be run repeat_delay seconds after this one finishes + serial_group + Only one job within this group may be run at once. If None, there is no limit session An existing SQLAlchemy session to use. If None, one will be created. If an existing session is used, it will be flushed (but not committed) before returning from this function. @@ -122,6 +126,7 @@ def add( after_function=after_function, after_function_kwargs=after_function_kwargs, repeat_delay=repeat_delay, + serial_group=serial_group, user_id=user_id, ) stmt = stmt.on_conflict_do_update( @@ -153,6 +158,7 @@ def add( after_function=after_function, after_function_kwargs=after_function_kwargs, repeat_delay=repeat_delay, + serial_group=serial_group, user_id=user_id, ) if unique_name: @@ -413,8 +419,16 @@ def _wait_for_job(session: Session, logger, conn, end_event): Blocks until a job is possibly available to run """ + serial_cte = select(InternalJobORM.serial_group).distinct() + serial_cte = serial_cte.where(InternalJobORM.status == InternalJobStatusEnum.running) + serial_cte = serial_cte.where(InternalJobORM.serial_group.is_not(None)) + serial_cte = serial_cte.cte() + next_job_stmt = select(InternalJobORM.scheduled_date) next_job_stmt = next_job_stmt.where(InternalJobORM.status == InternalJobStatusEnum.waiting) + next_job_stmt = next_job_stmt.where( + or_(InternalJobORM.serial_group.is_(None), InternalJobORM.serial_group.not_in(select(serial_cte))) + ) next_job_stmt = next_job_stmt.order_by(InternalJobORM.scheduled_date.asc()) # Skip any that are being claimed for running right now @@ -516,6 +530,11 @@ def run_loop(self, end_event): conn.set_isolation_level(psycopg2.extensions.ISOLATION_LEVEL_AUTOCOMMIT) # Prepare a statement for finding jobs. Filters will be added in the loop + serial_cte = select(InternalJobORM.serial_group).distinct() + serial_cte = serial_cte.where(InternalJobORM.status == InternalJobStatusEnum.running) + serial_cte = serial_cte.where(InternalJobORM.serial_group.is_not(None)) + serial_cte = serial_cte.cte() + stmt = select(InternalJobORM) stmt = stmt.order_by(InternalJobORM.scheduled_date.asc()).limit(1) stmt = stmt.with_for_update(skip_locked=True) @@ -528,10 +547,20 @@ def run_loop(self, end_event): logger.debug("checking for jobs") # Pick up anything waiting, or anything that hasn't been updated in a while (12 update periods) + now = now_at_utc() dead = now - timedelta(seconds=(self._update_frequency * 12)) logger.debug(f"checking for jobs before date {now}") - cond1 = and_(InternalJobORM.status == InternalJobStatusEnum.waiting, InternalJobORM.scheduled_date <= now) + + # Job is waiting, schedule to be run now or in the past, + # Serial group does not have any running or is not set + cond1 = and_( + InternalJobORM.status == InternalJobStatusEnum.waiting, + InternalJobORM.scheduled_date <= now, + or_(InternalJobORM.serial_group.is_(None), InternalJobORM.serial_group.not_in(select(serial_cte))), + ) + + # Job is running but runner is determined to be dead. Serial group doesn't matter cond2 = and_(InternalJobORM.status == InternalJobStatusEnum.running, InternalJobORM.last_updated < dead) stmt_now = stmt.where(or_(cond1, cond2)) @@ -560,8 +589,21 @@ def run_loop(self, end_event): job_orm.runner_uuid = runner_uuid job_orm.status = InternalJobStatusEnum.running - # Releases the row-level lock (from the with_for_update() in the original query) - session_main.commit() + # For logging below - the object might be in an odd state after the exception, where accessing + # it results in further exceptions + serial_group = job_orm.serial_group + + # Violation of the unique constraint may occur - two runners attempting to take + # different jobs of the same serial group at the same time + try: + # Releases the row-level lock (from the with_for_update() in the original query) + session_main.commit() + except IntegrityError: + logger.info( + f"Attempting to run job from serial group '{serial_group}, but seems like another runner got to another job from the same group first" + ) + session_main.rollback() + continue job_progress = JobProgress(job_orm.id, runner_uuid, session_status, self._update_frequency, end_event) self._run_single(session_main, job_orm, logger, job_progress=job_progress) diff --git a/qcfractal/qcfractal/components/internal_jobs/test_socket.py b/qcfractal/qcfractal/components/internal_jobs/test_socket.py index 8b39d79c4..092e7bef9 100644 --- a/qcfractal/qcfractal/components/internal_jobs/test_socket.py +++ b/qcfractal/qcfractal/components/internal_jobs/test_socket.py @@ -116,6 +116,50 @@ def test_internal_jobs_socket_run(storage_socket: SQLAlchemySocket, session: Ses th.join() +@pytest.mark.parametrize("job_func", ("internal_jobs.dummy_job", "internal_jobs.dummy_job_2")) +def test_internal_jobs_socket_run_serial(storage_socket: SQLAlchemySocket, session: Session, job_func: str): + id_1 = storage_socket.internal_jobs.add( + "dummy_job", now_at_utc(), job_func, {"iterations": 10}, None, unique_name=False, serial_group="test" + ) + id_2 = storage_socket.internal_jobs.add( + "dummy_job", now_at_utc(), job_func, {"iterations": 10}, None, unique_name=False, serial_group="test" + ) + id_3 = storage_socket.internal_jobs.add( + "dummy_job", + now_at_utc(), + job_func, + {"iterations": 10}, + None, + unique_name=False, + ) + + # Faster updates for testing + storage_socket.internal_jobs._update_frequency = 1 + + end_event = threading.Event() + th1 = threading.Thread(target=storage_socket.internal_jobs.run_loop, args=(end_event,)) + th2 = threading.Thread(target=storage_socket.internal_jobs.run_loop, args=(end_event,)) + th3 = threading.Thread(target=storage_socket.internal_jobs.run_loop, args=(end_event,)) + th1.start() + th2.start() + th3.start() + time.sleep(8) + + try: + job_1 = session.get(InternalJobORM, id_1) + job_2 = session.get(InternalJobORM, id_2) + job_3 = session.get(InternalJobORM, id_3) + assert job_1.status == InternalJobStatusEnum.running + assert job_2.status == InternalJobStatusEnum.waiting + assert job_3.status == InternalJobStatusEnum.running + + finally: + end_event.set() + th1.join() + th2.join() + th3.join() + + def test_internal_jobs_socket_recover(storage_socket: SQLAlchemySocket, session: Session): id_1 = storage_socket.internal_jobs.add( "dummy_job", now_at_utc(), "internal_jobs.dummy_job", {"iterations": 10}, None, unique_name=False diff --git a/qcportal/qcportal/internal_jobs/models.py b/qcportal/qcportal/internal_jobs/models.py index 1ab1dd7cd..2d5caa76f 100644 --- a/qcportal/qcportal/internal_jobs/models.py +++ b/qcportal/qcportal/internal_jobs/models.py @@ -57,6 +57,7 @@ class Config: runner_hostname: Optional[str] runner_uuid: Optional[str] repeat_delay: Optional[int] + serial_group: Optional[str] progress: int progress_description: Optional[str] = None From 8709442073784751482dcf0bb797a3417cc6e5f8 Mon Sep 17 00:00:00 2001 From: Benjamin Pritchard Date: Mon, 13 Jan 2025 15:24:48 -0500 Subject: [PATCH 17/26] Better handling of cancelling internal jobs --- .../components/internal_jobs/socket.py | 64 ++++++++++++++----- .../components/internal_jobs/status.py | 39 +++++++++-- .../components/internal_jobs/test_client.py | 42 ++++++------ .../components/internal_jobs/test_socket.py | 55 ++++++++++++++-- 4 files changed, 155 insertions(+), 45 deletions(-) diff --git a/qcfractal/qcfractal/components/internal_jobs/socket.py b/qcfractal/qcfractal/components/internal_jobs/socket.py index 32fcb3d2c..ac746d650 100644 --- a/qcfractal/qcfractal/components/internal_jobs/socket.py +++ b/qcfractal/qcfractal/components/internal_jobs/socket.py @@ -21,7 +21,7 @@ from qcportal.internal_jobs.models import InternalJobStatusEnum, InternalJobQueryFilters from qcportal.utils import now_at_utc from .db_models import InternalJobORM -from .status import JobProgress +from .status import JobProgress, CancelledJobException, JobRunnerStoppingException if TYPE_CHECKING: from qcfractal.db_socket.socket import SQLAlchemySocket @@ -333,6 +333,9 @@ def _run_single(self, session: Session, job_orm: InternalJobORM, logger, job_pro Runs a single job """ + # For logging (ORM may end up detached or somthing) + job_id = job_orm.id + try: func_attr = attrgetter(job_orm.function) @@ -351,29 +354,55 @@ def _run_single(self, session: Session, job_orm: InternalJobORM, logger, job_pro if "session" in func_params: add_kwargs["session"] = session + # Run the desired function + # Raises an exception if cancelled result = func(**job_orm.kwargs, **add_kwargs) - # Mark complete, unless this was cancelled - if not job_progress.cancelled(): - job_orm.status = InternalJobStatusEnum.complete - job_orm.progress = 100 - job_orm.progress_description = "Complete" + job_orm.status = InternalJobStatusEnum.complete + job_orm.progress = 100 + job_orm.progress_description = "Complete" + job_orm.result = result - except Exception: + # Job itself is being cancelled + except CancelledJobException: session.rollback() - result = traceback.format_exc() - logger.error(f"Job {job_orm.id} failed with exception:\n{result}") + logger.info(f"Job {job_id} was cancelled") + job_orm.status = InternalJobStatusEnum.cancelled + job_orm.result = None + # Runner is stopping down + except JobRunnerStoppingException: + session.rollback() + logger.info(f"Job {job_id} was running, but runner is stopping") + + # Basically reset everything + job_orm.status = InternalJobStatusEnum.waiting + job_orm.progress = 0 + job_orm.progress_description = None + job_orm.started_date = None + job_orm.last_updated = None + job_orm.result = None + job_orm.runner_uuid = None + + # Function itself had an error + except: + session.rollback() + job_orm.result = traceback.format_exc() + logger.error(f"Job {job_id} failed with exception:\n{job_orm.result}") job_orm.status = InternalJobStatusEnum.error - if not job_progress.deleted(): - job_orm.ended_date = now_at_utc() - job_orm.last_updated = job_orm.ended_date - job_orm.result = result + if job_progress.deleted: + # Row does not exist anymore + session.expunge(job_orm) + else: + # If status is waiting, that means the runner itself is stopping or something + if job_orm.status != InternalJobStatusEnum.waiting: + job_orm.ended_date = now_at_utc() + job_orm.last_updated = job_orm.ended_date - # Clear the unique name so we can add another one if needed - has_unique_name = job_orm.unique_name is not None - job_orm.unique_name = None + # Clear the unique name so we can add another one if needed + has_unique_name = job_orm.unique_name is not None + job_orm.unique_name = None # Flush but don't commit. This will prevent marking the task as finished # before the after_func has been run, but allow new ones to be added @@ -411,7 +440,8 @@ def _run_single(self, session: Session, job_orm: InternalJobORM, logger, job_pro repeat_delay=job_orm.repeat_delay, session=session, ) - session.commit() + + session.commit() @staticmethod def _wait_for_job(session: Session, logger, conn, end_event): diff --git a/qcfractal/qcfractal/components/internal_jobs/status.py b/qcfractal/qcfractal/components/internal_jobs/status.py index 6270cb813..159f73d47 100644 --- a/qcfractal/qcfractal/components/internal_jobs/status.py +++ b/qcfractal/qcfractal/components/internal_jobs/status.py @@ -1,8 +1,8 @@ from __future__ import annotations -from typing import Optional import threading import weakref +from typing import Optional from sqlalchemy import update from sqlalchemy.orm import Session @@ -12,6 +12,14 @@ from qcportal.utils import now_at_utc +class CancelledJobException(Exception): + pass + + +class JobRunnerStoppingException(Exception): + pass + + class JobProgress: """ Functor for updating progress and cancelling internal jobs @@ -30,6 +38,7 @@ def __init__(self, job_id: int, runner_uuid: str, session: Session, update_frequ self._description = None self._cancelled = False + self._runner_ending = False self._deleted = False self._end_event = end_event @@ -65,9 +74,9 @@ def _update_thread(self, session: Session, end_thread: threading.Event): # Job was stolen from us? self._cancelled = True - # Are we completely ending? + # Are we ending/cancelling because the runner is stopping/closing? if self._end_event.is_set(): - self._cancelled = True + self._runner_ending = True cancel = end_thread.wait(self._update_frequency) if cancel is True: @@ -89,8 +98,30 @@ def update_progress(self, progress: int, description: Optional[str] = None): self._progress = progress self._description = description + @property def cancelled(self) -> bool: - return self._cancelled + return self._cancelled or self._deleted + + @property + def runner_ending(self) -> bool: + return self._runner_ending + @property def deleted(self) -> bool: return self._deleted + + def raise_if_cancelled(self): + if self._cancelled or self._deleted: + raise CancelledJobException("Job was cancelled or deleted") + + if self._runner_ending: + raise JobRunnerStoppingException("Job runner is stopping/ending") + + +def raise_if_cancelled(job_progress: Optional[JobProgress]): + """ + Raises a CancelledJobExcepion if job_progress exists and is in a cancelled or deleted state + """ + + if job_progress is not None: + job_progress.raise_if_cancelled() diff --git a/qcfractal/qcfractal/components/internal_jobs/test_client.py b/qcfractal/qcfractal/components/internal_jobs/test_client.py index 3e83eb3e6..4e01fc049 100644 --- a/qcfractal/qcfractal/components/internal_jobs/test_client.py +++ b/qcfractal/qcfractal/components/internal_jobs/test_client.py @@ -23,10 +23,9 @@ def dummmy_internal_job(self, iterations: int, session, job_progress): for i in range(iterations): time.sleep(1.0) job_progress.update_progress(100 * ((i + 1) / iterations), f"Interation {i} of {iterations}") - print("Dummy internal job counter ", i) + #print("Dummy internal job counter ", i) - if job_progress.cancelled(): - return "Internal job cancelled" + job_progress.raise_if_cancelled() return "Internal job finished" @@ -36,8 +35,8 @@ def dummmy_internal_job_error(self, session, job_progress): raise RuntimeError("Expected error") -setattr(InternalJobSocket, "dummy_job", dummmy_internal_job) -setattr(InternalJobSocket, "dummy_job_error", dummmy_internal_job_error) +setattr(InternalJobSocket, "client_dummy_job", dummmy_internal_job) +setattr(InternalJobSocket, "client_dummy_job_error", dummmy_internal_job_error) def test_internal_jobs_client_error(snowflake: QCATestingSnowflake): @@ -45,7 +44,7 @@ def test_internal_jobs_client_error(snowflake: QCATestingSnowflake): snowflake_client = snowflake.client() id_1 = storage_socket.internal_jobs.add( - "dummy_job_error", now_at_utc(), "internal_jobs.dummy_job_error", {}, None, unique_name=False + "client_dummy_job_error", now_at_utc(), "internal_jobs.client_dummy_job_error", {}, None, unique_name=False ) # Faster updates for testing @@ -73,7 +72,7 @@ def test_internal_jobs_client_cancel_waiting(snowflake: QCATestingSnowflake): snowflake_client = snowflake.client() id_1 = storage_socket.internal_jobs.add( - "dummy_job", now_at_utc(), "internal_jobs.dummy_job", {"iterations": 10}, None, unique_name=False + "client_dummy_job", now_at_utc(), "internal_jobs.client.dummy_job", {"iterations": 10}, None, unique_name=False ) snowflake_client.cancel_internal_job(id_1) @@ -90,7 +89,7 @@ def test_internal_jobs_client_cancel_running(snowflake: QCATestingSnowflake): snowflake_client = snowflake.client() id_1 = storage_socket.internal_jobs.add( - "dummy_job", now_at_utc(), "internal_jobs.dummy_job", {"iterations": 10}, None, unique_name=False + "client_dummy_job", now_at_utc(), "internal_jobs.client_dummy_job", {"iterations": 10}, None, unique_name=False ) # Faster updates for testing @@ -112,7 +111,7 @@ def test_internal_jobs_client_cancel_running(snowflake: QCATestingSnowflake): job_1 = snowflake_client.get_internal_job(id_1) assert job_1.status == InternalJobStatusEnum.cancelled assert job_1.progress < 70 - assert job_1.result == "Internal job cancelled" + assert job_1.result is None finally: end_event.set() @@ -124,7 +123,7 @@ def test_internal_jobs_client_delete_waiting(snowflake: QCATestingSnowflake): snowflake_client = snowflake.client() id_1 = storage_socket.internal_jobs.add( - "dummy_job", now_at_utc(), "internal_jobs.dummy_job", {"iterations": 10}, None, unique_name=False + "client_dummy_job", now_at_utc(), "internal_jobs.client_dummy_job", {"iterations": 10}, None, unique_name=False ) snowflake_client.delete_internal_job(id_1) @@ -138,7 +137,7 @@ def test_internal_jobs_client_delete_running(snowflake: QCATestingSnowflake): snowflake_client = snowflake.client() id_1 = storage_socket.internal_jobs.add( - "dummy_job", now_at_utc(), "internal_jobs.dummy_job", {"iterations": 10}, None, unique_name=False + "client_dummy_job", now_at_utc(), "internal_jobs.client_dummy_job", {"iterations": 10}, None, unique_name=False ) # Faster updates for testing @@ -173,7 +172,12 @@ def test_internal_jobs_client_query(secure_snowflake: QCATestingSnowflake): time_0 = now_at_utc() id_1 = storage_socket.internal_jobs.add( - "dummy_job", now_at_utc(), "internal_jobs.dummy_job", {"iterations": 1}, read_id, unique_name=False + "client_dummy_job", + now_at_utc(), + "internal_jobs.client_dummy_job", + {"iterations": 1}, + read_id, + unique_name=False, ) time_1 = now_at_utc() @@ -196,7 +200,7 @@ def test_internal_jobs_client_query(secure_snowflake: QCATestingSnowflake): # Add one that will be waiting id_2 = storage_socket.internal_jobs.add( - "dummy_job", now_at_utc(), "internal_jobs.dummy_job", {"iterations": 1}, None, unique_name=False + "client_dummy_job", now_at_utc(), "internal_jobs.client_dummy_job", {"iterations": 1}, None, unique_name=False ) time_3 = now_at_utc() @@ -217,31 +221,31 @@ def test_internal_jobs_client_query(secure_snowflake: QCATestingSnowflake): assert len(r) == 1 assert r[0].id == id_1 - result = client.query_internal_jobs(name="dummy_job", status=["complete"]) + result = client.query_internal_jobs(name="client_dummy_job", status=["complete"]) r = list(result) assert len(r) == 1 assert r[0].id == id_1 - result = client.query_internal_jobs(name="dummy_job", status="waiting") + result = client.query_internal_jobs(name="client_dummy_job", status="waiting") r = list(result) assert len(r) == 1 assert r[0].id == id_2 - result = client.query_internal_jobs(name="dummy_job", added_after=time_0) + result = client.query_internal_jobs(name="client_dummy_job", added_after=time_0) r = list(result) assert len(r) == 2 assert {r[0].id, r[1].id} == {id_1, id_2} - result = client.query_internal_jobs(name="dummy_job", added_after=time_1, added_before=time_3) + result = client.query_internal_jobs(name="client_dummy_job", added_after=time_1, added_before=time_3) r = list(result) assert len(r) == 1 assert r[0].id == id_2 - result = client.query_internal_jobs(name="dummy_job", last_updated_after=time_2) + result = client.query_internal_jobs(name="client_dummy_job", last_updated_after=time_2) r = list(result) assert len(r) == 0 - result = client.query_internal_jobs(name="dummy_job", last_updated_before=time_2) + result = client.query_internal_jobs(name="client_dummy_job", last_updated_before=time_2) r = list(result) assert len(r) == 1 assert r[0].id == id_1 diff --git a/qcfractal/qcfractal/components/internal_jobs/test_socket.py b/qcfractal/qcfractal/components/internal_jobs/test_socket.py index 092e7bef9..ee2879285 100644 --- a/qcfractal/qcfractal/components/internal_jobs/test_socket.py +++ b/qcfractal/qcfractal/components/internal_jobs/test_socket.py @@ -3,6 +3,7 @@ import threading import time import uuid +from datetime import timedelta from typing import TYPE_CHECKING import pytest @@ -26,8 +27,7 @@ def dummy_internal_job(self, iterations: int, session, job_progress): job_progress.update_progress(100 * ((i + 1) / iterations)) # print("Dummy internal job counter ", i) - if job_progress.cancelled(): - return "Internal job cancelled" + job_progress.raise_if_cancelled() return "Internal job finished" @@ -160,7 +160,7 @@ def test_internal_jobs_socket_run_serial(storage_socket: SQLAlchemySocket, sessi th3.join() -def test_internal_jobs_socket_recover(storage_socket: SQLAlchemySocket, session: Session): +def test_internal_jobs_socket_runnerstop(storage_socket: SQLAlchemySocket, session: Session): id_1 = storage_socket.internal_jobs.add( "dummy_job", now_at_utc(), "internal_jobs.dummy_job", {"iterations": 10}, None, unique_name=False ) @@ -179,8 +179,13 @@ def test_internal_jobs_socket_recover(storage_socket: SQLAlchemySocket, session: assert not th.is_alive() job_1 = session.get(InternalJobORM, id_1) - assert job_1.status == InternalJobStatusEnum.running - assert job_1.progress > 10 + assert job_1.status == InternalJobStatusEnum.waiting + assert job_1.progress == 0 + assert job_1.started_date is None + assert job_1.last_updated is None + assert job_1.runner_uuid is None + + return old_uuid = job_1.runner_uuid # Change uuid @@ -203,3 +208,43 @@ def test_internal_jobs_socket_recover(storage_socket: SQLAlchemySocket, session: finally: end_event.set() th.join() + + +def test_internal_jobs_socket_recover(storage_socket: SQLAlchemySocket, session: Session): + id_1 = storage_socket.internal_jobs.add( + "dummy_job", now_at_utc(), "internal_jobs.dummy_job", {"iterations": 5}, None, unique_name=False + ) + + # Faster updates for testing + storage_socket.internal_jobs._update_frequency = 1 + + # Manually make it seem like it's running + old_uuid = str(uuid.uuid4()) + job_1 = session.get(InternalJobORM, id_1) + job_1.status = InternalJobStatusEnum.running + job_1.progress = 10 + job_1.last_updated = now_at_utc() - timedelta(seconds=60) + job_1.runner_uuid = old_uuid + session.commit() + + session.expire(job_1) + job_1 = session.get(InternalJobORM, id_1) + assert job_1.status == InternalJobStatusEnum.running + assert job_1.runner_uuid == old_uuid + + # Job is now running but orphaned. Should be picked up next time + end_event = threading.Event() + th = threading.Thread(target=storage_socket.internal_jobs.run_loop, args=(end_event,)) + th.start() + time.sleep(10) + + try: + session.expire(job_1) + job_1 = session.get(InternalJobORM, id_1) + assert job_1.status == InternalJobStatusEnum.complete + assert job_1.runner_uuid != old_uuid + assert job_1.progress == 100 + assert job_1.result == "Internal job finished" + finally: + end_event.set() + th.join() From 8472174c7e4a3f1e28458528e46c2b0454ed2751 Mon Sep 17 00:00:00 2001 From: Benjamin Pritchard Date: Mon, 13 Jan 2025 15:43:57 -0500 Subject: [PATCH 18/26] Better output from internal job watch() when it finishes --- qcportal/qcportal/internal_jobs/models.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/qcportal/qcportal/internal_jobs/models.py b/qcportal/qcportal/internal_jobs/models.py index 2d5caa76f..a4e46b6c3 100644 --- a/qcportal/qcportal/internal_jobs/models.py +++ b/qcportal/qcportal/internal_jobs/models.py @@ -132,8 +132,14 @@ def watch(self, interval: float = 2.0, timeout: Optional[float] = None): if end_time is not None and t >= end_time: raise TimeoutError("Timed out waiting for job to complete") - if self.status not in [InternalJobStatusEnum.waiting, InternalJobStatusEnum.running]: + if self.status == InternalJobStatusEnum.error: + print("Internal job resulted in an error:") + print(self.result) break + elif self.status not in [InternalJobStatusEnum.waiting, InternalJobStatusEnum.running]: + print(f"Internal job final status: {self.status.value}") + break + curtime = time.time() if end_time is not None: From 801bc12b5af53e9ad0d2000beb4b3926d067d432 Mon Sep 17 00:00:00 2001 From: Benjamin Pritchard Date: Mon, 13 Jan 2025 15:25:09 -0500 Subject: [PATCH 19/26] Make server side view creation cancellable --- .../components/dataset_processing/views.py | 10 ++++++-- .../qcfractal/components/dataset_socket.py | 10 ++++++-- .../components/external_files/socket.py | 23 ++++++++++++++----- 3 files changed, 33 insertions(+), 10 deletions(-) diff --git a/qcfractal/qcfractal/components/dataset_processing/views.py b/qcfractal/qcfractal/components/dataset_processing/views.py index a2d094dfd..d85d43b51 100644 --- a/qcfractal/qcfractal/components/dataset_processing/views.py +++ b/qcfractal/qcfractal/components/dataset_processing/views.py @@ -7,15 +7,15 @@ from sqlalchemy import select from sqlalchemy.orm import selectinload +from qcfractal.components.internal_jobs.status import JobProgress from qcfractal.components.record_db_models import BaseRecordORM +from qcportal.cache import DatasetCache from qcportal.dataset_models import BaseDataset from qcportal.record_models import RecordStatusEnum, BaseRecord from qcportal.utils import chunk_iterable -from qcportal.cache import DatasetCache if TYPE_CHECKING: from sqlalchemy.orm.session import Session - from qcfractal.components.internal_jobs.status import JobProgress from qcfractal.db_socket.socket import SQLAlchemySocket from typing import Optional, Iterable from typing import Iterable @@ -88,6 +88,7 @@ def create_view_file( # Entries if job_progress is not None: + job_progress.raise_if_cancelled() job_progress.update_progress(0, "Processing dataset entries") stmt = select(ds_socket.entry_orm) @@ -99,6 +100,7 @@ def create_view_file( view_db.update_entries(entries) if job_progress is not None: + job_progress.raise_if_cancelled() job_progress.update_progress(5, "Processing dataset specifications") # Specifications @@ -111,6 +113,7 @@ def create_view_file( view_db.update_specifications(specs) if job_progress is not None: + job_progress.raise_if_cancelled() job_progress.update_progress(10, "Loading record information") # Now all the records @@ -143,6 +146,7 @@ def create_view_file( record_type_map[record_type].append(record_id) if job_progress is not None: + job_progress.raise_if_cancelled() job_progress.update_progress(15, "Processing individual records") ############################################################################ @@ -163,6 +167,8 @@ def create_view_file( finished_count += len(id_chunk) if job_progress is not None: + job_progress.raise_if_cancelled() + # Fraction of the 75% left over (15 to start, 10 left over for uploading) job_progress.update_progress( 15 + int(75 * finished_count / record_count), "Processing individual records" diff --git a/qcfractal/qcfractal/components/dataset_socket.py b/qcfractal/qcfractal/components/dataset_socket.py index 97df06157..4d3aff78f 100644 --- a/qcfractal/qcfractal/components/dataset_socket.py +++ b/qcfractal/qcfractal/components/dataset_socket.py @@ -1738,6 +1738,7 @@ def attach_file( description: Optional[str], provenance: Dict[str, Any], *, + job_progress: Optional[JobProgress] = None, session: Optional[Session] = None, ) -> int: """ @@ -1761,6 +1762,8 @@ def attach_file( An optional description of the file provenance A dictionary containing metadata regarding the origin or history of the file. + job_progress + Object used to track progress if this function is being run in a background job session An existing SQLAlchemy session to use. If None, one will be created. If an existing session is used, it will be flushed (but not committed) before returning from this function. @@ -1784,7 +1787,10 @@ def attach_file( provenance=provenance, ) - file_id = self.root_socket.external_files.add_file(file_path, ef, session=session) + file_id = self.root_socket.external_files.add_file( + file_path, ef, session=session, job_progress=job_progress + ) + self._logger.info(f"Dataset attachment {file_path} successfully uploaded to S3. ID is {file_id}") return file_id @@ -1828,7 +1834,7 @@ def create_view_attachment( Specifies whether child records associated with the main records should also be included (recursively) in the view file. job_progress - Object used to track the progress of the job + Object used to track progress if this function is being run in a background job session An existing SQLAlchemy session to use. If None, one will be created. If an existing session is used, it will be flushed (but not committed) before returning from this function. diff --git a/qcfractal/qcfractal/components/external_files/socket.py b/qcfractal/qcfractal/components/external_files/socket.py index 309eacfc1..5cc2c4fb4 100644 --- a/qcfractal/qcfractal/components/external_files/socket.py +++ b/qcfractal/qcfractal/components/external_files/socket.py @@ -7,20 +7,23 @@ import uuid from typing import TYPE_CHECKING -# Torsiondrive package is optional +from sqlalchemy import select + +from qcportal.exceptions import MissingDataError +from qcportal.external_files import ExternalFileTypeEnum, ExternalFileStatusEnum +from .db_models import ExternalFileORM + +# Boto3 package is optional _boto3_spec = importlib.util.find_spec("boto3") if _boto3_spec is not None: boto3 = importlib.util.module_from_spec(_boto3_spec) _boto3_spec.loader.exec_module(boto3) -from qcportal.exceptions import MissingDataError -from qcportal.external_files import ExternalFileTypeEnum, ExternalFileStatusEnum -from .db_models import ExternalFileORM -from sqlalchemy import select if TYPE_CHECKING: from sqlalchemy.orm.session import Session + from qcfractal.components.internal_jobs.status import JobProgress from qcfractal.db_socket.socket import SQLAlchemySocket from typing import Optional, Dict, Any, Tuple, Union, BinaryIO, Callable, Generator @@ -73,6 +76,7 @@ def add_data( file_data: BinaryIO, file_orm: ExternalFileORM, *, + job_progress: Optional[JobProgress] = None, session: Optional[Session] = None, ) -> int: """ @@ -91,6 +95,8 @@ def add_data( Binary data to be read from file_orm Existing ORM object that will be filled in with metadata + job_progress + Object used to track progress if this function is being run in a background job session An existing SQLAlchemy session to use. If None, one will be created. If an existing session is used, it will be flushed (but not committed) before returning from this function. @@ -127,6 +133,8 @@ def add_data( try: while chunk := file_data.read(10 * 1024 * 1024): + job_progress.raise_if_cancelled() + sha256.update(chunk) file_size += len(chunk) @@ -157,6 +165,7 @@ def add_file( file_path: str, file_orm: ExternalFileORM, *, + job_progress: Optional[JobProgress] = None, session: Optional[Session] = None, ) -> int: """ @@ -170,6 +179,8 @@ def add_file( Path to an existing file to be read from file_orm Existing ORM object that will be filled in with metadata + job_progress + Object used to track progress if this function is being run in a background job session An existing SQLAlchemy session to use. If None, one will be created. If an existing session is used, it will be flushed (but not committed) before returning from this function. @@ -183,7 +194,7 @@ def add_file( self._logger.info(f"Uploading {file_path} to S3. File size: {os.path.getsize(file_path)/1048576} MiB") with open(file_path, "rb") as f: - return self.add_data(f, file_orm, session=session) + return self.add_data(f, file_orm, job_progress=job_progress, session=session) def get_metadata( self, From cd3c64d5aad4872c634e5f9a55b61403fe84f21e Mon Sep 17 00:00:00 2001 From: Benjamin Pritchard Date: Mon, 13 Jan 2025 15:44:29 -0500 Subject: [PATCH 20/26] Quieter internal job tests --- qcfractal/qcfractal/components/external_files/socket.py | 3 ++- qcfractal/qcfractal/components/internal_jobs/test_client.py | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/qcfractal/qcfractal/components/external_files/socket.py b/qcfractal/qcfractal/components/external_files/socket.py index 5cc2c4fb4..f939648bd 100644 --- a/qcfractal/qcfractal/components/external_files/socket.py +++ b/qcfractal/qcfractal/components/external_files/socket.py @@ -133,7 +133,8 @@ def add_data( try: while chunk := file_data.read(10 * 1024 * 1024): - job_progress.raise_if_cancelled() + if job_progress is not None: + job_progress.raise_if_cancelled() sha256.update(chunk) file_size += len(chunk) diff --git a/qcfractal/qcfractal/components/internal_jobs/test_client.py b/qcfractal/qcfractal/components/internal_jobs/test_client.py index 4e01fc049..db9180598 100644 --- a/qcfractal/qcfractal/components/internal_jobs/test_client.py +++ b/qcfractal/qcfractal/components/internal_jobs/test_client.py @@ -23,7 +23,7 @@ def dummmy_internal_job(self, iterations: int, session, job_progress): for i in range(iterations): time.sleep(1.0) job_progress.update_progress(100 * ((i + 1) / iterations), f"Interation {i} of {iterations}") - #print("Dummy internal job counter ", i) + # print("Dummy internal job counter ", i) job_progress.raise_if_cancelled() From 0ebbb0c75ded5ceee70f6e799fb3232f8091b7a5 Mon Sep 17 00:00:00 2001 From: Benjamin Pritchard Date: Mon, 13 Jan 2025 15:43:23 -0500 Subject: [PATCH 21/26] Create temporary_dir if it doesn't exist --- qcfractal/qcfractal/config.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/qcfractal/qcfractal/config.py b/qcfractal/qcfractal/config.py index fadc58b7e..a8aa8e938 100644 --- a/qcfractal/qcfractal/config.py +++ b/qcfractal/qcfractal/config.py @@ -7,6 +7,7 @@ import logging import os import secrets +import tempfile from typing import Optional, Dict, Union, Any import yaml @@ -508,6 +509,15 @@ def _convert_durations_days(cls, v): return int(v) * 86400 return duration_to_seconds(v) + @validator("temporary_dir", pre=True) + def _create_temporary_directory(cls, v, values): + v = _make_abs_path(v, values["base_folder"], tempfile.gettempdir()) + + if v is not None and not os.path.exists(v): + os.makedirs(v) + + return v + class Config(ConfigCommon): env_prefix = "QCF_" From 38f2da0e909f1557daae4a7522c80f7c2c87e94c Mon Sep 17 00:00:00 2001 From: Benjamin Pritchard Date: Mon, 13 Jan 2025 16:39:27 -0500 Subject: [PATCH 22/26] Improve dataset testing and remove duplicate submit tests --- .../gridoptimization/test_dataset_client.py | 79 ------------------- .../manybody/test_dataset_client.py | 79 ------------------- .../components/neb/test_dataset_client.py | 79 ------------------- .../optimization/test_dataset_client.py | 79 ------------------- .../reaction/test_dataset_client.py | 79 ------------------- .../singlepoint/test_dataset_client.py | 79 ------------------- .../torsiondrive/test_dataset_client.py | 79 ------------------- qcportal/qcportal/dataset_testing_helpers.py | 11 +++ .../gridoptimization/test_dataset_models.py | 10 ++- .../qcportal/manybody/test_dataset_models.py | 10 ++- qcportal/qcportal/neb/test_dataset_models.py | 6 +- .../optimization/test_dataset_models.py | 10 ++- .../qcportal/reaction/test_dataset_models.py | 10 ++- .../singlepoint/test_dataset_models.py | 10 ++- .../torsiondrive/test_dataset_models.py | 10 ++- 15 files changed, 56 insertions(+), 574 deletions(-) delete mode 100644 qcfractal/qcfractal/components/gridoptimization/test_dataset_client.py delete mode 100644 qcfractal/qcfractal/components/manybody/test_dataset_client.py delete mode 100644 qcfractal/qcfractal/components/neb/test_dataset_client.py delete mode 100644 qcfractal/qcfractal/components/optimization/test_dataset_client.py delete mode 100644 qcfractal/qcfractal/components/reaction/test_dataset_client.py delete mode 100644 qcfractal/qcfractal/components/singlepoint/test_dataset_client.py delete mode 100644 qcfractal/qcfractal/components/torsiondrive/test_dataset_client.py diff --git a/qcfractal/qcfractal/components/gridoptimization/test_dataset_client.py b/qcfractal/qcfractal/components/gridoptimization/test_dataset_client.py deleted file mode 100644 index d4b3099a5..000000000 --- a/qcfractal/qcfractal/components/gridoptimization/test_dataset_client.py +++ /dev/null @@ -1,79 +0,0 @@ -from __future__ import annotations - -from typing import TYPE_CHECKING - -import pytest - -from qcportal.gridoptimization import GridoptimizationDataset -from qcportal.record_models import PriorityEnum -from .testing_helpers import load_test_data - -if TYPE_CHECKING: - from qcportal import PortalClient - - -@pytest.fixture(scope="function") -def gridoptimization_ds(submitter_client: PortalClient): - ds = submitter_client.add_dataset( - "gridoptimization", - "Test dataset", - "Test Description", - "a Tagline", - ["tag1", "tag2"], - "new_group", - {"prov_key_1": "prov_value_1"}, - True, - "def_tag", - PriorityEnum.low, - {"meta_key_1": "meta_value_1"}, - "group1", - ) - - assert ds.owner_user is not None - assert ds.owner_user == ds._client.username - assert ds.owner_group == "group1" - - yield ds - - -@pytest.mark.parametrize("find_existing", [True, False]) -def test_gridoptimization_dataset_client_submit(gridoptimization_ds: GridoptimizationDataset, find_existing: bool): - input_spec_1, molecule_1, _ = load_test_data("go_H3NS_psi4_pbe") - - gridoptimization_ds.add_entry(name="test_molecule", initial_molecule=molecule_1) - gridoptimization_ds.add_specification("test_spec", input_spec_1, "test_specification") - - gridoptimization_ds.submit() - assert gridoptimization_ds.status()["test_spec"]["waiting"] == 1 - - gridoptimization_ds.submit() - assert gridoptimization_ds.status()["test_spec"]["waiting"] == 1 - - # Should only be one record - record_id = 0 - for e, s, r in gridoptimization_ds.iterate_records(): - assert r.owner_user == gridoptimization_ds.owner_user - assert r.owner_group == gridoptimization_ds.owner_group - record_id = r.id - - # delete & re-add entry, then resubmit - gridoptimization_ds.delete_entries(["test_molecule"]) - assert gridoptimization_ds.status() == {} - - # record still on the server? - r = gridoptimization_ds._client.get_records(record_id) - assert r.owner_user == gridoptimization_ds.owner_user - - # now resubmit - gridoptimization_ds.add_entry(name="test_molecule", initial_molecule=molecule_1) - gridoptimization_ds.submit(find_existing=find_existing) - assert gridoptimization_ds.status()["test_spec"]["waiting"] == 1 - - for e, s, r in gridoptimization_ds.iterate_records(): - assert r.owner_user == gridoptimization_ds.owner_user - assert r.owner_group == gridoptimization_ds.owner_group - - if find_existing: - assert r.id == record_id - else: - assert r.id != record_id diff --git a/qcfractal/qcfractal/components/manybody/test_dataset_client.py b/qcfractal/qcfractal/components/manybody/test_dataset_client.py deleted file mode 100644 index 7adb43156..000000000 --- a/qcfractal/qcfractal/components/manybody/test_dataset_client.py +++ /dev/null @@ -1,79 +0,0 @@ -from __future__ import annotations - -from typing import TYPE_CHECKING - -import pytest - -from qcportal.manybody import ManybodyDataset -from qcportal.record_models import PriorityEnum -from .testing_helpers import load_test_data - -if TYPE_CHECKING: - from qcportal import PortalClient - - -@pytest.fixture(scope="function") -def manybody_ds(submitter_client: PortalClient): - ds = submitter_client.add_dataset( - "manybody", - "Test dataset", - "Test Description", - "a Tagline", - ["tag1", "tag2"], - "new_group", - {"prov_key_1": "prov_value_1"}, - True, - "def_tag", - PriorityEnum.low, - {"meta_key_1": "meta_value_1"}, - "group1", - ) - - assert ds.owner_user is not None - assert ds.owner_user == ds._client.username - assert ds.owner_group == "group1" - - yield ds - - -@pytest.mark.parametrize("find_existing", [True, False]) -def test_manybody_dataset_client_submit(manybody_ds: ManybodyDataset, find_existing: bool): - input_spec_1, molecule_1, _ = load_test_data("mb_cp_he4_psi4_mp2") - - manybody_ds.add_entry(name="test_molecule", initial_molecule=molecule_1) - manybody_ds.add_specification("test_spec", input_spec_1, "test_specification") - - manybody_ds.submit() - assert manybody_ds.status()["test_spec"]["waiting"] == 1 - - manybody_ds.submit() - assert manybody_ds.status()["test_spec"]["waiting"] == 1 - - # Should only be one record - record_id = 0 - for e, s, r in manybody_ds.iterate_records(): - assert r.owner_user == manybody_ds.owner_user - assert r.owner_group == manybody_ds.owner_group - record_id = r.id - - # delete & re-add entry, then resubmit - manybody_ds.delete_entries(["test_molecule"]) - assert manybody_ds.status() == {} - - # record still on the server? - r = manybody_ds._client.get_records(record_id) - assert r.owner_user == manybody_ds.owner_user - - # now resubmit - manybody_ds.add_entry(name="test_molecule", initial_molecule=molecule_1) - manybody_ds.submit(find_existing=find_existing) - assert manybody_ds.status()["test_spec"]["waiting"] == 1 - - for e, s, r in manybody_ds.iterate_records(): - assert r.owner_user == manybody_ds.owner_user - assert r.owner_group == manybody_ds.owner_group - - if find_existing: - assert r.id == record_id - else: - assert r.id != record_id diff --git a/qcfractal/qcfractal/components/neb/test_dataset_client.py b/qcfractal/qcfractal/components/neb/test_dataset_client.py deleted file mode 100644 index 242e0f145..000000000 --- a/qcfractal/qcfractal/components/neb/test_dataset_client.py +++ /dev/null @@ -1,79 +0,0 @@ -from __future__ import annotations - -from typing import TYPE_CHECKING - -import pytest - -from qcportal.neb import NEBDataset -from qcportal.record_models import PriorityEnum -from .testing_helpers import load_test_data - -if TYPE_CHECKING: - from qcportal import PortalClient - - -@pytest.fixture(scope="function") -def neb_ds(submitter_client: PortalClient): - ds = submitter_client.add_dataset( - "neb", - "Test dataset", - "Test Description", - "a Tagline", - ["tag1", "tag2"], - "new_group", - {"prov_key_1": "prov_value_1"}, - True, - "def_tag", - PriorityEnum.low, - {"meta_key_1": "meta_value_1"}, - "group1", - ) - - assert ds.owner_user is not None - assert ds.owner_user == ds._client.username - assert ds.owner_group == "group1" - - yield ds - - -@pytest.mark.parametrize("find_existing", [True, False]) -def test_neb_dataset_client_submit(neb_ds: NEBDataset, find_existing: bool): - input_spec_1, molecule_1, _ = load_test_data("neb_HCN_psi4_pbe_opt_diff") - - neb_ds.add_entry(name="test_molecule", initial_chain=molecule_1) - neb_ds.add_specification("test_spec", input_spec_1, "test_specification") - - neb_ds.submit() - assert neb_ds.status()["test_spec"]["waiting"] == 1 - - neb_ds.submit() - assert neb_ds.status()["test_spec"]["waiting"] == 1 - - # Should only be one record - record_id = 0 - for e, s, r in neb_ds.iterate_records(): - assert r.owner_user == neb_ds.owner_user - assert r.owner_group == neb_ds.owner_group - record_id = r.id - - # delete & re-add entry, then resubmit - neb_ds.delete_entries(["test_molecule"]) - assert neb_ds.status() == {} - - # record still on the server? - r = neb_ds._client.get_records(record_id) - assert r.owner_user == neb_ds.owner_user - - # now resubmit - neb_ds.add_entry(name="test_molecule", initial_chain=molecule_1) - neb_ds.submit(find_existing=find_existing) - assert neb_ds.status()["test_spec"]["waiting"] == 1 - - for e, s, r in neb_ds.iterate_records(): - assert r.owner_user == neb_ds.owner_user - assert r.owner_group == neb_ds.owner_group - - if find_existing: - assert r.id == record_id - else: - assert r.id != record_id diff --git a/qcfractal/qcfractal/components/optimization/test_dataset_client.py b/qcfractal/qcfractal/components/optimization/test_dataset_client.py deleted file mode 100644 index 33e19eccd..000000000 --- a/qcfractal/qcfractal/components/optimization/test_dataset_client.py +++ /dev/null @@ -1,79 +0,0 @@ -from __future__ import annotations - -from typing import TYPE_CHECKING - -import pytest - -from qcportal.optimization import OptimizationDataset -from qcportal.record_models import PriorityEnum -from .testing_helpers import load_test_data - -if TYPE_CHECKING: - from qcportal import PortalClient - - -@pytest.fixture(scope="function") -def optimization_ds(submitter_client: PortalClient): - ds = submitter_client.add_dataset( - "optimization", - "Test dataset", - "Test Description", - "a Tagline", - ["tag1", "tag2"], - "new_group", - {"prov_key_1": "prov_value_1"}, - True, - "def_tag", - PriorityEnum.low, - {"meta_key_1": "meta_value_1"}, - "group1", - ) - - assert ds.owner_user is not None - assert ds.owner_user == ds._client.username - assert ds.owner_group == "group1" - - yield ds - - -@pytest.mark.parametrize("find_existing", [True, False]) -def test_optimization_dataset_client_submit(optimization_ds: OptimizationDataset, find_existing: bool): - input_spec_1, molecule_1, _ = load_test_data("opt_psi4_benzene") - - optimization_ds.add_entry(name="test_molecule", initial_molecule=molecule_1) - optimization_ds.add_specification("test_spec", input_spec_1, "test_specification") - - optimization_ds.submit() - assert optimization_ds.status()["test_spec"]["waiting"] == 1 - - optimization_ds.submit() - assert optimization_ds.status()["test_spec"]["waiting"] == 1 - - # Should only be one record - record_id = 0 - for e, s, r in optimization_ds.iterate_records(): - assert r.owner_user == optimization_ds.owner_user - assert r.owner_group == optimization_ds.owner_group - record_id = r.id - - # delete & re-add entry, then resubmit - optimization_ds.delete_entries(["test_molecule"]) - assert optimization_ds.status() == {} - - # record still on the server? - r = optimization_ds._client.get_records(record_id) - assert r.owner_user == optimization_ds.owner_user - - # now resubmit - optimization_ds.add_entry(name="test_molecule", initial_molecule=molecule_1) - optimization_ds.submit(find_existing=find_existing) - assert optimization_ds.status()["test_spec"]["waiting"] == 1 - - for e, s, r in optimization_ds.iterate_records(): - assert r.owner_user == optimization_ds.owner_user - assert r.owner_group == optimization_ds.owner_group - - if find_existing: - assert r.id == record_id - else: - assert r.id != record_id diff --git a/qcfractal/qcfractal/components/reaction/test_dataset_client.py b/qcfractal/qcfractal/components/reaction/test_dataset_client.py deleted file mode 100644 index 7efbf3632..000000000 --- a/qcfractal/qcfractal/components/reaction/test_dataset_client.py +++ /dev/null @@ -1,79 +0,0 @@ -from __future__ import annotations - -from typing import TYPE_CHECKING - -import pytest - -from qcportal.reaction import ReactionDataset -from qcportal.record_models import PriorityEnum -from .testing_helpers import load_test_data - -if TYPE_CHECKING: - from qcportal import PortalClient - - -@pytest.fixture(scope="function") -def reaction_ds(submitter_client: PortalClient): - ds = submitter_client.add_dataset( - "reaction", - "Test dataset", - "Test Description", - "a Tagline", - ["tag1", "tag2"], - "new_group", - {"prov_key_1": "prov_value_1"}, - True, - "def_tag", - PriorityEnum.low, - {"meta_key_1": "meta_value_1"}, - "group1", - ) - - assert ds.owner_user is not None - assert ds.owner_user == ds._client.username - assert ds.owner_group == "group1" - - yield ds - - -@pytest.mark.parametrize("find_existing", [True, False]) -def test_reaction_dataset_client_submit(reaction_ds: ReactionDataset, find_existing: bool): - input_spec_1, molecule_1, _ = load_test_data("rxn_H2O_psi4_mp2_optsp") - - reaction_ds.add_entry(name="test_molecule", stoichiometries=molecule_1) - reaction_ds.add_specification("test_spec", input_spec_1, "test_specification") - - reaction_ds.submit() - assert reaction_ds.status()["test_spec"]["waiting"] == 1 - - reaction_ds.submit() - assert reaction_ds.status()["test_spec"]["waiting"] == 1 - - # Should only be one record - record_id = 0 - for e, s, r in reaction_ds.iterate_records(): - assert r.owner_user == reaction_ds.owner_user - assert r.owner_group == reaction_ds.owner_group - record_id = r.id - - # delete & re-add entry, then resubmit - reaction_ds.delete_entries(["test_molecule"]) - assert reaction_ds.status() == {} - - # record still on the server? - r = reaction_ds._client.get_records(record_id) - assert r.owner_user == reaction_ds.owner_user - - # now resubmit - reaction_ds.add_entry(name="test_molecule", stoichiometries=molecule_1) - reaction_ds.submit(find_existing=find_existing) - assert reaction_ds.status()["test_spec"]["waiting"] == 1 - - for e, s, r in reaction_ds.iterate_records(): - assert r.owner_user == reaction_ds.owner_user - assert r.owner_group == reaction_ds.owner_group - - if find_existing: - assert r.id == record_id - else: - assert r.id != record_id diff --git a/qcfractal/qcfractal/components/singlepoint/test_dataset_client.py b/qcfractal/qcfractal/components/singlepoint/test_dataset_client.py deleted file mode 100644 index 84ae31f20..000000000 --- a/qcfractal/qcfractal/components/singlepoint/test_dataset_client.py +++ /dev/null @@ -1,79 +0,0 @@ -from __future__ import annotations - -from typing import TYPE_CHECKING - -import pytest - -from qcportal.record_models import PriorityEnum -from qcportal.singlepoint import SinglepointDataset -from .testing_helpers import load_test_data - -if TYPE_CHECKING: - from qcportal import PortalClient - - -@pytest.fixture(scope="function") -def singlepoint_ds(submitter_client: PortalClient): - ds = submitter_client.add_dataset( - "singlepoint", - "Test dataset", - "Test Description", - "a Tagline", - ["tag1", "tag2"], - "new_group", - {"prov_key_1": "prov_value_1"}, - True, - "def_tag", - PriorityEnum.low, - {"meta_key_1": "meta_value_1"}, - "group1", - ) - - assert ds.owner_user is not None - assert ds.owner_user == ds._client.username - assert ds.owner_group == "group1" - - yield ds - - -@pytest.mark.parametrize("find_existing", [True, False]) -def test_singlepoint_dataset_client_submit(singlepoint_ds: SinglepointDataset, find_existing: bool): - input_spec_1, molecule_1, _ = load_test_data("sp_psi4_benzene_energy_1") - - singlepoint_ds.add_entry(name="test_molecule", molecule=molecule_1) - singlepoint_ds.add_specification("test_spec", input_spec_1, "test_specification") - - singlepoint_ds.submit() - assert singlepoint_ds.status()["test_spec"]["waiting"] == 1 - - singlepoint_ds.submit() - assert singlepoint_ds.status()["test_spec"]["waiting"] == 1 - - # Should only be one record - record_id = 0 - for e, s, r in singlepoint_ds.iterate_records(): - assert r.owner_user == singlepoint_ds.owner_user - assert r.owner_group == singlepoint_ds.owner_group - record_id = r.id - - # delete & re-add entry, then resubmit - singlepoint_ds.delete_entries(["test_molecule"]) - assert singlepoint_ds.status() == {} - - # record still on the server? - r = singlepoint_ds._client.get_records(record_id) - assert r.owner_user == singlepoint_ds.owner_user - - # now resubmit - singlepoint_ds.add_entry(name="test_molecule", molecule=molecule_1) - singlepoint_ds.submit(find_existing=find_existing) - assert singlepoint_ds.status()["test_spec"]["waiting"] == 1 - - for e, s, r in singlepoint_ds.iterate_records(): - assert r.owner_user == singlepoint_ds.owner_user - assert r.owner_group == singlepoint_ds.owner_group - - if find_existing: - assert r.id == record_id - else: - assert r.id != record_id diff --git a/qcfractal/qcfractal/components/torsiondrive/test_dataset_client.py b/qcfractal/qcfractal/components/torsiondrive/test_dataset_client.py deleted file mode 100644 index f9f159e78..000000000 --- a/qcfractal/qcfractal/components/torsiondrive/test_dataset_client.py +++ /dev/null @@ -1,79 +0,0 @@ -from __future__ import annotations - -from typing import TYPE_CHECKING - -import pytest - -from qcportal.torsiondrive import TorsiondriveDataset -from qcportal.record_models import PriorityEnum -from .testing_helpers import load_test_data - -if TYPE_CHECKING: - from qcportal import PortalClient - - -@pytest.fixture(scope="function") -def torsiondrive_ds(submitter_client: PortalClient): - ds = submitter_client.add_dataset( - "torsiondrive", - "Test dataset", - "Test Description", - "a Tagline", - ["tag1", "tag2"], - "new_group", - {"prov_key_1": "prov_value_1"}, - True, - "def_tag", - PriorityEnum.low, - {"meta_key_1": "meta_value_1"}, - "group1", - ) - - assert ds.owner_user is not None - assert ds.owner_user == ds._client.username - assert ds.owner_group == "group1" - - yield ds - - -@pytest.mark.parametrize("find_existing", [True, False]) -def test_torsiondrive_dataset_client_submit(torsiondrive_ds: TorsiondriveDataset, find_existing: bool): - input_spec_1, molecule_1, _ = load_test_data("td_H2O2_mopac_pm6") - - torsiondrive_ds.add_entry(name="test_molecule", initial_molecules=molecule_1) - torsiondrive_ds.add_specification("test_spec", input_spec_1, "test_specification") - - torsiondrive_ds.submit() - assert torsiondrive_ds.status()["test_spec"]["waiting"] == 1 - - torsiondrive_ds.submit() - assert torsiondrive_ds.status()["test_spec"]["waiting"] == 1 - - # Should only be one record - record_id = 0 - for e, s, r in torsiondrive_ds.iterate_records(): - assert r.owner_user == torsiondrive_ds.owner_user - assert r.owner_group == torsiondrive_ds.owner_group - record_id = r.id - - # delete & re-add entry, then resubmit - torsiondrive_ds.delete_entries(["test_molecule"]) - assert torsiondrive_ds.status() == {} - - # record still on the server? - r = torsiondrive_ds._client.get_records(record_id) - assert r.owner_user == torsiondrive_ds.owner_user - - # now resubmit - torsiondrive_ds.add_entry(name="test_molecule", initial_molecules=molecule_1) - torsiondrive_ds.submit(find_existing=find_existing) - assert torsiondrive_ds.status()["test_spec"]["waiting"] == 1 - - for e, s, r in torsiondrive_ds.iterate_records(): - assert r.owner_user == torsiondrive_ds.owner_user - assert r.owner_group == torsiondrive_ds.owner_group - - if find_existing: - assert r.id == record_id - else: - assert r.id != record_id diff --git a/qcportal/qcportal/dataset_testing_helpers.py b/qcportal/qcportal/dataset_testing_helpers.py index af0f71d4b..babb5bdf3 100644 --- a/qcportal/qcportal/dataset_testing_helpers.py +++ b/qcportal/qcportal/dataset_testing_helpers.py @@ -441,6 +441,9 @@ def run_dataset_model_submit(ds, test_entries, test_spec, record_compare): record_compare(rec, test_entries[0], test_spec) + assert rec.owner_user == "submit_user" + assert rec.owner_group == "group1" + # Used default tag/priority if rec.is_service: assert rec.service.tag == "default_tag" @@ -477,6 +480,14 @@ def run_dataset_model_submit(ds, test_entries, test_spec, record_compare): assert rec.task.tag == "default_tag" assert rec.task.priority == PriorityEnum.low + # Don't find existing + old_rec_id = rec.id + ds.remove_records(test_entries[2].name, "spec_1") + ds.submit(find_existing=False) + + rec = ds.get_record(test_entries[2].name, "spec_1") + assert rec.id != old_rec_id + record_count = len(ds.entry_names) * len(ds.specifications) assert ds.record_count == record_count assert ds._client.list_datasets()[0]["record_count"] == record_count diff --git a/qcportal/qcportal/gridoptimization/test_dataset_models.py b/qcportal/qcportal/gridoptimization/test_dataset_models.py index 8a292603b..f8e227d30 100644 --- a/qcportal/qcportal/gridoptimization/test_dataset_models.py +++ b/qcportal/qcportal/gridoptimization/test_dataset_models.py @@ -161,9 +161,13 @@ def test_gridoptimization_dataset_model_remove_record(snowflake_client: PortalCl ds_helpers.run_dataset_model_remove_record(snowflake_client, ds, test_entries, test_specs) -def test_gridoptimization_dataset_model_submit(snowflake_client: PortalClient): - ds = snowflake_client.add_dataset( - "gridoptimization", "Test dataset", default_tag="default_tag", default_priority=PriorityEnum.low +def test_gridoptimization_dataset_model_submit(submitter_client: PortalClient): + ds = submitter_client.add_dataset( + "gridoptimization", + "Test dataset", + default_tag="default_tag", + default_priority=PriorityEnum.low, + owner_group="group1", ) ds_helpers.run_dataset_model_submit(ds, test_entries, test_specs[0], record_compare) diff --git a/qcportal/qcportal/manybody/test_dataset_models.py b/qcportal/qcportal/manybody/test_dataset_models.py index 119dd8985..ee2ff9598 100644 --- a/qcportal/qcportal/manybody/test_dataset_models.py +++ b/qcportal/qcportal/manybody/test_dataset_models.py @@ -117,9 +117,13 @@ def test_manybody_dataset_model_remove_record(snowflake_client: PortalClient): ds_helpers.run_dataset_model_remove_record(snowflake_client, ds, test_entries, test_specs) -def test_manybody_dataset_model_submit(snowflake_client: PortalClient): - ds = snowflake_client.add_dataset( - "manybody", "Test dataset", default_tag="default_tag", default_priority=PriorityEnum.low +def test_manybody_dataset_model_submit(submitter_client: PortalClient): + ds = submitter_client.add_dataset( + "manybody", + "Test dataset", + default_tag="default_tag", + default_priority=PriorityEnum.low, + owner_group="group1", ) ds_helpers.run_dataset_model_submit(ds, test_entries, test_specs[0], record_compare) diff --git a/qcportal/qcportal/neb/test_dataset_models.py b/qcportal/qcportal/neb/test_dataset_models.py index e5c4ec587..668b8633c 100644 --- a/qcportal/qcportal/neb/test_dataset_models.py +++ b/qcportal/qcportal/neb/test_dataset_models.py @@ -163,9 +163,9 @@ def test_neb_dataset_model_remove_record(snowflake_client: PortalClient): ds_helpers.run_dataset_model_remove_record(snowflake_client, ds, test_entries, test_specs) -def test_neb_dataset_model_submit(snowflake_client: PortalClient): - ds = snowflake_client.add_dataset( - "neb", "Test dataset", default_tag="default_tag", default_priority=PriorityEnum.low +def test_neb_dataset_model_submit(submitter_client: PortalClient): + ds = submitter_client.add_dataset( + "neb", "Test dataset", default_tag="default_tag", default_priority=PriorityEnum.low, owner_group="group1" ) ds_helpers.run_dataset_model_submit(ds, test_entries, test_specs[0], record_compare) diff --git a/qcportal/qcportal/optimization/test_dataset_models.py b/qcportal/qcportal/optimization/test_dataset_models.py index 7a48e678d..473275992 100644 --- a/qcportal/qcportal/optimization/test_dataset_models.py +++ b/qcportal/qcportal/optimization/test_dataset_models.py @@ -119,9 +119,13 @@ def test_optimization_dataset_model_remove_record(snowflake_client: PortalClient ds_helpers.run_dataset_model_remove_record(snowflake_client, ds, test_entries, test_specs) -def test_optimization_dataset_model_submit(snowflake_client: PortalClient): - ds = snowflake_client.add_dataset( - "optimization", "Test dataset", default_tag="default_tag", default_priority=PriorityEnum.low +def test_optimization_dataset_model_submit(submitter_client: PortalClient): + ds = submitter_client.add_dataset( + "optimization", + "Test dataset", + default_tag="default_tag", + default_priority=PriorityEnum.low, + owner_group="group1", ) ds_helpers.run_dataset_model_submit(ds, test_entries, test_specs[0], record_compare) diff --git a/qcportal/qcportal/reaction/test_dataset_models.py b/qcportal/qcportal/reaction/test_dataset_models.py index 51ee7fd2b..a1701f00f 100644 --- a/qcportal/qcportal/reaction/test_dataset_models.py +++ b/qcportal/qcportal/reaction/test_dataset_models.py @@ -139,9 +139,13 @@ def test_reaction_dataset_model_remove_record(snowflake_client: PortalClient): ds_helpers.run_dataset_model_remove_record(snowflake_client, ds, test_entries, test_specs) -def test_reaction_dataset_model_submit(snowflake_client: PortalClient): - ds = snowflake_client.add_dataset( - "reaction", "Test dataset", default_tag="default_tag", default_priority=PriorityEnum.low +def test_reaction_dataset_model_submit(submitter_client: PortalClient): + ds = submitter_client.add_dataset( + "reaction", + "Test dataset", + default_tag="default_tag", + default_priority=PriorityEnum.low, + owner_group="group1", ) ds_helpers.run_dataset_model_submit(ds, test_entries, test_specs[0], record_compare) diff --git a/qcportal/qcportal/singlepoint/test_dataset_models.py b/qcportal/qcportal/singlepoint/test_dataset_models.py index f4c8907ab..7bd1fa957 100644 --- a/qcportal/qcportal/singlepoint/test_dataset_models.py +++ b/qcportal/qcportal/singlepoint/test_dataset_models.py @@ -104,9 +104,13 @@ def test_singlepoint_dataset_model_remove_record(snowflake_client: PortalClient) ds_helpers.run_dataset_model_remove_record(snowflake_client, ds, test_entries, test_specs) -def test_singlepoint_dataset_model_submit(snowflake_client: PortalClient): - ds = snowflake_client.add_dataset( - "singlepoint", "Test dataset", default_tag="default_tag", default_priority=PriorityEnum.low +def test_singlepoint_dataset_model_submit(submitter_client: PortalClient): + ds = submitter_client.add_dataset( + "singlepoint", + "Test dataset", + default_tag="default_tag", + default_priority=PriorityEnum.low, + owner_group="group1", ) ds_helpers.run_dataset_model_submit(ds, test_entries, test_specs[0], record_compare) diff --git a/qcportal/qcportal/torsiondrive/test_dataset_models.py b/qcportal/qcportal/torsiondrive/test_dataset_models.py index 7788d3c99..dffd10207 100644 --- a/qcportal/qcportal/torsiondrive/test_dataset_models.py +++ b/qcportal/qcportal/torsiondrive/test_dataset_models.py @@ -166,9 +166,13 @@ def test_torsiondrive_dataset_model_remove_record(snowflake_client: PortalClient ds_helpers.run_dataset_model_remove_record(snowflake_client, ds, test_entries, test_specs) -def test_torsiondrive_dataset_model_submit(snowflake_client: PortalClient): - ds = snowflake_client.add_dataset( - "torsiondrive", "Test dataset", default_tag="default_tag", default_priority=PriorityEnum.low +def test_torsiondrive_dataset_model_submit(submitter_client: PortalClient): + ds = submitter_client.add_dataset( + "torsiondrive", + "Test dataset", + default_tag="default_tag", + default_priority=PriorityEnum.low, + owner_group="group1", ) ds_helpers.run_dataset_model_submit(ds, test_entries, test_specs[0], record_compare) From 7c1d92551641964400ca0857135f5214d1d81ab9 Mon Sep 17 00:00:00 2001 From: Benjamin Pritchard Date: Mon, 13 Jan 2025 16:51:49 -0500 Subject: [PATCH 23/26] Test temporary_dir creation --- qcfractal/qcfractal/test_config.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/qcfractal/qcfractal/test_config.py b/qcfractal/qcfractal/test_config.py index f86bb080e..600525e4c 100644 --- a/qcfractal/qcfractal/test_config.py +++ b/qcfractal/qcfractal/test_config.py @@ -1,4 +1,5 @@ import copy +import os from qcfractal.config import FractalConfig @@ -69,3 +70,13 @@ def test_config_durations_dhms(tmp_path): assert cfg.internal_job_keep == 100807 assert cfg.api.jwt_access_token_expires == 7450 assert cfg.api.jwt_refresh_token_expires == 637277 + + +def test_config_tmpdir_create(tmp_path): + base_folder = str(tmp_path) + base_config = copy.deepcopy(_base_config) + base_config["temporary_dir"] = str(tmp_path / "qcatmpdir") + cfg = FractalConfig(base_folder=base_folder, **base_config) + + assert cfg.temporary_dir == str(tmp_path / "qcatmpdir") + assert os.path.exists(cfg.temporary_dir) From 03369fe1ab189bc499c9a6863d3844a1a653fc55 Mon Sep 17 00:00:00 2001 From: Benjamin Pritchard Date: Tue, 14 Jan 2025 09:14:46 -0500 Subject: [PATCH 24/26] Move update_nested_dict to utils & remove duplicate --- qcarchivetesting/qcarchivetesting/testing_classes.py | 3 ++- qcfractal/qcfractal/config.py | 11 +---------- qcfractal/qcfractal/snowflake.py | 3 ++- qcfractalcompute/qcfractalcompute/config.py | 11 +---------- qcfractalcompute/qcfractalcompute/testing_helpers.py | 2 +- qcportal/qcportal/utils.py | 9 +++++++++ 6 files changed, 16 insertions(+), 23 deletions(-) diff --git a/qcarchivetesting/qcarchivetesting/testing_classes.py b/qcarchivetesting/qcarchivetesting/testing_classes.py index cc7250cba..483d2bc45 100644 --- a/qcarchivetesting/qcarchivetesting/testing_classes.py +++ b/qcarchivetesting/qcarchivetesting/testing_classes.py @@ -4,13 +4,14 @@ from copy import deepcopy from qcarchivetesting import geoip_path, geoip_filename, ip_tests_enabled -from qcfractal.config import DatabaseConfig, update_nested_dict +from qcfractal.config import DatabaseConfig from qcfractal.db_socket import SQLAlchemySocket from qcfractal.postgres_harness import PostgresHarness, create_snowflake_postgres from qcfractal.snowflake import FractalSnowflake from qcportal import PortalClient, ManagerClient from qcportal.auth import UserInfo, GroupInfo from qcportal.managers import ManagerName +from qcportal.utils import update_nested_dict from .helpers import test_users, test_groups _activated_manager_programs = { diff --git a/qcfractal/qcfractal/config.py b/qcfractal/qcfractal/config.py index a8aa8e938..30672f696 100644 --- a/qcfractal/qcfractal/config.py +++ b/qcfractal/qcfractal/config.py @@ -22,16 +22,7 @@ from sqlalchemy.engine.url import URL, make_url from qcfractal.port_util import find_open_port -from qcportal.utils import duration_to_seconds - - -def update_nested_dict(d, u): - for k, v in u.items(): - if isinstance(v, dict): - d[k] = update_nested_dict(d.get(k, {}), v) - else: - d[k] = v - return d +from qcportal.utils import duration_to_seconds, update_nested_dict def _make_abs_path(path: Optional[str], base_folder: str, default_filename: Optional[str]) -> Optional[str]: diff --git a/qcfractal/qcfractal/snowflake.py b/qcfractal/qcfractal/snowflake.py index 73b632210..c5a95bf51 100644 --- a/qcfractal/qcfractal/snowflake.py +++ b/qcfractal/qcfractal/snowflake.py @@ -17,7 +17,8 @@ from qcportal import PortalClient from qcportal.record_models import RecordStatusEnum -from .config import FractalConfig, DatabaseConfig, update_nested_dict +from qcportal.utils import update_nested_dict +from .config import FractalConfig, DatabaseConfig from .flask_app.waitress_app import FractalWaitressApp from .job_runner import FractalJobRunner from .port_util import find_open_port diff --git a/qcfractalcompute/qcfractalcompute/config.py b/qcfractalcompute/qcfractalcompute/config.py index 11b27633b..58f6d8a8e 100644 --- a/qcfractalcompute/qcfractalcompute/config.py +++ b/qcfractalcompute/qcfractalcompute/config.py @@ -10,7 +10,7 @@ from pydantic import BaseModel, Field, validator from typing_extensions import Literal -from qcportal.utils import seconds_to_hms, duration_to_seconds +from qcportal.utils import seconds_to_hms, duration_to_seconds, update_nested_dict def _make_abs_path(path: Optional[str], base_folder: str, default_filename: Optional[str]) -> Optional[str]: @@ -31,15 +31,6 @@ def _make_abs_path(path: Optional[str], base_folder: str, default_filename: Opti return os.path.abspath(path) -def update_nested_dict(d, u): - for k, v in u.items(): - if isinstance(v, dict): - d[k] = update_nested_dict(d.get(k, {}), v) - else: - d[k] = v - return d - - class PackageEnvironmentSettings(BaseModel): """ Environments with installed packages that can be used to run calculations diff --git a/qcfractalcompute/qcfractalcompute/testing_helpers.py b/qcfractalcompute/qcfractalcompute/testing_helpers.py index 64ddf241d..07ff91365 100644 --- a/qcfractalcompute/qcfractalcompute/testing_helpers.py +++ b/qcfractalcompute/qcfractalcompute/testing_helpers.py @@ -12,7 +12,6 @@ from qcfractal.components.optimization.testing_helpers import submit_test_data as submit_opt_test_data from qcfractal.components.singlepoint.testing_helpers import submit_test_data as submit_sp_test_data from qcfractal.config import FractalConfig -from qcfractal.config import update_nested_dict from qcfractal.db_socket import SQLAlchemySocket from qcfractalcompute.apps.models import AppTaskResult from qcfractalcompute.compress import compress_result @@ -20,6 +19,7 @@ from qcfractalcompute.config import FractalComputeConfig, FractalServerSettings, LocalExecutorConfig from qcportal.all_results import AllResultTypes, FailedOperation from qcportal.record_models import PriorityEnum, RecordTask +from qcportal.utils import update_nested_dict failed_op = FailedOperation( input_data=None, diff --git a/qcportal/qcportal/utils.py b/qcportal/qcportal/utils.py index a60f0ccb0..338acd173 100644 --- a/qcportal/qcportal/utils.py +++ b/qcportal/qcportal/utils.py @@ -440,3 +440,12 @@ def is_included(key: str, include: Optional[Iterable[str]], exclude: Optional[It exclude = tuple(sorted(exclude)) return _is_included(key, include, exclude, default) + + +def update_nested_dict(d: Dict[str, Any], u: Dict[str, Any]): + for k, v in u.items(): + if isinstance(v, dict): + d[k] = update_nested_dict(d.get(k, {}), v) + else: + d[k] = v + return d From e36fdba1816bd346b109762f6b185c934668eafd Mon Sep 17 00:00:00 2001 From: Benjamin Pritchard Date: Tue, 14 Jan 2025 09:22:26 -0500 Subject: [PATCH 25/26] Remove unused start_api option to QCATesting snowflake --- qcarchivetesting/qcarchivetesting/testing_classes.py | 5 +---- qcfractal/qcfractal/components/auth/test_auth_permissions.py | 1 - 2 files changed, 1 insertion(+), 5 deletions(-) diff --git a/qcarchivetesting/qcarchivetesting/testing_classes.py b/qcarchivetesting/qcarchivetesting/testing_classes.py index 483d2bc45..7cc2d7857 100644 --- a/qcarchivetesting/qcarchivetesting/testing_classes.py +++ b/qcarchivetesting/qcarchivetesting/testing_classes.py @@ -101,7 +101,6 @@ def __init__( self, pg_harness: QCATestingPostgresHarness, encoding: str, - start_api=True, create_users=False, enable_security=False, allow_unauthenticated_read=False, @@ -172,9 +171,7 @@ def __init__( if create_users: self.create_users() - # Start the flask api process if requested - if start_api: - self.start_api() + self.start_api() def create_users(self): # Get a storage socket and add the roles/users/passwords diff --git a/qcfractal/qcfractal/components/auth/test_auth_permissions.py b/qcfractal/qcfractal/components/auth/test_auth_permissions.py index 9c130858a..f5b04dbf2 100644 --- a/qcfractal/qcfractal/components/auth/test_auth_permissions.py +++ b/qcfractal/qcfractal/components/auth/test_auth_permissions.py @@ -19,7 +19,6 @@ def module_authtest_snowflake(postgres_server, pytestconfig): with QCATestingSnowflake( pg_harness, encoding, - start_api=True, create_users=False, enable_security=True, allow_unauthenticated_read=False, From 6f56eed84926a34ec7a32ef90c2c1ba2a7547eff Mon Sep 17 00:00:00 2001 From: Benjamin Pritchard Date: Tue, 14 Jan 2025 09:27:06 -0500 Subject: [PATCH 26/26] Have separate function for generating test config --- .../qcarchivetesting/testing_fixtures.py | 47 +++++++++++-------- 1 file changed, 27 insertions(+), 20 deletions(-) diff --git a/qcarchivetesting/qcarchivetesting/testing_fixtures.py b/qcarchivetesting/qcarchivetesting/testing_fixtures.py index df7a5f1b6..a8d9a7b10 100644 --- a/qcarchivetesting/qcarchivetesting/testing_fixtures.py +++ b/qcarchivetesting/qcarchivetesting/testing_fixtures.py @@ -13,10 +13,36 @@ from qcfractal.db_socket.socket import SQLAlchemySocket from qcportal import PortalClient from qcportal.managers import ManagerName +from qcportal.utils import update_nested_dict from .helpers import geoip_path, geoip_filename, ip_tests_enabled, test_users from .testing_classes import QCATestingPostgresServer, QCATestingSnowflake, _activated_manager_programs +def _generate_default_config(pg_harness, extra_config=None) -> FractalConfig: + # Create a configuration. Since this is mostly just for a storage socket, + # We can use defaults for almost all, since a flask server, etc, won't be instantiated + # Also disable connection pooling in the storage socket + # (which can leave db connections open, causing problems when we go to delete + # the database) + cfg_dict = {} + cfg_dict["base_folder"] = pg_harness.config.base_folder + cfg_dict["loglevel"] = "DEBUG" + cfg_dict["database"] = pg_harness.config.dict() + cfg_dict["database"]["pool_size"] = 0 + cfg_dict["log_access"] = True + + if ip_tests_enabled: + cfg_dict["geoip2_dir"] = geoip_path + cfg_dict["geoip2_filename"] = geoip_filename + + cfg_dict["api"] = {"secret_key": secrets.token_urlsafe(32), "jwt_secret_key": secrets.token_urlsafe(32)} + + if extra_config: + cfg_dict = update_nested_dict(cfg_dict, extra_config) + + return FractalConfig(**cfg_dict) + + @pytest.fixture(scope="session") def postgres_server(tmp_path_factory): """ @@ -40,26 +66,7 @@ def session_storage_socket(postgres_server): """ pg_harness = postgres_server.get_new_harness("session_storage") - - # Create a configuration. Since this is mostly just for a storage socket, - # We can use defaults for almost all, since a flask server, etc, won't be instantiated - # Also disable connection pooling in the storage socket - # (which can leave db connections open, causing problems when we go to delete - # the database) - cfg_dict = {} - cfg_dict["base_folder"] = pg_harness.config.base_folder - cfg_dict["loglevel"] = "DEBUG" - cfg_dict["database"] = pg_harness.config.dict() - cfg_dict["database"]["pool_size"] = 0 - cfg_dict["log_access"] = True - - if ip_tests_enabled: - cfg_dict["geoip2_dir"] = geoip_path - cfg_dict["geoip2_filename"] = geoip_filename - - cfg_dict["api"] = {"secret_key": secrets.token_urlsafe(32), "jwt_secret_key": secrets.token_urlsafe(32)} - qcf_config = FractalConfig(**cfg_dict) - + qcf_config = _generate_default_config(pg_harness) socket = SQLAlchemySocket(qcf_config) # Create the template database for use in re-creating the database