Skip to content

Commit

Permalink
Allow for only getting minimum optimizations for a torsiondrive
Browse files Browse the repository at this point in the history
  • Loading branch information
bennybp committed Mar 26, 2024
1 parent 5e010e0 commit 07e48df
Show file tree
Hide file tree
Showing 3 changed files with 82 additions and 49 deletions.
53 changes: 52 additions & 1 deletion qcfractal/qcfractal/components/torsiondrive/record_db_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,21 @@

from typing import TYPE_CHECKING

from sqlalchemy import select, Column, Integer, ForeignKey, String, UniqueConstraint, Index, CheckConstraint, event, DDL
from sqlalchemy import func
from sqlalchemy import (
select,
Integer,
ForeignKey,
String,
UniqueConstraint,
Index,
CheckConstraint,
event,
DDL,
Column,
TEXT,
DOUBLE_PRECISION,
)
from sqlalchemy.dialects.postgresql import JSONB
from sqlalchemy.ext.orderinglist import ordering_list
from sqlalchemy.orm import relationship, column_property
Expand Down Expand Up @@ -96,6 +110,36 @@ def short_description(self) -> str:
return f"{self.program}~{self.optimization_specification.short_description}"


# CTE for a table with minimimum optimizations. Has columns torsiondrive_id, key, and (minimum) optimization_id
# Chooses the optimization with the lowest energy, and if there are multiple, the one with the lowest id
_minopt_cte = (
select(
TorsiondriveOptimizationORM.torsiondrive_id.label("torsiondrive_id"),
TorsiondriveOptimizationORM.key.label("key"),
TorsiondriveOptimizationORM.optimization_id.label("optimization_id"),
)
.join(OptimizationRecordORM, TorsiondriveOptimizationORM.optimization_id == OptimizationRecordORM.id)
.distinct(TorsiondriveOptimizationORM.torsiondrive_id, TorsiondriveOptimizationORM.key)
.order_by(
TorsiondriveOptimizationORM.torsiondrive_id,
TorsiondriveOptimizationORM.key,
OptimizationRecordORM.energies[-1].cast(TEXT).cast(DOUBLE_PRECISION).asc(),
OptimizationRecordORM.id.asc(),
)
.cte()
)

# CTE for a table with minimimum optimizations, but as JSON. Has columns torsiondrive_id, minimum_optimizations (as JSONB)
_minopt_cte_agg = (
select(
_minopt_cte.c.torsiondrive_id.label("torsiondrive_id"),
func.jsonb_object_agg(_minopt_cte.c.key, _minopt_cte.c.optimization_id).label("minimum_optimizations"),
)
.group_by(_minopt_cte.c.torsiondrive_id)
.cte()
)


class TorsiondriveRecordORM(BaseRecordORM):
"""
Table for storing torsiondrive calculations
Expand All @@ -118,6 +162,13 @@ class TorsiondriveRecordORM(BaseRecordORM):
passive_deletes=True,
)

minimum_optimizations = column_property(
select(_minopt_cte_agg.c.minimum_optimizations)
.where(id == _minopt_cte_agg.c.torsiondrive_id)
.scalar_subquery(),
deferred=True,
)

__mapper_args__ = {
"polymorphic_identity": "torsiondrive",
}
Expand Down
39 changes: 7 additions & 32 deletions qcfractal/qcfractal/components/torsiondrive/record_socket.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,11 @@
except ImportError:
from pydantic import BaseModel, Extra
from sqlalchemy import select, func
from sqlalchemy.dialects.postgresql import insert, array_agg, aggregate_order_by, DOUBLE_PRECISION, TEXT
from sqlalchemy.dialects.postgresql import insert, array_agg, aggregate_order_by
from sqlalchemy.orm import lazyload, selectinload, joinedload, defer, undefer

from qcfractal.components.optimization.record_db_models import (
OptimizationSpecificationORM,
OptimizationRecordORM,
)
from qcfractal.components.services.db_models import ServiceQueueORM, ServiceDependencyORM
from qcfractal.components.singlepoint.record_db_models import QCSpecificationORM
Expand Down Expand Up @@ -416,6 +415,9 @@ def get(
if "**" in include or "optimizations" in include:
options.append(selectinload(TorsiondriveRecordORM.optimizations))

if "**" in include or "minimum_optimizations" in include:
options.append(undefer(TorsiondriveRecordORM.minimum_optimizations))

with self.root_socket.optional_session(session, True) as session:
return self.root_socket.records.get_base(
orm_type=self.record_orm,
Expand Down Expand Up @@ -803,35 +805,8 @@ def get_minimum_optimizations(
optimization in the torsiondrive (representing the angles)
"""

# Kind of complicated, but this is relatively cross platform
# (with postgres, could do a DISTINCT ON)

# CTE with columns torsiondrive_id, key, min_energy
energy_cte = (
select(
TorsiondriveOptimizationORM.torsiondrive_id.label("torsiondrive_id"),
TorsiondriveOptimizationORM.key.label("key"),
func.min(OptimizationRecordORM.energies[-1].cast(TEXT).cast(DOUBLE_PRECISION)).label("min_energy"),
)
.join(OptimizationRecordORM)
.group_by("torsiondrive_id", "key")
).cte()

# Select rows with matching minimum energies
# We order by the optimization id desc to handle the case where multiple records with same final energy
# Then, the dictionary comprehension at the end last of that energy (lowest id)
stmt = (
select(TorsiondriveOptimizationORM.key, TorsiondriveOptimizationORM.optimization_id)
.join(energy_cte, energy_cte.c.torsiondrive_id == TorsiondriveOptimizationORM.torsiondrive_id)
.join(TorsiondriveOptimizationORM.optimization_record)
.where(TorsiondriveOptimizationORM.torsiondrive_id == record_id)
.where(TorsiondriveOptimizationORM.key == energy_cte.c.key)
.where(OptimizationRecordORM.energies[-1].cast(TEXT).cast(DOUBLE_PRECISION) == energy_cte.c.min_energy)
.order_by(OptimizationRecordORM.id.desc())
)
stmt = select(TorsiondriveRecordORM.minimum_optimizations).where(TorsiondriveRecordORM.id == record_id)

with self.root_socket.optional_session(session, True) as session:
r = session.execute(stmt).all() # List of (key, id)

# If multiple records with the same energy are returned, then this will choose the last
return {x: y for x, y in r}
r = session.execute(stmt).scalar_one_or_none() # List of (key, id)
return {} if r is None else r
39 changes: 23 additions & 16 deletions qcportal/qcportal/torsiondrive/record_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,9 @@ class TorsiondriveRecord(BaseRecord):
######################################################
initial_molecules_ids_: Optional[List[int]] = Field(None, alias="initial_molecules_ids")
initial_molecules_: Optional[List[Molecule]] = Field(None, alias="initial_molecules")

optimizations_: Optional[List[TorsiondriveOptimization]] = Field(None, alias="optimizations")
minimum_optimizations_: Optional[Dict[str, int]] = Field(None, alias="minimum_optimizations")

########################################
# Caches
Expand All @@ -156,31 +158,43 @@ def _fetch_children_multi(cls, client, record_cache, records: Iterable[Torsiondr
for r in records:
if r.optimizations_:
opt_ids.update(x.optimization_id for x in r.optimizations_)
if r.minimum_optimizations_:
opt_ids.update(r.minimum_optimizations_.values())

include = ["**"] if recursive else None
opt_ids = list(opt_ids)
opt_records = get_records_with_cache(client, record_cache, opt_ids, OptimizationRecord, include=include)
opt_map = {x.id: x for x in opt_records}

for r in records:
if r.optimizations_ is None:
r._optimizations_cache = None
r._minimum_optimizations_cache = None
else:
r._optimizations_cache = None
r._minimum_optimizations_cache = None

if r.optimizations_ is None and r.minimum_optimizations_ is None:
continue

if r.optimizations_ is not None:
r._optimizations_cache = {}
for td_opt in r.optimizations_:
key = deserialize_key(td_opt.key)
r._optimizations_cache.setdefault(key, list())
r._optimizations_cache[key].append(opt_map[td_opt.optimization_id])

# find the minimum optimizations for each key
if r.minimum_optimizations_ is None and r.optimizations_ is not None:
# find the minimum optimizations for each key from what we have in the optimizations
# chooses the lowest id if there are records with the same energy
r._minimum_optimizations_cache = {}
r.minimum_optimizations_ = {}
for k, v in r._optimizations_cache.items():
# Remove any optimizations without energies
v2 = [x for x in v if x.energies]
if v2:
r._minimum_optimizations_cache[k] = min(v2, key=lambda x: (x.energies[-1], x.id))
lowest_opt = min(v2, key=lambda x: (x.energies[-1], x.id))
r.minimum_optimizations_[serialize_key(k)] = lowest_opt.id

if r.minimum_optimizations_ is not None: # either from the server or from above
r._minimum_optimizations_cache = {
deserialize_key(k): opt_map[v] for k, v in r.minimum_optimizations_.items()
}

r.propagate_client(r._client)

Expand Down Expand Up @@ -223,20 +237,13 @@ def _fetch_optimizations(self):
def _fetch_minimum_optimizations(self):
self._assert_online()

min_opt_ids = self._client.make_request(
self.minimum_optimizations_ = self._client.make_request(
"get",
f"api/v1/records/torsiondrive/{self.id}/minimum_optimizations",
Dict[str, int],
)

# Fetch optimization records from the server
opt_key_ids = list(min_opt_ids.items())
opt_ids = [x[1] for x in opt_key_ids]
opt_records = self._get_child_records(opt_ids, OptimizationRecord)

self._minimum_optimizations_cache = {deserialize_key(x[0]): y for x, y in zip(opt_key_ids, opt_records)}

self.propagate_client(self._client)
self.fetch_children(False)

@property
def initial_molecules(self) -> List[Molecule]:
Expand Down

0 comments on commit 07e48df

Please sign in to comment.