Skip to content

Commit

Permalink
Get child records via cache
Browse files Browse the repository at this point in the history
  • Loading branch information
bennybp committed Jan 22, 2024
1 parent ab3e05c commit 27c3425
Show file tree
Hide file tree
Showing 9 changed files with 152 additions and 83 deletions.
20 changes: 8 additions & 12 deletions qcportal/qcportal/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,9 +101,8 @@ def get_record(self, record_id: int, record_type: Type[_RECORD_T]) -> Optional[_

record = decompress_from_cache(record_data[1], record_type)

if not self.read_only:
record._record_cache = self
record._record_cache_uid = record_data[0]
record._record_cache = self
record._record_cache_uid = record_data[0]

return record

Expand All @@ -119,9 +118,8 @@ def get_records(self, record_ids: Iterable[int], record_type: Type[_RECORD_T]) -
for uid, compressed_record in rdata:
record = decompress_from_cache(compressed_record, record_type)

if not self.read_only:
record._record_cache = self
record._record_cache_uid = uid
record._record_cache = self
record._record_cache_uid = uid

all_records.append(record)

Expand Down Expand Up @@ -370,9 +368,8 @@ def get_dataset_record(self, entry_name: str, specification_name: str):

record = decompress_from_cache(record_data[1], self._record_type)

if not self.read_only:
record._record_cache = self
record._record_cache_uid = record_data[0]
record._record_cache = self
record._record_cache_uid = record_data[0]

return record

Expand All @@ -395,9 +392,8 @@ def get_dataset_records(self, entry_names: Iterable[str], specification_names: I
for uid, ename, sname, compressed_record in rdata:
record = decompress_from_cache(compressed_record, self._record_type)

if not self.read_only:
record._record_cache = self
record._record_cache_uid = uid
record._record_cache = self
record._record_cache_uid = uid

all_records.append((ename, sname, record))

Expand Down
1 change: 0 additions & 1 deletion qcportal/qcportal/dataset_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -1131,7 +1131,6 @@ def iterate_records(
existing_records = self._cache_data.get_existing_dataset_records(entry_names_batch, [spec_name])
existing_entries = [x[0] for x in existing_records]
batch_tofetch = [x for x in entry_names_batch if x not in existing_entries]
# print(f"BATCH TO FETCH: {len(batch_tofetch)}")

if batch_tofetch:
self._internal_fetch_records(batch_tofetch, [spec_name], status, include)
Expand Down
19 changes: 10 additions & 9 deletions qcportal/qcportal/gridoptimization/record_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,17 +227,18 @@ def _fetch_starting_molecule(self):
self.starting_molecule_ = self._client.get_molecules([self.starting_molecule_id])[0]

def _fetch_optimizations(self):
self._assert_online()

self.optimizations_ = self._client.make_request(
"get",
f"api/v1/records/gridoptimization/{self.id}/optimizations",
List[GridoptimizationOptimization],
)
# Always fetch optimization metadata if we can
if not self.offline or self.optimizations_ is None:
self._assert_online()
self.optimizations_ = self._client.make_request(
"get",
f"api/v1/records/gridoptimization/{self.id}/optimizations",
List[GridoptimizationOptimization],
)

# Fetch optimization records from the server
# Fetch optimization records from the server or the cache
opt_ids = [x.optimization_id for x in self.optimizations_]
opt_records = self._client.get_optimizations(opt_ids)
opt_records = self._get_child_records(opt_ids, OptimizationRecord)

self._optimizations_cache = {deserialize_key(x.key): y for x, y in zip(self.optimizations_, opt_records)}

Expand Down
31 changes: 19 additions & 12 deletions qcportal/qcportal/manybody/record_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,24 +112,31 @@ def _fetch_initial_molecule(self):
def _fetch_clusters(self):
self._assert_online()

self._clusters = self._client.make_request(
"get",
f"api/v1/records/manybody/{self.id}/clusters",
List[ManybodyCluster],
)
if not self.offline or self._clusters is None:
self._clusters = self._client.make_request(
"get",
f"api/v1/records/manybody/{self.id}/clusters",
List[ManybodyCluster],
)

mol_ids = [x.molecule_id for x in self._clusters]
mols = self._client.get_molecules(mol_ids)

for cluster, mol in zip(self._clusters, mols):
assert mol.id == cluster.molecule_id
cluster.molecule = mol

# Fetch singlepoint records and molecules
sp_ids = [x.singlepoint_id for x in self._clusters]
sp_recs = self._client.get_singlepoints(sp_ids)

mol_ids = [x.molecule_id for x in self._clusters]
mols = self._client.get_molecules(mol_ids)
sp_recs = self._get_child_records(sp_ids, SinglepointRecord)

for cluster, sp, mol in zip(self._clusters, sp_recs, mols):
for (
cluster,
sp,
) in zip(self._clusters, sp_recs):
assert sp.id == cluster.singlepoint_id
assert sp.molecule_id == mol.id == cluster.molecule_id
assert sp.molecule_id == cluster.molecule_id
cluster.singlepoint_record = sp
cluster.molecule = mol

self.propagate_client(self._client)

Expand Down
30 changes: 17 additions & 13 deletions qcportal/qcportal/neb/record_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,15 +172,17 @@ def fetch_all(self):
def _fetch_optimizations(self):
self._assert_online()

self.optimizations_ = self._client.make_request(
"get",
f"api/v1/records/neb/{self.id}/optimizations",
Dict[str, NEBOptimization],
)
if not self.offline or self.optimizations_ is None:
self._assert_online()
self.optimizations_ = self._client.make_request(
"get",
f"api/v1/records/neb/{self.id}/optimizations",
Dict[str, NEBOptimization],
)

# Fetch optimization records from the server or the cache
opt_ids = [opt.optimization_id for opt in self.optimizations_.values()]
opt_recs = self._client.get_optimizations(opt_ids)
opt_recs = self._get_child_records(opt_ids, OptimizationRecord)
opt_map = {opt.id: opt for opt in opt_recs}

self._optimizations_cache = {}
Expand All @@ -193,15 +195,17 @@ def _fetch_optimizations(self):
def _fetch_singlepoints(self):
self._assert_online()

self.singlepoints_ = self._client.make_request(
"get",
f"api/v1/records/neb/{self.id}/singlepoints",
List[NEBSinglepoint],
)
if not self.offline or self.singlepoints_ is None:
self._assert_online()
self.singlepoints_ = self._client.make_request(
"get",
f"api/v1/records/neb/{self.id}/singlepoints",
List[NEBSinglepoint],
)

# Fetch singlepoint records from server
# Fetch singlepoint records from server or the cache
sp_ids = [sp.singlepoint_id for sp in self.singlepoints_]
sp_recs = self._client.get_singlepoints(sp_ids)
sp_recs = self._get_child_records(sp_ids, SinglepointRecord)

self._singlepoints_cache = {}

Expand Down
20 changes: 10 additions & 10 deletions qcportal/qcportal/optimization/record_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,20 +83,20 @@ def _fetch_initial_molecule(self):
self.initial_molecule_ = self._client.get_molecules([self.initial_molecule_id])[0]

def _fetch_final_molecule(self):
self._assert_online()
if self.final_molecule_id is not None:
self._assert_online()
self.final_molecule_ = self._client.get_molecules([self.final_molecule_id])[0]

def _fetch_trajectory(self):
self._assert_online()

self.trajectory_ids_ = self._client.make_request(
"get",
f"api/v1/records/optimization/{self.id}/trajectory",
List[int],
)
if self.trajectory_ids_ is None:
self._assert_online()
self.trajectory_ids_ = self._client.make_request(
"get",
f"api/v1/records/optimization/{self.id}/trajectory",
List[int],
)

self._trajectory_records = self._client.get_singlepoints(self.trajectory_ids_)
self._trajectory_records = self._get_child_records(self.trajectory_ids_, SinglepointRecord)
self.propagate_client(self._client)

def _handle_includes(self, includes: Optional[Iterable[str]]):
Expand Down Expand Up @@ -146,7 +146,7 @@ def trajectory_element(self, trajectory_index: int) -> SinglepointRecord:

if self.trajectory_ids_ is not None:
traj_id = self.trajectory_ids_[trajectory_index]
sp_rec = self._client.get_singlepoints(traj_id)
sp_rec = self._get_child_records([traj_id], SinglepointRecord)[0]
sp_rec.propagate_client(self._client)
return sp_rec
else:
Expand Down
32 changes: 17 additions & 15 deletions qcportal/qcportal/reaction/record_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,32 +109,34 @@ def fetch_all(self):
def _fetch_components(self):
self._assert_online()

self._components = self._client.make_request(
"get",
f"api/v1/records/reaction/{self.id}/components",
List[ReactionComponent],
)

# Fetch records & molecules
if not self.offline or self._components is None:
self._assert_online()
self._components = self._client.make_request(
"get",
f"api/v1/records/reaction/{self.id}/components",
List[ReactionComponent],
)

mol_ids = [c.molecule_id for c in self._components]
mols = self._client.get_molecules(mol_ids)
for c, mol in zip(self._components, mols):
assert mol.id == c.molecule_id
c.molecule = mol

# Fetch records from server or cache
sp_comp = [c for c in self._components if c.singlepoint_id is not None]
sp_ids = [c.singlepoint_id for c in sp_comp]
sp_recs = self._client.get_singlepoints(sp_ids)
sp_recs = self._get_child_records(sp_ids, SinglepointRecord)
for c, rec in zip(sp_comp, sp_recs):
c.singlepoint_record = rec

opt_comp = [c for c in self._components if c.optimization_id is not None]
opt_ids = [c.optimization_id for c in opt_comp]
opt_recs = self._client.get_optimizations(opt_ids)
opt_recs = self._get_child_records(opt_ids, OptimizationRecord)
for c, rec in zip(opt_comp, opt_recs):
assert rec.initial_molecule_id == c.molecule_id
c.optimization_record = rec

mol_ids = [c.molecule_id for c in self._components]
mols = self._client.get_molecules(mol_ids)
for c, mol in zip(self._components, mols):
assert mol.id == c.molecule_id
c.molecule = mol

self.propagate_client(self._client)

def _handle_includes(self, includes: Optional[Iterable[str]]):
Expand Down
59 changes: 59 additions & 0 deletions qcportal/qcportal/record_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -482,6 +482,65 @@ def fetch_all(self):
for nf in self.native_files_.values():
nf.fetch_all()

def _cache_child_records(self, record_ids: Iterable[int], record_type: Type[_Record_T]) -> None:
    """
    Fetch child records from the server and store them in the cache

    The cache is checked first, and only records it does not already hold
    are requested from the server. If this record has no cache attached,
    nothing is done.

    Parameters
    ----------
    record_ids
        IDs of the child records that should be present in the cache
    record_type
        Record class passed to the client when fetching the missing records

    Raises
    ------
    RuntimeError
        If some records are missing from the cache and this record is offline
    """

    if self._record_cache is None:
        return

    # Materialize once: the ids are iterated more than once below, which
    # would silently exhaust a generator after the first pass
    record_ids = list(record_ids)

    existing_record_ids = set(self._record_cache.get_existing_records(record_ids))

    # dict.fromkeys removes duplicates while keeping the caller's order, so
    # the request sent to the server is deterministic
    records_tofetch = [rid for rid in dict.fromkeys(record_ids) if rid not in existing_record_ids]

    # Everything is already cached, so there is no reason to touch the client.
    # This also keeps a fully-cached lookup working while offline
    # (NOTE(review): assumes an offline record has no usable client - confirm)
    if not records_tofetch:
        return

    if self.offline:
        raise RuntimeError("Need to fetch some records, but not connected to a client")

    recs = self._client._get_records_by_type(record_type, records_tofetch)
    self._record_cache.update_records(recs)

def _get_child_records(self, record_ids: Sequence[int], record_type: Type[_Record_T]) -> List[_Record_T]:
    """
    Obtain child records from the cache, falling back to the server for the rest

    Records already held by the attached cache are reused. Any remaining
    records are fetched through the client, written to the cache (when one is
    attached), and tagged with the cache and the uid the cache returned for them.

    The records are returned in the same order as the `record_ids` parameter.
    An id that appears several times yields the same record object each time.

    Parameters
    ----------
    record_ids
        IDs of the child records to obtain
    record_type
        Record class used for the cache lookup and the server request

    Returns
    -------
    :
        One record per entry of `record_ids`, in the same order

    Raises
    ------
    RuntimeError
        If records must be fetched but this record is offline, or if some
        records could be found neither in the cache nor on the server
    """

    record_ids = list(record_ids)

    if self._record_cache is None:
        # Without a cache every record has to come from the server
        self._assert_online()
        found = {}
    else:
        # Build our own mapping rather than extending the list the cache returned
        cached_records = self._record_cache.get_records(record_ids, record_type)
        found = {r.id: r for r in cached_records}

    # dict.fromkeys removes duplicates while keeping the caller's order, so
    # the request sent to the server is deterministic and free of repeats
    records_tofetch = [rid for rid in dict.fromkeys(record_ids) if rid not in found]

    if records_tofetch:
        if self.offline:
            raise RuntimeError("Need to fetch some records, but not connected to a client")

        recs = self._client._get_records_by_type(record_type, records_tofetch)

        if self._record_cache is not None:
            uids = self._record_cache.update_records(recs)

            # Tag freshly fetched records the same way cached ones are tagged
            # (assumes update_records returns uids in the order of recs - TODO confirm)
            for u, r in zip(uids, recs):
                r._record_cache = self._record_cache
                r._record_cache_uid = u

        found.update((r.id, r) for r in recs)

    # Return everything in the same order as the input
    ret = [found.get(rid) for rid in record_ids]

    # Identity test on purpose: `None in ret` would fall back to the records'
    # own equality comparison, which is neither needed nor cheap
    if any(r is None for r in ret):
        missing_ids = set(record_ids) - set(found.keys())
        raise RuntimeError(
            f"Not all records found either in the cache or on the server. Missing records: {missing_ids}"
        )

    return ret

def _assert_online(self):
"""Raises an exception if this record does not have an associated client"""
if self.offline:
Expand Down
23 changes: 12 additions & 11 deletions qcportal/qcportal/torsiondrive/record_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,17 +177,18 @@ def _fetch_initial_molecules(self):
self.initial_molecules_ = self._client.get_molecules(self.initial_molecules_ids_)

def _fetch_optimizations(self):
self._assert_online()

self.optimizations_ = self._client.make_request(
"get",
f"api/v1/records/torsiondrive/{self.id}/optimizations",
List[TorsiondriveOptimization],
)

# Fetch optimization records from the server
# Always fetch optimization metadata if we can
if not self.offline or self.optimizations_ is None:
self._assert_online()
self.optimizations_ = self._client.make_request(
"get",
f"api/v1/records/torsiondrive/{self.id}/optimizations",
List[TorsiondriveOptimization],
)

# Fetch optimization records from the server or cache
opt_ids = [x.optimization_id for x in self.optimizations_]
opt_records = self._client.get_optimizations(opt_ids)
opt_records = self._get_child_records(opt_ids, OptimizationRecord)

self._optimizations_cache = {}
for td_opt, opt_record in zip(self.optimizations_, opt_records):
Expand Down Expand Up @@ -218,7 +219,7 @@ def _fetch_minimum_optimizations(self):
# Fetch optimization records from the server or the cache
opt_key_ids = list(min_opt_ids.items())
opt_ids = [x[1] for x in opt_key_ids]
opt_records = self._client.get_optimizations(opt_ids)
opt_records = self._get_child_records(opt_ids, OptimizationRecord)

self._minimum_optimizations_cache = {deserialize_key(x[0]): y for x, y in zip(opt_key_ids, opt_records)}

Expand Down

0 comments on commit 27c3425

Please sign in to comment.