diff --git a/qcfractal/qcfractal/components/dataset_routes.py b/qcfractal/qcfractal/components/dataset_routes.py
index d1d39062a..51999443b 100644
--- a/qcfractal/qcfractal/components/dataset_routes.py
+++ b/qcfractal/qcfractal/components/dataset_routes.py
@@ -11,6 +11,7 @@
    DatasetQueryModel,
    DatasetFetchRecordsBody,
    DatasetFetchEntryBody,
+    DatasetFetchSpecificationBody,
    DatasetSubmitBody,
    DatasetDeleteStrBody,
    DatasetRecordModifyBody,
@@ -181,13 +182,37 @@ def submit_dataset_v1(dataset_type: str, dataset_id: int, body_data: DatasetSubm
 ###################
 # Specifications
 ###################
+@api_v1.route("/datasets/<string:dataset_type>/<int:dataset_id>/specification_names", methods=["GET"])
+@wrap_route("READ")
+def fetch_dataset_specification_names_v1(dataset_type: str, dataset_id: int):
+    ds_socket = storage_socket.datasets.get_socket(dataset_type)
+    return ds_socket.fetch_specification_names(dataset_id)
+
+
 @api_v1.route("/datasets/<string:dataset_type>/<int:dataset_id>/specifications", methods=["GET"])
 @wrap_route("READ")
-def fetch_dataset_specifications_v1(dataset_type: str, dataset_id: int):
+def fetch_all_dataset_specifications_v1(dataset_type: str, dataset_id: int):
     ds_socket = storage_socket.datasets.get_socket(dataset_type)
     return ds_socket.fetch_specifications(dataset_id)
 
 
+@api_v1.route("/datasets/<string:dataset_type>/<int:dataset_id>/specifications/bulkFetch", methods=["POST"])
+@wrap_route("READ")
+def fetch_dataset_specifications_v1(dataset_type: str, dataset_id: int, body_data: DatasetFetchSpecificationBody):
+    # There is no specification-specific API limit, so reuse the dataset entry limit
+    limit = current_app.config["QCFRACTAL_CONFIG"].api_limits.get_dataset_entries
+
+    if len(body_data.names) > limit:
+        raise LimitExceededError(f"Cannot get {len(body_data.names)} dataset specifications - limit is {limit}")
+
+    ds_socket = storage_socket.datasets.get_socket(dataset_type)
+    return ds_socket.fetch_specifications(
+        dataset_id,
+        specification_names=body_data.names,
+        missing_ok=body_data.missing_ok,
+    )
+
+
 @api_v1.route("/datasets/<string:dataset_type>/<int:dataset_id>/specifications/bulkDelete", methods=["POST"])
 @wrap_route("DELETE")
 def delete_dataset_specifications_v1(dataset_type: str, dataset_id: int, body_data: DatasetDeleteStrBody):
diff --git a/qcfractal/qcfractal/components/dataset_socket.py b/qcfractal/qcfractal/components/dataset_socket.py
index b6f8ed613..124452f09 100644
--- a/qcfractal/qcfractal/components/dataset_socket.py
+++ b/qcfractal/qcfractal/components/dataset_socket.py
@@ -536,16 +536,47 @@ def add_specifications(
         return InsertMetadata(inserted_idx=inserted_idx, existing_idx=existing_idx)
 
+    def fetch_specification_names(
+        self,
+        dataset_id: int,
+        *,
+        session: Optional[Session] = None,
+    ) -> List[str]:
+        """
+        Obtain all specification names for a dataset
+
+        Parameters
+        ----------
+        dataset_id
+            ID of a dataset
+        session
+            An existing SQLAlchemy session to use. If None, one will be created. If an existing session
+            is used, it will be flushed (but not committed) before returning from this function.
+
+        Returns
+        -------
+        :
+            All specification names as a list
+        """
+        stmt = select(self.specification_orm.name)
+        stmt = stmt.where(self.specification_orm.dataset_id == dataset_id)
+
+        with self.root_socket.optional_session(session, True) as session:
+            ret = session.execute(stmt).scalars().all()
+            return list(ret)
+
     def fetch_specifications(
         self,
         dataset_id: int,
+        specification_names: Optional[Sequence[str]] = None,
         include: Optional[Iterable[str]] = None,
         exclude: Optional[Iterable[str]] = None,
+        missing_ok: bool = False,
         *,
         session: Optional[Session] = None,
     ) -> Dict[str, Any]:
         """
-        Get specifications for a dataset from the database
+        Fetch specifications for a dataset from the database
 
         It's expected there aren't too many specifications, so this always fetches them all.
 
@@ -553,6 +584,8 @@
         ----------
         dataset_id
             ID of a dataset
+        specification_names
+            Names of the specifications to fetch. If None, fetch all
         include
             Which fields of the result to return. Default is to return all fields.
         exclude
@@ -572,14 +605,24 @@
         stmt = stmt.where(self.specification_orm.dataset_id == dataset_id)
         stmt = stmt.options(joinedload(self.specification_orm.specification))
 
+        if specification_names is not None:
+            stmt = stmt.where(self.specification_orm.name.in_(specification_names))
+
         if include or exclude:
-            query_opts = get_query_proj_options(self.entry_orm, include, exclude)
+            query_opts = get_query_proj_options(self.specification_orm, include, exclude)
             stmt = stmt.options(*query_opts)
 
         with self.root_socket.optional_session(session, True) as session:
-            entries = session.execute(stmt).scalars().all()
+            specifications = session.execute(stmt).scalars().all()
 
-        return {x.name: x.model_dict() for x in entries}
+            if specification_names is not None and missing_ok is False:
+                found_specifications = {x.name for x in specifications}
+                missing_specifications = set(specification_names) - found_specifications
+                if missing_specifications:
+                    s = "\n".join(missing_specifications)
+                    raise MissingDataError(f"Missing {len(missing_specifications)} specifications: {s}")
+
+            return {x.name: x.model_dict() for x in specifications}
 
     def delete_specifications(
         self,
diff --git a/qcportal/qcportal/dataset_models.py b/qcportal/qcportal/dataset_models.py
index 8af7d745c..95c52398a 100644
--- a/qcportal/qcportal/dataset_models.py
+++ b/qcportal/qcportal/dataset_models.py
@@ -1368,6 +1368,11 @@ class DatasetQueryModel(RestModelBase):
     exclude: Optional[List[str]] = None
 
 
+class DatasetFetchSpecificationBody(RestModelBase):
+    names: List[str]
+    missing_ok: bool = False
+
+
 class DatasetFetchEntryBody(RestModelBase):
     names: List[str]
     missing_ok: bool = False