diff --git a/qcarchivetesting/qcarchivetesting/testing_classes.py b/qcarchivetesting/qcarchivetesting/testing_classes.py index cc7250cba..7cc2d7857 100644 --- a/qcarchivetesting/qcarchivetesting/testing_classes.py +++ b/qcarchivetesting/qcarchivetesting/testing_classes.py @@ -4,13 +4,14 @@ from copy import deepcopy from qcarchivetesting import geoip_path, geoip_filename, ip_tests_enabled -from qcfractal.config import DatabaseConfig, update_nested_dict +from qcfractal.config import DatabaseConfig from qcfractal.db_socket import SQLAlchemySocket from qcfractal.postgres_harness import PostgresHarness, create_snowflake_postgres from qcfractal.snowflake import FractalSnowflake from qcportal import PortalClient, ManagerClient from qcportal.auth import UserInfo, GroupInfo from qcportal.managers import ManagerName +from qcportal.utils import update_nested_dict from .helpers import test_users, test_groups _activated_manager_programs = { @@ -100,7 +101,6 @@ def __init__( self, pg_harness: QCATestingPostgresHarness, encoding: str, - start_api=True, create_users=False, enable_security=False, allow_unauthenticated_read=False, @@ -171,9 +171,7 @@ def __init__( if create_users: self.create_users() - # Start the flask api process if requested - if start_api: - self.start_api() + self.start_api() def create_users(self): # Get a storage socket and add the roles/users/passwords diff --git a/qcarchivetesting/qcarchivetesting/testing_fixtures.py b/qcarchivetesting/qcarchivetesting/testing_fixtures.py index df7a5f1b6..a8d9a7b10 100644 --- a/qcarchivetesting/qcarchivetesting/testing_fixtures.py +++ b/qcarchivetesting/qcarchivetesting/testing_fixtures.py @@ -13,10 +13,36 @@ from qcfractal.db_socket.socket import SQLAlchemySocket from qcportal import PortalClient from qcportal.managers import ManagerName +from qcportal.utils import update_nested_dict from .helpers import geoip_path, geoip_filename, ip_tests_enabled, test_users from .testing_classes import QCATestingPostgresServer, QCATestingSnowflake, _activated_manager_programs +def _generate_default_config(pg_harness, extra_config=None) -> FractalConfig: + # Create a configuration. Since this is mostly just for a storage socket, + # We can use defaults for almost all, since a flask server, etc, won't be instantiated + # Also disable connection pooling in the storage socket + # (which can leave db connections open, causing problems when we go to delete + # the database) + cfg_dict = {} + cfg_dict["base_folder"] = pg_harness.config.base_folder + cfg_dict["loglevel"] = "DEBUG" + cfg_dict["database"] = pg_harness.config.dict() + cfg_dict["database"]["pool_size"] = 0 + cfg_dict["log_access"] = True + + if ip_tests_enabled: + cfg_dict["geoip2_dir"] = geoip_path + cfg_dict["geoip2_filename"] = geoip_filename + + cfg_dict["api"] = {"secret_key": secrets.token_urlsafe(32), "jwt_secret_key": secrets.token_urlsafe(32)} + + if extra_config: + cfg_dict = update_nested_dict(cfg_dict, extra_config) + + return FractalConfig(**cfg_dict) + + @pytest.fixture(scope="session") def postgres_server(tmp_path_factory): """ @@ -40,26 +66,7 @@ def session_storage_socket(postgres_server): """ pg_harness = postgres_server.get_new_harness("session_storage") - - # Create a configuration. 
Since this is mostly just for a storage socket, - # We can use defaults for almost all, since a flask server, etc, won't be instantiated - # Also disable connection pooling in the storage socket - # (which can leave db connections open, causing problems when we go to delete - # the database) - cfg_dict = {} - cfg_dict["base_folder"] = pg_harness.config.base_folder - cfg_dict["loglevel"] = "DEBUG" - cfg_dict["database"] = pg_harness.config.dict() - cfg_dict["database"]["pool_size"] = 0 - cfg_dict["log_access"] = True - - if ip_tests_enabled: - cfg_dict["geoip2_dir"] = geoip_path - cfg_dict["geoip2_filename"] = geoip_filename - - cfg_dict["api"] = {"secret_key": secrets.token_urlsafe(32), "jwt_secret_key": secrets.token_urlsafe(32)} - qcf_config = FractalConfig(**cfg_dict) - + qcf_config = _generate_default_config(pg_harness) socket = SQLAlchemySocket(qcf_config) # Create the template database for use in re-creating the database diff --git a/qcfractal/pyproject.toml b/qcfractal/pyproject.toml index 89fc563bc..92f999758 100644 --- a/qcfractal/pyproject.toml +++ b/qcfractal/pyproject.toml @@ -42,6 +42,9 @@ geoip = [ snowflake = [ "qcfractalcompute" ] +s3 = [ + "boto3" +] [project.urls] diff --git a/qcfractal/qcfractal/alembic/versions/2025-01-02-e798462e0c03_add_repeat_delay_to_internal_jobs.py b/qcfractal/qcfractal/alembic/versions/2025-01-02-e798462e0c03_add_repeat_delay_to_internal_jobs.py index 7af6bc267..abc9d73d8 100644 --- a/qcfractal/qcfractal/alembic/versions/2025-01-02-e798462e0c03_add_repeat_delay_to_internal_jobs.py +++ b/qcfractal/qcfractal/alembic/versions/2025-01-02-e798462e0c03_add_repeat_delay_to_internal_jobs.py @@ -20,6 +20,19 @@ def upgrade(): # ### commands auto generated by Alembic - please adjust! ### op.add_column("internal_jobs", sa.Column("repeat_delay", sa.Integer(), nullable=True)) + + # Remove old periodic tasks + op.execute( + """DELETE FROM internal_jobs WHERE status IN ('waiting', 'running') AND name IN ( + 'delete_old_internal_jobs', + 'delete_old_access_log', + 'iterate_services', + 'geolocate_accesses', + 'check_manager_heartbeats', + 'update_geoip2_file' + ) + """ + ) # ### end Alembic commands ### diff --git a/qcfractal/qcfractal/alembic/versions/2025-01-03-02afa97249c7_add_external_files_table.py b/qcfractal/qcfractal/alembic/versions/2025-01-03-02afa97249c7_add_external_files_table.py new file mode 100644 index 000000000..aec8be66c --- /dev/null +++ b/qcfractal/qcfractal/alembic/versions/2025-01-03-02afa97249c7_add_external_files_table.py @@ -0,0 +1,43 @@ +"""Add external files table + +Revision ID: 02afa97249c7 +Revises: c13116948b54 +Create Date: 2025-01-03 10:01:55.717905 + +""" + +import sqlalchemy as sa +from alembic import op +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. +revision = "02afa97249c7" +down_revision = "c13116948b54" +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! 
### + op.create_table( + "external_file", + sa.Column("id", sa.Integer(), nullable=False), + sa.Column("created_on", sa.TIMESTAMP(), nullable=False), + sa.Column("status", sa.Enum("available", "processing", name="externalfilestatusenum"), nullable=False), + sa.Column("file_type", sa.Enum("dataset_attachment", name="externalfiletypeenum"), nullable=False), + sa.Column("bucket", sa.String(), nullable=False), + sa.Column("file_name", sa.String(), nullable=False), + sa.Column("object_key", sa.String(), nullable=False), + sa.Column("sha256sum", sa.String(), nullable=False), + sa.Column("file_size", sa.BigInteger(), nullable=False), + sa.Column("description", sa.String(), nullable=True), + sa.Column("provenance", postgresql.JSONB(astext_type=sa.Text()), nullable=False), + sa.PrimaryKeyConstraint("id"), + ) + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.drop_table("external_file") + # ### end Alembic commands ### diff --git a/qcfractal/qcfractal/alembic/versions/2025-01-03-84285e3620fd_add_dataset_attachment_table.py b/qcfractal/qcfractal/alembic/versions/2025-01-03-84285e3620fd_add_dataset_attachment_table.py new file mode 100644 index 000000000..cc3ada892 --- /dev/null +++ b/qcfractal/qcfractal/alembic/versions/2025-01-03-84285e3620fd_add_dataset_attachment_table.py @@ -0,0 +1,38 @@ +"""Add dataset attachment table + +Revision ID: 84285e3620fd +Revises: 02afa97249c7 +Create Date: 2025-01-03 10:04:16.201770 + +""" + +import sqlalchemy as sa +from alembic import op + +# revision identifiers, used by Alembic. +revision = "84285e3620fd" +down_revision = "02afa97249c7" +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.create_table( + "dataset_attachment", + sa.Column("id", sa.Integer(), nullable=False), + sa.Column("dataset_id", sa.Integer(), nullable=False), + sa.Column("attachment_type", sa.Enum("other", "view", name="datasetattachmenttype"), nullable=False), + sa.ForeignKeyConstraint(["dataset_id"], ["base_dataset.id"], ondelete="cascade"), + sa.ForeignKeyConstraint(["id"], ["external_file.id"], ondelete="cascade"), + sa.PrimaryKeyConstraint("id"), + ) + op.create_index("ix_dataset_attachment_dataset_id", "dataset_attachment", ["dataset_id"], unique=False) + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.drop_index("ix_dataset_attachment_dataset_id", table_name="dataset_attachment") + op.drop_table("dataset_attachment") + # ### end Alembic commands ### diff --git a/qcfractal/qcfractal/alembic/versions/2025-01-07-c13116948b54_add_progress_description.py b/qcfractal/qcfractal/alembic/versions/2025-01-07-c13116948b54_add_progress_description.py new file mode 100644 index 000000000..58d1d9149 --- /dev/null +++ b/qcfractal/qcfractal/alembic/versions/2025-01-07-c13116948b54_add_progress_description.py @@ -0,0 +1,28 @@ +"""Add progress description + +Revision ID: c13116948b54 +Revises: e798462e0c03 +Create Date: 2025-01-07 10:35:23.654928 + +""" + +import sqlalchemy as sa +from alembic import op + +# revision identifiers, used by Alembic. +revision = "c13116948b54" +down_revision = "e798462e0c03" +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! 
### + op.add_column("internal_jobs", sa.Column("progress_description", sa.String(), nullable=True)) + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.drop_column("internal_jobs", "progress_description") + # ### end Alembic commands ### diff --git a/qcfractal/qcfractal/alembic/versions/2025-01-09-5f6f804e11d3_add_dataset_internal_jobs.py b/qcfractal/qcfractal/alembic/versions/2025-01-09-5f6f804e11d3_add_dataset_internal_jobs.py new file mode 100644 index 000000000..647260f50 --- /dev/null +++ b/qcfractal/qcfractal/alembic/versions/2025-01-09-5f6f804e11d3_add_dataset_internal_jobs.py @@ -0,0 +1,35 @@ +"""Add dataset internal jobs + +Revision ID: 5f6f804e11d3 +Revises: 84285e3620fd +Create Date: 2025-01-09 16:25:50.187495 + +""" + +import sqlalchemy as sa +from alembic import op + +# revision identifiers, used by Alembic. +revision = "5f6f804e11d3" +down_revision = "84285e3620fd" +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.create_table( + "dataset_internal_job", + sa.Column("internal_job_id", sa.Integer(), nullable=False), + sa.Column("dataset_id", sa.Integer(), nullable=False), + sa.ForeignKeyConstraint(["dataset_id"], ["base_dataset.id"], ondelete="cascade"), + sa.ForeignKeyConstraint(["internal_job_id"], ["internal_jobs.id"], ondelete="cascade"), + sa.PrimaryKeyConstraint("internal_job_id", "dataset_id"), + ) + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.drop_table("dataset_internal_job") + # ### end Alembic commands ### diff --git a/qcfractal/qcfractal/alembic/versions/2025-01-10-3690c677f8d1_add_internal_job_serial_group.py b/qcfractal/qcfractal/alembic/versions/2025-01-10-3690c677f8d1_add_internal_job_serial_group.py new file mode 100644 index 000000000..9b4cdd8d4 --- /dev/null +++ b/qcfractal/qcfractal/alembic/versions/2025-01-10-3690c677f8d1_add_internal_job_serial_group.py @@ -0,0 +1,40 @@ +"""Add internal job serial group + +Revision ID: 3690c677f8d1 +Revises: 5f6f804e11d3 +Create Date: 2025-01-10 16:08:36.541807 + +""" + +import sqlalchemy as sa +from alembic import op + +# revision identifiers, used by Alembic. +revision = "3690c677f8d1" +down_revision = "5f6f804e11d3" +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.add_column("internal_jobs", sa.Column("serial_group", sa.String(), nullable=True)) + op.create_index( + "ux_internal_jobs_status_serial_group", + "internal_jobs", + ["status", "serial_group"], + unique=True, + postgresql_where=sa.text("status = 'running'"), + ) + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! 
### + op.drop_index( + "ux_internal_jobs_status_serial_group", + table_name="internal_jobs", + postgresql_where=sa.text("status = 'running'"), + ) + op.drop_column("internal_jobs", "serial_group") + # ### end Alembic commands ### diff --git a/qcfractal/qcfractal/components/auth/test_auth_permissions.py b/qcfractal/qcfractal/components/auth/test_auth_permissions.py index 9c130858a..f5b04dbf2 100644 --- a/qcfractal/qcfractal/components/auth/test_auth_permissions.py +++ b/qcfractal/qcfractal/components/auth/test_auth_permissions.py @@ -19,7 +19,6 @@ def module_authtest_snowflake(postgres_server, pytestconfig): with QCATestingSnowflake( pg_harness, encoding, - start_api=True, create_users=False, enable_security=True, allow_unauthenticated_read=False, diff --git a/qcfractal/qcfractal/components/create_view.py b/qcfractal/qcfractal/components/create_view.py deleted file mode 100644 index 6fceb99eb..000000000 --- a/qcfractal/qcfractal/components/create_view.py +++ /dev/null @@ -1,132 +0,0 @@ -import os - -from sqlalchemy import select, create_engine, Column, String, ForeignKey, LargeBinary -from sqlalchemy.orm import selectinload, sessionmaker, declarative_base - -from qcfractal.db_socket.socket import SQLAlchemySocket -from qcportal.compression import compress, CompressionEnum -from qcportal.serialization import serialize - -ViewBaseORM = declarative_base() - - -class DatasetViewEntry(ViewBaseORM): - __tablename__ = "dataset_entry" - name = Column(String, primary_key=True) - data = Column(LargeBinary, nullable=False) - - -class DatasetViewSpecification(ViewBaseORM): - __tablename__ = "dataset_specification" - name = Column(String, primary_key=True) - data = Column(LargeBinary, nullable=False) - - -class DatasetViewRecord(ViewBaseORM): - __tablename__ = "dataset_record" - entry_name = Column(String, ForeignKey(DatasetViewEntry.name), primary_key=True) - specification_name = Column(String, ForeignKey(DatasetViewSpecification.name), primary_key=True) - data = Column(LargeBinary, nullable=False) - - -class DatasetViewMetadata(ViewBaseORM): - __tablename__ = "dataset_metadata" - - key = Column(String, ForeignKey(DatasetViewEntry.name), primary_key=True) - value = Column(LargeBinary, nullable=False) - - -def _serialize_orm(orm, exclude=None): - s_data = serialize(orm.model_dict(exclude=exclude), "application/msgpack") - c_data, _, _ = compress(s_data, CompressionEnum.zstd, 7) - return c_data - - -def create_dataset_view(dataset_id: int, socket: SQLAlchemySocket, view_file_path: str): - if os.path.exists(view_file_path): - raise RuntimeError(f"File {view_file_path} exists - will not overwrite") - - if os.path.isdir(view_file_path): - raise RuntimeError(f"{view_file_path} is a directory") - - uri = "sqlite:///" + view_file_path - engine = create_engine(uri) - ViewSession = sessionmaker(bind=engine) - - ViewBaseORM.metadata.create_all(engine) - - view_session = ViewSession() - - with socket.session_scope(True) as fractal_session: - ds_type = socket.datasets.lookup_type(dataset_id) - ds_socket = socket.datasets.get_socket(ds_type) - - dataset_orm = ds_socket.dataset_orm - entry_orm = ds_socket.entry_orm - specification_orm = ds_socket.specification_orm - record_item_orm = ds_socket.record_item_orm - - stmt = select(dataset_orm).where(dataset_orm.id == dataset_id) - stmt = stmt.options(selectinload("*")) - ds_orm = fractal_session.execute(stmt).scalar_one() - - # Metadata - metadata_bytes = _serialize_orm(ds_orm) - metadata_orm = DatasetViewMetadata(key="raw_data", value=metadata_bytes) - 
view_session.add(metadata_orm) - view_session.commit() - - # Entries - stmt = select(entry_orm) - stmt = stmt.options(selectinload("*")) - stmt = stmt.where(entry_orm.dataset_id == dataset_id) - entries = fractal_session.execute(stmt).scalars().all() - - for entry in entries: - entry_bytes = _serialize_orm(entry) - entry_orm = DatasetViewEntry(name=entry.name, data=entry_bytes) - view_session.add(entry_orm) - - view_session.commit() - - # Specifications - stmt = select(specification_orm) - stmt = stmt.options(selectinload("*")) - stmt = stmt.where(specification_orm.dataset_id == dataset_id) - specs = fractal_session.execute(stmt).scalars().all() - - for spec in specs: - spec_bytes = _serialize_orm(spec) - specification_orm = DatasetViewSpecification(name=spec.name, data=spec_bytes) - view_session.add(specification_orm) - - view_session.commit() - - base_stmt = select(record_item_orm).where(record_item_orm.dataset_id == dataset_id) - base_stmt = base_stmt.options(selectinload("*")) - base_stmt = base_stmt.order_by(record_item_orm.record_id.asc()) - - skip = 0 - while True: - stmt = base_stmt.offset(skip).limit(10) - batch = fractal_session.execute(stmt).scalars() - - count = 0 - for item_orm in batch: - item_bytes = _serialize_orm(item_orm) - - view_record_orm = DatasetViewRecord( - entry_name=item_orm.entry_name, - specification_name=item_orm.specification_name, - data=item_bytes, - ) - - view_session.add(view_record_orm) - count += 1 - - view_session.commit() - - if count == 0: - break - - skip += count diff --git a/qcfractal/qcfractal/components/dataset_db_models.py b/qcfractal/qcfractal/components/dataset_db_models.py index 261f7b877..173fd7748 100644 --- a/qcfractal/qcfractal/components/dataset_db_models.py +++ b/qcfractal/qcfractal/components/dataset_db_models.py @@ -13,12 +13,16 @@ ForeignKey, ForeignKeyConstraint, UniqueConstraint, + Enum, ) from sqlalchemy.orm import relationship from sqlalchemy.orm.collections import attribute_keyed_dict from qcfractal.components.auth.db_models import UserIDMapSubquery, GroupIDMapSubquery, UserORM, GroupORM +from qcfractal.components.external_files.db_models import ExternalFileORM +from qcfractal.components.internal_jobs.db_models import InternalJobORM from qcfractal.db_socket import BaseORM, MsgpackExt +from qcportal.dataset_models import DatasetAttachmentType class BaseDatasetORM(BaseORM): @@ -78,6 +82,12 @@ class BaseDatasetORM(BaseORM): passive_deletes=True, ) + attachments = relationship( + "DatasetAttachmentORM", + cascade="all, delete-orphan", + passive_deletes=True, + ) + __table_args__ = ( UniqueConstraint("dataset_type", "lname", name="ux_base_dataset_dataset_type_lname"), Index("ix_base_dataset_dataset_type", "dataset_type"), @@ -129,3 +139,25 @@ class ContributedValuesORM(BaseORM): __table_args__ = (Index("ix_contributed_values_dataset_id", "dataset_id"),) _qcportal_model_excludes = ["dataset_id"] + + +class DatasetInternalJobORM(BaseORM): + __tablename__ = "dataset_internal_job" + + internal_job_id = Column(Integer, ForeignKey(InternalJobORM.id, ondelete="cascade"), primary_key=True) + dataset_id = Column(Integer, ForeignKey("base_dataset.id", ondelete="cascade"), primary_key=True) + + +class DatasetAttachmentORM(ExternalFileORM): + __tablename__ = "dataset_attachment" + + id = Column(Integer, ForeignKey(ExternalFileORM.id, ondelete="cascade"), primary_key=True) + dataset_id = Column(Integer, ForeignKey("base_dataset.id", ondelete="cascade"), nullable=False) + + attachment_type = Column(Enum(DatasetAttachmentType), nullable=False) + + 
__mapper_args__ = {"polymorphic_identity": "dataset_attachment"} + + __table_args__ = (Index("ix_dataset_attachment_dataset_id", "dataset_id"),) + + _qcportal_model_excludes = ["dataset_id"] diff --git a/qcfractal/qcfractal/components/dataset_processing/__init__.py b/qcfractal/qcfractal/components/dataset_processing/__init__.py new file mode 100644 index 000000000..4e62cdf45 --- /dev/null +++ b/qcfractal/qcfractal/components/dataset_processing/__init__.py @@ -0,0 +1 @@ +from .views import create_view_file diff --git a/qcfractal/qcfractal/components/dataset_processing/views.py b/qcfractal/qcfractal/components/dataset_processing/views.py new file mode 100644 index 000000000..d85d43b51 --- /dev/null +++ b/qcfractal/qcfractal/components/dataset_processing/views.py @@ -0,0 +1,179 @@ +from __future__ import annotations + +import os +from collections import defaultdict +from typing import TYPE_CHECKING + +from sqlalchemy import select +from sqlalchemy.orm import selectinload + +from qcfractal.components.internal_jobs.status import JobProgress +from qcfractal.components.record_db_models import BaseRecordORM +from qcportal.cache import DatasetCache +from qcportal.dataset_models import BaseDataset +from qcportal.record_models import RecordStatusEnum, BaseRecord +from qcportal.utils import chunk_iterable + +if TYPE_CHECKING: + from sqlalchemy.orm.session import Session + from qcfractal.db_socket.socket import SQLAlchemySocket + from typing import Optional, Iterable + from typing import Iterable + from sqlalchemy.orm.session import Session + + +def create_view_file( + session: Session, + socket: SQLAlchemySocket, + dataset_id: int, + dataset_type: str, + output_path: str, + status: Optional[Iterable[RecordStatusEnum]] = None, + include: Optional[Iterable[str]] = None, + exclude: Optional[Iterable[str]] = None, + *, + include_children: bool = True, + job_progress: Optional[JobProgress] = None, +): + """ + Creates a view file for a dataset + + Note: the job progress object will be filled to 90% to leave room for uploading + + Parameters + ---------- + session + An existing SQLAlchemy session to use. + socket + Full SQLAlchemy socket to use for getting records + dataset_type + Type of the underlying dataset to create a view of (as a string) + dataset_id + ID of the dataset to create the view for + output_path + Full path (including filename) to output the view data to. Must not already exist + status + List of statuses to include. Default is to include records with any status + include + List of specific record fields to include in the export. Default is to include most fields + exclude + List of specific record fields to exclude from the export. Defaults to excluding none. + include_children + Specifies whether child records associated with the main records should also be included (recursively) + in the view file. 
+ job_progress + Object used to track the progress of the job + """ + + if os.path.exists(output_path): + raise RuntimeError(f"File {output_path} exists - will not overwrite") + + if os.path.isdir(output_path): + raise RuntimeError(f"{output_path} is a directory") + + ds_socket = socket.datasets.get_socket(dataset_type) + ptl_dataset_type = BaseDataset.get_subclass(dataset_type) + + ptl_entry_type = ptl_dataset_type._entry_type + ptl_specification_type = ptl_dataset_type._specification_type + + view_db = DatasetCache(output_path, read_only=False, dataset_type=ptl_dataset_type) + + stmt = select(ds_socket.dataset_orm).where(ds_socket.dataset_orm.id == dataset_id) + stmt = stmt.options(selectinload("*")) + ds_orm = session.execute(stmt).scalar_one() + + # Metadata + view_db.update_metadata("dataset_metadata", ds_orm.model_dict()) + + # Entries + if job_progress is not None: + job_progress.raise_if_cancelled() + job_progress.update_progress(0, "Processing dataset entries") + + stmt = select(ds_socket.entry_orm) + stmt = stmt.options(selectinload("*")) + stmt = stmt.where(ds_socket.entry_orm.dataset_id == dataset_id) + + entries = session.execute(stmt).scalars().all() + entries = [e.to_model(ptl_entry_type) for e in entries] + view_db.update_entries(entries) + + if job_progress is not None: + job_progress.raise_if_cancelled() + job_progress.update_progress(5, "Processing dataset specifications") + + # Specifications + stmt = select(ds_socket.specification_orm) + stmt = stmt.options(selectinload("*")) + stmt = stmt.where(ds_socket.specification_orm.dataset_id == dataset_id) + + specs = session.execute(stmt).scalars().all() + specs = [s.to_model(ptl_specification_type) for s in specs] + view_db.update_specifications(specs) + + if job_progress is not None: + job_progress.raise_if_cancelled() + job_progress.update_progress(10, "Loading record information") + + # Now all the records + stmt = select(ds_socket.record_item_orm).where(ds_socket.record_item_orm.dataset_id == dataset_id) + stmt = stmt.order_by(ds_socket.record_item_orm.record_id.asc()) + record_items = session.execute(stmt).scalars().all() + + record_ids = set(ri.record_id for ri in record_items) + all_ids = set(record_ids) + + if include_children: + # Get all the children ids + children_ids = socket.records.get_children_ids(session, record_ids) + all_ids |= set(children_ids) + + ############################################################################ + # Determine the record types of all the ids (top-level and children if desired) + ############################################################################ + stmt = select(BaseRecordORM.id, BaseRecordORM.record_type) + + if status is not None: + stmt = stmt.where(BaseRecordORM.status.in_(status)) + + # Sort into a dictionary with keys being the record type + record_type_map = defaultdict(list) + + for id_chunk in chunk_iterable(all_ids, 500): + stmt2 = stmt.where(BaseRecordORM.id.in_(id_chunk)) + for record_id, record_type in session.execute(stmt2).yield_per(100): + record_type_map[record_type].append(record_id) + + if job_progress is not None: + job_progress.raise_if_cancelled() + job_progress.update_progress(15, "Processing individual records") + + ############################################################################ + # Actually fetch the record data now + # We go one over the different types of records, then load them in batches + ############################################################################ + record_count = len(all_ids) + finished_count = 0 + + for 
record_type_str, record_ids in record_type_map.items(): + record_socket = socket.records.get_socket(record_type_str) + record_type = BaseRecord.get_subclass(record_type_str) + + for id_chunk in chunk_iterable(record_ids, 200): + record_dicts = record_socket.get(id_chunk, include=include, exclude=exclude, session=session) + record_data = [record_type(**r) for r in record_dicts] + view_db.update_records(record_data) + + finished_count += len(id_chunk) + if job_progress is not None: + job_progress.raise_if_cancelled() + + # Fraction of the 75% left over (15 to start, 10 left over for uploading) + job_progress.update_progress( + 15 + int(75 * finished_count / record_count), "Processing individual records" + ) + + # Update the dataset <-> record association + record_info = [(ri.entry_name, ri.specification_name, ri.record_id) for ri in record_items] + view_db.update_dataset_records(record_info) diff --git a/qcfractal/qcfractal/components/dataset_routes.py b/qcfractal/qcfractal/components/dataset_routes.py index e39381387..d53266eee 100644 --- a/qcfractal/qcfractal/components/dataset_routes.py +++ b/qcfractal/qcfractal/components/dataset_routes.py @@ -1,4 +1,4 @@ -from typing import Dict, Any +from typing import Dict from flask import current_app, g @@ -12,6 +12,7 @@ DatasetFetchRecordsBody, DatasetFetchEntryBody, DatasetFetchSpecificationBody, + DatasetCreateViewBody, DatasetSubmitBody, DatasetDeleteStrBody, DatasetRecordModifyBody, @@ -21,6 +22,7 @@ DatasetQueryRecords, DatasetDeleteParams, DatasetModifyEntryBody, + DatasetGetInternalJobParams, ) from qcportal.exceptions import LimitExceededError @@ -37,7 +39,14 @@ def get_general_dataset_v1(dataset_id: int, url_params: ProjURLParameters): with storage_socket.session_scope(True) as session: ds_type = storage_socket.datasets.lookup_type(dataset_id, session=session) ds_socket = storage_socket.datasets.get_socket(ds_type) - return ds_socket.get(dataset_id, url_params.include, url_params.exclude, session=session) + + r = ds_socket.get(dataset_id, url_params.include, url_params.exclude, session=session) + + # TODO - remove this eventually + # Don't return attachments by default + r.pop("attachments", None) + + return r @api_v1.route("/datasets/query", methods=["POST"]) @@ -161,6 +170,24 @@ def modify_dataset_metadata_v1(dataset_type: str, dataset_id: int, body_data: Da return ds_socket.update_metadata(dataset_id, new_metadata=body_data) +######################### +# Views & Attachments +######################### +@api_v1.route("/datasets///create_view", methods=["POST"]) +@wrap_route("WRITE") +def create_dataset_view_v1(dataset_type: str, dataset_id: int, body_data: DatasetCreateViewBody): + return storage_socket.datasets.add_create_view_attachment_job( + dataset_id, + dataset_type, + description=body_data.description, + provenance=body_data.provenance, + status=body_data.status, + include=body_data.include, + exclude=body_data.exclude, + include_children=body_data.include_children, + ) + + ######################### # Computation submission ######################### @@ -342,9 +369,39 @@ def revert_dataset_records_v1(dataset_type: str, dataset_id: int, body_data: Dat ) -################### +################################# +# Internal Jobs +################################# +@api_v1.route("/datasets//internal_jobs/", methods=["GET"]) +@wrap_route("READ") +def get_dataset_internal_job_v1(dataset_id: int, job_id: int): + return storage_socket.datasets.get_internal_job(dataset_id, job_id) + + +@api_v1.route("/datasets//internal_jobs", methods=["GET"]) 
+@wrap_route("READ") +def list_dataset_internal_jobs_v1(dataset_id: int, url_params: DatasetGetInternalJobParams): + return storage_socket.datasets.list_internal_jobs(dataset_id, status=url_params.status) + + +################################# +# Attachments +################################# +@api_v1.route("/datasets//attachments", methods=["GET"]) +@wrap_route("READ") +def fetch_dataset_attachments_v1(dataset_id: int): + return storage_socket.datasets.get_attachments(dataset_id) + + +@api_v1.route("/datasets//attachments/", methods=["DELETE"]) +@wrap_route("DELETE") +def delete_dataset_attachment_v1(dataset_id: int, attachment_id: int): + return storage_socket.datasets.delete_attachment(dataset_id, attachment_id) + + +################################# # Contributed Values -################### +################################# @api_v1.route("/datasets//contributed_values", methods=["GET"]) @wrap_route("READ") def fetch_dataset_contributed_values_v1(dataset_id: int): diff --git a/qcfractal/qcfractal/components/dataset_socket.py b/qcfractal/qcfractal/components/dataset_socket.py index 986ee1d71..4d3aff78f 100644 --- a/qcfractal/qcfractal/components/dataset_socket.py +++ b/qcfractal/qcfractal/components/dataset_socket.py @@ -1,29 +1,45 @@ from __future__ import annotations import logging +import os +import tempfile from typing import TYPE_CHECKING from sqlalchemy import select, delete, func, union, text, and_ +from sqlalchemy.dialects.postgresql import insert from sqlalchemy.orm import load_only, lazyload, joinedload, with_polymorphic from sqlalchemy.orm.attributes import flag_modified -from qcfractal.components.dataset_db_models import BaseDatasetORM, ContributedValuesORM +from qcfractal.components.dataset_db_models import ( + BaseDatasetORM, + ContributedValuesORM, + DatasetAttachmentORM, + DatasetInternalJobORM, +) +from qcfractal.components.dataset_processing import create_view_file +from qcfractal.components.internal_jobs.db_models import InternalJobORM from qcfractal.components.record_db_models import BaseRecordORM from qcfractal.db_socket.helpers import ( get_general, get_query_proj_options, ) -from qcportal.exceptions import AlreadyExistsError, MissingDataError +from qcportal.dataset_models import DatasetAttachmentType +from qcportal.exceptions import AlreadyExistsError, MissingDataError, UserReportableError +from qcportal.internal_jobs import InternalJobStatusEnum from qcportal.metadata_models import InsertMetadata, DeleteMetadata, UpdateMetadata from qcportal.record_models import RecordStatusEnum, PriorityEnum -from qcportal.utils import chunk_iterable +from qcportal.utils import chunk_iterable, now_at_utc if TYPE_CHECKING: from sqlalchemy.orm.session import Session from qcportal.dataset_models import DatasetModifyMetadata + from qcfractal.components.internal_jobs.status import JobProgress from qcfractal.db_socket.socket import SQLAlchemySocket from qcfractal.db_socket.base_orm import BaseORM from typing import Dict, Any, Optional, Sequence, Iterable, Tuple, List, Union + from typing import Iterable, List, Dict, Any + from qcfractal.db_socket.socket import SQLAlchemySocket + from sqlalchemy.orm.session import Session class BaseDatasetSocket: @@ -56,6 +72,8 @@ def __init__( # Use the identity from the ORM object. 
This keeps everything consistent self.dataset_type = self.dataset_orm.__mapper_args__["polymorphic_identity"] + self._logger = logging.getLogger(__name__) + def _add_specification(self, session, specification) -> Tuple[InsertMetadata, Optional[int]]: raise NotImplementedError("_add_specification must be overridden by the derived class") @@ -1649,3 +1667,277 @@ def get_contributed_values(self, dataset_id: int, *, session: Optional[Session] with self.root_socket.optional_session(session, True) as session: cv = session.execute(stmt).scalars().all() return [x.model_dict() for x in cv] + + def get_internal_job( + self, + dataset_id: int, + job_id: int, + *, + session: Optional[Session] = None, + ): + stmt = select(InternalJobORM) + stmt = stmt.join(DatasetInternalJobORM, DatasetInternalJobORM.internal_job_id == InternalJobORM.id) + stmt = stmt.where(DatasetInternalJobORM.dataset_id == dataset_id) + stmt = stmt.where(DatasetInternalJobORM.internal_job_id == job_id) + + with self.root_socket.optional_session(session, True) as session: + ij_orm = session.execute(stmt).scalar_one_or_none() + if ij_orm is None: + raise MissingDataError(f"Job id {job_id} not found in dataset {dataset_id}") + return ij_orm.model_dict() + + def list_internal_jobs( + self, + dataset_id: int, + status: Optional[Iterable[InternalJobStatusEnum]] = None, + *, + session: Optional[Session] = None, + ): + stmt = select(InternalJobORM) + stmt = stmt.join(DatasetInternalJobORM, DatasetInternalJobORM.internal_job_id == InternalJobORM.id) + stmt = stmt.where(DatasetInternalJobORM.dataset_id == dataset_id) + + if status is not None: + stmt = stmt.where(InternalJobORM.status.in_(status)) + + with self.root_socket.optional_session(session, True) as session: + ij_orm = session.execute(stmt).scalars().all() + return [i.model_dict() for i in ij_orm] + + def get_attachments(self, dataset_id: int, *, session: Optional[Session] = None) -> List[Dict[str, Any]]: + """ + Get the attachments for a dataset + """ + + stmt = select(DatasetAttachmentORM) + stmt = stmt.where(DatasetAttachmentORM.dataset_id == dataset_id) + + with self.root_socket.optional_session(session, True) as session: + att = session.execute(stmt).scalars().all() + return [x.model_dict() for x in att] + + def delete_attachment(self, dataset_id: int, file_id: int, *, session: Optional[Session] = None): + stmt = select(DatasetAttachmentORM) + stmt = stmt.where(DatasetAttachmentORM.dataset_id == dataset_id) + stmt = stmt.where(DatasetAttachmentORM.id == file_id) + stmt = stmt.with_for_update() + + with self.root_socket.optional_session(session) as session: + att = session.execute(stmt).scalar_one_or_none() + if att is None: + raise MissingDataError(f"Attachment with file id {file_id} not found in dataset {dataset_id}") + + return self.root_socket.external_files.delete(file_id, session=session) + + def attach_file( + self, + dataset_id: int, + attachment_type: DatasetAttachmentType, + file_path: str, + file_name: str, + description: Optional[str], + provenance: Dict[str, Any], + *, + job_progress: Optional[JobProgress] = None, + session: Optional[Session] = None, + ) -> int: + """ + Attach a file to a dataset + + This function uploads the specified file and associates it with the dataset by creating + a corresponding dataset attachment record. This operation requires S3 storage to be enabled. + + Parameters + ---------- + dataset_id + The ID of the dataset to which the file will be attached. + attachment_type + The type of attachment that categorizes the file being added. 
+ file_path + The local file system path to the file that needs to be uploaded. + file_name + The name of the file to be used in the attachment record. This is the filename that is + recommended to the user by default. + description + An optional description of the file + provenance + A dictionary containing metadata regarding the origin or history of the file. + job_progress + Object used to track progress if this function is being run in a background job + session + An existing SQLAlchemy session to use. If None, one will be created. If an existing session + is used, it will be flushed (but not committed) before returning from this function. + + Raises + ------ + UserReportableError + Raised if S3 storage is not enabled + """ + + if not self.root_socket.qcf_config.s3.enabled: + raise UserReportableError("S3 storage is not enabled. Can not attach file to a dataset") + + self._logger.info(f"Uploading/Attaching dataset-related file: {file_path}") + with self.root_socket.optional_session(session) as session: + ef = DatasetAttachmentORM( + dataset_id=dataset_id, + attachment_type=attachment_type, + file_name=file_name, + description=description, + provenance=provenance, + ) + + file_id = self.root_socket.external_files.add_file( + file_path, ef, session=session, job_progress=job_progress + ) + + self._logger.info(f"Dataset attachment {file_path} successfully uploaded to S3. ID is {file_id}") + return file_id + + def create_view_attachment( + self, + dataset_id: int, + dataset_type: str, + description: Optional[str], + provenance: Dict[str, Any], + status: Optional[Iterable[RecordStatusEnum]] = None, + include: Optional[Iterable[str]] = None, + exclude: Optional[Iterable[str]] = None, + *, + include_children: bool = True, + job_progress: Optional[JobProgress] = None, + session: Optional[Session] = None, + ): + """ + Creates a dataset view and attaches it to the dataset + + Uses a temporary directory within the globally-configured `temporary_dir` + + Parameters + ---------- + dataset_id : int + ID of the dataset to create the view for + dataset_type + Type of dataset the ID is + description + Optional string describing the view file + provenance + Dictionary with any metadata or other information about the view. Information regarding + the options used to create the view will be added. + status + List of statuses to include. Default is to include records with any status + include + List of specific record fields to include in the export. Default is to include most fields + exclude + List of specific record fields to exclude from the export. Defaults to excluding none. + include_children + Specifies whether child records associated with the main records should also be included (recursively) + in the view file. + job_progress + Object used to track progress if this function is being run in a background job + session + An existing SQLAlchemy session to use. If None, one will be created. If an existing session + is used, it will be flushed (but not committed) before returning from this function. + """ + + if not self.root_socket.qcf_config.s3.enabled: + raise UserReportableError("S3 storage is not enabled. 
Cannot create view") + + # Add the options for the view creation to the provenance + provenance = provenance | { + "options": { + "status": status, + "include": include, + "exclude": exclude, + "include_children": include_children, + } + } + + with tempfile.TemporaryDirectory(dir=self.root_socket.qcf_config.temporary_dir) as tmpdir: + self._logger.info(f"Using temporary directory {tmpdir} for view creation") + + file_name = f"dataset_{dataset_id}_view.sqlite" + tmp_file_path = os.path.join(tmpdir, file_name) + + create_view_file( + session, + self.root_socket, + dataset_id, + dataset_type, + tmp_file_path, + status=status, + include=include, + exclude=exclude, + include_children=include_children, + job_progress=job_progress, + ) + + self._logger.info(f"View file created. File size is {os.path.getsize(tmp_file_path)/1048576} MiB.") + + if job_progress is not None: + job_progress.update_progress(90, "Uploading view file to S3") + + file_id = self.attach_file( + dataset_id, DatasetAttachmentType.view, tmp_file_path, file_name, description, provenance + ) + + if job_progress is not None: + job_progress.update_progress(100) + + return file_id + + def add_create_view_attachment_job( + self, + dataset_id: int, + dataset_type: str, + description: str, + provenance: Dict[str, Any], + status: Optional[Iterable[RecordStatusEnum]] = None, + include: Optional[Iterable[str]] = None, + exclude: Optional[Iterable[str]] = None, + *, + include_children: bool = True, + session: Optional[Session] = None, + ) -> int: + """ + Creates an internal job for creating and attaching a view to a dataset + + See :ref:`create_view_attachment` for a description of the parameters + + Returns + ------- + : + ID of the created internal job + """ + + if not self.root_socket.qcf_config.s3.enabled: + raise UserReportableError("S3 storage is not enabled.
Cannot create view") + + with self.root_socket.optional_session(session) as session: + job_id = self.root_socket.internal_jobs.add( + f"create_attach_view_ds_{dataset_id}", + now_at_utc(), + f"datasets.create_view_attachment", + { + "dataset_id": dataset_id, + "dataset_type": dataset_type, + "description": description, + "provenance": provenance, + "status": status, + "include": include, + "exclude": exclude, + "include_children": include_children, + }, + user_id=None, + unique_name=True, + serial_group="ds_create_view", + session=session, + ) + + stmt = ( + insert(DatasetInternalJobORM) + .values(dataset_id=dataset_id, internal_job_id=job_id) + .on_conflict_do_nothing() + ) + session.execute(stmt) + return job_id diff --git a/qcfractal/qcfractal/components/external_files/__init__.py b/qcfractal/qcfractal/components/external_files/__init__.py new file mode 100644 index 000000000..ee9c19c68 --- /dev/null +++ b/qcfractal/qcfractal/components/external_files/__init__.py @@ -0,0 +1 @@ +from .socket import ExternalFileSocket diff --git a/qcfractal/qcfractal/components/external_files/db_models.py b/qcfractal/qcfractal/components/external_files/db_models.py new file mode 100644 index 000000000..7b6d776e9 --- /dev/null +++ b/qcfractal/qcfractal/components/external_files/db_models.py @@ -0,0 +1,36 @@ +from __future__ import annotations + +from sqlalchemy import Column, Integer, String, Enum, TIMESTAMP, BigInteger +from sqlalchemy.dialects.postgresql import JSONB + +from qcfractal.db_socket.base_orm import BaseORM +from qcportal.external_files import ExternalFileStatusEnum, ExternalFileTypeEnum +from qcportal.utils import now_at_utc + + +class ExternalFileORM(BaseORM): + """ + Table for storing information about external files (e.g., files stored in S3) + """ + + __tablename__ = "external_file" + + id = Column(Integer, primary_key=True) + file_type = Column(Enum(ExternalFileTypeEnum), nullable=False) + + created_on = Column(TIMESTAMP, default=now_at_utc, nullable=False) + status = Column(Enum(ExternalFileStatusEnum), nullable=False) + + file_name = Column(String, nullable=False) + description = Column(String, nullable=True) + provenance = Column(JSONB, nullable=False) + + sha256sum = Column(String, nullable=False) + file_size = Column(BigInteger, nullable=False) + + bucket = Column(String, nullable=False) + object_key = Column(String, nullable=False) + + __mapper_args__ = {"polymorphic_on": "file_type"} + + _qcportal_model_excludes = ["object_key", "bucket"] diff --git a/qcfractal/qcfractal/components/external_files/routes.py b/qcfractal/qcfractal/components/external_files/routes.py new file mode 100644 index 000000000..6af0f0d6a --- /dev/null +++ b/qcfractal/qcfractal/components/external_files/routes.py @@ -0,0 +1,29 @@ +from flask import redirect, Response, stream_with_context, current_app + +from qcfractal.flask_app import storage_socket +from qcfractal.flask_app.api_v1.blueprint import api_v1 +from qcfractal.flask_app.api_v1.helpers import wrap_route + + +@api_v1.route("/external_files/<int:file_id>", methods=["GET"]) +@wrap_route("READ") +def get_external_file_metadata_v1(file_id: int): + return storage_socket.external_files.get_metadata(file_id) + + +@api_v1.route("/external_files/<int:file_id>/download", methods=["GET"]) +@wrap_route("READ") +def download_external_file_v1(file_id: int): + passthrough = current_app.config["QCFRACTAL_CONFIG"].s3.passthrough + + if passthrough: + file_name, streamer_func = storage_socket.external_files.get_file_streamer(file_id) + + return Response( + stream_with_context(streamer_func()), + content_type="application/octet-stream",
+ headers={"Content-Disposition": f'attachment; filename="{file_name}"'}, + ) + else: + _, url = storage_socket.external_files.get_url(file_id) + return redirect(url, code=302) diff --git a/qcfractal/qcfractal/components/external_files/socket.py b/qcfractal/qcfractal/components/external_files/socket.py new file mode 100644 index 000000000..f939648bd --- /dev/null +++ b/qcfractal/qcfractal/components/external_files/socket.py @@ -0,0 +1,334 @@ +from __future__ import annotations + +import hashlib +import importlib +import logging +import os +import uuid +from typing import TYPE_CHECKING + +from sqlalchemy import select + +from qcportal.exceptions import MissingDataError +from qcportal.external_files import ExternalFileTypeEnum, ExternalFileStatusEnum +from .db_models import ExternalFileORM + +# Boto3 package is optional +_boto3_spec = importlib.util.find_spec("boto3") + +if _boto3_spec is not None: + boto3 = importlib.util.module_from_spec(_boto3_spec) + _boto3_spec.loader.exec_module(boto3) + + +if TYPE_CHECKING: + from sqlalchemy.orm.session import Session + from qcfractal.components.internal_jobs.status import JobProgress + from qcfractal.db_socket.socket import SQLAlchemySocket + from typing import Optional, Dict, Any, Tuple, Union, BinaryIO, Callable, Generator + + +class ExternalFileSocket: + """ + Socket for managing/querying external files + """ + + def __init__(self, root_socket: SQLAlchemySocket): + self.root_socket = root_socket + self._logger = logging.getLogger(__name__) + self._s3_config = root_socket.qcf_config.s3 + + if not self._s3_config.enabled: + self._logger.info("S3 service for external files is not configured") + return + elif _boto3_spec is None: + raise RuntimeError("boto3 package is required for S3 support") + + self._s3_client = boto3.client( + "s3", + endpoint_url=self._s3_config.endpoint_url, + aws_access_key_id=self._s3_config.access_key_id, + aws_secret_access_key=self._s3_config.secret_access_key, + verify=self._s3_config.verify, + ) + + # may raise an exception (bad/missing credentials, etc) + server_bucket_info = self._s3_client.list_buckets() + server_buckets = {k["Name"] for k in server_bucket_info["Buckets"]} + self._logger.info(f"Found {len(server_buckets)} buckets on the S3 server/account") + + # Make sure the buckets we use exist + self._bucket_map = self._s3_config.bucket_map + if self._bucket_map.dataset_attachment not in server_buckets: + raise RuntimeError(f"Bucket {self._bucket_map.dataset_attachment} (for dataset attachments)") + + def _lookup_bucket(self, file_type: Union[ExternalFileTypeEnum, str]) -> str: + # Can sometimes be a string from sqlalchemy (if set because of polymorphic identity) + if isinstance(file_type, str): + return getattr(self._bucket_map, file_type) + elif isinstance(file_type, ExternalFileTypeEnum): + return getattr(self._bucket_map, file_type.value) + else: + raise ValueError("Unknown parameter type for lookup_bucket: ", type(file_type)) + + def add_data( + self, + file_data: BinaryIO, + file_orm: ExternalFileORM, + *, + job_progress: Optional[JobProgress] = None, + session: Optional[Session] = None, + ) -> int: + """ + Add raw file data to the database. + + This will fill in the appropriate fields on the given ORM object. + The file_name and file_type must be filled in already. + + In this function, the `file_orm` will be added to the session and the session will be flushed. + This is done at the beginning to show that the file is "processing", and at the end when the + addition is completed. 
+ + Parameters + ---------- + file_data + Binary data to be read from + file_orm + Existing ORM object that will be filled in with metadata + job_progress + Object used to track progress if this function is being run in a background job + session + An existing SQLAlchemy session to use. If None, one will be created. If an existing session + is used, it will be flushed (but not committed) before returning from this function. + + Returns + ------- + : + ID of the external file (which is also set in the given ORM object) + """ + + bucket = self._lookup_bucket(file_orm.file_type) + + self._logger.info( + f"Uploading data to S3 bucket {bucket}. file_name={file_orm.file_name} type={file_orm.file_type}" + ) + object_key = str(uuid.uuid4()) + sha256 = hashlib.sha256() + file_size = 0 + + multipart_upload = self._s3_client.create_multipart_upload(Bucket=bucket, Key=object_key) + upload_id = multipart_upload["UploadId"] + parts = [] + part_number = 1 + + with self.root_socket.optional_session(session) as session: + file_orm.status = ExternalFileStatusEnum.processing + file_orm.bucket = bucket + file_orm.object_key = object_key + file_orm.sha256sum = "" + file_orm.file_size = 0 + + session.add(file_orm) + session.flush() + + try: + while chunk := file_data.read(10 * 1024 * 1024): + if job_progress is not None: + job_progress.raise_if_cancelled() + + sha256.update(chunk) + file_size += len(chunk) + + response = self._s3_client.upload_part( + Bucket=bucket, Key=object_key, PartNumber=part_number, UploadId=upload_id, Body=chunk + ) + parts.append({"PartNumber": part_number, "ETag": response["ETag"]}) + part_number += 1 + + self._s3_client.complete_multipart_upload( + Bucket=bucket, Key=object_key, UploadId=upload_id, MultipartUpload={"Parts": parts} + ) + + except Exception as e: + self._s3_client.abort_multipart_upload(Bucket=bucket, Key=object_key, UploadId=upload_id) + raise e + + self._logger.info(f"Uploading data to S3 bucket complete. Finishing writing metadata to db") + file_orm.status = ExternalFileStatusEnum.available + file_orm.sha256sum = sha256.hexdigest() + file_orm.file_size = file_size + session.flush() + + return file_orm.id + + def add_file( + self, + file_path: str, + file_orm: ExternalFileORM, + *, + job_progress: Optional[JobProgress] = None, + session: Optional[Session] = None, + ) -> int: + """ + Add an existing file to the database + + See documentation for :ref:`add_data` for more information. + + Parameters + ---------- + file_path + Path to an existing file to be read from + file_orm + Existing ORM object that will be filled in with metadata + job_progress + Object used to track progress if this function is being run in a background job + session + An existing SQLAlchemy session to use. If None, one will be created. If an existing session + is used, it will be flushed (but not committed) before returning from this function. + + Returns + ------- + : + ID of the external file (which is also set in the given ORM object) + """ + + self._logger.info(f"Uploading {file_path} to S3. File size: {os.path.getsize(file_path)/1048576} MiB") + + with open(file_path, "rb") as f: + return self.add_data(f, file_orm, job_progress=job_progress, session=session) + + def get_metadata( + self, + file_id: int, + *, + session: Optional[Session] = None, + ) -> Dict[str, Any]: + """ + Obtain external file information + + Parameters + ---------- + file_id + ID for the external file + session + An existing SQLAlchemy session to use. If None, one will be created. 
If an existing session + is used, it will be flushed (but not committed) before returning from this function. + + Returns + ------- + : + Metadata for the external file, as a dictionary. + This includes fields such as the file name, size, checksum, status, and provenance. + """ + + with self.root_socket.optional_session(session, True) as session: + ef = session.get(ExternalFileORM, file_id) + if ef is None: + raise MissingDataError(f"Cannot find external file with id {file_id} in the database") + + return ef.model_dict() + + def delete(self, file_id: int, *, session: Optional[Session] = None): + """ + Deletes an external file from the database and from remote storage + + Parameters + ---------- + file_id + ID of the external file to remove + session + An existing SQLAlchemy session to use. If None, one will be created. If an existing session + is used, it will be flushed (but not committed) before returning from this function. + + Returns + ------- + : + Metadata about what was deleted and any errors that occurred + """ + + with self.root_socket.optional_session(session) as session: + stmt = select(ExternalFileORM).where(ExternalFileORM.id == file_id) + ef_orm = session.execute(stmt).scalar_one_or_none() + + if ef_orm is None: + raise MissingDataError(f"Cannot find external file with id {file_id} in the database") + + bucket = ef_orm.bucket + object_key = ef_orm.object_key + + session.delete(ef_orm) + session.flush() + + # now delete from S3 storage + self._s3_client.delete_object(Bucket=bucket, Key=object_key) + + def get_url(self, file_id: int, *, session: Optional[Session] = None) -> Tuple[str, str]: + """ + Obtain a URL that a user can use to download the file directly from the S3 bucket + + Will raise an exception if the file_id does not exist + + Parameters + ---------- + file_id + ID of the external file + session + An existing SQLAlchemy session to use. If None, one will be created. If an existing session + is used, it will be flushed (but not committed) before returning from this function. + + Returns + ------- + : + File name and direct URL to the file + """ + + with self.root_socket.optional_session(session, True) as session: + ef = session.get(ExternalFileORM, file_id) + if ef is None: + raise MissingDataError(f"Cannot find external file with id {file_id} in the database") + + url = self._s3_client.generate_presigned_url( + ClientMethod="get_object", + Params={ + "Bucket": ef.bucket, + "Key": ef.object_key, + "ResponseContentDisposition": f'attachment; filename = "{ef.file_name}"', + }, + HttpMethod="GET", + ExpiresIn=120, + ) + + return ef.file_name, url + + def get_file_streamer( + self, file_id: int, *, session: Optional[Session] = None + ) -> Tuple[str, Callable[[], Generator[bytes, None, None]]]: + """ + Returns a function that streams a file from an S3 bucket. + + Parameters + ---------- + file_id + ID of the external file + session + An existing SQLAlchemy session to use. If None, one will be created. If an existing session + is used, it will be flushed (but not committed) before returning from this function.
+ + Returns + ------- + : + The recommended filename and a generator function that yields chunks of the file + """ + + with self.root_socket.optional_session(session, True) as session: + ef = session.get(ExternalFileORM, file_id) + if ef is None: + raise MissingDataError(f"Cannot find external file with id {file_id} in the database") + + s3_obj = self._s3_client.get_object(Bucket=ef.bucket, Key=ef.object_key) + + def _generator_func(): + for chunk in s3_obj["Body"].iter_chunks(chunk_size=(8 * 1048576)): + yield chunk + + return ef.file_name, _generator_func diff --git a/qcfractal/qcfractal/components/gridoptimization/test_dataset_client.py b/qcfractal/qcfractal/components/gridoptimization/test_dataset_client.py deleted file mode 100644 index d4b3099a5..000000000 --- a/qcfractal/qcfractal/components/gridoptimization/test_dataset_client.py +++ /dev/null @@ -1,79 +0,0 @@ -from __future__ import annotations - -from typing import TYPE_CHECKING - -import pytest - -from qcportal.gridoptimization import GridoptimizationDataset -from qcportal.record_models import PriorityEnum -from .testing_helpers import load_test_data - -if TYPE_CHECKING: - from qcportal import PortalClient - - -@pytest.fixture(scope="function") -def gridoptimization_ds(submitter_client: PortalClient): - ds = submitter_client.add_dataset( - "gridoptimization", - "Test dataset", - "Test Description", - "a Tagline", - ["tag1", "tag2"], - "new_group", - {"prov_key_1": "prov_value_1"}, - True, - "def_tag", - PriorityEnum.low, - {"meta_key_1": "meta_value_1"}, - "group1", - ) - - assert ds.owner_user is not None - assert ds.owner_user == ds._client.username - assert ds.owner_group == "group1" - - yield ds - - -@pytest.mark.parametrize("find_existing", [True, False]) -def test_gridoptimization_dataset_client_submit(gridoptimization_ds: GridoptimizationDataset, find_existing: bool): - input_spec_1, molecule_1, _ = load_test_data("go_H3NS_psi4_pbe") - - gridoptimization_ds.add_entry(name="test_molecule", initial_molecule=molecule_1) - gridoptimization_ds.add_specification("test_spec", input_spec_1, "test_specification") - - gridoptimization_ds.submit() - assert gridoptimization_ds.status()["test_spec"]["waiting"] == 1 - - gridoptimization_ds.submit() - assert gridoptimization_ds.status()["test_spec"]["waiting"] == 1 - - # Should only be one record - record_id = 0 - for e, s, r in gridoptimization_ds.iterate_records(): - assert r.owner_user == gridoptimization_ds.owner_user - assert r.owner_group == gridoptimization_ds.owner_group - record_id = r.id - - # delete & re-add entry, then resubmit - gridoptimization_ds.delete_entries(["test_molecule"]) - assert gridoptimization_ds.status() == {} - - # record still on the server? 
- r = gridoptimization_ds._client.get_records(record_id) - assert r.owner_user == gridoptimization_ds.owner_user - - # now resubmit - gridoptimization_ds.add_entry(name="test_molecule", initial_molecule=molecule_1) - gridoptimization_ds.submit(find_existing=find_existing) - assert gridoptimization_ds.status()["test_spec"]["waiting"] == 1 - - for e, s, r in gridoptimization_ds.iterate_records(): - assert r.owner_user == gridoptimization_ds.owner_user - assert r.owner_group == gridoptimization_ds.owner_group - - if find_existing: - assert r.id == record_id - else: - assert r.id != record_id diff --git a/qcfractal/qcfractal/components/internal_jobs/db_models.py b/qcfractal/qcfractal/components/internal_jobs/db_models.py index dc5abbbd0..c9184c7cf 100644 --- a/qcfractal/qcfractal/components/internal_jobs/db_models.py +++ b/qcfractal/qcfractal/components/internal_jobs/db_models.py @@ -31,6 +31,7 @@ class InternalJobORM(BaseORM): runner_uuid = Column(String) progress = Column(Integer, nullable=False, default=0) + progress_description = Column(String, nullable=True) function = Column(String, nullable=False) kwargs = Column(JSON, nullable=False) @@ -55,6 +56,9 @@ class InternalJobORM(BaseORM): # it must be unique. null != null always unique_name = Column(String, nullable=True) + # If this job is part of a serial group (only one may run at a time) + serial_group = Column(String, nullable=True) + __table_args__ = ( Index("ix_internal_jobs_added_date", "added_date", postgresql_using="brin"), Index("ix_internal_jobs_scheduled_date", "scheduled_date", postgresql_using="brin"), @@ -63,6 +67,14 @@ class InternalJobORM(BaseORM): Index("ix_internal_jobs_name", "name"), Index("ix_internal_jobs_user_id", "user_id"), UniqueConstraint("unique_name", name="ux_internal_jobs_unique_name"), + # Enforces only one running per serial group + Index( + "ux_internal_jobs_status_serial_group", + "status", + "serial_group", + unique=True, + postgresql_where=(status == InternalJobStatusEnum.running), + ), ) _qcportal_model_excludes = ["unique_name", "user_id"] diff --git a/qcfractal/qcfractal/components/internal_jobs/socket.py b/qcfractal/qcfractal/components/internal_jobs/socket.py index 7c4e0663f..ac746d650 100644 --- a/qcfractal/qcfractal/components/internal_jobs/socket.py +++ b/qcfractal/qcfractal/components/internal_jobs/socket.py @@ -13,6 +13,7 @@ import psycopg2.extensions from sqlalchemy import select, delete, update, and_, or_ from sqlalchemy.dialects.postgresql import insert +from sqlalchemy.exc import IntegrityError from qcfractal.components.auth.db_models import UserIDMapSubquery from qcfractal.db_socket.helpers import get_query_proj_options @@ -20,7 +21,7 @@ from qcportal.internal_jobs.models import InternalJobStatusEnum, InternalJobQueryFilters from qcportal.utils import now_at_utc from .db_models import InternalJobORM -from .status import JobProgress +from .status import JobProgress, CancelledJobException, JobRunnerStoppingException if TYPE_CHECKING: from qcfractal.db_socket.socket import SQLAlchemySocket @@ -71,6 +72,7 @@ def add( after_function: Optional[str] = None, after_function_kwargs: Optional[Dict[str, Any]] = None, repeat_delay: Optional[int] = None, + serial_group: Optional[str] = None, *, session: Optional[Session] = None, ) -> int: @@ -99,6 +101,8 @@ def add( Arguments to use when calling `after_function` repeat_delay If set, will submit a new, identical job to be run repeat_delay seconds after this one finishes + serial_group + Only one job within this group may be run at once. 
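
Together with the partial unique index on `(status, serial_group)` added above, the new `serial_group` argument guarantees at most one running job per group. A usage sketch, mirroring the call signature used by the tests later in this diff (the group name here is made up):

```python
# Two jobs sharing a (made-up) serial group: runners may claim them in any order,
# but the partial unique index prevents both from being in the 'running' state at once.
from qcportal.utils import now_at_utc

storage_socket.internal_jobs.add(
    "dummy_job", now_at_utc(), "internal_jobs.dummy_job", {"iterations": 10},
    None, unique_name=False, serial_group="example_group",
)
storage_socket.internal_jobs.add(
    "dummy_job", now_at_utc(), "internal_jobs.dummy_job", {"iterations": 10},
    None, unique_name=False, serial_group="example_group",
)
```
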
If None, there is no limit session An existing SQLAlchemy session to use. If None, one will be created. If an existing session is used, it will be flushed (but not committed) before returning from this function. @@ -122,6 +126,7 @@ def add( after_function=after_function, after_function_kwargs=after_function_kwargs, repeat_delay=repeat_delay, + serial_group=serial_group, user_id=user_id, ) stmt = stmt.on_conflict_do_update( @@ -153,6 +158,7 @@ def add( after_function=after_function, after_function_kwargs=after_function_kwargs, repeat_delay=repeat_delay, + serial_group=serial_group, user_id=user_id, ) if unique_name: @@ -287,7 +293,7 @@ def delete_old_internal_jobs(self, session: Session) -> None: if self._internal_job_keep <= 0: return - before = now_at_utc() - timedelta(days=self._internal_job_keep) + before = now_at_utc() - timedelta(seconds=self._internal_job_keep) stmt = delete(InternalJobORM) stmt = stmt.where( @@ -327,6 +333,9 @@ def _run_single(self, session: Session, job_orm: InternalJobORM, logger, job_pro Runs a single job """ + # For logging (ORM may end up detached or somthing) + job_id = job_orm.id + try: func_attr = attrgetter(job_orm.function) @@ -345,28 +354,55 @@ def _run_single(self, session: Session, job_orm: InternalJobORM, logger, job_pro if "session" in func_params: add_kwargs["session"] = session + # Run the desired function + # Raises an exception if cancelled result = func(**job_orm.kwargs, **add_kwargs) - # Mark complete, unless this was cancelled - if not job_progress.cancelled(): - job_orm.status = InternalJobStatusEnum.complete - job_orm.progress = 100 + job_orm.status = InternalJobStatusEnum.complete + job_orm.progress = 100 + job_orm.progress_description = "Complete" + job_orm.result = result - except Exception: + # Job itself is being cancelled + except CancelledJobException: session.rollback() - result = traceback.format_exc() - logger.error(f"Job {job_orm.id} failed with exception:\n{result}") + logger.info(f"Job {job_id} was cancelled") + job_orm.status = InternalJobStatusEnum.cancelled + job_orm.result = None + # Runner is stopping down + except JobRunnerStoppingException: + session.rollback() + logger.info(f"Job {job_id} was running, but runner is stopping") + + # Basically reset everything + job_orm.status = InternalJobStatusEnum.waiting + job_orm.progress = 0 + job_orm.progress_description = None + job_orm.started_date = None + job_orm.last_updated = None + job_orm.result = None + job_orm.runner_uuid = None + + # Function itself had an error + except: + session.rollback() + job_orm.result = traceback.format_exc() + logger.error(f"Job {job_id} failed with exception:\n{job_orm.result}") job_orm.status = InternalJobStatusEnum.error - if not job_progress.deleted(): - job_orm.ended_date = now_at_utc() - job_orm.last_updated = job_orm.ended_date - job_orm.result = result + if job_progress.deleted: + # Row does not exist anymore + session.expunge(job_orm) + else: + # If status is waiting, that means the runner itself is stopping or something + if job_orm.status != InternalJobStatusEnum.waiting: + job_orm.ended_date = now_at_utc() + job_orm.last_updated = job_orm.ended_date - # Clear the unique name so we can add another one if needed - has_unique_name = job_orm.unique_name is not None - job_orm.unique_name = None + # Clear the unique name so we can add another one if needed + has_unique_name = job_orm.unique_name is not None + job_orm.unique_name = None # Flush but don't commit. 
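
The exception-based flow above implies that long-running job functions should call `raise_if_cancelled()` periodically instead of checking a flag and returning early; this is the pattern the updated test helpers later in this diff follow. A minimal sketch of such a function:

```python
# Sketch of an internal-job function written against the new JobProgress API:
# progress updates carry an optional description, and cancellation or runner shutdown
# is signalled by raise_if_cancelled() raising, rather than by an early return.
import time


def example_internal_job(self, iterations: int, session, job_progress):
    for i in range(iterations):
        time.sleep(1.0)
        job_progress.update_progress(100 * ((i + 1) / iterations), f"Iteration {i + 1} of {iterations}")
        job_progress.raise_if_cancelled()

    return "finished"
```
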
This will prevent marking the task as finished # before the after_func has been run, but allow new ones to be added @@ -375,14 +411,21 @@ def _run_single(self, session: Session, job_orm: InternalJobORM, logger, job_pro # Run the function specified to be run after if job_orm.status == InternalJobStatusEnum.complete and job_orm.after_function is not None: - after_func_attr = attrgetter(job_orm.after_function) - after_func = after_func_attr(self.root_socket) + try: + after_func_attr = attrgetter(job_orm.after_function) + after_func = after_func_attr(self.root_socket) - after_func_params = inspect.signature(after_func).parameters - add_after_kwargs = {} - if "session" in after_func_params: - add_after_kwargs["session"] = session - after_func(**job_orm.after_function_kwargs, **add_after_kwargs) + after_func_params = inspect.signature(after_func).parameters + add_after_kwargs = {} + if "session" in after_func_params: + add_after_kwargs["session"] = session + after_func(**job_orm.after_function_kwargs, **add_after_kwargs) + except Exception: + # Don't rollback? not sure what to do here + result = traceback.format_exc() + logger.error(f"Job {job_orm.id} failed with exception:\n{result}") + + job_orm.status = InternalJobStatusEnum.error if job_orm.status == InternalJobStatusEnum.complete and job_orm.repeat_delay is not None: self.add( @@ -397,7 +440,8 @@ def _run_single(self, session: Session, job_orm: InternalJobORM, logger, job_pro repeat_delay=job_orm.repeat_delay, session=session, ) - session.commit() + + session.commit() @staticmethod def _wait_for_job(session: Session, logger, conn, end_event): @@ -405,8 +449,16 @@ def _wait_for_job(session: Session, logger, conn, end_event): Blocks until a job is possibly available to run """ + serial_cte = select(InternalJobORM.serial_group).distinct() + serial_cte = serial_cte.where(InternalJobORM.status == InternalJobStatusEnum.running) + serial_cte = serial_cte.where(InternalJobORM.serial_group.is_not(None)) + serial_cte = serial_cte.cte() + next_job_stmt = select(InternalJobORM.scheduled_date) next_job_stmt = next_job_stmt.where(InternalJobORM.status == InternalJobStatusEnum.waiting) + next_job_stmt = next_job_stmt.where( + or_(InternalJobORM.serial_group.is_(None), InternalJobORM.serial_group.not_in(select(serial_cte))) + ) next_job_stmt = next_job_stmt.order_by(InternalJobORM.scheduled_date.asc()) # Skip any that are being claimed for running right now @@ -508,6 +560,11 @@ def run_loop(self, end_event): conn.set_isolation_level(psycopg2.extensions.ISOLATION_LEVEL_AUTOCOMMIT) # Prepare a statement for finding jobs. 
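
The CTE built here (and again in `run_loop` below) excludes waiting jobs whose serial group already has a running member. Restated in plain Python, purely for illustration, the eligibility rule amounts to:

```python
# Pure-Python restatement of the scheduling filter; the real implementation is the
# SQLAlchemy statement above, evaluated entirely in the database.
def eligible_waiting_jobs(jobs):
    busy_groups = {
        j["serial_group"]
        for j in jobs
        if j["status"] == "running" and j["serial_group"] is not None
    }
    return [
        j
        for j in jobs
        if j["status"] == "waiting"
        and (j["serial_group"] is None or j["serial_group"] not in busy_groups)
    ]
```
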
Filters will be added in the loop + serial_cte = select(InternalJobORM.serial_group).distinct() + serial_cte = serial_cte.where(InternalJobORM.status == InternalJobStatusEnum.running) + serial_cte = serial_cte.where(InternalJobORM.serial_group.is_not(None)) + serial_cte = serial_cte.cte() + stmt = select(InternalJobORM) stmt = stmt.order_by(InternalJobORM.scheduled_date.asc()).limit(1) stmt = stmt.with_for_update(skip_locked=True) @@ -520,10 +577,20 @@ def run_loop(self, end_event): logger.debug("checking for jobs") # Pick up anything waiting, or anything that hasn't been updated in a while (12 update periods) + now = now_at_utc() dead = now - timedelta(seconds=(self._update_frequency * 12)) logger.debug(f"checking for jobs before date {now}") - cond1 = and_(InternalJobORM.status == InternalJobStatusEnum.waiting, InternalJobORM.scheduled_date <= now) + + # Job is waiting, schedule to be run now or in the past, + # Serial group does not have any running or is not set + cond1 = and_( + InternalJobORM.status == InternalJobStatusEnum.waiting, + InternalJobORM.scheduled_date <= now, + or_(InternalJobORM.serial_group.is_(None), InternalJobORM.serial_group.not_in(select(serial_cte))), + ) + + # Job is running but runner is determined to be dead. Serial group doesn't matter cond2 = and_(InternalJobORM.status == InternalJobStatusEnum.running, InternalJobORM.last_updated < dead) stmt_now = stmt.where(or_(cond1, cond2)) @@ -552,8 +619,21 @@ def run_loop(self, end_event): job_orm.runner_uuid = runner_uuid job_orm.status = InternalJobStatusEnum.running - # Releases the row-level lock (from the with_for_update() in the original query) - session_main.commit() + # For logging below - the object might be in an odd state after the exception, where accessing + # it results in further exceptions + serial_group = job_orm.serial_group + + # Violation of the unique constraint may occur - two runners attempting to take + # different jobs of the same serial group at the same time + try: + # Releases the row-level lock (from the with_for_update() in the original query) + session_main.commit() + except IntegrityError: + logger.info( + f"Attempting to run job from serial group '{serial_group}, but seems like another runner got to another job from the same group first" + ) + session_main.rollback() + continue job_progress = JobProgress(job_orm.id, runner_uuid, session_status, self._update_frequency, end_event) self._run_single(session_main, job_orm, logger, job_progress=job_progress) diff --git a/qcfractal/qcfractal/components/internal_jobs/status.py b/qcfractal/qcfractal/components/internal_jobs/status.py index dec8a261d..159f73d47 100644 --- a/qcfractal/qcfractal/components/internal_jobs/status.py +++ b/qcfractal/qcfractal/components/internal_jobs/status.py @@ -2,6 +2,7 @@ import threading import weakref +from typing import Optional from sqlalchemy import update from sqlalchemy.orm import Session @@ -11,6 +12,14 @@ from qcportal.utils import now_at_utc +class CancelledJobException(Exception): + pass + + +class JobRunnerStoppingException(Exception): + pass + + class JobProgress: """ Functor for updating progress and cancelling internal jobs @@ -26,7 +35,10 @@ def __init__(self, job_id: int, runner_uuid: str, session: Session, update_frequ .returning(InternalJobORM.status, InternalJobORM.runner_uuid) ) self._progress = 0 + self._description = None + self._cancelled = False + self._runner_ending = False self._deleted = False self._end_event = end_event @@ -45,7 +57,9 @@ def __init__(self, job_id: int, runner_uuid: str, 
session: Session, update_frequ def _update_thread(self, session: Session, end_thread: threading.Event): while True: # Update progress - stmt = self._stmt.values(progress=self._progress, last_updated=now_at_utc()) + stmt = self._stmt.values( + progress=self._progress, progress_description=self._description, last_updated=now_at_utc() + ) ret = session.execute(stmt).one_or_none() session.commit() @@ -60,9 +74,9 @@ def _update_thread(self, session: Session, end_thread: threading.Event): # Job was stolen from us? self._cancelled = True - # Are we completely ending? + # Are we ending/cancelling because the runner is stopping/closing? if self._end_event.is_set(): - self._cancelled = True + self._runner_ending = True cancel = end_thread.wait(self._update_frequency) if cancel is True: @@ -80,11 +94,34 @@ def _stop_thread(cancel_event: threading.Event, thread: threading.Thread): def stop(self): self._finalizer() - def update_progress(self, progress: int): + def update_progress(self, progress: int, description: Optional[str] = None): self._progress = progress + self._description = description + @property def cancelled(self) -> bool: - return self._cancelled + return self._cancelled or self._deleted + + @property + def runner_ending(self) -> bool: + return self._runner_ending + @property def deleted(self) -> bool: return self._deleted + + def raise_if_cancelled(self): + if self._cancelled or self._deleted: + raise CancelledJobException("Job was cancelled or deleted") + + if self._runner_ending: + raise JobRunnerStoppingException("Job runner is stopping/ending") + + +def raise_if_cancelled(job_progress: Optional[JobProgress]): + """ + Raises a CancelledJobExcepion if job_progress exists and is in a cancelled or deleted state + """ + + if job_progress is not None: + job_progress.raise_if_cancelled() diff --git a/qcfractal/qcfractal/components/internal_jobs/test_client.py b/qcfractal/qcfractal/components/internal_jobs/test_client.py index f23aef068..db9180598 100644 --- a/qcfractal/qcfractal/components/internal_jobs/test_client.py +++ b/qcfractal/qcfractal/components/internal_jobs/test_client.py @@ -22,11 +22,10 @@ def dummmy_internal_job(self, iterations: int, session, job_progress): for i in range(iterations): time.sleep(1.0) - job_progress.update_progress(100 * ((i + 1) / iterations)) - print("Dummy internal job counter ", i) + job_progress.update_progress(100 * ((i + 1) / iterations), f"Interation {i} of {iterations}") + # print("Dummy internal job counter ", i) - if job_progress.cancelled(): - return "Internal job cancelled" + job_progress.raise_if_cancelled() return "Internal job finished" @@ -36,8 +35,8 @@ def dummmy_internal_job_error(self, session, job_progress): raise RuntimeError("Expected error") -setattr(InternalJobSocket, "dummy_job", dummmy_internal_job) -setattr(InternalJobSocket, "dummy_job_error", dummmy_internal_job_error) +setattr(InternalJobSocket, "client_dummy_job", dummmy_internal_job) +setattr(InternalJobSocket, "client_dummy_job_error", dummmy_internal_job_error) def test_internal_jobs_client_error(snowflake: QCATestingSnowflake): @@ -45,7 +44,7 @@ def test_internal_jobs_client_error(snowflake: QCATestingSnowflake): snowflake_client = snowflake.client() id_1 = storage_socket.internal_jobs.add( - "dummy_job_error", now_at_utc(), "internal_jobs.dummy_job_error", {}, None, unique_name=False + "client_dummy_job_error", now_at_utc(), "internal_jobs.client_dummy_job_error", {}, None, unique_name=False ) # Faster updates for testing @@ -73,7 +72,7 @@ def 
test_internal_jobs_client_cancel_waiting(snowflake: QCATestingSnowflake): snowflake_client = snowflake.client() id_1 = storage_socket.internal_jobs.add( - "dummy_job", now_at_utc(), "internal_jobs.dummy_job", {"iterations": 10}, None, unique_name=False + "client_dummy_job", now_at_utc(), "internal_jobs.client.dummy_job", {"iterations": 10}, None, unique_name=False ) snowflake_client.cancel_internal_job(id_1) @@ -90,7 +89,7 @@ def test_internal_jobs_client_cancel_running(snowflake: QCATestingSnowflake): snowflake_client = snowflake.client() id_1 = storage_socket.internal_jobs.add( - "dummy_job", now_at_utc(), "internal_jobs.dummy_job", {"iterations": 10}, None, unique_name=False + "client_dummy_job", now_at_utc(), "internal_jobs.client_dummy_job", {"iterations": 10}, None, unique_name=False ) # Faster updates for testing @@ -112,7 +111,7 @@ def test_internal_jobs_client_cancel_running(snowflake: QCATestingSnowflake): job_1 = snowflake_client.get_internal_job(id_1) assert job_1.status == InternalJobStatusEnum.cancelled assert job_1.progress < 70 - assert job_1.result == "Internal job cancelled" + assert job_1.result is None finally: end_event.set() @@ -124,7 +123,7 @@ def test_internal_jobs_client_delete_waiting(snowflake: QCATestingSnowflake): snowflake_client = snowflake.client() id_1 = storage_socket.internal_jobs.add( - "dummy_job", now_at_utc(), "internal_jobs.dummy_job", {"iterations": 10}, None, unique_name=False + "client_dummy_job", now_at_utc(), "internal_jobs.client_dummy_job", {"iterations": 10}, None, unique_name=False ) snowflake_client.delete_internal_job(id_1) @@ -138,7 +137,7 @@ def test_internal_jobs_client_delete_running(snowflake: QCATestingSnowflake): snowflake_client = snowflake.client() id_1 = storage_socket.internal_jobs.add( - "dummy_job", now_at_utc(), "internal_jobs.dummy_job", {"iterations": 10}, None, unique_name=False + "client_dummy_job", now_at_utc(), "internal_jobs.client_dummy_job", {"iterations": 10}, None, unique_name=False ) # Faster updates for testing @@ -173,7 +172,12 @@ def test_internal_jobs_client_query(secure_snowflake: QCATestingSnowflake): time_0 = now_at_utc() id_1 = storage_socket.internal_jobs.add( - "dummy_job", now_at_utc(), "internal_jobs.dummy_job", {"iterations": 1}, read_id, unique_name=False + "client_dummy_job", + now_at_utc(), + "internal_jobs.client_dummy_job", + {"iterations": 1}, + read_id, + unique_name=False, ) time_1 = now_at_utc() @@ -196,7 +200,7 @@ def test_internal_jobs_client_query(secure_snowflake: QCATestingSnowflake): # Add one that will be waiting id_2 = storage_socket.internal_jobs.add( - "dummy_job", now_at_utc(), "internal_jobs.dummy_job", {"iterations": 1}, None, unique_name=False + "client_dummy_job", now_at_utc(), "internal_jobs.client_dummy_job", {"iterations": 1}, None, unique_name=False ) time_3 = now_at_utc() @@ -217,31 +221,31 @@ def test_internal_jobs_client_query(secure_snowflake: QCATestingSnowflake): assert len(r) == 1 assert r[0].id == id_1 - result = client.query_internal_jobs(name="dummy_job", status=["complete"]) + result = client.query_internal_jobs(name="client_dummy_job", status=["complete"]) r = list(result) assert len(r) == 1 assert r[0].id == id_1 - result = client.query_internal_jobs(name="dummy_job", status="waiting") + result = client.query_internal_jobs(name="client_dummy_job", status="waiting") r = list(result) assert len(r) == 1 assert r[0].id == id_2 - result = client.query_internal_jobs(name="dummy_job", added_after=time_0) + result = client.query_internal_jobs(name="client_dummy_job", 
added_after=time_0) r = list(result) assert len(r) == 2 assert {r[0].id, r[1].id} == {id_1, id_2} - result = client.query_internal_jobs(name="dummy_job", added_after=time_1, added_before=time_3) + result = client.query_internal_jobs(name="client_dummy_job", added_after=time_1, added_before=time_3) r = list(result) assert len(r) == 1 assert r[0].id == id_2 - result = client.query_internal_jobs(name="dummy_job", last_updated_after=time_2) + result = client.query_internal_jobs(name="client_dummy_job", last_updated_after=time_2) r = list(result) assert len(r) == 0 - result = client.query_internal_jobs(name="dummy_job", last_updated_before=time_2) + result = client.query_internal_jobs(name="client_dummy_job", last_updated_before=time_2) r = list(result) assert len(r) == 1 assert r[0].id == id_1 diff --git a/qcfractal/qcfractal/components/internal_jobs/test_socket.py b/qcfractal/qcfractal/components/internal_jobs/test_socket.py index 8b39d79c4..ee2879285 100644 --- a/qcfractal/qcfractal/components/internal_jobs/test_socket.py +++ b/qcfractal/qcfractal/components/internal_jobs/test_socket.py @@ -3,6 +3,7 @@ import threading import time import uuid +from datetime import timedelta from typing import TYPE_CHECKING import pytest @@ -26,8 +27,7 @@ def dummy_internal_job(self, iterations: int, session, job_progress): job_progress.update_progress(100 * ((i + 1) / iterations)) # print("Dummy internal job counter ", i) - if job_progress.cancelled(): - return "Internal job cancelled" + job_progress.raise_if_cancelled() return "Internal job finished" @@ -116,7 +116,51 @@ def test_internal_jobs_socket_run(storage_socket: SQLAlchemySocket, session: Ses th.join() -def test_internal_jobs_socket_recover(storage_socket: SQLAlchemySocket, session: Session): +@pytest.mark.parametrize("job_func", ("internal_jobs.dummy_job", "internal_jobs.dummy_job_2")) +def test_internal_jobs_socket_run_serial(storage_socket: SQLAlchemySocket, session: Session, job_func: str): + id_1 = storage_socket.internal_jobs.add( + "dummy_job", now_at_utc(), job_func, {"iterations": 10}, None, unique_name=False, serial_group="test" + ) + id_2 = storage_socket.internal_jobs.add( + "dummy_job", now_at_utc(), job_func, {"iterations": 10}, None, unique_name=False, serial_group="test" + ) + id_3 = storage_socket.internal_jobs.add( + "dummy_job", + now_at_utc(), + job_func, + {"iterations": 10}, + None, + unique_name=False, + ) + + # Faster updates for testing + storage_socket.internal_jobs._update_frequency = 1 + + end_event = threading.Event() + th1 = threading.Thread(target=storage_socket.internal_jobs.run_loop, args=(end_event,)) + th2 = threading.Thread(target=storage_socket.internal_jobs.run_loop, args=(end_event,)) + th3 = threading.Thread(target=storage_socket.internal_jobs.run_loop, args=(end_event,)) + th1.start() + th2.start() + th3.start() + time.sleep(8) + + try: + job_1 = session.get(InternalJobORM, id_1) + job_2 = session.get(InternalJobORM, id_2) + job_3 = session.get(InternalJobORM, id_3) + assert job_1.status == InternalJobStatusEnum.running + assert job_2.status == InternalJobStatusEnum.waiting + assert job_3.status == InternalJobStatusEnum.running + + finally: + end_event.set() + th1.join() + th2.join() + th3.join() + + +def test_internal_jobs_socket_runnerstop(storage_socket: SQLAlchemySocket, session: Session): id_1 = storage_socket.internal_jobs.add( "dummy_job", now_at_utc(), "internal_jobs.dummy_job", {"iterations": 10}, None, unique_name=False ) @@ -135,8 +179,13 @@ def test_internal_jobs_socket_recover(storage_socket: 
SQLAlchemySocket, session: assert not th.is_alive() job_1 = session.get(InternalJobORM, id_1) - assert job_1.status == InternalJobStatusEnum.running - assert job_1.progress > 10 + assert job_1.status == InternalJobStatusEnum.waiting + assert job_1.progress == 0 + assert job_1.started_date is None + assert job_1.last_updated is None + assert job_1.runner_uuid is None + + return old_uuid = job_1.runner_uuid # Change uuid @@ -159,3 +208,43 @@ def test_internal_jobs_socket_recover(storage_socket: SQLAlchemySocket, session: finally: end_event.set() th.join() + + +def test_internal_jobs_socket_recover(storage_socket: SQLAlchemySocket, session: Session): + id_1 = storage_socket.internal_jobs.add( + "dummy_job", now_at_utc(), "internal_jobs.dummy_job", {"iterations": 5}, None, unique_name=False + ) + + # Faster updates for testing + storage_socket.internal_jobs._update_frequency = 1 + + # Manually make it seem like it's running + old_uuid = str(uuid.uuid4()) + job_1 = session.get(InternalJobORM, id_1) + job_1.status = InternalJobStatusEnum.running + job_1.progress = 10 + job_1.last_updated = now_at_utc() - timedelta(seconds=60) + job_1.runner_uuid = old_uuid + session.commit() + + session.expire(job_1) + job_1 = session.get(InternalJobORM, id_1) + assert job_1.status == InternalJobStatusEnum.running + assert job_1.runner_uuid == old_uuid + + # Job is now running but orphaned. Should be picked up next time + end_event = threading.Event() + th = threading.Thread(target=storage_socket.internal_jobs.run_loop, args=(end_event,)) + th.start() + time.sleep(10) + + try: + session.expire(job_1) + job_1 = session.get(InternalJobORM, id_1) + assert job_1.status == InternalJobStatusEnum.complete + assert job_1.runner_uuid != old_uuid + assert job_1.progress == 100 + assert job_1.result == "Internal job finished" + finally: + end_event.set() + th.join() diff --git a/qcfractal/qcfractal/components/manybody/test_dataset_client.py b/qcfractal/qcfractal/components/manybody/test_dataset_client.py deleted file mode 100644 index 7adb43156..000000000 --- a/qcfractal/qcfractal/components/manybody/test_dataset_client.py +++ /dev/null @@ -1,79 +0,0 @@ -from __future__ import annotations - -from typing import TYPE_CHECKING - -import pytest - -from qcportal.manybody import ManybodyDataset -from qcportal.record_models import PriorityEnum -from .testing_helpers import load_test_data - -if TYPE_CHECKING: - from qcportal import PortalClient - - -@pytest.fixture(scope="function") -def manybody_ds(submitter_client: PortalClient): - ds = submitter_client.add_dataset( - "manybody", - "Test dataset", - "Test Description", - "a Tagline", - ["tag1", "tag2"], - "new_group", - {"prov_key_1": "prov_value_1"}, - True, - "def_tag", - PriorityEnum.low, - {"meta_key_1": "meta_value_1"}, - "group1", - ) - - assert ds.owner_user is not None - assert ds.owner_user == ds._client.username - assert ds.owner_group == "group1" - - yield ds - - -@pytest.mark.parametrize("find_existing", [True, False]) -def test_manybody_dataset_client_submit(manybody_ds: ManybodyDataset, find_existing: bool): - input_spec_1, molecule_1, _ = load_test_data("mb_cp_he4_psi4_mp2") - - manybody_ds.add_entry(name="test_molecule", initial_molecule=molecule_1) - manybody_ds.add_specification("test_spec", input_spec_1, "test_specification") - - manybody_ds.submit() - assert manybody_ds.status()["test_spec"]["waiting"] == 1 - - manybody_ds.submit() - assert manybody_ds.status()["test_spec"]["waiting"] == 1 - - # Should only be one record - record_id = 0 - for e, s, r in 
manybody_ds.iterate_records(): - assert r.owner_user == manybody_ds.owner_user - assert r.owner_group == manybody_ds.owner_group - record_id = r.id - - # delete & re-add entry, then resubmit - manybody_ds.delete_entries(["test_molecule"]) - assert manybody_ds.status() == {} - - # record still on the server? - r = manybody_ds._client.get_records(record_id) - assert r.owner_user == manybody_ds.owner_user - - # now resubmit - manybody_ds.add_entry(name="test_molecule", initial_molecule=molecule_1) - manybody_ds.submit(find_existing=find_existing) - assert manybody_ds.status()["test_spec"]["waiting"] == 1 - - for e, s, r in manybody_ds.iterate_records(): - assert r.owner_user == manybody_ds.owner_user - assert r.owner_group == manybody_ds.owner_group - - if find_existing: - assert r.id == record_id - else: - assert r.id != record_id diff --git a/qcfractal/qcfractal/components/neb/test_dataset_client.py b/qcfractal/qcfractal/components/neb/test_dataset_client.py deleted file mode 100644 index 242e0f145..000000000 --- a/qcfractal/qcfractal/components/neb/test_dataset_client.py +++ /dev/null @@ -1,79 +0,0 @@ -from __future__ import annotations - -from typing import TYPE_CHECKING - -import pytest - -from qcportal.neb import NEBDataset -from qcportal.record_models import PriorityEnum -from .testing_helpers import load_test_data - -if TYPE_CHECKING: - from qcportal import PortalClient - - -@pytest.fixture(scope="function") -def neb_ds(submitter_client: PortalClient): - ds = submitter_client.add_dataset( - "neb", - "Test dataset", - "Test Description", - "a Tagline", - ["tag1", "tag2"], - "new_group", - {"prov_key_1": "prov_value_1"}, - True, - "def_tag", - PriorityEnum.low, - {"meta_key_1": "meta_value_1"}, - "group1", - ) - - assert ds.owner_user is not None - assert ds.owner_user == ds._client.username - assert ds.owner_group == "group1" - - yield ds - - -@pytest.mark.parametrize("find_existing", [True, False]) -def test_neb_dataset_client_submit(neb_ds: NEBDataset, find_existing: bool): - input_spec_1, molecule_1, _ = load_test_data("neb_HCN_psi4_pbe_opt_diff") - - neb_ds.add_entry(name="test_molecule", initial_chain=molecule_1) - neb_ds.add_specification("test_spec", input_spec_1, "test_specification") - - neb_ds.submit() - assert neb_ds.status()["test_spec"]["waiting"] == 1 - - neb_ds.submit() - assert neb_ds.status()["test_spec"]["waiting"] == 1 - - # Should only be one record - record_id = 0 - for e, s, r in neb_ds.iterate_records(): - assert r.owner_user == neb_ds.owner_user - assert r.owner_group == neb_ds.owner_group - record_id = r.id - - # delete & re-add entry, then resubmit - neb_ds.delete_entries(["test_molecule"]) - assert neb_ds.status() == {} - - # record still on the server? 
- r = neb_ds._client.get_records(record_id) - assert r.owner_user == neb_ds.owner_user - - # now resubmit - neb_ds.add_entry(name="test_molecule", initial_chain=molecule_1) - neb_ds.submit(find_existing=find_existing) - assert neb_ds.status()["test_spec"]["waiting"] == 1 - - for e, s, r in neb_ds.iterate_records(): - assert r.owner_user == neb_ds.owner_user - assert r.owner_group == neb_ds.owner_group - - if find_existing: - assert r.id == record_id - else: - assert r.id != record_id diff --git a/qcfractal/qcfractal/components/optimization/test_dataset_client.py b/qcfractal/qcfractal/components/optimization/test_dataset_client.py deleted file mode 100644 index 33e19eccd..000000000 --- a/qcfractal/qcfractal/components/optimization/test_dataset_client.py +++ /dev/null @@ -1,79 +0,0 @@ -from __future__ import annotations - -from typing import TYPE_CHECKING - -import pytest - -from qcportal.optimization import OptimizationDataset -from qcportal.record_models import PriorityEnum -from .testing_helpers import load_test_data - -if TYPE_CHECKING: - from qcportal import PortalClient - - -@pytest.fixture(scope="function") -def optimization_ds(submitter_client: PortalClient): - ds = submitter_client.add_dataset( - "optimization", - "Test dataset", - "Test Description", - "a Tagline", - ["tag1", "tag2"], - "new_group", - {"prov_key_1": "prov_value_1"}, - True, - "def_tag", - PriorityEnum.low, - {"meta_key_1": "meta_value_1"}, - "group1", - ) - - assert ds.owner_user is not None - assert ds.owner_user == ds._client.username - assert ds.owner_group == "group1" - - yield ds - - -@pytest.mark.parametrize("find_existing", [True, False]) -def test_optimization_dataset_client_submit(optimization_ds: OptimizationDataset, find_existing: bool): - input_spec_1, molecule_1, _ = load_test_data("opt_psi4_benzene") - - optimization_ds.add_entry(name="test_molecule", initial_molecule=molecule_1) - optimization_ds.add_specification("test_spec", input_spec_1, "test_specification") - - optimization_ds.submit() - assert optimization_ds.status()["test_spec"]["waiting"] == 1 - - optimization_ds.submit() - assert optimization_ds.status()["test_spec"]["waiting"] == 1 - - # Should only be one record - record_id = 0 - for e, s, r in optimization_ds.iterate_records(): - assert r.owner_user == optimization_ds.owner_user - assert r.owner_group == optimization_ds.owner_group - record_id = r.id - - # delete & re-add entry, then resubmit - optimization_ds.delete_entries(["test_molecule"]) - assert optimization_ds.status() == {} - - # record still on the server? 
- r = optimization_ds._client.get_records(record_id) - assert r.owner_user == optimization_ds.owner_user - - # now resubmit - optimization_ds.add_entry(name="test_molecule", initial_molecule=molecule_1) - optimization_ds.submit(find_existing=find_existing) - assert optimization_ds.status()["test_spec"]["waiting"] == 1 - - for e, s, r in optimization_ds.iterate_records(): - assert r.owner_user == optimization_ds.owner_user - assert r.owner_group == optimization_ds.owner_group - - if find_existing: - assert r.id == record_id - else: - assert r.id != record_id diff --git a/qcfractal/qcfractal/components/reaction/test_dataset_client.py b/qcfractal/qcfractal/components/reaction/test_dataset_client.py deleted file mode 100644 index 7efbf3632..000000000 --- a/qcfractal/qcfractal/components/reaction/test_dataset_client.py +++ /dev/null @@ -1,79 +0,0 @@ -from __future__ import annotations - -from typing import TYPE_CHECKING - -import pytest - -from qcportal.reaction import ReactionDataset -from qcportal.record_models import PriorityEnum -from .testing_helpers import load_test_data - -if TYPE_CHECKING: - from qcportal import PortalClient - - -@pytest.fixture(scope="function") -def reaction_ds(submitter_client: PortalClient): - ds = submitter_client.add_dataset( - "reaction", - "Test dataset", - "Test Description", - "a Tagline", - ["tag1", "tag2"], - "new_group", - {"prov_key_1": "prov_value_1"}, - True, - "def_tag", - PriorityEnum.low, - {"meta_key_1": "meta_value_1"}, - "group1", - ) - - assert ds.owner_user is not None - assert ds.owner_user == ds._client.username - assert ds.owner_group == "group1" - - yield ds - - -@pytest.mark.parametrize("find_existing", [True, False]) -def test_reaction_dataset_client_submit(reaction_ds: ReactionDataset, find_existing: bool): - input_spec_1, molecule_1, _ = load_test_data("rxn_H2O_psi4_mp2_optsp") - - reaction_ds.add_entry(name="test_molecule", stoichiometries=molecule_1) - reaction_ds.add_specification("test_spec", input_spec_1, "test_specification") - - reaction_ds.submit() - assert reaction_ds.status()["test_spec"]["waiting"] == 1 - - reaction_ds.submit() - assert reaction_ds.status()["test_spec"]["waiting"] == 1 - - # Should only be one record - record_id = 0 - for e, s, r in reaction_ds.iterate_records(): - assert r.owner_user == reaction_ds.owner_user - assert r.owner_group == reaction_ds.owner_group - record_id = r.id - - # delete & re-add entry, then resubmit - reaction_ds.delete_entries(["test_molecule"]) - assert reaction_ds.status() == {} - - # record still on the server? - r = reaction_ds._client.get_records(record_id) - assert r.owner_user == reaction_ds.owner_user - - # now resubmit - reaction_ds.add_entry(name="test_molecule", stoichiometries=molecule_1) - reaction_ds.submit(find_existing=find_existing) - assert reaction_ds.status()["test_spec"]["waiting"] == 1 - - for e, s, r in reaction_ds.iterate_records(): - assert r.owner_user == reaction_ds.owner_user - assert r.owner_group == reaction_ds.owner_group - - if find_existing: - assert r.id == record_id - else: - assert r.id != record_id diff --git a/qcfractal/qcfractal/components/register_all.py b/qcfractal/qcfractal/components/register_all.py index 80ac72dc2..cc24d092c 100644 --- a/qcfractal/qcfractal/components/register_all.py +++ b/qcfractal/qcfractal/components/register_all.py @@ -9,6 +9,7 @@ from .tasks import db_models, routes from .services import db_models from .internal_jobs import db_models, routes +from .external_files import db_models, routes from . 
import record_db_models, dataset_db_models, record_routes, dataset_routes from .singlepoint import record_db_models, dataset_db_models, routes diff --git a/qcfractal/qcfractal/components/serverinfo/socket.py b/qcfractal/qcfractal/components/serverinfo/socket.py index 60524e121..8c3ce3317 100644 --- a/qcfractal/qcfractal/components/serverinfo/socket.py +++ b/qcfractal/qcfractal/components/serverinfo/socket.py @@ -281,7 +281,7 @@ def delete_old_access_logs(self, session: Session) -> None: if self._access_log_keep <= 0 or not self._access_log_enabled: return - before = now_at_utc() - timedelta(days=self._access_log_keep) + before = now_at_utc() - timedelta(seconds=self._access_log_keep) num_deleted = self.delete_access_logs(before, session=session) self._logger.info(f"Deleted {num_deleted} access logs before {before}") diff --git a/qcfractal/qcfractal/components/singlepoint/test_dataset_client.py b/qcfractal/qcfractal/components/singlepoint/test_dataset_client.py deleted file mode 100644 index 84ae31f20..000000000 --- a/qcfractal/qcfractal/components/singlepoint/test_dataset_client.py +++ /dev/null @@ -1,79 +0,0 @@ -from __future__ import annotations - -from typing import TYPE_CHECKING - -import pytest - -from qcportal.record_models import PriorityEnum -from qcportal.singlepoint import SinglepointDataset -from .testing_helpers import load_test_data - -if TYPE_CHECKING: - from qcportal import PortalClient - - -@pytest.fixture(scope="function") -def singlepoint_ds(submitter_client: PortalClient): - ds = submitter_client.add_dataset( - "singlepoint", - "Test dataset", - "Test Description", - "a Tagline", - ["tag1", "tag2"], - "new_group", - {"prov_key_1": "prov_value_1"}, - True, - "def_tag", - PriorityEnum.low, - {"meta_key_1": "meta_value_1"}, - "group1", - ) - - assert ds.owner_user is not None - assert ds.owner_user == ds._client.username - assert ds.owner_group == "group1" - - yield ds - - -@pytest.mark.parametrize("find_existing", [True, False]) -def test_singlepoint_dataset_client_submit(singlepoint_ds: SinglepointDataset, find_existing: bool): - input_spec_1, molecule_1, _ = load_test_data("sp_psi4_benzene_energy_1") - - singlepoint_ds.add_entry(name="test_molecule", molecule=molecule_1) - singlepoint_ds.add_specification("test_spec", input_spec_1, "test_specification") - - singlepoint_ds.submit() - assert singlepoint_ds.status()["test_spec"]["waiting"] == 1 - - singlepoint_ds.submit() - assert singlepoint_ds.status()["test_spec"]["waiting"] == 1 - - # Should only be one record - record_id = 0 - for e, s, r in singlepoint_ds.iterate_records(): - assert r.owner_user == singlepoint_ds.owner_user - assert r.owner_group == singlepoint_ds.owner_group - record_id = r.id - - # delete & re-add entry, then resubmit - singlepoint_ds.delete_entries(["test_molecule"]) - assert singlepoint_ds.status() == {} - - # record still on the server? 
- r = singlepoint_ds._client.get_records(record_id) - assert r.owner_user == singlepoint_ds.owner_user - - # now resubmit - singlepoint_ds.add_entry(name="test_molecule", molecule=molecule_1) - singlepoint_ds.submit(find_existing=find_existing) - assert singlepoint_ds.status()["test_spec"]["waiting"] == 1 - - for e, s, r in singlepoint_ds.iterate_records(): - assert r.owner_user == singlepoint_ds.owner_user - assert r.owner_group == singlepoint_ds.owner_group - - if find_existing: - assert r.id == record_id - else: - assert r.id != record_id diff --git a/qcfractal/qcfractal/components/torsiondrive/test_dataset_client.py b/qcfractal/qcfractal/components/torsiondrive/test_dataset_client.py deleted file mode 100644 index f9f159e78..000000000 --- a/qcfractal/qcfractal/components/torsiondrive/test_dataset_client.py +++ /dev/null @@ -1,79 +0,0 @@ -from __future__ import annotations - -from typing import TYPE_CHECKING - -import pytest - -from qcportal.torsiondrive import TorsiondriveDataset -from qcportal.record_models import PriorityEnum -from .testing_helpers import load_test_data - -if TYPE_CHECKING: - from qcportal import PortalClient - - -@pytest.fixture(scope="function") -def torsiondrive_ds(submitter_client: PortalClient): - ds = submitter_client.add_dataset( - "torsiondrive", - "Test dataset", - "Test Description", - "a Tagline", - ["tag1", "tag2"], - "new_group", - {"prov_key_1": "prov_value_1"}, - True, - "def_tag", - PriorityEnum.low, - {"meta_key_1": "meta_value_1"}, - "group1", - ) - - assert ds.owner_user is not None - assert ds.owner_user == ds._client.username - assert ds.owner_group == "group1" - - yield ds - - -@pytest.mark.parametrize("find_existing", [True, False]) -def test_torsiondrive_dataset_client_submit(torsiondrive_ds: TorsiondriveDataset, find_existing: bool): - input_spec_1, molecule_1, _ = load_test_data("td_H2O2_mopac_pm6") - - torsiondrive_ds.add_entry(name="test_molecule", initial_molecules=molecule_1) - torsiondrive_ds.add_specification("test_spec", input_spec_1, "test_specification") - - torsiondrive_ds.submit() - assert torsiondrive_ds.status()["test_spec"]["waiting"] == 1 - - torsiondrive_ds.submit() - assert torsiondrive_ds.status()["test_spec"]["waiting"] == 1 - - # Should only be one record - record_id = 0 - for e, s, r in torsiondrive_ds.iterate_records(): - assert r.owner_user == torsiondrive_ds.owner_user - assert r.owner_group == torsiondrive_ds.owner_group - record_id = r.id - - # delete & re-add entry, then resubmit - torsiondrive_ds.delete_entries(["test_molecule"]) - assert torsiondrive_ds.status() == {} - - # record still on the server? 
- r = torsiondrive_ds._client.get_records(record_id) - assert r.owner_user == torsiondrive_ds.owner_user - - # now resubmit - torsiondrive_ds.add_entry(name="test_molecule", initial_molecules=molecule_1) - torsiondrive_ds.submit(find_existing=find_existing) - assert torsiondrive_ds.status()["test_spec"]["waiting"] == 1 - - for e, s, r in torsiondrive_ds.iterate_records(): - assert r.owner_user == torsiondrive_ds.owner_user - assert r.owner_group == torsiondrive_ds.owner_group - - if find_existing: - assert r.id == record_id - else: - assert r.id != record_id diff --git a/qcfractal/qcfractal/config.py b/qcfractal/qcfractal/config.py index 5fc57adcf..30672f696 100644 --- a/qcfractal/qcfractal/config.py +++ b/qcfractal/qcfractal/config.py @@ -7,6 +7,7 @@ import logging import os import secrets +import tempfile from typing import Optional, Dict, Union, Any import yaml @@ -21,16 +22,7 @@ from sqlalchemy.engine.url import URL, make_url from qcfractal.port_util import find_open_port -from qcportal.utils import duration_to_seconds - - -def update_nested_dict(d, u): - for k, v in u.items(): - if isinstance(v, dict): - d[k] = update_nested_dict(d.get(k, {}), v) - else: - d[k] = v - return d +from qcportal.utils import duration_to_seconds, update_nested_dict def _make_abs_path(path: Optional[str], base_folder: str, default_filename: Optional[str]) -> Optional[str]: @@ -324,6 +316,37 @@ class Config(ConfigCommon): env_prefix = "QCF_API_" +class S3BucketMap(ConfigBase): + dataset_attachment: str = Field("dataset_attachment", description="Bucket to hold dataset views") + + +class S3Config(ConfigBase): + """ + Settings for using external files with S3 + """ + + enabled: bool = False + verify: bool = True + passthrough: bool = False + endpoint_url: Optional[str] = Field(None, description="S3 endpoint URL") + access_key_id: Optional[str] = Field(None, description="AWS/S3 access key") + secret_access_key: Optional[str] = Field(None, description="AWS/S3 secret key") + + bucket_map: S3BucketMap = Field(S3BucketMap(), description="Configuration for where to store various files") + + class Config(ConfigCommon): + env_prefix = "QCF_S3_" + + @root_validator() + def _check_enabled(cls, values): + if values.get("enabled", False) is True: + for key in ["endpoint_url", "access_key_id", "secret_access_key"]: + if values.get(key, None) is None: + raise ValueError(f"S3 enabled but {key} not set") + + return values + + class FractalConfig(ConfigBase): """ Fractal Server settings @@ -334,6 +357,11 @@ class FractalConfig(ConfigBase): description="The base directory to use as the default for some options (logs, etc). Default is the location of the config file.", ) + temporary_dir: Optional[str] = Field( + None, + description="Temporary directory to use for things such as view creation. If None, uses system default. This may require a lot of space!", + ) + # Info for the REST interface name: str = Field("QCFractal Server", description="The QCFractal server name") @@ -380,7 +408,7 @@ class FractalConfig(ConfigBase): # Access logging log_access: bool = Field(False, description="Store API access in the database") access_log_keep: int = Field( - 0, description="How far back to keep access logs (in seconds or a string). 0 means keep all" + 0, description="How far back to keep access logs (in days or as a duration string). 
0 means keep all" ) # maxmind_account_id: Optional[int] = Field(None, description="Account ID for MaxMind GeoIP2 service") @@ -403,7 +431,7 @@ class FractalConfig(ConfigBase): 1, description="Number of processes for processing internal jobs and async requests" ) internal_job_keep: int = Field( - 0, description="Number of days of finished internal job logs to keep. 0 means keep all" + 0, description="How far back to keep finished internal jobs (in days or as a duration string). 0 means keep all" ) # Homepage settings @@ -413,6 +441,7 @@ class FractalConfig(ConfigBase): # Other settings blocks database: DatabaseConfig = Field(..., description="Configuration of the settings for the database") api: WebAPIConfig = Field(..., description="Configuration of the REST interface") + s3: S3Config = Field(S3Config(), description="Configuration of the S3 file storage (optional)") api_limits: APILimitConfig = Field(..., description="Configuration of the limits to the api") auto_reset: AutoResetConfig = Field(..., description="Configuration for automatic resetting of tasks") @@ -461,10 +490,25 @@ def _check_loglevel(cls, v): raise ValidationError(f"{v} is not a valid loglevel. Must be DEBUG, INFO, WARNING, ERROR, or CRITICAL") return v - @validator("service_frequency", "heartbeat_frequency", "access_log_keep", pre=True) + @validator("service_frequency", "heartbeat_frequency", pre=True) def _convert_durations(cls, v): return duration_to_seconds(v) + @validator("access_log_keep", "internal_job_keep", pre=True) + def _convert_durations_days(cls, v): + if isinstance(v, int) or (isinstance(v, str) and v.isdigit()): + return int(v) * 86400 + return duration_to_seconds(v) + + @validator("temporary_dir", pre=True) + def _create_temporary_directory(cls, v, values): + v = _make_abs_path(v, values["base_folder"], tempfile.gettempdir()) + + if v is not None and not os.path.exists(v): + os.makedirs(v) + + return v + class Config(ConfigCommon): env_prefix = "QCF_" diff --git a/qcfractal/qcfractal/db_socket/socket.py b/qcfractal/qcfractal/db_socket/socket.py index 3635439b0..a128f7f01 100644 --- a/qcfractal/qcfractal/db_socket/socket.py +++ b/qcfractal/qcfractal/db_socket/socket.py @@ -100,6 +100,7 @@ def checkout(dbapi_connection, connection_record, connection_proxy): from ..components.managers.socket import ManagerSocket from ..components.tasks.socket import TaskSocket from ..components.services.socket import ServiceSocket + from ..components.external_files import ExternalFileSocket from ..components.record_socket import RecordSocket from ..components.dataset_socket import DatasetSocket @@ -111,6 +112,7 @@ def checkout(dbapi_connection, connection_record, connection_proxy): self.molecules = MoleculeSocket(self) self.datasets = DatasetSocket(self) self.records = RecordSocket(self) + self.external_files = ExternalFileSocket(self) self.tasks = TaskSocket(self) self.services = ServiceSocket(self) self.managers = ManagerSocket(self) diff --git a/qcfractal/qcfractal/flask_app/api_v1/helpers.py b/qcfractal/qcfractal/flask_app/api_v1/helpers.py index 7b983cb5b..0d9d24870 100644 --- a/qcfractal/qcfractal/flask_app/api_v1/helpers.py +++ b/qcfractal/qcfractal/flask_app/api_v1/helpers.py @@ -94,7 +94,10 @@ def wrapper(*args, **kwargs): # Now call the function, and validate the output ret = fn(*args, **kwargs) - # Serialize the output + # Serialize the output it it's not a normal flask response + if isinstance(ret, Response): + return ret + serialized = serialize(ret, accept_type) return Response(serialized, 
content_type=accept_type) diff --git a/qcfractal/qcfractal/snowflake.py b/qcfractal/qcfractal/snowflake.py index 73b632210..c5a95bf51 100644 --- a/qcfractal/qcfractal/snowflake.py +++ b/qcfractal/qcfractal/snowflake.py @@ -17,7 +17,8 @@ from qcportal import PortalClient from qcportal.record_models import RecordStatusEnum -from .config import FractalConfig, DatabaseConfig, update_nested_dict +from qcportal.utils import update_nested_dict +from .config import FractalConfig, DatabaseConfig from .flask_app.waitress_app import FractalWaitressApp from .job_runner import FractalJobRunner from .port_util import find_open_port diff --git a/qcfractal/qcfractal/test_config.py b/qcfractal/qcfractal/test_config.py index f507e78ab..600525e4c 100644 --- a/qcfractal/qcfractal/test_config.py +++ b/qcfractal/qcfractal/test_config.py @@ -1,4 +1,5 @@ import copy +import os from qcfractal.config import FractalConfig @@ -17,14 +18,16 @@ def test_config_durations_plain(tmp_path): base_config = copy.deepcopy(_base_config) base_config["service_frequency"] = 3600 base_config["heartbeat_frequency"] = 30 - base_config["access_log_keep"] = 100802 + base_config["access_log_keep"] = 31 + base_config["internal_job_keep"] = 7 base_config["api"]["jwt_access_token_expires"] = 7450 base_config["api"]["jwt_refresh_token_expires"] = 637277 cfg = FractalConfig(base_folder=base_folder, **base_config) assert cfg.service_frequency == 3600 assert cfg.heartbeat_frequency == 30 - assert cfg.access_log_keep == 100802 + assert cfg.access_log_keep == 2678400 # interpreted as days + assert cfg.internal_job_keep == 604800 assert cfg.api.jwt_access_token_expires == 7450 assert cfg.api.jwt_refresh_token_expires == 637277 @@ -36,6 +39,7 @@ def test_config_durations_str(tmp_path): base_config["service_frequency"] = "1h" base_config["heartbeat_frequency"] = "30s" base_config["access_log_keep"] = "1d4h2s" + base_config["internal_job_keep"] = "1d4h7s" base_config["api"]["jwt_access_token_expires"] = "2h4m10s" base_config["api"]["jwt_refresh_token_expires"] = "7d9h77s" cfg = FractalConfig(base_folder=base_folder, **base_config) @@ -43,6 +47,7 @@ def test_config_durations_str(tmp_path): assert cfg.service_frequency == 3600 assert cfg.heartbeat_frequency == 30 assert cfg.access_log_keep == 100802 + assert cfg.internal_job_keep == 100807 assert cfg.api.jwt_access_token_expires == 7450 assert cfg.api.jwt_refresh_token_expires == 637277 @@ -54,6 +59,7 @@ def test_config_durations_dhms(tmp_path): base_config["service_frequency"] = "1:00:00" base_config["heartbeat_frequency"] = "30" base_config["access_log_keep"] = "1:04:00:02" + base_config["internal_job_keep"] = "1:04:00:07" base_config["api"]["jwt_access_token_expires"] = "2:04:10" base_config["api"]["jwt_refresh_token_expires"] = "7:09:00:77" cfg = FractalConfig(base_folder=base_folder, **base_config) @@ -61,5 +67,16 @@ def test_config_durations_dhms(tmp_path): assert cfg.service_frequency == 3600 assert cfg.heartbeat_frequency == 30 assert cfg.access_log_keep == 100802 + assert cfg.internal_job_keep == 100807 assert cfg.api.jwt_access_token_expires == 7450 assert cfg.api.jwt_refresh_token_expires == 637277 + + +def test_config_tmpdir_create(tmp_path): + base_folder = str(tmp_path) + base_config = copy.deepcopy(_base_config) + base_config["temporary_dir"] = str(tmp_path / "qcatmpdir") + cfg = FractalConfig(base_folder=base_folder, **base_config) + + assert cfg.temporary_dir == str(tmp_path / "qcatmpdir") + assert os.path.exists(cfg.temporary_dir) diff --git 
a/qcfractalcompute/qcfractalcompute/config.py b/qcfractalcompute/qcfractalcompute/config.py index 11b27633b..58f6d8a8e 100644 --- a/qcfractalcompute/qcfractalcompute/config.py +++ b/qcfractalcompute/qcfractalcompute/config.py @@ -10,7 +10,7 @@ from pydantic import BaseModel, Field, validator from typing_extensions import Literal -from qcportal.utils import seconds_to_hms, duration_to_seconds +from qcportal.utils import seconds_to_hms, duration_to_seconds, update_nested_dict def _make_abs_path(path: Optional[str], base_folder: str, default_filename: Optional[str]) -> Optional[str]: @@ -31,15 +31,6 @@ def _make_abs_path(path: Optional[str], base_folder: str, default_filename: Opti return os.path.abspath(path) -def update_nested_dict(d, u): - for k, v in u.items(): - if isinstance(v, dict): - d[k] = update_nested_dict(d.get(k, {}), v) - else: - d[k] = v - return d - - class PackageEnvironmentSettings(BaseModel): """ Environments with installed packages that can be used to run calculations diff --git a/qcfractalcompute/qcfractalcompute/testing_helpers.py b/qcfractalcompute/qcfractalcompute/testing_helpers.py index 64ddf241d..07ff91365 100644 --- a/qcfractalcompute/qcfractalcompute/testing_helpers.py +++ b/qcfractalcompute/qcfractalcompute/testing_helpers.py @@ -12,7 +12,6 @@ from qcfractal.components.optimization.testing_helpers import submit_test_data as submit_opt_test_data from qcfractal.components.singlepoint.testing_helpers import submit_test_data as submit_sp_test_data from qcfractal.config import FractalConfig -from qcfractal.config import update_nested_dict from qcfractal.db_socket import SQLAlchemySocket from qcfractalcompute.apps.models import AppTaskResult from qcfractalcompute.compress import compress_result @@ -20,6 +19,7 @@ from qcfractalcompute.config import FractalComputeConfig, FractalServerSettings, LocalExecutorConfig from qcportal.all_results import AllResultTypes, FailedOperation from qcportal.record_models import PriorityEnum, RecordTask +from qcportal.utils import update_nested_dict failed_op = FailedOperation( input_data=None, diff --git a/qcportal/qcportal/cache.py b/qcportal/qcportal/cache.py index ea2e4d826..bc9b474b6 100644 --- a/qcportal/qcportal/cache.py +++ b/qcportal/qcportal/cache.py @@ -74,11 +74,11 @@ def _create_tables(self): self._conn.execute( """ CREATE TABLE IF NOT EXISTS records ( - id INTEGER NOT NULL PRIMARY KEY, + id INTEGER PRIMARY KEY, status TEXT NOT NULL, modified_on DECIMAL NOT NULL, record BLOB NOT NULL - ) WITHOUT ROWID + ) """ ) @@ -540,21 +540,37 @@ def __init__(self, server_uri: str, cache_dir: Optional[str], max_size: int): self.cache_dir = None + def get_cache_path(self, cache_name: str) -> str: + if not self._is_disk: + raise RuntimeError("Cannot get path to cache for memory-only cache") + + return os.path.join(self.cache_dir, f"{cache_name}.sqlite") + def get_cache_uri(self, cache_name: str) -> str: if self._is_disk: - file_path = os.path.join(self.cache_dir, f"{cache_name}.sqlite") + file_path = self.get_cache_path(cache_name) uri = f"file:{file_path}" else: uri = ":memory:" return uri + def get_dataset_cache_path(self, dataset_id: int) -> str: + return self.get_cache_path(f"dataset_{dataset_id}") + + def get_dataset_cache_uri(self, dataset_id: int) -> str: + return self.get_cache_uri(f"dataset_{dataset_id}") + def get_dataset_cache(self, dataset_id: int, dataset_type: Type[_DATASET_T]) -> DatasetCache: - uri = self.get_cache_uri(f"dataset_{dataset_id}") + uri = self.get_dataset_cache_uri(dataset_id) # If you are asking this for a 
dataset cache, it should be writable return DatasetCache(uri, False, dataset_type) + @property + def is_disk(self) -> bool: + return self._is_disk + def vacuum(self, cache_name: Optional[str] = None): if self._is_disk: # TODO diff --git a/qcportal/qcportal/client.py b/qcportal/qcportal/client.py index ba450e3c3..e1fe97cee 100644 --- a/qcportal/qcportal/client.py +++ b/qcportal/qcportal/client.py @@ -2,12 +2,14 @@ import logging import math +import os from datetime import datetime from typing import Any, Dict, List, Optional, Tuple, Union, Sequence, Iterable, TypeVar, Type from tabulate import tabulate from qcportal.cache import DatasetCache, read_dataset_metadata +from qcportal.external_files import ExternalFile from qcportal.gridoptimization import ( GridoptimizationKeywords, GridoptimizationAddBody, @@ -325,6 +327,52 @@ def delete_dataset(self, dataset_id: int, delete_records: bool): params = DatasetDeleteParams(delete_records=delete_records) return self.make_request("delete", f"api/v1/datasets/{dataset_id}", None, url_params=params) + ############################################################## + # External files + ############################################################## + def download_external_file(self, file_id: int, destination_path: str, overwrite: bool = False) -> Tuple[int, str]: + """ + Downloads an external file to the given path + + The file size and checksum will be checked against the metadata stored on the server + + Parameters + ---------- + file_id + ID of the file to obtain + destination_path + Full path to the destination file (including filename) + overwrite + If True, allow for overwriting an existing file. If False, and a file already exists at the given + destination path, an exception will be raised. + + Returns + ------- + : + A tuple of file size and sha256 checksum. + + """ + meta_url = f"api/v1/external_files/{file_id}" + download_url = f"api/v1/external_files/{file_id}/download" + + # Check for local file existence before doing any requests + if os.path.exists(destination_path) and not overwrite: + raise RuntimeError(f"File already exists at {destination_path}. To overwrite, use `overwrite=True`") + + # First, get the metadata + file_info = self.make_request("get", meta_url, ExternalFile) + + # Now actually download the file + file_size, file_sha256 = self.download_file(download_url, destination_path, overwrite=overwrite) + + if file_size != file_info.file_size: + raise RuntimeError(f"Inconsistent file size. Expected {file_info.file_size}, got {file_size}") + + if file_sha256 != file_info.sha256sum: + raise RuntimeError(f"Inconsistent file checksum. 
Expected {file_info.sha256sum}, got {file_sha256}") + + return file_size, file_sha256 + ############################################################## # Molecules ############################################################## @@ -2685,7 +2733,8 @@ def get_internal_job(self, job_id: int) -> InternalJob: Gets information about an internal job on the server """ - return self.make_request("get", f"api/v1/internal_jobs/{job_id}", InternalJob) + ij_dict = self.make_request("get", f"api/v1/internal_jobs/{job_id}", Dict[str, Any]) + return InternalJob(client=self, **ij_dict) def query_internal_jobs( self, diff --git a/qcportal/qcportal/client_base.py b/qcportal/qcportal/client_base.py index f036e4d96..c60725831 100644 --- a/qcportal/qcportal/client_base.py +++ b/qcportal/qcportal/client_base.py @@ -20,7 +20,9 @@ except ImportError: import pydantic import requests +from typing import Tuple import yaml +import hashlib from packaging.version import parse as parse_version from . import __version__ @@ -336,7 +338,9 @@ def _request( try: if not allow_retries: - r = self._req_session.send(prep_req, verify=self._verify, timeout=self.timeout) + r = self._req_session.send(prep_req, verify=self._verify, timeout=self.timeout, allow_redirects=False) + if r.is_redirect: + raise RuntimeError("Redirection is not allowed") else: current_retries = 0 while True: @@ -428,6 +432,38 @@ def make_request( else: return pydantic.parse_obj_as(response_model, d) + def download_file(self, endpoint: str, destination_path: str, overwrite: bool = False) -> Tuple[int, str]: + + sha256 = hashlib.sha256() + file_size = 0 + + # Remove if overwrite=True. This allows for any processes still using the old file to keep using it + # (at least on linux) + if os.path.exists(destination_path): + if overwrite: + os.remove(destination_path) + else: + raise RuntimeError(f"File already exists at {destination_path}. 
To overwrite, use `overwrite=True`") + + full_uri = self.address + endpoint + response = self._req_session.get(full_uri, stream=True, allow_redirects=False) + + if response.is_redirect: + # send again, but using a plain requests object + # that way, we don't pass the JWT to someone else + new_location = response.headers["Location"] + response = requests.get(new_location, stream=True, allow_redirects=True) + + response.raise_for_status() + with open(destination_path, "wb") as f: + for chunk in response.iter_content(chunk_size=None): + if chunk: + f.write(chunk) + sha256.update(chunk) + file_size += len(chunk) + + return file_size, sha256.hexdigest() + def ping(self) -> bool: """ Pings the server to see if it is up diff --git a/qcportal/qcportal/dataset_models.py b/qcportal/qcportal/dataset_models.py index 9685b9a6b..62b162efd 100644 --- a/qcportal/qcportal/dataset_models.py +++ b/qcportal/qcportal/dataset_models.py @@ -3,6 +3,7 @@ import math import os from datetime import datetime +from enum import Enum from typing import ( TYPE_CHECKING, Optional, @@ -30,17 +31,31 @@ from tqdm import tqdm from qcportal.base_models import RestModelBase, validate_list_to_single, CommonBulkGetBody -from qcportal.metadata_models import DeleteMetadata -from qcportal.metadata_models import InsertMetadata +from qcportal.internal_jobs import InternalJob, InternalJobStatusEnum +from qcportal.metadata_models import DeleteMetadata, InsertMetadata from qcportal.record_models import PriorityEnum, RecordStatusEnum, BaseRecord from qcportal.utils import make_list, chunk_iterable from qcportal.cache import DatasetCache, read_dataset_metadata, get_records_with_cache +from qcportal.external_files import ExternalFile if TYPE_CHECKING: from qcportal.client import PortalClient from pandas import DataFrame +class DatasetAttachmentType(str, Enum): + """ + The type of attachment a file is for a dataset + """ + + other = "other" + view = "view" + + +class DatasetAttachment(ExternalFile): + attachment_type: DatasetAttachmentType + + class Citation(BaseModel): """A literature citation.""" @@ -117,6 +132,7 @@ class Config: # Fields not always included when fetching the dataset ###################################################### contributed_values_: Optional[Dict[str, ContributedValues]] = Field(None, alias="contributed_values") + attachments_: Optional[List[DatasetAttachment]] = Field(None, alias="attachments") ############################# # Private non-pydantic fields @@ -320,6 +336,225 @@ def submit( "post", f"api/v1/datasets/{self.dataset_type}/{self.id}/submit", Any, body=body_data ) + ######################################### + # Internal jobs + ######################################### + def get_internal_job(self, job_id: int) -> InternalJob: + self.assert_is_not_view() + self.assert_online() + + ij_dict = self._client.make_request("get", f"/api/v1/datasets/{self.id}/internal_jobs/{job_id}", Dict[str, Any]) + refresh_url = f"/api/v1/datasets/{self.id}/internal_jobs/{ij_dict['id']}" + return InternalJob(client=self._client, refresh_url=refresh_url, **ij_dict) + + def list_internal_jobs( + self, status: Optional[Union[InternalJobStatusEnum, Iterable[InternalJobStatusEnum]]] = None + ) -> List[InternalJob]: + self.assert_is_not_view() + self.assert_online() + + url_params = DatasetGetInternalJobParams(status=make_list(status)) + ij_dicts = self._client.make_request( + "get", f"/api/v1/datasets/{self.id}/internal_jobs", List[Dict[str, Any]], url_params=url_params + ) + return [ + InternalJob(client=self._client, 
refresh_url=f"/api/v1/datasets/{self.id}/internal_jobs/{ij['id']}", **ij) + for ij in ij_dicts + ] + + ######################################### + # Attachments + ######################################### + def fetch_attachments(self): + self.assert_is_not_view() + self.assert_online() + + self.attachments_ = self._client.make_request( + "get", + f"api/v1/datasets/{self.id}/attachments", + Optional[List[DatasetAttachment]], + ) + + @property + def attachments(self) -> List[DatasetAttachment]: + if not self.attachments_: + self.fetch_attachments() + + return self.attachments_ + + def delete_attachment(self, file_id: int): + self.assert_is_not_view() + self.assert_online() + + self._client.make_request( + "delete", + f"api/v1/datasets/{self.id}/attachments/{file_id}", + None, + ) + + self.fetch_attachments() + + ######################################### + # View creation and use + ######################################### + def list_views(self): + return [x for x in self.attachments if x.attachment_type == DatasetAttachmentType.view] + + def download_view( + self, + view_file_id: Optional[int] = None, + destination_path: Optional[str] = None, + overwrite: bool = True, + ): + """ + Downloads a view for this dataset + + If a `view_file_id` is not given, the most recent view will be downloaded. + + If destination path is not given, the file will be placed in the current directory, and the + filename determined by what is stored on the server. + + Parameters + ---------- + view_file_id + ID of the view to download. See :ref:`list_views`. If `None`, will download the latest view + destination_path + Full path to the destination file (including filename) + overwrite + If True, any existing file will be overwritten + """ + + my_views = self.list_views() + + if not my_views: + raise ValueError(f"No views available for this dataset") + + if view_file_id is None: + latest_view_ids = max(my_views, key=lambda x: x.created_on) + view_file_id = latest_view_ids.id + + view_map = {x.id: x for x in self.list_views()} + if view_file_id not in view_map: + raise ValueError(f"File id {view_file_id} is not a valid ID for this dataset") + + if destination_path is None: + view_data = view_map[view_file_id] + destination_path = os.path.join(os.getcwd(), view_data.file_name) + + self._client.download_external_file(view_file_id, destination_path, overwrite=overwrite) + + def use_view_cache( + self, + view_file_path: str, + ): + """ + Downloads and loads a view for this dataset + + Parameters + ---------- + view_file_path + Full path to the view file + """ + + cache_uri = f"file:{view_file_path}" + dcache = DatasetCache(cache_uri=cache_uri, read_only=False, dataset_type=type(self)) + + meta = dcache.get_metadata("dataset_metadata") + + if meta["id"] != self.id: + raise ValueError( + f"Info in view file does not match this dataset. ID in the file {meta['id']}, ID of this dataset {self.id}" + ) + + if meta["dataset_type"] != self.dataset_type: + raise ValueError( + f"Info in view file does not match this dataset. Dataset type in the file {meta['dataset_type']}, dataset type of this dataset {self.dataset_type}" + ) + + if meta["name"] != self.name: + raise ValueError( + f"Info in view file does not match this dataset. Dataset name in the file {meta['name']}, name of this dataset {self.name}" + ) + + self._cache_data = dcache + + def preload_cache(self, view_file_id: Optional[int] = None): + """ + Downloads a view file and uses it as the current cache + + Parameters + ---------- + view_file_id + ID of the view to download. 
See :ref:`list_views`. If `None`, will download the latest view + """ + + self.assert_is_not_view() + self.assert_online() + + if not self._client.cache.is_disk: + raise RuntimeError("Caching to disk is not enabled. Set the cache_dir path when constructing the client") + + destination_path = self._client.cache.get_dataset_cache_path(self.id) + self.download_view(view_file_id=view_file_id, destination_path=destination_path, overwrite=True) + self.use_view_cache(destination_path) + + def create_view( + self, + description: str, + provenance: Dict[str, Any], + status: Optional[Iterable[RecordStatusEnum]] = None, + include: Optional[Iterable[str]] = None, + exclude: Optional[Iterable[str]] = None, + *, + include_children: bool = True, + ) -> InternalJob: + """ + Creates a view of this dataset on the server + + This function will return an :ref:`InternalJob` which can be used to watch + for completion if desired. The job will run server-side without user interaction. + + Note the ID field of the returned object if you wish to retrieve this internal job later + (via :ref:`PortalClient.get_internal_job` and :ref:`get_internal_jobs`) + + Parameters + ---------- + description + Optional string describing the view file + provenance + Dictionary with any metadata or other information about the view. Information regarding + the options used to create the view will be added. + status + List of statuses to include. Default is to include records with any status + include + List of specific record fields to include in the export. Default is to include most fields + exclude + List of specific record fields to exclude from the export. Defaults to excluding none. + include_children + Specifies whether child records associated with the main records should also be included (recursively) + in the view file. + + Returns + ------- + : + An :ref:`InternalJob` object which can be used to watch for completion. 
+ """ + + body = DatasetCreateViewBody( + description=description, + provenance=provenance, + status=status, + include=include, + exclude=exclude, + include_children=include_children, + ) + + job_id = self._client.make_request( + "post", f"api/v1/datasets/{self.dataset_type}/{self.id}/create_view", int, body=body + ) + + return self.get_internal_job(job_id) + ######################################### # Various properties and getters/setters ######################################### @@ -1786,6 +2021,15 @@ class DatasetFetchRecordsBody(RestModelBase): status: Optional[List[RecordStatusEnum]] = None +class DatasetCreateViewBody(RestModelBase): + description: Optional[str] + provenance: Dict[str, Any] + status: Optional[List[RecordStatusEnum]] = (None,) + include: Optional[List[str]] = (None,) + exclude: Optional[List[str]] = (None,) + include_children: bool = (True,) + + class DatasetSubmitBody(RestModelBase): entry_names: Optional[List[str]] = None specification_names: Optional[List[str]] = None @@ -1831,6 +2075,10 @@ class DatasetModifyEntryBody(RestModelBase): overwrite_attributes: bool = False +class DatasetGetInternalJobParams(RestModelBase): + status: Optional[List[InternalJobStatusEnum]] = None + + def dataset_from_dict(data: Dict[str, Any], client: Any, cache_data: Optional[DatasetCache] = None) -> BaseDataset: """ Create a dataset object from a datamodel diff --git a/qcportal/qcportal/dataset_testing_helpers.py b/qcportal/qcportal/dataset_testing_helpers.py index af0f71d4b..babb5bdf3 100644 --- a/qcportal/qcportal/dataset_testing_helpers.py +++ b/qcportal/qcportal/dataset_testing_helpers.py @@ -441,6 +441,9 @@ def run_dataset_model_submit(ds, test_entries, test_spec, record_compare): record_compare(rec, test_entries[0], test_spec) + assert rec.owner_user == "submit_user" + assert rec.owner_group == "group1" + # Used default tag/priority if rec.is_service: assert rec.service.tag == "default_tag" @@ -477,6 +480,14 @@ def run_dataset_model_submit(ds, test_entries, test_spec, record_compare): assert rec.task.tag == "default_tag" assert rec.task.priority == PriorityEnum.low + # Don't find existing + old_rec_id = rec.id + ds.remove_records(test_entries[2].name, "spec_1") + ds.submit(find_existing=False) + + rec = ds.get_record(test_entries[2].name, "spec_1") + assert rec.id != old_rec_id + record_count = len(ds.entry_names) * len(ds.specifications) assert ds.record_count == record_count assert ds._client.list_datasets()[0]["record_count"] == record_count diff --git a/qcportal/qcportal/external_files/__init__.py b/qcportal/qcportal/external_files/__init__.py new file mode 100644 index 000000000..632ac5a16 --- /dev/null +++ b/qcportal/qcportal/external_files/__init__.py @@ -0,0 +1 @@ +from .models import ExternalFileStatusEnum, ExternalFileTypeEnum, ExternalFile diff --git a/qcportal/qcportal/external_files/models.py b/qcportal/qcportal/external_files/models.py new file mode 100644 index 000000000..53eab4865 --- /dev/null +++ b/qcportal/qcportal/external_files/models.py @@ -0,0 +1,42 @@ +from __future__ import annotations + +from datetime import datetime +from enum import Enum +from typing import Dict, Any, Optional + +try: + from pydantic.v1 import BaseModel, Extra, validator, PrivateAttr, Field +except ImportError: + from pydantic import BaseModel, Extra, validator, PrivateAttr, Field + + +class ExternalFileStatusEnum(str, Enum): + """ + The state of an external file + """ + + available = "available" + processing = "processing" + + +class ExternalFileTypeEnum(str, Enum): + """ + The state 
of an external file + """ + + dataset_attachment = "dataset_attachment" + + +class ExternalFile(BaseModel): + id: int + file_type: ExternalFileTypeEnum + + created_on: datetime + status: ExternalFileStatusEnum + + file_name: str + description: Optional[str] + provenance: Dict[str, Any] + + sha256sum: str + file_size: int diff --git a/qcportal/qcportal/gridoptimization/test_dataset_models.py b/qcportal/qcportal/gridoptimization/test_dataset_models.py index 8a292603b..f8e227d30 100644 --- a/qcportal/qcportal/gridoptimization/test_dataset_models.py +++ b/qcportal/qcportal/gridoptimization/test_dataset_models.py @@ -161,9 +161,13 @@ def test_gridoptimization_dataset_model_remove_record(snowflake_client: PortalCl ds_helpers.run_dataset_model_remove_record(snowflake_client, ds, test_entries, test_specs) -def test_gridoptimization_dataset_model_submit(snowflake_client: PortalClient): - ds = snowflake_client.add_dataset( - "gridoptimization", "Test dataset", default_tag="default_tag", default_priority=PriorityEnum.low +def test_gridoptimization_dataset_model_submit(submitter_client: PortalClient): + ds = submitter_client.add_dataset( + "gridoptimization", + "Test dataset", + default_tag="default_tag", + default_priority=PriorityEnum.low, + owner_group="group1", ) ds_helpers.run_dataset_model_submit(ds, test_entries, test_specs[0], record_compare) diff --git a/qcportal/qcportal/internal_jobs/models.py b/qcportal/qcportal/internal_jobs/models.py index d8c9ca672..a4e46b6c3 100644 --- a/qcportal/qcportal/internal_jobs/models.py +++ b/qcportal/qcportal/internal_jobs/models.py @@ -1,3 +1,4 @@ +import time from datetime import datetime from enum import Enum from typing import Optional, Dict, Any, List, Union @@ -5,12 +6,13 @@ from dateutil.parser import parse as date_parser try: - from pydantic.v1 import BaseModel, Extra, validator + from pydantic.v1 import BaseModel, Extra, validator, PrivateAttr except ImportError: - from pydantic import BaseModel, Extra, validator + from pydantic import BaseModel, Extra, validator, PrivateAttr from qcportal.base_models import QueryProjModelBase from ..base_models import QueryIteratorBase +from tqdm import tqdm class InternalJobStatusEnum(str, Enum): @@ -55,8 +57,10 @@ class Config: runner_hostname: Optional[str] runner_uuid: Optional[str] repeat_delay: Optional[int] + serial_group: Optional[str] progress: int + progress_description: Optional[str] = None function: str kwargs: Dict[str, Any] @@ -65,6 +69,85 @@ class Config: result: Any user: Optional[str] + _client: Any = PrivateAttr(None) + _refresh_url: Optional[str] = PrivateAttr(None) + + def __init__(self, client=None, refresh_url=None, **kwargs): + BaseModel.__init__(self, **kwargs) + self._client = client + self._refresh_url = refresh_url + + def refresh(self): + """ + Updates the data of this object with information from the server + """ + + if self._client is None: + raise RuntimeError("Client is not set") + + if self._refresh_url is None: + server_data = self._client.get_internal_job(self.id) + else: + server_data = self._client.make_request("get", self._refresh_url, InternalJob) + + for k, v in server_data: + setattr(self, k, v) + + def watch(self, interval: float = 2.0, timeout: Optional[float] = None): + """ + Watch an internal job for completion + + Will poll every `interval` seconds until the job is finished (complete, error, cancelled, etc). + + Parameters + ---------- + interval + Time (in seconds) between polls on the server + timeout + Max amount of time (in seconds) to wait. 
If None, will wait forever. + + Returns + ------- + + """ + + if self.status not in [InternalJobStatusEnum.waiting, InternalJobStatusEnum.running]: + return + + begin_time = time.time() + + end_time = None + if timeout is not None: + end_time = begin_time + timeout + + pbar = tqdm(initial=self.progress, total=100, desc=self.progress_description) + + while True: + t = time.time() + + self.refresh() + pbar.update(self.progress - pbar.n) + pbar.set_description(self.progress_description) + + if end_time is not None and t >= end_time: + raise TimeoutError("Timed out waiting for job to complete") + + if self.status == InternalJobStatusEnum.error: + print("Internal job resulted in an error:") + print(self.result) + break + elif self.status not in [InternalJobStatusEnum.waiting, InternalJobStatusEnum.running]: + print(f"Internal job final status: {self.status.value}") + break + + curtime = time.time() + + if end_time is not None: + # sleep the normal interval, or up to the timeout time + time.sleep(min(interval, max(end_time - curtime, 0.0) + 0.1)) + else: + time.sleep(interval) + class InternalJobQueryFilters(QueryProjModelBase): job_id: Optional[List[int]] = None @@ -118,9 +201,11 @@ def __init__(self, client, query_filters: InternalJobQueryFilters): QueryIteratorBase.__init__(self, client, query_filters, batch_limit) def _request(self) -> List[InternalJob]: - return self._client.make_request( + ij_dicts = self._client.make_request( "post", "api/v1/internal_jobs/query", - List[InternalJob], + List[Dict[str, Any]], body=self._query_filters, ) + + return [InternalJob(client=self._client, **ij_dict) for ij_dict in ij_dicts] diff --git a/qcportal/qcportal/manybody/test_dataset_models.py b/qcportal/qcportal/manybody/test_dataset_models.py index 119dd8985..ee2ff9598 100644 --- a/qcportal/qcportal/manybody/test_dataset_models.py +++ b/qcportal/qcportal/manybody/test_dataset_models.py @@ -117,9 +117,13 @@ def test_manybody_dataset_model_remove_record(snowflake_client: PortalClient): ds_helpers.run_dataset_model_remove_record(snowflake_client, ds, test_entries, test_specs) -def test_manybody_dataset_model_submit(snowflake_client: PortalClient): - ds = snowflake_client.add_dataset( - "manybody", "Test dataset", default_tag="default_tag", default_priority=PriorityEnum.low +def test_manybody_dataset_model_submit(submitter_client: PortalClient): + ds = submitter_client.add_dataset( + "manybody", + "Test dataset", + default_tag="default_tag", + default_priority=PriorityEnum.low, + owner_group="group1", ) ds_helpers.run_dataset_model_submit(ds, test_entries, test_specs[0], record_compare) diff --git a/qcportal/qcportal/neb/test_dataset_models.py b/qcportal/qcportal/neb/test_dataset_models.py index e5c4ec587..668b8633c 100644 --- a/qcportal/qcportal/neb/test_dataset_models.py +++ b/qcportal/qcportal/neb/test_dataset_models.py @@ -163,9 +163,9 @@ def test_neb_dataset_model_remove_record(snowflake_client: PortalClient): ds_helpers.run_dataset_model_remove_record(snowflake_client, ds, test_entries, test_specs) -def test_neb_dataset_model_submit(snowflake_client: PortalClient): - ds = snowflake_client.add_dataset( - "neb", "Test dataset", default_tag="default_tag", default_priority=PriorityEnum.low +def test_neb_dataset_model_submit(submitter_client: PortalClient): + ds = submitter_client.add_dataset( + "neb", "Test dataset", default_tag="default_tag", default_priority=PriorityEnum.low, owner_group="group1" ) ds_helpers.run_dataset_model_submit(ds, test_entries, test_specs[0], record_compare) diff --git 
a/qcportal/qcportal/optimization/test_dataset_models.py b/qcportal/qcportal/optimization/test_dataset_models.py index 7a48e678d..473275992 100644 --- a/qcportal/qcportal/optimization/test_dataset_models.py +++ b/qcportal/qcportal/optimization/test_dataset_models.py @@ -119,9 +119,13 @@ def test_optimization_dataset_model_remove_record(snowflake_client: PortalClient ds_helpers.run_dataset_model_remove_record(snowflake_client, ds, test_entries, test_specs) -def test_optimization_dataset_model_submit(snowflake_client: PortalClient): - ds = snowflake_client.add_dataset( - "optimization", "Test dataset", default_tag="default_tag", default_priority=PriorityEnum.low +def test_optimization_dataset_model_submit(submitter_client: PortalClient): + ds = submitter_client.add_dataset( + "optimization", + "Test dataset", + default_tag="default_tag", + default_priority=PriorityEnum.low, + owner_group="group1", ) ds_helpers.run_dataset_model_submit(ds, test_entries, test_specs[0], record_compare) diff --git a/qcportal/qcportal/reaction/test_dataset_models.py b/qcportal/qcportal/reaction/test_dataset_models.py index 51ee7fd2b..a1701f00f 100644 --- a/qcportal/qcportal/reaction/test_dataset_models.py +++ b/qcportal/qcportal/reaction/test_dataset_models.py @@ -139,9 +139,13 @@ def test_reaction_dataset_model_remove_record(snowflake_client: PortalClient): ds_helpers.run_dataset_model_remove_record(snowflake_client, ds, test_entries, test_specs) -def test_reaction_dataset_model_submit(snowflake_client: PortalClient): - ds = snowflake_client.add_dataset( - "reaction", "Test dataset", default_tag="default_tag", default_priority=PriorityEnum.low +def test_reaction_dataset_model_submit(submitter_client: PortalClient): + ds = submitter_client.add_dataset( + "reaction", + "Test dataset", + default_tag="default_tag", + default_priority=PriorityEnum.low, + owner_group="group1", ) ds_helpers.run_dataset_model_submit(ds, test_entries, test_specs[0], record_compare) diff --git a/qcportal/qcportal/singlepoint/test_dataset_models.py b/qcportal/qcportal/singlepoint/test_dataset_models.py index f4c8907ab..7bd1fa957 100644 --- a/qcportal/qcportal/singlepoint/test_dataset_models.py +++ b/qcportal/qcportal/singlepoint/test_dataset_models.py @@ -104,9 +104,13 @@ def test_singlepoint_dataset_model_remove_record(snowflake_client: PortalClient) ds_helpers.run_dataset_model_remove_record(snowflake_client, ds, test_entries, test_specs) -def test_singlepoint_dataset_model_submit(snowflake_client: PortalClient): - ds = snowflake_client.add_dataset( - "singlepoint", "Test dataset", default_tag="default_tag", default_priority=PriorityEnum.low +def test_singlepoint_dataset_model_submit(submitter_client: PortalClient): + ds = submitter_client.add_dataset( + "singlepoint", + "Test dataset", + default_tag="default_tag", + default_priority=PriorityEnum.low, + owner_group="group1", ) ds_helpers.run_dataset_model_submit(ds, test_entries, test_specs[0], record_compare) diff --git a/qcportal/qcportal/torsiondrive/test_dataset_models.py b/qcportal/qcportal/torsiondrive/test_dataset_models.py index 7788d3c99..dffd10207 100644 --- a/qcportal/qcportal/torsiondrive/test_dataset_models.py +++ b/qcportal/qcportal/torsiondrive/test_dataset_models.py @@ -166,9 +166,13 @@ def test_torsiondrive_dataset_model_remove_record(snowflake_client: PortalClient ds_helpers.run_dataset_model_remove_record(snowflake_client, ds, test_entries, test_specs) -def test_torsiondrive_dataset_model_submit(snowflake_client: PortalClient): - ds = snowflake_client.add_dataset( - 
"torsiondrive", "Test dataset", default_tag="default_tag", default_priority=PriorityEnum.low +def test_torsiondrive_dataset_model_submit(submitter_client: PortalClient): + ds = submitter_client.add_dataset( + "torsiondrive", + "Test dataset", + default_tag="default_tag", + default_priority=PriorityEnum.low, + owner_group="group1", ) ds_helpers.run_dataset_model_submit(ds, test_entries, test_specs[0], record_compare) diff --git a/qcportal/qcportal/utils.py b/qcportal/qcportal/utils.py index a60f0ccb0..338acd173 100644 --- a/qcportal/qcportal/utils.py +++ b/qcportal/qcportal/utils.py @@ -440,3 +440,12 @@ def is_included(key: str, include: Optional[Iterable[str]], exclude: Optional[It exclude = tuple(sorted(exclude)) return _is_included(key, include, exclude, default) + + +def update_nested_dict(d: Dict[str, Any], u: Dict[str, Any]): + for k, v in u.items(): + if isinstance(v, dict): + d[k] = update_nested_dict(d.get(k, {}), v) + else: + d[k] = v + return d