Fix incorrect dG contribution of pure solvent when re-weighting (#122)
SimonBoothroyd authored Nov 10, 2024
1 parent c01d912 commit a1bcf93
Showing 3 changed files with 87 additions and 107 deletions.
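
In the re-weighting path, the pure-solvent samples stored in "pure.pt" contribute to both the free energy and its parameter gradients; previously only the gradient contribution was subtracted. The refactored reweight_dg_and_grads below now subtracts both, which in illustrative notation (the symbols are placeholders, not identifiers from the code) amounts to

    \Delta G \leftarrow \Delta G_{\mathrm{solution}} - \Delta G_{\mathrm{pure}},
    \qquad
    \frac{\partial \Delta G}{\partial \theta}
        \leftarrow
        \frac{\partial \Delta G_{\mathrm{solution}}}{\partial \theta}
        - \frac{\partial \Delta G_{\mathrm{pure}}}{\partial \theta}.
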
156 changes: 82 additions & 74 deletions smee/mm/_fe.py
@@ -346,6 +346,23 @@ def _compute_energy(
return energy


def _compute_grads_solvent(
force_field: smee.TensorForceField,
theta: tuple[torch.Tensor, ...],
output_dir: pathlib.Path,
) -> tuple[torch.Tensor, ...]:
device = force_field.potentials[0].parameters.device

system, xyz, box, *_ = torch.load(output_dir / "pure.pt")
system.to(device)

with torch.enable_grad():
energy = _compute_energy(system, force_field, xyz, box)
grads = torch.autograd.grad(energy.mean(), theta)

return grads


def compute_dg_and_grads(
force_field: smee.TensorForceField,
theta: tuple[torch.Tensor, ...],
@@ -366,66 +383,47 @@ def compute_dg_and_grads(
)

f_i = mbar.compute_free_energy_differences()["Delta_f"][0, :]
dg = (f_i[-1] - f_i[0]) / beta

grads = ()

if len(theta) > 0:
with torch.enable_grad():
energy = _compute_energy(system, force_field, xyz_0, box_0)
grads = torch.autograd.grad(energy.mean(), theta)

return smee.utils.tensor_like(dg, force_field.potentials[0].parameters), grads


def compute_grads_solvent(
force_field: smee.TensorForceField,
theta: tuple[torch.Tensor, ...],
output_dir: pathlib.Path,
) -> tuple[torch.Tensor, ...]:
device = force_field.potentials[0].parameters.device
dg = (f_i[-1] - f_i[0]) / beta
dg = smee.utils.tensor_like(dg, force_field.potentials[0].parameters)

system, xyz, box, *_ = torch.load(output_dir / "pure.pt")
system.to(device)
if len(theta) == 0:
return dg, ()

grads = ()
with torch.enable_grad():
energy = _compute_energy(system, force_field, xyz_0, box_0)
grads = torch.autograd.grad(energy.mean(), theta)

if len(theta) > 0:
with torch.enable_grad():
energy = _compute_energy(system, force_field, xyz, box)
grads = torch.autograd.grad(energy.mean(), theta)
if (output_dir / "pure.pt").exists():
grads_solvent = _compute_grads_solvent(force_field, theta, output_dir)
grads = tuple(g - g_s for g, g_s in zip(grads, grads_solvent, strict=True))

return grads
return dg, grads


def reweight_dg_and_grads(
def _reweight_dg_and_grads(
system: smee.TensorSystem,
force_field: smee.TensorForceField,
xyz: torch.Tensor,
box: torch.Tensor | None,
u_0: torch.Tensor,
n_0: torch.Tensor,
beta: float,
pressure: float | None,
theta: tuple[torch.Tensor, ...],
output_dir: pathlib.Path,
) -> tuple[torch.Tensor, tuple[torch.Tensor, ...], float]:
import pymbar

device = force_field.potentials[0].parameters.device
dtype = force_field.potentials[0].parameters.dtype

system, beta, pressure, u_kn, n_k, xyz_0, box_0 = _load_samples(
output_dir, device, dtype
)
assert (box_0 is not None) == system.is_periodic
assert (box_0 is not None) == (pressure is not None)

u_0_old = u_kn[0, : n_k[0]]

with torch.enable_grad():
energy_0 = _compute_energy(system, force_field, xyz_0, box_0)
energy_new = _compute_energy(system, force_field, xyz, box)

u_0_new = energy_0.detach().clone() * beta
u_new = energy_new.detach().clone() * beta

if pressure is not None:
u_0_new += pressure * torch.det(box_0) * beta
u_new += pressure * torch.det(box) * beta

u_kn = numpy.stack([u_0_old.cpu().numpy(), u_0_new.cpu().numpy()])
n_k = numpy.array([n_k[0].cpu(), 0])
u_kn = numpy.stack([u_0.detach().cpu().numpy(), u_new.detach().cpu().numpy()])
n_k = numpy.array([n_0.detach().cpu().numpy().item(), 0])

mbar = pymbar.MBAR(
u_kn,
@@ -436,57 +434,67 @@ def reweight_dg_and_grads(
n_eff = mbar.compute_effective_sample_number().min().item()

f_i = mbar.compute_free_energy_differences()["Delta_f"][0, :]

dg = (f_i[-1] - f_i[0]) / beta
dg = smee.utils.tensor_like(dg, force_field.potentials[0].parameters)

weights = smee.utils.tensor_like(mbar.W_nk[:, 1], energy_0)
weights = smee.utils.tensor_like(mbar.W_nk[:, 1], energy_new)
grads = ()

if len(theta) > 0:
grads = torch.autograd.grad((energy_0 * weights).sum(), theta)
grads = torch.autograd.grad((energy_new * weights).sum(), theta)

return smee.utils.tensor_like(dg, energy_0), grads, n_eff
return dg, grads, n_eff


def reweight_grads_solvent(
def reweight_dg_and_grads(
force_field: smee.TensorForceField,
theta: tuple[torch.Tensor, ...],
output_dir: pathlib.Path,
) -> tuple[tuple[torch.Tensor, ...], float]:
import pymbar

) -> tuple[torch.Tensor, tuple[torch.Tensor, ...], float]:
device = force_field.potentials[0].parameters.device
dtype = force_field.potentials[0].parameters.dtype

system, xyz, box, beta, pressure, energy_old = torch.load(output_dir / "pure.pt")
system.to(device)

u_old = energy_old.detach().clone() * beta
system, beta, pressure, u_kn, n_k, xyz_0, box_0 = _load_samples(
output_dir, device, dtype
)
assert (box_0 is not None) == system.is_periodic
assert (box_0 is not None) == (pressure is not None)

if pressure is not None:
u_old += pressure * torch.det(box) * beta
u_0 = u_kn[0, : n_k[0]]
n_0 = n_k[0]

with torch.enable_grad():
energy_new = _compute_energy(system, force_field, xyz, box)
dg, grads, n_eff = _reweight_dg_and_grads(
system, force_field, xyz_0, box_0, u_0, n_0, beta, pressure, theta
)

u_new = energy_new.detach().clone() * beta
if not (output_dir / "pure.pt").exists():
return dg, grads, n_eff

if pressure is not None:
u_new += pressure * torch.det(box) * beta
system_solv, xyz_solv, box_solv, _, _, energy_solv = torch.load(
output_dir / "pure.pt"
)

u_kn = numpy.stack([u_old.cpu().numpy(), u_new.cpu().numpy()])
n_k = numpy.array([len(u_old), 0])
u_0_solv = energy_solv.detach().clone() * beta

mbar = pymbar.MBAR(
u_kn,
n_k,
solver_protocol=[{"method": "adaptive", "options": {"min_sc_iter": 0}}],
)
if pressure is not None:
u_0_solv += pressure * torch.det(box_solv) * beta

n_eff = mbar.compute_effective_sample_number().min().item()
n_0_solv = smee.utils.tensor_like([len(u_0_solv)], u_0_solv)

weights = smee.utils.tensor_like(mbar.W_nk[:, 1], energy_new)
grads = ()
dg_solv, grads_solv, n_eff_solv = _reweight_dg_and_grads(
system_solv,
force_field,
xyz_solv,
box_solv,
u_0_solv,
n_0_solv,
beta,
pressure,
theta,
)

if len(theta) > 0:
grads = torch.autograd.grad((energy_new * weights).sum(), theta)
dg -= dg_solv
grads = tuple(g - g_s for g, g_s in zip(grads, grads_solv, strict=True))

return grads, n_eff
return dg, grads, min(n_eff, n_eff_solv)
22 changes: 3 additions & 19 deletions smee/mm/_ops.py
@@ -616,7 +616,7 @@ def reweight_ensemble_averages(
class _ComputeDGSolv(torch.autograd.Function):
@staticmethod
def forward(ctx, kwargs, *theta: torch.Tensor):
from smee.mm._fe import compute_dg_and_grads, compute_grads_solvent
from smee.mm._fe import compute_dg_and_grads

force_field = _unpack_force_field(
theta,
@@ -638,19 +638,11 @@ def forward(ctx, kwargs, *theta: torch.Tensor):
force_field, theta_grad, kwargs["fep_dir"] / "solvent-b"
)

if (kwargs["fep_dir"] / "solvent-a" / "pure.pt").exists():
raise NotImplementedError("solvent-a is expected to be vacuum")

dg_solv_b_d_theta = compute_grads_solvent(
force_field, theta_grad, kwargs["fep_dir"] / "solvent-b"
)

dg = dg_a - dg_b
dg_d_theta = [None] * len(theta)

for grad_idx, orig_idx in enumerate(needs_grad):
dg_d_theta[orig_idx] = dg_d_theta_b[grad_idx] - dg_d_theta_a[grad_idx]
dg_d_theta[orig_idx] -= dg_solv_b_d_theta[grad_idx]

ctx.save_for_backward(*dg_d_theta)

@@ -667,7 +659,7 @@ def backward(ctx, *grad_outputs):
class _ReweightDGSolv(torch.autograd.Function):
@staticmethod
def forward(ctx, kwargs, *theta: torch.Tensor):
from smee.mm._fe import reweight_dg_and_grads, reweight_grads_solvent
from smee.mm._fe import reweight_dg_and_grads

force_field = _unpack_force_field(
theta,
@@ -692,23 +684,15 @@ def forward(ctx, kwargs, *theta: torch.Tensor):
force_field, theta_grad, kwargs["fep_dir"] / "solvent-b"
)

if (kwargs["fep_dir"] / "solvent-a" / "pure.pt").exists():
raise NotImplementedError("solvent-a is expected to be vacuum")

dg_solv_b_d_theta, n_effective_solv = reweight_grads_solvent(
force_field, theta_grad, kwargs["fep_dir"] / "solvent-b"
)

dg = -dg_a + dg_0 + dg_b
dg_d_theta = [None] * len(theta)

for grad_idx, orig_idx in enumerate(needs_grad):
dg_d_theta[orig_idx] = dg_d_theta_b[grad_idx] - dg_d_theta_a[grad_idx]
dg_d_theta[orig_idx] -= dg_solv_b_d_theta[grad_idx]

ctx.save_for_backward(*dg_d_theta)

return dg, min(n_effective_a, n_effective_b, n_effective_solv)
return dg, min(n_effective_a, n_effective_b)

@staticmethod
def backward(ctx, *grad_outputs):
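
Both wrappers above follow the same custom-autograd pattern: the free energy and its parameter gradients are computed eagerly in forward (via MBAR), stashed with ctx.save_for_backward, and replayed in backward. A minimal, hypothetical sketch of that pattern (names and values are illustrative, not smee API):

import torch


class _PrecomputedGrad(torch.autograd.Function):
    @staticmethod
    def forward(ctx, value, grads, *theta):
        # ``grads[i]`` holds d(value)/d(theta[i]), computed outside of
        # autograd (e.g. from MBAR weights), so it is stored rather than traced.
        ctx.save_for_backward(*grads)
        return value.detach()

    @staticmethod
    def backward(ctx, grad_output):
        # No gradients for ``value`` or ``grads`` themselves; chain the
        # stored gradients into each ``theta``.
        return None, None, *(grad_output * g for g in ctx.saved_tensors)


theta = (torch.tensor([1.0, 2.0], requires_grad=True),)
dg = _PrecomputedGrad.apply(torch.tensor(3.0), (torch.tensor([0.5, -0.5]),), *theta)
dg.backward()  # theta[0].grad is now tensor([0.5, -0.5])
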
16 changes: 2 additions & 14 deletions smee/tests/mm/test_ops.py
@@ -448,16 +448,12 @@ def test_compute_dg_solv(mocker, tmp_path, mock_argon_tensors):
(torch.tensor(4.0).double(), (torch.tensor([[5.0, 6.0]]).double(),)),
],
)
mocker.patch(
"smee.mm._fe.compute_grads_solvent",
side_effect=[(torch.tensor([[8.0, 9.0]]).double(),)],
)

dg = smee.mm.compute_dg_solv(tensor_ff, tmp_path)
dg_dtheta = torch.autograd.grad(dg, params)[0]

assert torch.isclose(dg, torch.tensor(-3.0).double())
assert torch.allclose(dg_dtheta, torch.tensor([[-5.0, -6.0]]).double())
assert torch.allclose(dg_dtheta, torch.tensor([[3.0, 3.0]]).double())


def test_reweight_dg_solv(mocker, tmp_path, mock_argon_tensors):
@@ -473,18 +469,14 @@ def test_reweight_dg_solv(mocker, tmp_path, mock_argon_tensors):
(torch.tensor(5.0).double(), (torch.tensor([[6.0, 7.0]]).double(),), 8.0),
],
)
mocker.patch(
"smee.mm._fe.reweight_grads_solvent",
side_effect=[((torch.tensor([[9.0, 10.0]]).double(),), 11.0)],
)

dg_0 = torch.tensor(-3.0).double()

dg, n_eff = smee.mm.reweight_dg_solv(tensor_ff, tmp_path, dg_0, 3)
dg_dtheta = torch.autograd.grad(dg, params)[0]

assert torch.isclose(dg, torch.tensor(1.0).double())
assert torch.allclose(dg_dtheta, torch.tensor([[-5.0, -6.0]]).double())
assert torch.allclose(dg_dtheta, torch.tensor([[4.0, 4.0]]).double())

assert n_eff == 4.0

@@ -502,10 +494,6 @@ def test_reweight_dg_solv_error(mocker, tmp_path, mock_argon_tensors):
(torch.tensor(5.0).double(), (torch.tensor([[6.0, 7.0]]).double(),), 8.0),
],
)
mocker.patch(
"smee.mm._fe.reweight_grads_solvent",
side_effect=[((torch.tensor([[9.0, 10.0]]).double(),), 11.0)],
)

dg_0 = torch.tensor(-3.0).double()
