From 4f9c52dfec0cd44ac3f39c7f532a77bf6e953c39 Mon Sep 17 00:00:00 2001 From: Michael Deistler Date: Mon, 6 Nov 2023 12:03:22 +0100 Subject: [PATCH] Fix warning due to jax.config --- tests/neurax_vs_neuron/test_branch.py | 6 +++--- tests/neurax_vs_neuron/test_cell.py | 6 +++--- tests/neurax_vs_neuron/test_comp.py | 6 +++--- tests/test_cell_matches_branch.py | 6 +++--- tests/test_make_trainable.py | 5 +++++ tests/test_swc.py | 6 +++--- 6 files changed, 20 insertions(+), 15 deletions(-) diff --git a/tests/neurax_vs_neuron/test_branch.py b/tests/neurax_vs_neuron/test_branch.py index c53e488f..4a11e3c1 100644 --- a/tests/neurax_vs_neuron/test_branch.py +++ b/tests/neurax_vs_neuron/test_branch.py @@ -1,7 +1,7 @@ -from jax.config import config +import jax -config.update("jax_enable_x64", True) -config.update("jax_platform_name", "cpu") +jax.config.update("jax_enable_x64", True) +jax.config.update("jax_platform_name", "cpu") import os diff --git a/tests/neurax_vs_neuron/test_cell.py b/tests/neurax_vs_neuron/test_cell.py index f752b2e6..b4530686 100644 --- a/tests/neurax_vs_neuron/test_cell.py +++ b/tests/neurax_vs_neuron/test_cell.py @@ -1,7 +1,7 @@ -from jax.config import config +import jax -config.update("jax_enable_x64", True) -config.update("jax_platform_name", "cpu") +jax.config.update("jax_enable_x64", True) +jax.config.update("jax_platform_name", "cpu") import os diff --git a/tests/neurax_vs_neuron/test_comp.py b/tests/neurax_vs_neuron/test_comp.py index 343d2711..3d6d6276 100644 --- a/tests/neurax_vs_neuron/test_comp.py +++ b/tests/neurax_vs_neuron/test_comp.py @@ -1,7 +1,7 @@ -from jax.config import config +import jax -config.update("jax_enable_x64", True) -config.update("jax_platform_name", "cpu") +jax.config.update("jax_enable_x64", True) +jax.config.update("jax_platform_name", "cpu") import os diff --git a/tests/test_cell_matches_branch.py b/tests/test_cell_matches_branch.py index 603bb3f8..2e960a37 100644 --- a/tests/test_cell_matches_branch.py +++ b/tests/test_cell_matches_branch.py @@ -1,7 +1,7 @@ -from jax.config import config +import jax -config.update("jax_enable_x64", True) -config.update("jax_platform_name", "cpu") +jax.config.update("jax_enable_x64", True) +jax.config.update("jax_platform_name", "cpu") import jax.numpy as jnp diff --git a/tests/test_make_trainable.py b/tests/test_make_trainable.py index 5297645b..c1212750 100644 --- a/tests/test_make_trainable.py +++ b/tests/test_make_trainable.py @@ -1,3 +1,8 @@ +import jax + +jax.config.update("jax_enable_x64", True) +jax.config.update("jax_platform_name", "cpu") + import jax.numpy as jnp import neurax as nx diff --git a/tests/test_swc.py b/tests/test_swc.py index eb3d8adf..b4ef78c7 100644 --- a/tests/test_swc.py +++ b/tests/test_swc.py @@ -1,7 +1,7 @@ -from jax.config import config +import jax -config.update("jax_enable_x64", True) -config.update("jax_platform_name", "cpu") +jax.config.update("jax_enable_x64", True) +jax.config.update("jax_platform_name", "cpu") import os