diff --git a/tests/neurax_vs_neuron/test_branch.py b/tests/neurax_vs_neuron/test_branch.py index c53e488f..4a11e3c1 100644 --- a/tests/neurax_vs_neuron/test_branch.py +++ b/tests/neurax_vs_neuron/test_branch.py @@ -1,7 +1,7 @@ -from jax.config import config +import jax -config.update("jax_enable_x64", True) -config.update("jax_platform_name", "cpu") +jax.config.update("jax_enable_x64", True) +jax.config.update("jax_platform_name", "cpu") import os diff --git a/tests/neurax_vs_neuron/test_cell.py b/tests/neurax_vs_neuron/test_cell.py index f752b2e6..b4530686 100644 --- a/tests/neurax_vs_neuron/test_cell.py +++ b/tests/neurax_vs_neuron/test_cell.py @@ -1,7 +1,7 @@ -from jax.config import config +import jax -config.update("jax_enable_x64", True) -config.update("jax_platform_name", "cpu") +jax.config.update("jax_enable_x64", True) +jax.config.update("jax_platform_name", "cpu") import os diff --git a/tests/neurax_vs_neuron/test_comp.py b/tests/neurax_vs_neuron/test_comp.py index 343d2711..3d6d6276 100644 --- a/tests/neurax_vs_neuron/test_comp.py +++ b/tests/neurax_vs_neuron/test_comp.py @@ -1,7 +1,7 @@ -from jax.config import config +import jax -config.update("jax_enable_x64", True) -config.update("jax_platform_name", "cpu") +jax.config.update("jax_enable_x64", True) +jax.config.update("jax_platform_name", "cpu") import os diff --git a/tests/test_cell_matches_branch.py b/tests/test_cell_matches_branch.py index 603bb3f8..2e960a37 100644 --- a/tests/test_cell_matches_branch.py +++ b/tests/test_cell_matches_branch.py @@ -1,7 +1,7 @@ -from jax.config import config +import jax -config.update("jax_enable_x64", True) -config.update("jax_platform_name", "cpu") +jax.config.update("jax_enable_x64", True) +jax.config.update("jax_platform_name", "cpu") import jax.numpy as jnp diff --git a/tests/test_make_trainable.py b/tests/test_make_trainable.py index 5297645b..c1212750 100644 --- a/tests/test_make_trainable.py +++ b/tests/test_make_trainable.py @@ -1,3 +1,8 @@ +import jax + +jax.config.update("jax_enable_x64", True) +jax.config.update("jax_platform_name", "cpu") + import jax.numpy as jnp import neurax as nx diff --git a/tests/test_swc.py b/tests/test_swc.py index eb3d8adf..b4ef78c7 100644 --- a/tests/test_swc.py +++ b/tests/test_swc.py @@ -1,7 +1,7 @@ -from jax.config import config +import jax -config.update("jax_enable_x64", True) -config.update("jax_platform_name", "cpu") +jax.config.update("jax_enable_x64", True) +jax.config.update("jax_platform_name", "cpu") import os