diff --git a/pymbar/mbar_solvers.py b/pymbar/mbar_solvers.py index c87ad2d8..fb8bacde 100644 --- a/pymbar/mbar_solvers.py +++ b/pymbar/mbar_solvers.py @@ -15,7 +15,6 @@ raise ImportError("Jax disabled by force_no_jax in mbar_solvers.py") from jax.config import config - config.update("jax_enable_x64", True) from jax.numpy import exp, sum, newaxis, diag, dot, s_ from jax.numpy import pad as npad