From f256b8410b0b713b1cf551aed71682e93e5cd27e Mon Sep 17 00:00:00 2001 From: Benjamin Pritchard Date: Tue, 27 Feb 2024 14:50:41 -0500 Subject: [PATCH] Fix ordering bug in reactions --- .../components/reaction/record_socket.py | 10 ++++++---- qcportal/qcportal/reaction/test_record_models.py | 16 +++++++++++++--- 2 files changed, 19 insertions(+), 7 deletions(-) diff --git a/qcfractal/qcfractal/components/reaction/record_socket.py b/qcfractal/qcfractal/components/reaction/record_socket.py index e178eaefe..2100b179c 100644 --- a/qcfractal/qcfractal/components/reaction/record_socket.py +++ b/qcfractal/qcfractal/components/reaction/record_socket.py @@ -122,8 +122,6 @@ def iterate_service( # of an optimization) required_sp_mols = {x.molecule_id for x in rxn_orm.components if has_singlepoint} - complete_tasks = service_orm.dependencies - # What was already completed and/or submitted sub_opt_mols = {x.molecule_id for x in rxn_orm.components if x.optimization_id is not None} sub_sp_mols = {x.molecule_id for x in rxn_orm.components if x.singlepoint_id is not None} @@ -135,6 +133,10 @@ def iterate_service( # Singlepoint calculations must wait for optimizations sp_mols_to_compute -= opt_mols_to_compute + # Convert to well-ordered lists + opt_mols_to_compute = list(opt_mols_to_compute) + sp_mols_to_compute = list(sp_mols_to_compute) + service_orm.dependencies = [] output = "" @@ -168,11 +170,11 @@ def iterate_service( if sp_mols_to_compute: # If an optimization was specified, we need to get the final molecule from that if has_optimization: - real_mols_to_compute = { + real_mols_to_compute = [ x.optimization_record.final_molecule_id for x in rxn_orm.components if x.molecule_id in sp_mols_to_compute - } + ] else: real_mols_to_compute = sp_mols_to_compute diff --git a/qcportal/qcportal/reaction/test_record_models.py b/qcportal/qcportal/reaction/test_record_models.py index e1b65a293..fa4f7b83a 100644 --- a/qcportal/qcportal/reaction/test_record_models.py +++ b/qcportal/qcportal/reaction/test_record_models.py @@ -15,14 +15,15 @@ @pytest.mark.parametrize("includes", [None, all_includes]) -def test_reactionrecord_model(snowflake: QCATestingSnowflake, includes: Optional[List[str]]): +@pytest.mark.parametrize("testfile", ["rxn_H2O_psi4_mp2_optsp", "rxn_H2O_psi4_mp2_opt", "rxn_H2O_psi4_b3lyp_sp"]) +def test_reactionrecord_model(snowflake: QCATestingSnowflake, includes: Optional[List[str]], testfile: str): storage_socket = snowflake.get_storage_socket() snowflake_client = snowflake.client() activated_manager_name, _ = snowflake.activate_manager() - input_spec, stoichiometry, results = load_test_data("rxn_H2O_psi4_mp2_optsp") + input_spec, stoichiometry, results = load_test_data(testfile) - rec_id = run_test_data(storage_socket, activated_manager_name, "rxn_H2O_psi4_mp2_optsp") + rec_id = run_test_data(storage_socket, activated_manager_name, testfile) record = snowflake_client.get_reactions(rec_id, include=includes) if includes is not None: @@ -42,6 +43,15 @@ def test_reactionrecord_model(snowflake: QCATestingSnowflake, includes: Optional for c in com: if c.singlepoint_id is not None: + # Molecule id may represent the initial molecule for the optimization, not + # necessarily the single point calculation + if c.optimization_id is None: + assert c.singlepoint_record.molecule.id == c.molecule_id + else: + assert c.singlepoint_record.molecule.id == c.optimization_record.final_molecule.id + + assert list(c.singlepoint_record.molecule.symbols) == list(c.molecule.symbols) assert c.singlepoint_record.id == c.singlepoint_id if c.optimization_id is not None: + assert c.optimization_record.initial_molecule.id == c.molecule_id assert c.optimization_record.id == c.optimization_id