From 27c342575439110e29405e44b4b385e04c03dbc2 Mon Sep 17 00:00:00 2001 From: Benjamin Pritchard Date: Thu, 18 Jan 2024 15:39:57 -0500 Subject: [PATCH] Get child records via cache --- qcportal/qcportal/cache.py | 20 +++---- qcportal/qcportal/dataset_models.py | 1 - .../gridoptimization/record_models.py | 19 +++--- qcportal/qcportal/manybody/record_models.py | 31 ++++++---- qcportal/qcportal/neb/record_models.py | 30 ++++++---- .../qcportal/optimization/record_models.py | 20 +++---- qcportal/qcportal/reaction/record_models.py | 32 +++++----- qcportal/qcportal/record_models.py | 59 +++++++++++++++++++ .../qcportal/torsiondrive/record_models.py | 23 ++++---- 9 files changed, 152 insertions(+), 83 deletions(-) diff --git a/qcportal/qcportal/cache.py b/qcportal/qcportal/cache.py index 0a0ec9bab..2bbccb3fc 100644 --- a/qcportal/qcportal/cache.py +++ b/qcportal/qcportal/cache.py @@ -101,9 +101,8 @@ def get_record(self, record_id: int, record_type: Type[_RECORD_T]) -> Optional[_ record = decompress_from_cache(record_data[1], record_type) - if not self.read_only: - record._record_cache = self - record._record_cache_uid = record_data[0] + record._record_cache = self + record._record_cache_uid = record_data[0] return record @@ -119,9 +118,8 @@ def get_records(self, record_ids: Iterable[int], record_type: Type[_RECORD_T]) - for uid, compressed_record in rdata: record = decompress_from_cache(compressed_record, record_type) - if not self.read_only: - record._record_cache = self - record._record_cache_uid = uid + record._record_cache = self + record._record_cache_uid = uid all_records.append(record) @@ -370,9 +368,8 @@ def get_dataset_record(self, entry_name: str, specification_name: str): record = decompress_from_cache(record_data[1], self._record_type) - if not self.read_only: - record._record_cache = self - record._record_cache_uid = record_data[0] + record._record_cache = self + record._record_cache_uid = record_data[0] return record @@ -395,9 +392,8 @@ def get_dataset_records(self, entry_names: Iterable[str], specification_names: I for uid, ename, sname, compressed_record in rdata: record = decompress_from_cache(compressed_record, self._record_type) - if not self.read_only: - record._record_cache = self - record._record_cache_uid = uid + record._record_cache = self + record._record_cache_uid = uid all_records.append((ename, sname, record)) diff --git a/qcportal/qcportal/dataset_models.py b/qcportal/qcportal/dataset_models.py index 907271393..ed2b38dd9 100644 --- a/qcportal/qcportal/dataset_models.py +++ b/qcportal/qcportal/dataset_models.py @@ -1131,7 +1131,6 @@ def iterate_records( existing_records = self._cache_data.get_existing_dataset_records(entry_names_batch, [spec_name]) existing_entries = [x[0] for x in existing_records] batch_tofetch = [x for x in entry_names_batch if x not in existing_entries] - # print(f"BATCH TO FETCH: {len(batch_tofetch)}") if batch_tofetch: self._internal_fetch_records(batch_tofetch, [spec_name], status, include) diff --git a/qcportal/qcportal/gridoptimization/record_models.py b/qcportal/qcportal/gridoptimization/record_models.py index caea7f584..d49458cbf 100644 --- a/qcportal/qcportal/gridoptimization/record_models.py +++ b/qcportal/qcportal/gridoptimization/record_models.py @@ -227,17 +227,18 @@ def _fetch_starting_molecule(self): self.starting_molecule_ = self._client.get_molecules([self.starting_molecule_id])[0] def _fetch_optimizations(self): - self._assert_online() - - self.optimizations_ = self._client.make_request( - "get", - f"api/v1/records/gridoptimization/{self.id}/optimizations", - List[GridoptimizationOptimization], - ) + # Always fetch optimization metadata if we can + if not self.offline or self.optimizations_ is None: + self._assert_online() + self.optimizations_ = self._client.make_request( + "get", + f"api/v1/records/gridoptimization/{self.id}/optimizations", + List[GridoptimizationOptimization], + ) - # Fetch optimization records from the server + # Fetch optimization records from the server or the cache opt_ids = [x.optimization_id for x in self.optimizations_] - opt_records = self._client.get_optimizations(opt_ids) + opt_records = self._get_child_records(opt_ids, OptimizationRecord) self._optimizations_cache = {deserialize_key(x.key): y for x, y in zip(self.optimizations_, opt_records)} diff --git a/qcportal/qcportal/manybody/record_models.py b/qcportal/qcportal/manybody/record_models.py index 75ada58e9..70ad1f4d2 100644 --- a/qcportal/qcportal/manybody/record_models.py +++ b/qcportal/qcportal/manybody/record_models.py @@ -112,24 +112,31 @@ def _fetch_initial_molecule(self): def _fetch_clusters(self): self._assert_online() - self._clusters = self._client.make_request( - "get", - f"api/v1/records/manybody/{self.id}/clusters", - List[ManybodyCluster], - ) + if not self.offline or self._clusters is None: + self._clusters = self._client.make_request( + "get", + f"api/v1/records/manybody/{self.id}/clusters", + List[ManybodyCluster], + ) + + mol_ids = [x.molecule_id for x in self._clusters] + mols = self._client.get_molecules(mol_ids) + + for cluster, mol in zip(self._clusters, mols): + assert mol.id == cluster.molecule_id + cluster.molecule = mol # Fetch singlepoint records and molecules sp_ids = [x.singlepoint_id for x in self._clusters] - sp_recs = self._client.get_singlepoints(sp_ids) - - mol_ids = [x.molecule_id for x in self._clusters] - mols = self._client.get_molecules(mol_ids) + sp_recs = self._get_child_records(sp_ids, SinglepointRecord) - for cluster, sp, mol in zip(self._clusters, sp_recs, mols): + for ( + cluster, + sp, + ) in zip(self._clusters, sp_recs): assert sp.id == cluster.singlepoint_id - assert sp.molecule_id == mol.id == cluster.molecule_id + assert sp.molecule_id == cluster.molecule_id cluster.singlepoint_record = sp - cluster.molecule = mol self.propagate_client(self._client) diff --git a/qcportal/qcportal/neb/record_models.py b/qcportal/qcportal/neb/record_models.py index 6f9691555..193d6a491 100644 --- a/qcportal/qcportal/neb/record_models.py +++ b/qcportal/qcportal/neb/record_models.py @@ -172,15 +172,17 @@ def fetch_all(self): def _fetch_optimizations(self): self._assert_online() - self.optimizations_ = self._client.make_request( - "get", - f"api/v1/records/neb/{self.id}/optimizations", - Dict[str, NEBOptimization], - ) + if not self.offline or self.optimizations_ is None: + self._assert_online() + self.optimizations_ = self._client.make_request( + "get", + f"api/v1/records/neb/{self.id}/optimizations", + Dict[str, NEBOptimization], + ) # Fetch optimization records from server opt_ids = [opt.optimization_id for opt in self.optimizations_.values()] - opt_recs = self._client.get_optimizations(opt_ids) + opt_recs = self._get_child_records(opt_ids, OptimizationRecord) opt_map = {opt.id: opt for opt in opt_recs} self._optimizations_cache = {} @@ -193,15 +195,17 @@ def _fetch_optimizations(self): def _fetch_singlepoints(self): self._assert_online() - self.singlepoints_ = self._client.make_request( - "get", - f"api/v1/records/neb/{self.id}/singlepoints", - List[NEBSinglepoint], - ) + if not self.offline or self.singlepoints_ is None: + self._assert_online() + self.singlepoints_ = self._client.make_request( + "get", + f"api/v1/records/neb/{self.id}/singlepoints", + List[NEBSinglepoint], + ) - # Fetch singlepoint records from server + # Fetch singlepoint records from server or the cache sp_ids = [sp.singlepoint_id for sp in self.singlepoints_] - sp_recs = self._client.get_singlepoints(sp_ids) + sp_recs = self._get_child_records(sp_ids, SinglepointRecord) self._singlepoints_cache = {} diff --git a/qcportal/qcportal/optimization/record_models.py b/qcportal/qcportal/optimization/record_models.py index a72c800d1..1fac69ee0 100644 --- a/qcportal/qcportal/optimization/record_models.py +++ b/qcportal/qcportal/optimization/record_models.py @@ -83,20 +83,20 @@ def _fetch_initial_molecule(self): self.initial_molecule_ = self._client.get_molecules([self.initial_molecule_id])[0] def _fetch_final_molecule(self): - self._assert_online() if self.final_molecule_id is not None: + self._assert_online() self.final_molecule_ = self._client.get_molecules([self.final_molecule_id])[0] def _fetch_trajectory(self): - self._assert_online() - - self.trajectory_ids_ = self._client.make_request( - "get", - f"api/v1/records/optimization/{self.id}/trajectory", - List[int], - ) + if self.trajectory_ids_ is None: + self._assert_online() + self.trajectory_ids_ = self._client.make_request( + "get", + f"api/v1/records/optimization/{self.id}/trajectory", + List[int], + ) - self._trajectory_records = self._client.get_singlepoints(self.trajectory_ids_) + self._trajectory_records = self._get_child_records(self.trajectory_ids_, SinglepointRecord) self.propagate_client(self._client) def _handle_includes(self, includes: Optional[Iterable[str]]): @@ -146,7 +146,7 @@ def trajectory_element(self, trajectory_index: int) -> SinglepointRecord: if self.trajectory_ids_ is not None: traj_id = self.trajectory_ids_[trajectory_index] - sp_rec = self._client.get_singlepoints(traj_id) + sp_rec = self._get_child_records([traj_id], SinglepointRecord)[0] sp_rec.propagate_client(self._client) return sp_rec else: diff --git a/qcportal/qcportal/reaction/record_models.py b/qcportal/qcportal/reaction/record_models.py index 15ecdf77c..f7847a817 100644 --- a/qcportal/qcportal/reaction/record_models.py +++ b/qcportal/qcportal/reaction/record_models.py @@ -109,32 +109,34 @@ def fetch_all(self): def _fetch_components(self): self._assert_online() - self._components = self._client.make_request( - "get", - f"api/v1/records/reaction/{self.id}/components", - List[ReactionComponent], - ) - - # Fetch records & molecules + if not self.offline or self._components is None: + self._assert_online() + self._components = self._client.make_request( + "get", + f"api/v1/records/reaction/{self.id}/components", + List[ReactionComponent], + ) + + mol_ids = [c.molecule_id for c in self._components] + mols = self._client.get_molecules(mol_ids) + for c, mol in zip(self._components, mols): + assert mol.id == c.molecule_id + c.molecule = mol + + # Fetch records from server or cache sp_comp = [c for c in self._components if c.singlepoint_id is not None] sp_ids = [c.singlepoint_id for c in sp_comp] - sp_recs = self._client.get_singlepoints(sp_ids) + sp_recs = self._get_child_records(sp_ids, SinglepointRecord) for c, rec in zip(sp_comp, sp_recs): c.singlepoint_record = rec opt_comp = [c for c in self._components if c.optimization_id is not None] opt_ids = [c.optimization_id for c in opt_comp] - opt_recs = self._client.get_optimizations(opt_ids) + opt_recs = self._get_child_records(opt_ids, OptimizationRecord) for c, rec in zip(opt_comp, opt_recs): assert rec.initial_molecule_id == c.molecule_id c.optimization_record = rec - mol_ids = [c.molecule_id for c in self._components] - mols = self._client.get_molecules(mol_ids) - for c, mol in zip(self._components, mols): - assert mol.id == c.molecule_id - c.molecule = mol - self.propagate_client(self._client) def _handle_includes(self, includes: Optional[Iterable[str]]): diff --git a/qcportal/qcportal/record_models.py b/qcportal/qcportal/record_models.py index 276e18719..70c8a51d4 100644 --- a/qcportal/qcportal/record_models.py +++ b/qcportal/qcportal/record_models.py @@ -482,6 +482,65 @@ def fetch_all(self): for nf in self.native_files_.values(): nf.fetch_all() + def _cache_child_records(self, record_ids: Iterable[int], record_type: Type[_Record_T]) -> None: + """ + Fetching child records and stores them in the cache + + The cache will be checked for existing records + """ + + if self._record_cache is None: + return + + existing_record_ids = self._record_cache.get_existing_records(record_ids) + records_tofetch = list(set(record_ids) - set(existing_record_ids)) + + if self.offline and records_tofetch: + raise RuntimeError("Need to fetch some records, but not connected to a client") + + recs = self._client._get_records_by_type(record_type, records_tofetch) + self._record_cache.update_records(recs) + + def _get_child_records(self, record_ids: Sequence[int], record_type: Type[_Record_T]) -> List[_Record_T]: + """ + Helper function for obtaining child records either from the cache or from the server + + The records are returned in the same order as the `record_ids` parameter. + """ + + if self._record_cache is None: + self._assert_online() + existing_records = [] + records_tofetch = record_ids + else: + existing_records = self._record_cache.get_records(record_ids, record_type) + records_tofetch = set(record_ids) - {x.id for x in existing_records} + + if self.offline and records_tofetch: + raise RuntimeError("Need to fetch some records, but not connected to a client") + + if records_tofetch: + recs = self._client._get_records_by_type(record_type, list(records_tofetch)) + if self._record_cache is not None: + uids = self._record_cache.update_records(recs) + + for u, r in zip(uids, recs): + r._record_cache = self._record_cache + r._record_cache_uid = u + + existing_records += recs + + # Return everything in the same order as the input + all_recs = {r.id: r for r in existing_records} + ret = [all_recs.get(rid, None) for rid in record_ids] + if None in ret: + missing_ids = set(record_ids) - set(all_recs.keys()) + raise RuntimeError( + f"Not all records found either in the cache or on the server. Missing records: {missing_ids}" + ) + + return ret + def _assert_online(self): """Raises an exception if this record does not have an associated client""" if self.offline: diff --git a/qcportal/qcportal/torsiondrive/record_models.py b/qcportal/qcportal/torsiondrive/record_models.py index 2e9a50574..588b9e90e 100644 --- a/qcportal/qcportal/torsiondrive/record_models.py +++ b/qcportal/qcportal/torsiondrive/record_models.py @@ -177,17 +177,18 @@ def _fetch_initial_molecules(self): self.initial_molecules_ = self._client.get_molecules(self.initial_molecules_ids_) def _fetch_optimizations(self): - self._assert_online() - - self.optimizations_ = self._client.make_request( - "get", - f"api/v1/records/torsiondrive/{self.id}/optimizations", - List[TorsiondriveOptimization], - ) - - # Fetch optimization records from the server + # Always fetch optimization metadata if we can + if not self.offline or self.optimizations_ is None: + self._assert_online() + self.optimizations_ = self._client.make_request( + "get", + f"api/v1/records/torsiondrive/{self.id}/optimizations", + List[TorsiondriveOptimization], + ) + + # Fetch optimization records from the server or cache opt_ids = [x.optimization_id for x in self.optimizations_] - opt_records = self._client.get_optimizations(opt_ids) + opt_records = self._get_child_records(opt_ids, OptimizationRecord) self._optimizations_cache = {} for td_opt, opt_record in zip(self.optimizations_, opt_records): @@ -218,7 +219,7 @@ def _fetch_minimum_optimizations(self): # Fetch optimization records from the server opt_key_ids = list(min_opt_ids.items()) opt_ids = [x[1] for x in opt_key_ids] - opt_records = self._client.get_optimizations(opt_ids) + opt_records = self._get_child_records(opt_ids, OptimizationRecord) self._minimum_optimizations_cache = {deserialize_key(x[0]): y for x, y in zip(opt_key_ids, opt_records)}