diff --git a/smee/geometry.py b/smee/geometry.py index 36f9034..95bed7c 100644 --- a/smee/geometry.py +++ b/smee/geometry.py @@ -199,27 +199,27 @@ def _build_v_site_coord_frames( origin = weighted_coords[:, 0, :] xy_plane = weighted_coords[:, 1:, :] - xy_plane /= torch.norm(xy_plane, dim=-1).unsqueeze(-1) + xy_plane_hat = xy_plane / torch.norm(xy_plane, dim=-1).unsqueeze(-1) - x_hat = xy_plane[:, 0, :] + x_hat = xy_plane_hat[:, 0, :] - z_hat = torch.cross(x_hat, xy_plane[:, 1, :]) - z_hat_norm = torch.norm(z_hat, dim=-1).unsqueeze(-1) - z_hat_norm = torch.where( - torch.isclose(z_hat_norm, smee.utils.tensor_like(0.0, other=z_hat_norm)), - smee.utils.tensor_like(1.0, other=z_hat_norm), - z_hat_norm, + z = torch.cross(x_hat, xy_plane_hat[:, 1, :]) + z_norm = torch.norm(z, dim=-1).unsqueeze(-1) + z_norm_clamped = torch.where( + torch.isclose(z_norm, smee.utils.tensor_like(0.0, other=z_norm)), + smee.utils.tensor_like(1.0, other=z_norm), + z_norm, ) - z_hat /= z_hat_norm - - y_hat = torch.cross(z_hat, x_hat) - y_hat_norm = torch.norm(y_hat, dim=-1).unsqueeze(-1) - y_hat_norm = torch.where( - torch.isclose(y_hat_norm, smee.utils.tensor_like(0.0, other=y_hat_norm)), - smee.utils.tensor_like(1.0, other=y_hat_norm), - y_hat_norm, + z_hat = z / z_norm_clamped + + y = torch.cross(z_hat, x_hat) + y_norm = torch.norm(y, dim=-1).unsqueeze(-1) + y_norm_clamped = torch.where( + torch.isclose(y_norm, smee.utils.tensor_like(0.0, other=y_norm)), + smee.utils.tensor_like(1.0, other=y_norm), + y_norm, ) - y_hat /= y_hat_norm + y_hat = y / y_norm_clamped stacked_frames[0].append(origin) stacked_frames[1].append(x_hat) diff --git a/smee/tests/test_geometry.py b/smee/tests/test_geometry.py index 07e90c3..b405749 100644 --- a/smee/tests/test_geometry.py +++ b/smee/tests/test_geometry.py @@ -352,3 +352,29 @@ def test_add_v_site_coords(conformer, v_site_coords, cat_dim, mocker): assert coordinates.shape == expected_coords.shape assert torch.allclose(coordinates, expected_coords) + + +def test_add_v_site_coords_grad(v_site_force_field): + """Test that the gradients of functions of v-site coordinates can be computed, + and also gradients of the gradients (e.g. loss of forces involving v-sites). This + was found to be a bug upstream, yeilding a tensor modified in place error.""" + molecule = openff.toolkit.Molecule.from_mapped_smiles("[H:2][O:1][H:3]") + molecule.generate_conformers(n_conformers=1) + + conformer = torch.tensor(molecule.conformers[0].m_as(unit.angstrom)) + conformer.requires_grad_(True) + + interchange = openff.interchange.Interchange.from_smirnoff( + v_site_force_field, molecule.to_topology() + ) + force_field, [topology] = smee.converters.convert_interchange(interchange) + + assert topology.v_sites is not None + assert len(topology.v_sites.keys) > 0 + + v_site_coords = add_v_site_coords(topology.v_sites, conformer, force_field) + + grad = torch.autograd.grad(v_site_coords.sum(), conformer, create_graph=True)[0] + + loss = grad.sum() + loss.backward()