diff --git a/openfe/tests/protocols/test_openmm_equil_rfe_protocols.py b/openfe/tests/protocols/test_openmm_equil_rfe_protocols.py index 198499693..b1885a5fc 100644 --- a/openfe/tests/protocols/test_openmm_equil_rfe_protocols.py +++ b/openfe/tests/protocols/test_openmm_equil_rfe_protocols.py @@ -1435,6 +1435,39 @@ def _assert_total_charge(system, atom_classes, chgA, chgB): assert chgB == pytest.approx(np.sum(stateB_charges)) +def test_dry_run_alchemwater_solvent(benzene_to_benzoic_mapping, tmpdir): + stateA_system = openfe.ChemicalSystem( + {'ligand': benzene_to_benzoic_mapping.componentA, + 'solvent': openfe.SolventComponent()} + ) + stateB_system = openfe.ChemicalSystem( + {'ligand': benzene_to_benzoic_mapping.componentB, + 'solvent': openfe.SolventComponent()} + ) + solv_settings = openmm_rfe.RelativeHybridTopologyProtocol.default_settings() + protocol = openmm_rfe.RelativeHybridTopologyProtocol( + settings=solv_settings, + ) + + # create DAG from protocol and take first (and only) work unit from within + dag = protocol.create( + stateA=stateA_system, + stateB=stateB_system, + mapping={'ligand': benzene_to_benzoic_mapping}, + ) + unit = list(dag.protocol_units)[0] + + with tmpdir.as_cwd(): + sampler = unit.run(dry=True)['debug']['sampler'] + htf = sampler._factory + _assert_total_charge(htf.hybrid_system, + htf._atom_classes, 0, 0) + + assert len(htf._atom_classes['core_atoms']) == 14 + assert len(htf._atom_classes['unique_new_atoms']) == 3 + assert len(htf._atom_classes['unique_old_atoms']) == 1 + + @pytest.mark.slow @pytest.mark.parametrize('mapping_name,chgA,chgB,correction,core_atoms,new_uniq,old_uniq', [ ['benzene_to_aniline_mapping', 0, 1, False, 11, 4, 1], @@ -1445,19 +1478,21 @@ def _assert_total_charge(system, atom_classes, chgA, chgB): ['benzoic_to_benzene_mapping', 0, 0, True, 14, 1, 3], ['benzoic_to_benzene_mapping', 0, 1, False, 11, 1, 3], ]) -def test_dry_run_alchemwater_totcharge( +def test_dry_run_complex_alchemwater_totcharge( mapping_name, chgA, chgB, correction, core_atoms, - new_uniq, old_uniq, tmpdir, request, + new_uniq, old_uniq, tmpdir, request, T4_protein_component, ): mapping = request.getfixturevalue(mapping_name) stateA_system = openfe.ChemicalSystem( {'ligand': mapping.componentA, - 'solvent': openfe.SolventComponent()} + 'solvent': openfe.SolventComponent(), + 'protein': T4_protein_component} ) stateB_system = openfe.ChemicalSystem( {'ligand': mapping.componentB, - 'solvent': openfe.SolventComponent()} + 'solvent': openfe.SolventComponent(), + 'protein': T4_protein_component} ) solv_settings = openmm_rfe.RelativeHybridTopologyProtocol.default_settings()