
Commit

Merge pull request #881 from MolSSI/ds_submit_bkg
Add ability to submit dataset records as a background job
bennybp authored Jan 21, 2025
2 parents 260aca0 + 282cf1a commit 78a3b15
Showing 21 changed files with 563 additions and 81 deletions.
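Before the per-file diffs, a quick sketch of the feature from the client side. This is hypothetical: it assumes the portal Dataset object gained a background_submit wrapper matching the new REST route below (the client-side changes are among the 21 files but are not shown in this excerpt), and the server address is made up.

from qcportal import PortalClient

client = PortalClient("https://qcfractal.example.com")  # assumed server address
ds = client.get_dataset("singlepoint", "my dataset")    # assumed dataset

# Instead of blocking while thousands of records are created,
# queue the submission as a background (internal) job:
job_id = ds.background_submit()  # assumed wrapper; returns the internal job ID per the route below
print(f"Submission running as internal job {job_id}")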
16 changes: 16 additions & 0 deletions qcfractal/qcfractal/components/dataset_routes.py
@@ -207,6 +207,22 @@ def submit_dataset_v1(dataset_type: str, dataset_id: int, body_data: DatasetSubmitBody):
)


@api_v1.route("/datasets/<string:dataset_type>/<int:dataset_id>/background_submit", methods=["POST"])
@wrap_route("WRITE")
def background_submit_dataset_v1(dataset_type: str, dataset_id: int, body_data: DatasetSubmitBody):
ds_socket = storage_socket.datasets.get_socket(dataset_type)
return ds_socket.background_submit(
dataset_id,
entry_names=body_data.entry_names,
specification_names=body_data.specification_names,
tag=body_data.tag,
priority=body_data.priority,
owner_user=g.username,
owner_group=body_data.owner_group,
find_existing=body_data.find_existing,
)


###################
# Specifications
###################
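For reviewers who want to poke at the new route directly, a raw HTTP sketch. Only the path segment and the payload fields (mirroring DatasetSubmitBody as used in the handler) come from the diff above; the /api/v1 prefix, port, and lack of authentication are assumptions.

import requests

# Hypothetical server and dataset ID
url = "http://localhost:7777/api/v1/datasets/singlepoint/1/background_submit"
body = {
    "entry_names": None,          # None = all entries
    "specification_names": None,  # None = all specifications
    "tag": None,                  # fall back to the dataset's default tag
    "priority": None,
    "owner_group": None,
    "find_existing": True,
}
r = requests.post(url, json=body)
r.raise_for_status()
print("queued as internal job", r.json())  # the route returns the internal job ID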
239 changes: 204 additions & 35 deletions qcfractal/qcfractal/components/dataset_socket.py
@@ -9,6 +9,7 @@
from sqlalchemy.dialects.postgresql import insert
from sqlalchemy.orm import load_only, lazyload, joinedload, with_polymorphic
from sqlalchemy.orm.attributes import flag_modified
from qcfractal.db_socket.helpers import get_count

from qcfractal.components.dataset_db_models import (
BaseDatasetORM,
@@ -26,7 +27,7 @@
from qcportal.dataset_models import DatasetAttachmentType
from qcportal.exceptions import AlreadyExistsError, MissingDataError, UserReportableError
from qcportal.internal_jobs import InternalJobStatusEnum
from qcportal.metadata_models import InsertMetadata, DeleteMetadata, UpdateMetadata
from qcportal.metadata_models import InsertMetadata, DeleteMetadata, UpdateMetadata, InsertCountsMetadata
from qcportal.record_models import RecordStatusEnum, PriorityEnum
from qcportal.utils import chunk_iterable, now_at_utc

@@ -123,7 +124,7 @@ def _submit(
owner_user_id: Optional[int],
owner_group_id: Optional[int],
find_existing: bool,
):
) -> InsertCountsMetadata:
raise NotImplementedError("_submit must be overridden by the derived class")

def get_submit_info(
@@ -1206,8 +1207,9 @@ def submit(
owner_group: Optional[Union[int, str]],
find_existing: bool,
*,
job_progress: Optional[JobProgress] = None,
session: Optional[Session] = None,
):
) -> InsertCountsMetadata:
"""
Submit computations for this dataset
@@ -1232,11 +1234,23 @@
Group with additional permission for these records
find_existing
If True, search for existing records and return those. If False, always add new records
job_progress
Object used to track progress if this function is being run in a background job
session
An existing SQLAlchemy session to use. If None, one will be created. If an existing session
is used, it will be flushed (but not committed) before returning from this function.
Returns
-------
:
Counts of how many records were inserted or already existing. This only applies to records - records
that were already part of this dataset (i.e., a given entry/specification pair already had a record)
are not counted as existing in the return value.
"""

n_inserted = 0
n_existing = 0

with self.root_socket.optional_session(session) as session:
tag, priority, owner_user_id, owner_group_id = self.get_submit_info(
dataset_id, tag, priority, owner_user, owner_group, session=session
@@ -1247,7 +1261,7 @@
################################
stmt = select(self.specification_orm)

# We want the actual optimization specification as well
# We want the actual full specification as well
stmt = stmt.join(self.specification_orm.specification)
stmt = stmt.where(self.specification_orm.dataset_id == dataset_id)
if specification_names is not None:
@@ -1265,46 +1279,166 @@
################################
# Get entry details
################################
stmt = select(self.entry_orm)
stmt = stmt.where(self.entry_orm.dataset_id == dataset_id)
if entry_names is None:
# Do all entries in batches using server-side cursors
stmt = select(self.entry_orm)
stmt = stmt.where(self.entry_orm.dataset_id == dataset_id)

# for progress tracking
if job_progress is not None:
total_records = len(ds_specs) * get_count(session, stmt)
records_done = 0

r = session.execute(stmt).scalars()

while entries_batch := r.fetchmany(500):
entries_batch_names = [e.name for e in entries_batch]

# Find which records/record_items already exist
stmt = select(self.record_item_orm.entry_name, self.record_item_orm.specification_name)
stmt = stmt.where(self.record_item_orm.dataset_id == dataset_id)

stmt = stmt.where(self.record_item_orm.entry_name.in_(entries_batch_names))
if specification_names is not None:
stmt = stmt.where(self.record_item_orm.specification_name.in_(specification_names))

existing_records = session.execute(stmt).all()

batch_meta = self._submit(
session,
dataset_id,
entries_batch,
ds_specs,
existing_records,
tag,
priority,
owner_user_id,
owner_group_id,
find_existing,
)

if entry_names is not None:
stmt = stmt.where(self.entry_orm.name.in_(entry_names))
n_inserted += batch_meta.n_inserted
n_existing += batch_meta.n_existing

entries = session.execute(stmt).scalars().all()
if job_progress is not None:
job_progress.raise_if_cancelled()
records_done += len(entries_batch)
job_progress.update_progress(100 * (records_done * len(ds_specs)) / total_records)

# Check to make sure we found all the entries
if entry_names is not None:
found_entries = {x.name for x in entries}
missing_entries = set(entry_names) - found_entries
if missing_entries:
raise MissingDataError(f"Could not find all entries. Missing: {missing_entries}")
else: # entry names were given

# Find which records/record_items already exist
stmt = select(self.record_item_orm)
stmt = stmt.where(self.record_item_orm.dataset_id == dataset_id)
# for progress tracking
if job_progress is not None:
total_records = len(ds_specs) * len(entry_names)
records_done = 0

if entry_names is not None:
stmt = stmt.where(self.record_item_orm.entry_name.in_(entry_names))
if specification_names is not None:
stmt = stmt.where(self.record_item_orm.specification_name.in_(specification_names))
# Used later to check for missing entries
found_entries = []

existing_record_orm = session.execute(stmt).scalars().all()
existing_records = [(x.entry_name, x.specification_name) for x in existing_record_orm]
# Do entries in batches using the given entry names
for entries_names_batch in chunk_iterable(entry_names, 500):
stmt = select(self.entry_orm)
stmt = stmt.where(self.entry_orm.dataset_id == dataset_id)
stmt = stmt.where(self.entry_orm.name.in_(entries_names_batch))

return self._submit(
session,
dataset_id,
entries,
ds_specs,
existing_records,
tag,
priority,
owner_user_id,
owner_group_id,
find_existing,
entries_batch = session.execute(stmt).scalars().all()

entries_batch_names = [e.name for e in entries_batch]
found_entries.extend(entries_batch_names)

# Find which records/record_items already exist
stmt = select(self.record_item_orm.entry_name, self.record_item_orm.specification_name)
stmt = stmt.where(self.record_item_orm.dataset_id == dataset_id)

stmt = stmt.where(self.record_item_orm.entry_name.in_(entries_batch_names))
if specification_names is not None:
stmt = stmt.where(self.record_item_orm.specification_name.in_(specification_names))

existing_records = session.execute(stmt).all()

batch_meta = self._submit(
session,
dataset_id,
entries_batch,
ds_specs,
existing_records,
tag,
priority,
owner_user_id,
owner_group_id,
find_existing,
)

n_inserted += batch_meta.n_inserted
n_existing += batch_meta.n_existing

if job_progress is not None:
job_progress.raise_if_cancelled()
records_done += len(entries_names_batch)
job_progress.update_progress(100 * (records_done * len(ds_specs)) / total_records)

if entry_names is not None:
missing_entries = set(entry_names) - set(found_entries)
if missing_entries:
raise MissingDataError(f"Could not find all entries. Missing: {missing_entries}")

return InsertCountsMetadata(n_inserted=n_inserted, n_existing=n_existing)
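The progress arithmetic is the same in both branches and is easy to sanity-check: with E entries and S specifications there are E*S candidate records, and a processed batch of entries accounts for len(batch)*S of them, so the reported percentage is 100*(entries_done*S)/(E*S). A worked example with made-up sizes:

# Assumed sizes, for illustration only
n_entries = 1200                     # E
n_specs = 3                          # S
total_records = n_specs * n_entries  # 3600 candidate records

records_done = 0
for batch in (500, 500, 200):        # three fetchmany/chunk_iterable batches
    records_done += batch
    progress = 100 * (records_done * n_specs) / total_records
    print(f"{progress:.2f}%")        # 41.67%, 83.33%, 100.00%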

def background_submit(
self,
dataset_id: int,
entry_names: Optional[Iterable[str]],
specification_names: Optional[Iterable[str]],
tag: Optional[str],
priority: Optional[PriorityEnum],
owner_user: Optional[Union[int, str]],
owner_group: Optional[Union[int, str]],
find_existing: bool,
*,
session: Optional[Session] = None,
) -> int:
"""
Submit computations for this dataset as an internal job
This creates an internal job for the submission and returns the ID
See :ref:`submit` for the rest of the details on functionality and parameters.
Returns
-------
:
ID of the created internal job
"""

with self.root_socket.optional_session(session) as session:
job_id = self.root_socket.internal_jobs.add(
f"dataset_submit_{dataset_id}",
now_at_utc(),
f"datasets.submit",
{
"dataset_id": dataset_id,
"entry_names": entry_names,
"specification_names": specification_names,
"tag": tag,
"priority": priority,
"owner_user": owner_user,
"owner_group": owner_group,
"find_existing": find_existing,
},
user_id=None,
unique_name=False,
serial_group=f"ds_submit_{dataset_id}", # only run one submission for this dataset at a time
session=session,
)

stmt = (
insert(DatasetInternalJobORM)
.values(dataset_id=dataset_id, internal_job_id=job_id)
.on_conflict_do_nothing()
)
session.execute(stmt)
return job_id
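The serial_group string is what keeps concurrent submissions for one dataset from racing each other: internal jobs sharing a serial group run strictly one at a time, while jobs in different groups may run in parallel. A hypothetical illustration (argument order follows background_submit above; the IDs and user are made up):

# Two submissions to dataset 17 share serial_group "ds_submit_17",
# so they run sequentially even if multiple job runners are available
job_a = ds_socket.background_submit(17, None, None, None, None, "alice", None, True)
job_b = ds_socket.background_submit(17, ["entry_5"], None, None, None, "alice", None, True)

# A submission to dataset 18 lands in serial_group "ds_submit_18"
# and is independent of dataset 17's queue
job_c = ds_socket.background_submit(18, None, None, None, None, "alice", None, True)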

#######################
# Record modification
#######################
@@ -1941,3 +2075,38 @@ def add_create_view_attachment_job(
)
session.execute(stmt)
return job_id

def submit(
self,
dataset_id: int,
entry_names: Optional[Iterable[str]],
specification_names: Optional[Iterable[str]],
tag: Optional[str],
priority: Optional[PriorityEnum],
owner_user: Optional[Union[int, str]],
owner_group: Optional[Union[int, str]],
find_existing: bool,
*,
job_progress: Optional[JobProgress] = None,
session: Optional[Session] = None,
) -> InsertCountsMetadata:
"""
Submit computations for a dataset
This function looks up the dataset socket and then call submit on that socket
"""

with self.root_socket.optional_session(session) as session:
ds_type = self.lookup_type(dataset_id)
ds_socket = self.get_socket(ds_type)

return ds_socket.submit(
dataset_id=dataset_id,
entry_names=entry_names,
specification_names=specification_names,
tag=tag,
priority=priority,
owner_user=owner_user,
owner_group=owner_group,
find_existing=find_existing,
job_progress=job_progress,
session=session,
)
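This wrapper closes the loop with background_submit above: the internal job stores only the string "datasets.submit" plus keyword arguments, and when a runner picks the job up, that path resolves to this method, which rediscovers the type-specific socket from the dataset ID. Conceptually (a sketch; the context named in the comments is assumed):

# What executing the stored job effectively does
# (dataset_socket = root_socket.datasets; kwargs come from the stored job)
ds_type = dataset_socket.lookup_type(dataset_id)   # e.g. "singlepoint"
typed_socket = dataset_socket.get_socket(ds_type)
meta = typed_socket.submit(
    dataset_id=dataset_id,
    entry_names=None,
    specification_names=None,
    tag=None,
    priority=None,
    owner_user="alice",  # assumed value
    owner_group=None,
    find_existing=True,
)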
13 changes: 11 additions & 2 deletions qcfractal/qcfractal/components/gridoptimization/dataset_socket.py
@@ -7,6 +7,7 @@
from qcfractal.components.dataset_socket import BaseDatasetSocket
from qcfractal.components.gridoptimization.record_db_models import GridoptimizationRecordORM
from qcportal.gridoptimization import GridoptimizationDatasetNewEntry, GridoptimizationSpecification
from qcportal.metadata_models import InsertMetadata, InsertCountsMetadata
from qcportal.record_models import PriorityEnum
from .dataset_db_models import (
GridoptimizationDatasetORM,
@@ -17,7 +18,6 @@

if TYPE_CHECKING:
from sqlalchemy.orm.session import Session
from qcportal.metadata_models import InsertMetadata
from qcfractal.db_socket.socket import SQLAlchemySocket
from typing import Optional, Sequence, Iterable, Tuple

@@ -76,7 +76,11 @@ def _submit(
owner_user_id: Optional[int],
owner_group_id: Optional[int],
find_existing: bool,
):
) -> InsertCountsMetadata:

n_inserted = 0
n_existing = 0

for spec in spec_orm:
goopt_spec_obj = spec.specification.to_model(GridoptimizationSpecification)
goopt_spec_input_dict = goopt_spec_obj.dict()
@@ -113,3 +117,8 @@ def _submit(
record_id=gridopt_ids[0],
)
session.add(rec)

n_inserted += meta.n_inserted
n_existing += meta.n_existing

return InsertCountsMetadata(n_inserted=n_inserted, n_existing=n_existing)
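Each type-specific _submit now follows the same counting contract as the base class: accumulate n_inserted/n_existing across the per-specification inserts and return a single InsertCountsMetadata. The folding step in isolation (a sketch; the field names come from the qcportal import above):

from qcportal.metadata_models import InsertCountsMetadata

def combine_counts(batch_metas):
    """Fold per-batch insert counts into one summary."""
    return InsertCountsMetadata(
        n_inserted=sum(m.n_inserted for m in batch_metas),
        n_existing=sum(m.n_existing for m in batch_metas),
    )

# e.g. combining two hypothetical batches:
total = combine_counts([
    InsertCountsMetadata(n_inserted=480, n_existing=20),
    InsertCountsMetadata(n_inserted=0, n_existing=500),
])
# total.n_inserted == 480, total.n_existing == 520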
3 changes: 2 additions & 1 deletion qcfractal/qcfractal/components/internal_jobs/socket.py
@@ -19,6 +19,7 @@
from qcfractal.db_socket.helpers import get_query_proj_options
from qcportal.exceptions import MissingDataError
from qcportal.internal_jobs.models import InternalJobStatusEnum, InternalJobQueryFilters
from qcportal.serialization import encode_to_json
from qcportal.utils import now_at_utc
from .db_models import InternalJobORM
from .status import JobProgress, CancelledJobException, JobRunnerStoppingException
@@ -361,7 +362,7 @@ def _run_single(self, session: Session, job_orm: InternalJobORM, logger, job_progress: JobProgress):
job_orm.status = InternalJobStatusEnum.complete
job_orm.progress = 100
job_orm.progress_description = "Complete"
job_orm.result = result
job_orm.result = encode_to_json(result)

# Job itself is being cancelled
except CancelledJobException:
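This one-line change is load-bearing for the PR: submit now returns an InsertCountsMetadata, a pydantic-style model that the job's JSON result column cannot store directly, so results pass through encode_to_json first. A sketch of the failure it avoids, using a minimal stand-in model (the real one lives in qcportal.metadata_models):

import json
from pydantic import BaseModel

class InsertCountsMetadata(BaseModel):  # minimal stand-in, not the real class
    n_inserted: int
    n_existing: int

meta = InsertCountsMetadata(n_inserted=10, n_existing=2)

try:
    json.dumps(meta)  # what storing the raw result would attempt
except TypeError as e:
    print(e)          # "Object of type InsertCountsMetadata is not JSON serializable"

json.dumps(meta.dict())  # roughly what encode_to_json produces: '{"n_inserted": 10, "n_existing": 2}'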