Skip to content

Commit

Permalink
Includes - neb
Browse files Browse the repository at this point in the history
  • Loading branch information
bennybp committed Mar 24, 2024
1 parent 347bcb4 commit 5e010e0
Show file tree
Hide file tree
Showing 4 changed files with 121 additions and 61 deletions.
31 changes: 31 additions & 0 deletions qcfractal/qcfractal/components/neb/record_db_models.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from __future__ import annotations

from typing import TYPE_CHECKING

from sqlalchemy import Column, Integer, ForeignKey, String, UniqueConstraint, Index, Boolean, event, DDL
from sqlalchemy.dialects.postgresql import JSONB
from sqlalchemy.ext.orderinglist import ordering_list
Expand All @@ -11,6 +13,9 @@
from qcfractal.components.singlepoint.record_db_models import QCSpecificationORM, SinglepointRecordORM
from qcfractal.db_socket import BaseORM

if TYPE_CHECKING:
from typing import Dict, Any, Optional, Iterable


class NEBOptimizationsORM(BaseORM):
__tablename__ = "neb_optimizations"
Expand Down Expand Up @@ -131,6 +136,32 @@ class NEBRecordORM(BaseRecordORM):

_qcportal_model_excludes = [*BaseRecordORM._qcportal_model_excludes, "specification_id"]

def model_dict(self, exclude: Optional[Iterable[str]] = None) -> Dict[str, Any]:
d = BaseRecordORM.model_dict(self, exclude)

# Return initial molecule or just the ids, depending on what we have
if "initial_chain" in d:
init_chain = d.pop("initial_chain")
d["initial_chain_molecule_ids"] = [x["molecule_id"] for x in init_chain]
if "molecule" in init_chain[0]:
d["initial_chain"] = [x["molecule"] for x in init_chain]

if "optimizations" in d:
optimizations = d.pop("optimizations")

opt_dict = {}
for opt in optimizations:
if opt["ts"]:
opt_dict["transition"] = opt
elif opt["position"] == 0:
opt_dict["initial"] = opt
else:
opt_dict["final"] = opt

d["optimizations"] = opt_dict

return d

@property
def short_description(self) -> str:
n_mols = len(self.initial_chain)
Expand Down
25 changes: 13 additions & 12 deletions qcfractal/qcfractal/components/neb/record_socket.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

from sqlalchemy import select, func
from sqlalchemy.dialects.postgresql import insert, array_agg, aggregate_order_by, DOUBLE_PRECISION, TEXT
from sqlalchemy.orm import lazyload, joinedload, defer, undefer
from sqlalchemy.orm import lazyload, joinedload, selectinload, defer, undefer

from qcfractal.components.molecules.db_models import MoleculeORM
from qcfractal.components.services.db_models import ServiceQueueORM, ServiceDependencyORM
Expand Down Expand Up @@ -536,13 +536,24 @@ def get(
*,
session: Optional[Session] = None,
) -> List[Optional[Dict[str, Any]]]:
options = []

if include:
if "**" in include or "initial_chain" in include:
options.append(selectinload(NEBRecordORM.initial_chain).joinedload(NEBInitialchainORM.molecule))
if "**" in include or "singlepoints" in include:
options.append(selectinload(NEBRecordORM.singlepoints))
if "**" in include or "optimizations" in include:
options.append(selectinload(NEBRecordORM.optimizations))

with self.root_socket.optional_session(session, True) as session:
return self.root_socket.records.get_base(
orm_type=self.record_orm,
record_ids=record_ids,
include=include,
exclude=exclude,
missing_ok=missing_ok,
additional_options=options,
session=session,
)

Expand Down Expand Up @@ -875,17 +886,7 @@ def get_optimizations(
if rec is None:
raise MissingDataError(f"Cannot find record {record_id}")

ret = {}

for opt in rec.optimizations:
if opt.ts:
ret["transition"] = opt.model_dict()
elif opt.position == 0:
ret["initial"] = opt.model_dict()
else:
ret["final"] = opt.model_dict()

return ret
return rec.model_dict()["optimizations"]

def get_neb_result(
self,
Expand Down
105 changes: 58 additions & 47 deletions qcportal/qcportal/neb/record_models.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

from typing import List, Optional, Union, Dict, Iterable

try:
Expand All @@ -6,6 +8,7 @@
from pydantic import BaseModel, Field, Extra, root_validator, constr, validator, PrivateAttr
from typing_extensions import Literal

from qcportal.cache import get_records_with_cache
from qcportal.molecules import Molecule
from qcportal.record_models import BaseRecord, RecordAddBodyBase, RecordQueryFilters
from qcportal.utils import recursive_normalizer
Expand Down Expand Up @@ -135,7 +138,6 @@ class NEBRecord(BaseRecord):
initial_chain_molecule_ids_: Optional[List[int]] = Field(None, alias="initial_chain_molecule_ids")
singlepoints_: Optional[List[NEBSinglepoint]] = Field(None, alias="singlepoints")
optimizations_: Optional[Dict[str, NEBOptimization]] = Field(None, alias="optimizations")
ts_hessian_: Optional[SinglepointRecord] = Field(None, alias="ts_hessian")
neb_result_: Optional[Molecule] = Field(None, alias="neb_result")
initial_chain_: Optional[List[Molecule]] = Field(None, alias="initial_chain")

Expand All @@ -144,6 +146,7 @@ class NEBRecord(BaseRecord):
########################################
_optimizations_cache: Optional[Dict[str, OptimizationRecord]] = PrivateAttr(None)
_singlepoints_cache: Optional[Dict[int, List[SinglepointRecord]]] = PrivateAttr(None)
_ts_hessian: Optional[SinglepointRecord] = PrivateAttr(None)

def propagate_client(self, client):
BaseRecord.propagate_client(self, client)
Expand All @@ -157,73 +160,81 @@ def propagate_client(self, client):
for sp2 in splist:
sp2.propagate_client(client)

def fetch_all(self):
BaseRecord.fetch_all(self)

self._fetch_initial_chain()
self._fetch_singlepoints()
self._fetch_optimizations()
self._fetch_neb_result()

for opt in self._optimizations_cache.values():
opt.fetch_all()

for splist in self._singlepoints_cache.values():
for sp2 in splist:
sp2.fetch_all()
@classmethod
def _fetch_children_multi(cls, client, record_cache, records: Iterable[NEBRecord], recursive: bool):
# Should be checked by the calling function
assert records
assert all(isinstance(x, NEBRecord) for x in records)

# Collect optimization and singlepoint ids for all NEB
opt_ids = set()
sp_ids = set()

for r in records:
if r.optimizations_ is not None:
opt_ids.update(x.optimization_id for x in r.optimizations_.values())
if r.singlepoints_ is not None:
sp_ids.update(x.singlepoint_id for x in r.singlepoints_)

include = ["**"] if recursive else None
sp_ids = list(sp_ids)
opt_ids = list(opt_ids)

sp_records = get_records_with_cache(client, record_cache, sp_ids, SinglepointRecord, include=include)
opt_records = get_records_with_cache(client, record_cache, opt_ids, OptimizationRecord, include=include)

sp_map = {r.id: r for r in sp_records}
opt_map = {r.id: r for r in opt_records}

for r in records:
if r.optimizations_ is None:
r._optimizations_cache = None
else:
r._optimizations_cache = dict()
for opt_key, opt_info in r.optimizations_.items():
r._optimizations_cache[opt_key] = opt_map[opt_info.optimization_id]

if r.singlepoints_ is None:
r._singlepoints_cache = None
else:
r._singlepoints_cache = dict()
for sp_info in r.singlepoints_:
r._singlepoints_cache.setdefault(sp_info.chain_iteration, list())
r._singlepoints_cache[sp_info.chain_iteration].append(sp_map[sp_info.singlepoint_id])

if len(r._singlepoints_cache) > 0:
if len(r._singlepoints_cache[max(r._singlepoints_cache)]) == 1:
_, temp_list = r._singlepoints_cache.popitem()
r._ts_hessian = temp_list[0]
assert r._ts_hessian.specification.driver == "hessian"

r.propagate_client(r._client)

def _fetch_optimizations(self):
self._assert_online()

if not self.offline or self.optimizations_ is None:
if self.optimizations_ is None:
self._assert_online()
self.optimizations_ = self._client.make_request(
"get",
f"api/v1/records/neb/{self.id}/optimizations",
Dict[str, NEBOptimization],
)

# Fetch optimization records from server
opt_ids = [opt.optimization_id for opt in self.optimizations_.values()]
opt_recs = self._get_child_records(opt_ids, OptimizationRecord)
opt_map = {opt.id: opt for opt in opt_recs}

self._optimizations_cache = {}

for opt_key, opt_info in self.optimizations_.items():
self._optimizations_cache[opt_key] = opt_map[opt_info.optimization_id]

self.propagate_client(self._client)
self.fetch_children(False)

def _fetch_singlepoints(self):
self._assert_online()

if not self.offline or self.singlepoints_ is None:
if self.singlepoints_ is None:
self._assert_online()
self.singlepoints_ = self._client.make_request(
"get",
f"api/v1/records/neb/{self.id}/singlepoints",
List[NEBSinglepoint],
)

# Fetch singlepoint records from server or the cache
sp_ids = [sp.singlepoint_id for sp in self.singlepoints_]
sp_recs = self._get_child_records(sp_ids, SinglepointRecord)

self._singlepoints_cache = {}

# Singlepoints should be in order of (iteration, position)
for sp_info, sp_rec in zip(self.singlepoints_, sp_recs):
self._singlepoints_cache.setdefault(sp_info.chain_iteration, list())
self._singlepoints_cache[sp_info.chain_iteration].append(sp_rec)

if len(self._singlepoints_cache) > 0:
if len(self._singlepoints_cache[max(self._singlepoints_cache)]) == 1:
_, temp_list = self._singlepoints_cache.popitem()
self.ts_hessian_ = temp_list[0]
assert self.ts_hessian_.specification.driver == "hessian"

self.propagate_client(self._client)
self.fetch_children(False)

def _fetch_initial_chain(self):
self._assert_online()
Expand Down Expand Up @@ -281,4 +292,4 @@ def ts_optimization(self) -> Optional[OptimizationRecord]:
def ts_hessian(self) -> Optional[SinglepointRecord]:
if self._singlepoints_cache is None:
self._fetch_singlepoints()
return self.ts_hessian_
return self._ts_hessian
21 changes: 19 additions & 2 deletions qcportal/qcportal/neb/test_record_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
all_includes = ["initial_chain", "singlepoints", "optimizations"]


@pytest.mark.parametrize("includes", [None, all_includes])
@pytest.mark.parametrize("includes", [None, ["**"], all_includes])
def test_neb_record_model(snowflake: QCATestingSnowflake, includes: Optional[List[str]]):
storage_socket = snowflake.get_storage_socket()
snowflake_client = snowflake.client()
Expand All @@ -25,9 +25,26 @@ def test_neb_record_model(snowflake: QCATestingSnowflake, includes: Optional[Lis
record = snowflake_client.get_nebs(rec_id, include=includes)

if includes is not None:
record._client = None
assert record.initial_chain_ is not None
assert record.initial_chain_molecule_ids_ is not None
assert record.singlepoints_ is not None
assert record.optimizations_ is not None
record.propagate_client(None)
assert record.offline

# children have all data fetched
assert all(x.initial_molecule_ is not None for x in record.optimizations.values())
assert all(x.comments_ is not None for x in record.optimizations.values())

for sp in record.singlepoints.values():
assert all(x.molecule_ is not None for x in sp)
assert all(x.comments_ is not None for x in sp)
else:
assert record.initial_chain_ is None
assert record.initial_chain_molecule_ids_ is None
assert record.singlepoints_ is None
assert record.optimizations_ is None

assert record.id == rec_id
assert record.status == RecordStatusEnum.complete

Expand Down

0 comments on commit 5e010e0

Please sign in to comment.