diff --git a/smee/mm/_fe.py b/smee/mm/_fe.py index 6e830da..9669f69 100644 --- a/smee/mm/_fe.py +++ b/smee/mm/_fe.py @@ -48,29 +48,50 @@ def _extract_pure_solvent( return system, xyz, box, beta, pressure, energy +def _to_solvent_dict( + solvent: smee.TensorTopology, n_solvent: int +) -> dict[str, int] | None: + if solvent is None: + return None + + smiles = openff.toolkit.Molecule.from_rdkit( + smee.mm._utils.topology_to_rdkit(solvent), + allow_undefined_stereo=True, + ).to_smiles(mapped=True) + + return {smiles: n_solvent} + + def generate_dg_solv_data( solute: smee.TensorTopology, - solvent: smee.TensorTopology, + solvent_a: smee.TensorTopology | None, + solvent_b: smee.TensorTopology | None, force_field: smee.TensorForceField, temperature: openmm.unit.Quantity = 298.15 * openmm.unit.kelvin, pressure: openmm.unit.Quantity = 1.0 * openmm.unit.atmosphere, - vacuum_protocol: typing.Optional["absolv.config.EquilibriumProtocol"] = None, - solvent_protocol: typing.Optional["absolv.config.EquilibriumProtocol"] = None, - n_solvent: int = 216, + solvent_a_protocol: typing.Optional["absolv.config.EquilibriumProtocol"] = None, + solvent_b_protocol: typing.Optional["absolv.config.EquilibriumProtocol"] = None, + n_solvent_a: int = 216, + n_solvent_b: int = 216, output_dir: pathlib.Path | None = None, ): """Run a solvation free energy calculation using ``absolv``, and saves the output such that a differentiable free energy can be computed. + The free energy will correspond to the free energy of transferring a solute from + solvent A to solvent B. + Args: solute: The solute topology. - solvent: The solvent topology. + solvent_a: The topology of solvent A, or ``None`` if solvent A is vacuum. + solvent_b: The topology of solvent B, or ``None`` if solvent B is vacuum. force_field: The force field to parameterize the system with. temperature: The temperature to simulate at. pressure: The pressure to simulate at. - vacuum_protocol: The protocol to use for the vacuum phase. - solvent_protocol: The protocol to use for the solvent phase. - n_solvent: The number of solvent molecules to use. + solvent_a_protocol: The protocol to use to decouple the solute in solvent A. + solvent_b_protocol: The protocol to use to decouple the solute in solvent B. + n_solvent_a: The number of solvent A molecules to use. + n_solvent_b: The number of solvent B molecules to use. output_dir: The directory to write the output FEP data to. """ import absolv.config @@ -80,65 +101,69 @@ def generate_dg_solv_data( output_dir = pathlib.Path.cwd() if output_dir is None else output_dir - if vacuum_protocol is None: - vacuum_protocol = absolv.config.EquilibriumProtocol( - production_protocol=absolv.config.HREMDProtocol( - n_steps_per_cycle=500, - n_cycles=2000, - integrator=femto.md.config.LangevinIntegrator( - timestep=1.0 * openmm.unit.femtosecond - ), - trajectory_interval=1, + vacuum_protocol = absolv.config.EquilibriumProtocol( + production_protocol=absolv.config.HREMDProtocol( + n_steps_per_cycle=500, + n_cycles=2000, + integrator=femto.md.config.LangevinIntegrator( + timestep=1.0 * openmm.unit.femtosecond ), - lambda_sterics=absolv.config.DEFAULT_LAMBDA_STERICS_VACUUM, - lambda_electrostatics=absolv.config.DEFAULT_LAMBDA_ELECTROSTATICS_VACUUM, - ) - if solvent_protocol is None: - solvent_protocol = absolv.config.EquilibriumProtocol( - production_protocol=absolv.config.HREMDProtocol( - n_steps_per_cycle=500, - n_cycles=1000, - integrator=femto.md.config.LangevinIntegrator( - timestep=4.0 * openmm.unit.femtosecond - ), - trajectory_interval=1, - trajectory_enforce_pbc=True, + trajectory_interval=1, + ), + lambda_sterics=absolv.config.DEFAULT_LAMBDA_STERICS_VACUUM, + lambda_electrostatics=absolv.config.DEFAULT_LAMBDA_ELECTROSTATICS_VACUUM, + ) + solution_protocol = absolv.config.EquilibriumProtocol( + production_protocol=absolv.config.HREMDProtocol( + n_steps_per_cycle=500, + n_cycles=1000, + integrator=femto.md.config.LangevinIntegrator( + timestep=4.0 * openmm.unit.femtosecond ), - lambda_sterics=absolv.config.DEFAULT_LAMBDA_STERICS_SOLVENT, - lambda_electrostatics=absolv.config.DEFAULT_LAMBDA_ELECTROSTATICS_SOLVENT, - ) + trajectory_interval=1, + trajectory_enforce_pbc=True, + ), + lambda_sterics=absolv.config.DEFAULT_LAMBDA_STERICS_SOLVENT, + lambda_electrostatics=absolv.config.DEFAULT_LAMBDA_ELECTROSTATICS_SOLVENT, + ) config = absolv.config.Config( temperature=temperature, pressure=pressure, - alchemical_protocol_a=vacuum_protocol, - alchemical_protocol_b=solvent_protocol, + alchemical_protocol_a=solvent_a_protocol + if solvent_a_protocol is not None + else (vacuum_protocol if solvent_a is None else solution_protocol), + alchemical_protocol_b=solvent_b_protocol + if solvent_b_protocol is not None + else (vacuum_protocol if solvent_b is None else solution_protocol), ) solute_mol = openff.toolkit.Molecule.from_rdkit( smee.mm._utils.topology_to_rdkit(solute), allow_undefined_stereo=True, ) - solvent_mol = openff.toolkit.Molecule.from_rdkit( - smee.mm._utils.topology_to_rdkit(solvent), - allow_undefined_stereo=True, - ) system_config = absolv.config.System( solutes={solute_mol.to_smiles(mapped=True): 1}, - solvent_a=None, - solvent_b={solvent_mol.to_smiles(mapped=True): n_solvent}, + solvent_a=_to_solvent_dict(solvent_a, n_solvent_a), + solvent_b=_to_solvent_dict(solvent_b, n_solvent_b), ) topologies = { - "solvent-a": smee.TensorSystem([solute], [1], is_periodic=False), - "solvent-b": smee.TensorSystem( - [solute, solvent], [1, n_solvent], is_periodic=True - ), + "solvent-a": smee.TensorSystem([solute], [1], is_periodic=False) + if solvent_a is None + else smee.TensorSystem([solute, solvent_a], [1, n_solvent_a], is_periodic=True), + "solvent-b": smee.TensorSystem([solute], [1], is_periodic=False) + if solvent_b is None + else smee.TensorSystem([solute, solvent_b], [1, n_solvent_b], is_periodic=True), } pressures = { - "solvent-a": None, - "solvent-b": pressure.value_in_unit(openmm.unit.atmosphere), + "solvent-a": None + if solvent_a is None + else pressure.value_in_unit(openmm.unit.atmosphere), + "solvent-b": None + if solvent_b is None + else pressure.value_in_unit(openmm.unit.atmosphere), } for phase, topology in topologies.items(): @@ -173,8 +198,12 @@ def _parameterize( config, prepared_system_a, prepared_system_b, "CUDA", output_dir, parallel=True ) - solvent_b_output = _extract_pure_solvent(force_field, output_dir / "solvent-b") - torch.save(solvent_b_output, output_dir / "solvent-b" / "pure.pt") + if solvent_a is not None: + solvent_a_output = _extract_pure_solvent(force_field, output_dir / "solvent-a") + torch.save(solvent_a_output, output_dir / "solvent-a" / "pure.pt") + if solvent_b is not None: + solvent_b_output = _extract_pure_solvent(force_field, output_dir / "solvent-b") + torch.save(solvent_b_output, output_dir / "solvent-b" / "pure.pt") return result diff --git a/smee/tests/mm/test_fe.py b/smee/tests/mm/test_fe.py index f74bf4f..f739740 100644 --- a/smee/tests/mm/test_fe.py +++ b/smee/tests/mm/test_fe.py @@ -54,7 +54,9 @@ def test_fe_ops(tmp_cwd): output_dir = pathlib.Path("CCO") output_dir.mkdir(parents=True, exist_ok=True) - smee.mm.generate_dg_solv_data(top_solute, top_solvent, ff, output_dir=output_dir) + smee.mm.generate_dg_solv_data( + top_solute, None, top_solvent, ff, output_dir=output_dir + ) params = ff.potentials_by_type["Electrostatics"].parameters params.requires_grad_(True)