From 7b540de2b7831829ab4ec3fbd60b0d393a683da9 Mon Sep 17 00:00:00 2001 From: vfdev-5 Date: Thu, 14 Nov 2024 17:28:26 +0100 Subject: [PATCH] Added CI job with TSAN and free-threading Use bazel to run tests --- .github/workflows/tsan.yaml | 153 ++++++++++++++++++++++++++++ .tsan_ignore | 38 +++++++ build/requirements_lock_3_13_ft.txt | 68 ++----------- tests/api_test.py | 114 ++++++--------------- tests/jaxpr_effects_test.py | 13 ++- 5 files changed, 238 insertions(+), 148 deletions(-) create mode 100644 .github/workflows/tsan.yaml create mode 100644 .tsan_ignore diff --git a/.github/workflows/tsan.yaml b/.github/workflows/tsan.yaml new file mode 100644 index 000000000000..b9756124e26a --- /dev/null +++ b/.github/workflows/tsan.yaml @@ -0,0 +1,153 @@ +name: CI - Free-threading and Thread Sanitizer (nightly) + +concurrency: + group: ${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: true + +on: + schedule: + - cron: "0 12 * * *" # Daily at 12:00 UTC + workflow_dispatch: # allows triggering the workflow run manually + pull_request: # Automatically trigger on pull requests affecting this file + # branches: + # - main + paths: + - '**/workflows/tsan.yaml' + +jobs: + tsan: + runs-on: linux-x86-n2-64 + container: + image: index.docker.io/library/ubuntu@sha256:b359f1067efa76f37863778f7b6d0e8d911e3ee8efa807ad01fbf5dc1ef9006b # ratchet:ubuntu:24.04 + strategy: + fail-fast: false + defaults: + run: + shell: bash -l {0} + steps: + # Install git before actions/checkout as otherwise it will download the code with the GitHub + # REST API and therefore any subsequent git commands will fail. + - name: Install clang 18 + env: + DEBIAN_FRONTEND: noninteractive + run: | + apt update + apt install -y clang-18 libstdc++-14-dev build-essential libssl-dev \ + zlib1g-dev libbz2-dev libreadline-dev libsqlite3-dev curl git \ + libncursesw5-dev xz-utils tk-dev libxml2-dev libxmlsec1-dev \ + libffi-dev liblzma-dev + - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + with: + path: jax + - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + with: + repository: python/cpython + path: cpython + ref: "3.13" + - name: Build CPython with TSAN enabled + run: | + cd cpython + mkdir ${GITHUB_WORKSPACE}/cpython-tsan + CC=clang-18 CXX=clang++-18 ./configure --prefix ${GITHUB_WORKSPACE}/cpython-tsan --disable-gil --with-thread-sanitizer + make -j64 + make install + # Check whether free-threading mode is enabled + PYTHON_GIL=0 ${GITHUB_WORKSPACE}/cpython-tsan/bin/python3 -c "import sys; assert not sys._is_gil_enabled()" + + # Create archive to be used with bazel as hermetic python: + cd ${GITHUB_WORKSPACE} && tar -czpf python-tsan.tgz cpython-tsan + - name: Build and install JAX + run: | + cd jax + + export PYTHON_SHA256=($(sha256sum ${GITHUB_WORKSPACE}/python-tsan.tgz)) + echo "Python sha256: ${PYTHON_SHA256}" + + ${GITHUB_WORKSPACE}/cpython-tsan/bin/python3 build/build.py build --wheels=jaxlib \ + --python_version=3.13-ft \ + --bazel_options=--repo_env=HERMETIC_PYTHON_URL="file://${GITHUB_WORKSPACE}/python-tsan.tgz" \ + --bazel_options=--repo_env=HERMETIC_PYTHON_SHA256=${PYTHON_SHA256} \ + --bazel_options=--repo_env=HERMETIC_PYTHON_PREFIX="cpython-tsan/" \ + --bazel_options=--color=yes \ + --bazel_options=--copt=-fsanitize=thread \ + --bazel_options=--linkopt="-fsanitize=thread" \ + --bazel_options=--copt=-g \ + --clang_path=/usr/bin/clang-18 + - name: Run tests + timeout-minutes: 120 + env: + JAX_NUM_GENERATED_CASES: 1 + JAX_ENABLE_X64: true + JAX_SKIP_SLOW_TESTS: true + PY_COLORS: 1 + run: | + cd jax + echo "JAX_NUM_GENERATED_CASES=$JAX_NUM_GENERATED_CASES" + echo "JAX_ENABLE_X64=$JAX_ENABLE_X64" + echo "JAX_SKIP_SLOW_TESTS=$JAX_SKIP_SLOW_TESTS" + + # As we do not have yet free-threading support + # there will be the following warning: + # RuntimeWarning: The global interpreter lock (GIL) has been enabled to load module 'jaxlib.utils', + # which has not declared that it can run safely without the GIL. + # To avoid that we temporarily define PYTHON_GIL + export PYTHON_GIL=0 + + # Set symlink to the bazel executable + bazel_exec=($(ls bazel-*)) + ln -s ${bazel_exec} bazel + + # Create tsan suppressions file + cat << EOF > $PWD/.tsan_ignore + # false-positive caused because we haven't tsan-instrumented libgcc_s. Multiple threads + # are racing on a call to __register_frame_info(), but that function appears to be correctly locked internally. + race:llvm::RuntimeDyldELF::registerEHFrames + + # https://github.com/python/cpython/issues/128050 + race:partial_vectorcall_fallback + + # https://github.com/python/cpython/issues/128100 + race:ensure_nonmanaged_dict + + # https://github.com/openxla/xla/issues/20686 + race:dnnl_sgemm + + # https://github.com/numpy/numpy/issues/28041 + race:get_initial_from_ufunc + + # https://github.com/numpy/numpy/issues/28042 + race:PyArray_UpdateFlags + + # https://github.com/python/cpython/issues/128130 + race_top:run_eval_code_obj + + race:dump_traceback + + # https://github.com/numpy/numpy/issues/28045 not sure about this one + race:arraymethod_dealloc + + # https://github.com/python/cpython/issues/128133 + race:bytes_hash + + # https://github.com/python/cpython/issues/128137 + race:immortalize_interned + + # https://github.com/python/cpython/issues/128144 + race_top:PyMember_GetOne + + # https://github.com/python/cpython/issues/128657 + race:py_digest_by_name + EOF + + ./bazel test \ + --python_version=3.13-ft \ + --//jax:build_jaxlib=false \ + --repo_env=JAX_NUM_GENERATED_CASES=$JAX_NUM_GENERATED_CASES \ + --repo_env=JAX_ENABLE_X64=$JAX_ENABLE_X64 \ + --repo_env=JAX_SKIP_SLOW_TESTS=$JAX_SKIP_SLOW_TESTS \ + --repo_env=PYTHON_GIL=$PYTHON_GIL \ + --test_env=TSAN_OPTIONS=halt_on_error=1,suppressions=$PWD/.tsan_ignore \ + --test_env=JAX_TEST_NUM_THREADS=8 \ + --nocache_test_results \ + --test_output=all \ + //tests:cpu_tests diff --git a/.tsan_ignore b/.tsan_ignore new file mode 100644 index 000000000000..64523713d0de --- /dev/null +++ b/.tsan_ignore @@ -0,0 +1,38 @@ +# I believe this is a false-positive caused because we haven't tsan-instrumented libgcc_s. Multiple threads +# are racing on a call to __register_frame_info(), but that function appears to be correctly locked internally. +race:llvm::RuntimeDyldELF::registerEHFrames + +# https://github.com/python/cpython/issues/128050 +race:partial_vectorcall_fallback + +# https://github.com/python/cpython/issues/128100 +race:ensure_nonmanaged_dict + +# https://github.com/openxla/xla/issues/20686 +race:dnnl_sgemm + +# https://github.com/numpy/numpy/issues/28041 +race:get_initial_from_ufunc + +# https://github.com/numpy/numpy/issues/28042 +race:PyArray_UpdateFlags + +# https://github.com/python/cpython/issues/128130 +race_top:run_eval_code_obj + +race:dump_traceback + +# https://github.com/numpy/numpy/issues/28045 not sure about this one +race:arraymethod_dealloc + +# https://github.com/python/cpython/issues/128133 +race:bytes_hash + +# https://github.com/python/cpython/issues/128137 +race:immortalize_interned + +# https://github.com/python/cpython/issues/128144 +race_top:PyMember_GetOne + +# https://github.com/python/cpython/issues/128657 +race:py_digest_by_name diff --git a/build/requirements_lock_3_13_ft.txt b/build/requirements_lock_3_13_ft.txt index dfefaf042a21..2700e140e0ce 100644 --- a/build/requirements_lock_3_13_ft.txt +++ b/build/requirements_lock_3_13_ft.txt @@ -4,6 +4,12 @@ # # pip-compile --allow-unsafe --generate-hashes --output-file=build/requirements_lock_3_13_ft.txt build/requirements.in # + +--pre +--extra-index-url https://pypi.anaconda.org/scientific-python-nightly-wheels/simple +numpy + + absl-py==2.1.0 \ --hash=sha256:526a04eadab8b4ee719ce68f204172ead1027549089702d99b9059f129ff1308 \ --hash=sha256:7820790efbb316739cde8b4e19357243fc3608a152024288513dd968d7d959ff @@ -328,68 +334,6 @@ mpmath==1.3.0 \ --hash=sha256:7a28eb2a9774d00c7bc92411c19a89209d5da7c4c9a9e227be8330a23a25b91f \ --hash=sha256:a0b2b9fe80bbcd81a6647ff13108738cfb482d481d826cc0e02f5b35e5c88d2c # via -r build/test-requirements.txt -numpy==2.2.1 ; python_version >= "3.13" \ - --hash=sha256:059e6a747ae84fce488c3ee397cee7e5f905fd1bda5fb18c66bc41807ff119b2 \ - --hash=sha256:08ef779aed40dbc52729d6ffe7dd51df85796a702afbf68a4f4e41fafdc8bda5 \ - --hash=sha256:164a829b6aacf79ca47ba4814b130c4020b202522a93d7bff2202bfb33b61c60 \ - --hash=sha256:26c9c4382b19fcfbbed3238a14abf7ff223890ea1936b8890f058e7ba35e8d71 \ - --hash=sha256:27f5cdf9f493b35f7e41e8368e7d7b4bbafaf9660cba53fb21d2cd174ec09631 \ - --hash=sha256:31b89fa67a8042e96715c68e071a1200c4e172f93b0fbe01a14c0ff3ff820fc8 \ - --hash=sha256:32cb94448be47c500d2c7a95f93e2f21a01f1fd05dd2beea1ccd049bb6001cd2 \ - --hash=sha256:360137f8fb1b753c5cde3ac388597ad680eccbbbb3865ab65efea062c4a1fd16 \ - --hash=sha256:3683a8d166f2692664262fd4900f207791d005fb088d7fdb973cc8d663626faa \ - --hash=sha256:38efc1e56b73cc9b182fe55e56e63b044dd26a72128fd2fbd502f75555d92591 \ - --hash=sha256:3d03883435a19794e41f147612a77a8f56d4e52822337844fff3d4040a142964 \ - --hash=sha256:3ecc47cd7f6ea0336042be87d9e7da378e5c7e9b3c8ad0f7c966f714fc10d821 \ - --hash=sha256:40f9e544c1c56ba8f1cf7686a8c9b5bb249e665d40d626a23899ba6d5d9e1484 \ - --hash=sha256:4250888bcb96617e00bfa28ac24850a83c9f3a16db471eca2ee1f1714df0f957 \ - --hash=sha256:4511d9e6071452b944207c8ce46ad2f897307910b402ea5fa975da32e0102800 \ - --hash=sha256:45681fd7128c8ad1c379f0ca0776a8b0c6583d2f69889ddac01559dfe4390918 \ - --hash=sha256:48fd472630715e1c1c89bf1feab55c29098cb403cc184b4859f9c86d4fcb6a95 \ - --hash=sha256:4c86e2a209199ead7ee0af65e1d9992d1dce7e1f63c4b9a616500f93820658d0 \ - --hash=sha256:4dfda918a13cc4f81e9118dea249e192ab167a0bb1966272d5503e39234d694e \ - --hash=sha256:5062dc1a4e32a10dc2b8b13cedd58988261416e811c1dc4dbdea4f57eea61b0d \ - --hash=sha256:51faf345324db860b515d3f364eaa93d0e0551a88d6218a7d61286554d190d73 \ - --hash=sha256:526fc406ab991a340744aad7e25251dd47a6720a685fa3331e5c59fef5282a59 \ - --hash=sha256:53c09385ff0b72ba79d8715683c1168c12e0b6e84fb0372e97553d1ea91efe51 \ - --hash=sha256:55ba24ebe208344aa7a00e4482f65742969a039c2acfcb910bc6fcd776eb4355 \ - --hash=sha256:5b6c390bfaef8c45a260554888966618328d30e72173697e5cabe6b285fb2348 \ - --hash=sha256:5c5cc0cbabe9452038ed984d05ac87910f89370b9242371bd9079cb4af61811e \ - --hash=sha256:5edb4e4caf751c1518e6a26a83501fda79bff41cc59dac48d70e6d65d4ec4440 \ - --hash=sha256:61048b4a49b1c93fe13426e04e04fdf5a03f456616f6e98c7576144677598675 \ - --hash=sha256:676f4eebf6b2d430300f1f4f4c2461685f8269f94c89698d832cdf9277f30b84 \ - --hash=sha256:67d4cda6fa6ffa073b08c8372aa5fa767ceb10c9a0587c707505a6d426f4e046 \ - --hash=sha256:694f9e921a0c8f252980e85bce61ebbd07ed2b7d4fa72d0e4246f2f8aa6642ab \ - --hash=sha256:733585f9f4b62e9b3528dd1070ec4f52b8acf64215b60a845fa13ebd73cd0712 \ - --hash=sha256:7671dc19c7019103ca44e8d94917eba8534c76133523ca8406822efdd19c9308 \ - --hash=sha256:780077d95eafc2ccc3ced969db22377b3864e5b9a0ea5eb347cc93b3ea900315 \ - --hash=sha256:7ba9cc93a91d86365a5d270dee221fdc04fb68d7478e6bf6af650de78a8339e3 \ - --hash=sha256:89b16a18e7bba224ce5114db863e7029803c179979e1af6ad6a6b11f70545008 \ - --hash=sha256:9036d6365d13b6cbe8f27a0eaf73ddcc070cae584e5ff94bb45e3e9d729feab5 \ - --hash=sha256:93cf4e045bae74c90ca833cba583c14b62cb4ba2cba0abd2b141ab52548247e2 \ - --hash=sha256:9ad014faa93dbb52c80d8f4d3dcf855865c876c9660cb9bd7553843dd03a4b1e \ - --hash=sha256:9b1d07b53b78bf84a96898c1bc139ad7f10fda7423f5fd158fd0f47ec5e01ac7 \ - --hash=sha256:a7746f235c47abc72b102d3bce9977714c2444bdfaea7888d241b4c4bb6a78bf \ - --hash=sha256:aa3017c40d513ccac9621a2364f939d39e550c542eb2a894b4c8da92b38896ab \ - --hash=sha256:b34d87e8a3090ea626003f87f9392b3929a7bbf4104a05b6667348b6bd4bf1cd \ - --hash=sha256:b541032178a718c165a49638d28272b771053f628382d5e9d1c93df23ff58dbf \ - --hash=sha256:ba5511d8f31c033a5fcbda22dd5c813630af98c70b2661f2d2c654ae3cdfcfc8 \ - --hash=sha256:bc8a37ad5b22c08e2dbd27df2b3ef7e5c0864235805b1e718a235bcb200cf1cb \ - --hash=sha256:bff7d8ec20f5f42607599f9994770fa65d76edca264a87b5e4ea5629bce12268 \ - --hash=sha256:c1ad395cf254c4fbb5b2132fee391f361a6e8c1adbd28f2cd8e79308a615fe9d \ - --hash=sha256:f1d09e520217618e76396377c81fba6f290d5f926f50c35f3a5f72b01a0da780 \ - --hash=sha256:f3eac17d9ec51be534685ba877b6ab5edc3ab7ec95c8f163e5d7b39859524716 \ - --hash=sha256:f419290bc8968a46c4933158c91a0012b7a99bb2e465d5ef5293879742f8797e \ - --hash=sha256:f62aa6ee4eb43b024b0e5a01cf65a0bb078ef8c395e8713c6e8a12a697144528 \ - --hash=sha256:f74e6fdeb9a265624ec3a3918430205dff1df7e95a230779746a6af78bc615af \ - --hash=sha256:f9b57eaa3b0cd8db52049ed0330747b0364e899e8a606a624813452b8203d5f7 \ - --hash=sha256:fce4f615f8ca31b2e61aa0eb5865a21e14f5629515c9151850aa936c02a1ee51 - # via - # -r build/requirements.in - # contourpy - # matplotlib - # ml-dtypes - # scipy opt-einsum==3.4.0 \ --hash=sha256:69bb92469f86a1565195ece4ac0323943e83477171b91d24c35afe028a90d7cd \ --hash=sha256:96ca72f1b886d148241348783498194c577fa30a8faac108586b14f1ba4473ac diff --git a/tests/api_test.py b/tests/api_test.py index fbe3610b1f08..af56b28b2e93 100644 --- a/tests/api_test.py +++ b/tests/api_test.py @@ -96,38 +96,6 @@ def my_function(): jitted = jit(my_function) self.assertEqual(repr(jitted), f"") - def test_fun_name(self): - def my_function(): - return - - with self.subTest("function"): - jitted = jit(my_function) - self.assertEqual( - jitted.__getstate__()["function_name"], my_function.__name__ - ) - with self.subTest("default_partial"): - my_partial = partial(my_function) - jitted = jit(my_partial) - self.assertEqual( - jitted.__getstate__()["function_name"], my_function.__name__ - ) - with self.subTest("nested_default_partial"): - my_partial = partial(partial(my_function)) - jitted = jit(my_partial) - self.assertEqual( - jitted.__getstate__()["function_name"], my_function.__name__ - ) - with self.subTest("named_partial"): - my_partial = partial(my_function) - my_partial.__name__ = "my_partial" - jitted = jit(my_partial) - self.assertEqual( - jitted.__getstate__()["function_name"], my_partial.__name__ - ) - with self.subTest("lambda"): - jitted = jit(lambda: my_function()) - self.assertEqual(jitted.__getstate__()["function_name"], "") - def test_jit_repr_errors(self): class Callable: def __call__(self): pass @@ -320,14 +288,14 @@ def test_jit_default_device(self, module): self.assertEqual(f(1).devices(), system_default_devices) def test_jit_default_platform(self): - with jax.default_device("cpu"): - result = jax.jit(lambda x: x + 1)(1) - self.assertEqual(result.device.platform, "cpu") - self.assertEqual(result.device, jax.local_devices(backend="cpu")[0]) + with jax.default_device("cpu"): + result = jax.jit(lambda x: x + 1)(1) + self.assertEqual(result.device.platform, "cpu") + self.assertEqual(result.device, jax.local_devices(backend="cpu")[0]) - result = jax.jit(lambda x: x + 1)(1) - self.assertEqual(result.device.platform, jax.default_backend()) - self.assertEqual(result.device, jax.local_devices()[0]) + result = jax.jit(lambda x: x + 1)(1) + self.assertEqual(result.device.platform, jax.default_backend()) + self.assertEqual(result.device, jax.local_devices()[0]) def test_complex_support(self): self.assertEqual(jit(lambda x: x + 1)(1 + 1j), 2 + 1j) @@ -664,7 +632,6 @@ def f(x): python_should_be_executing = False jit(f)(3) - @jtu.thread_unsafe_test() # GC effects aren't predictable with threads def test_jit_cache_clear(self): @jit def f(x, y): @@ -1327,7 +1294,7 @@ def f(x, y, *args, **kwargs): return y['hi'] + args[1] + sum(kwargs.values()) lowered = jax.jit(f).lower({'hi': 1.}, {'hi': 2.}, 3., 4., z=5., w=6.) - hlo_str = lowered.as_text("stablehlo", debug_info=True) + hlo_str = mlir.module_to_string(lowered.compiler_ir('stablehlo')) self.assertNotIn("\"x\"", hlo_str) self.assertIn("y['hi']", hlo_str) self.assertNotIn("args[0]", hlo_str) @@ -1335,7 +1302,10 @@ def f(x, y, *args, **kwargs): self.assertIn("kwargs['z']", hlo_str) self.assertIn("kwargs['w']", hlo_str) - hlo_str = lowered.as_text("stablehlo", debug_info=False) + hlo_str = mlir.module_to_string( + lowered.compiler_ir('stablehlo'), + enable_debug_info=False, + ) for s in ("\"x\"", "y['hi']", "args[0]", "args[1]", "kwargs['z']", "kwargs['w']"): self.assertNotIn(s, hlo_str) @@ -1344,10 +1314,9 @@ def test_jit_lower_arg_info_static_argnums(self, static_argnums): def f(x, y, *args, **kwargs): return y['hi'] + args[1] + sum(kwargs.values()) - lowered = jax.jit(f, static_argnums=static_argnums).lower( - (1.,), {'hi': 2.}, 3., 4., z=5., w=6.) - - hlo_str = lowered.as_text("stablehlo", debug_info=True) + ir = jax.jit(f, static_argnums=static_argnums).lower( + (1.,), {'hi': 2.}, 3., 4., z=5., w=6.).compiler_ir('stablehlo') + hlo_str = mlir.module_to_string(ir) self.assertNotIn("\"x\"", hlo_str) self.assertIn("y['hi']", hlo_str) self.assertNotIn("args[0]", hlo_str) @@ -1355,7 +1324,7 @@ def f(x, y, *args, **kwargs): self.assertIn("kwargs['z']", hlo_str) self.assertIn("kwargs['w']", hlo_str) - hlo_str = lowered.as_text("stablehlo", debug_info=False) + hlo_str = mlir.module_to_string(ir, enable_debug_info=False) for s in ("\"x\"", "y['hi']", "args[0]", "args[1]", "kwargs['z']", "kwargs['w']"): self.assertNotIn(s, hlo_str) @@ -1364,9 +1333,9 @@ def test_jit_lower_arg_info_static_argnames(self, static_argnames): def f(x, y, *args, **kwargs): return y['hi'] + args[1] + kwargs['z'] + kwargs['w'] - lowered = jax.jit(f, static_argnames=static_argnames).lower( - (1.,), {'hi': 2.}, 3., 4., z=5., w=6., a=7., b=8.) - hlo_str = lowered.as_text("stablehlo", debug_info=True) + ir = jax.jit(f, static_argnames=static_argnames).lower( + (1.,), {'hi': 2.}, 3., 4., z=5., w=6., a=7., b=8.).compiler_ir('stablehlo') + hlo_str = mlir.module_to_string(ir) self.assertNotIn("\"x\"", hlo_str) self.assertIn("y['hi']", hlo_str) self.assertNotIn("args[0]", hlo_str) @@ -1376,7 +1345,7 @@ def f(x, y, *args, **kwargs): self.assertNotIn("kwargs['a']", hlo_str) self.assertNotIn("kwargs['b']", hlo_str) - hlo_str = lowered.as_text("stablehlo", debug_info=False) + hlo_str = mlir.module_to_string(ir, enable_debug_info=False) for s in ( "\"x\"", "y['hi']", "args[0]", "args[1]", "kwargs['z']", "kwargs['w']", "kwargs['a']", "kwargs['b']" @@ -1387,7 +1356,8 @@ def test_jit_lower_result_info(self): def f(x, y, z): return {'a': x, 'b': [y]} - hlo_str = jax.jit(f).lower(1., (2,), [3]).as_text("stablehlo", debug_info=True) + ir = jax.jit(f).lower(1., (2,), [3]).compiler_ir('stablehlo') + hlo_str = mlir.module_to_string(ir) self.assertIn("jax.result_info = \"['a']\"", hlo_str) self.assertIn("jax.result_info = \"['b'][0][0]\"", hlo_str) @@ -1637,7 +1607,6 @@ def f(x, y, z, flag=False): assert api.value_and_grad(f, argnums=1)(1.0, 1.0, 1.0, flag=True) == (y, 2.0) assert api.value_and_grad(f, argnums=(2, 0))(1.0, 1.0, 1.0, flag=True) == (y, (3.0, 1.0)) - @jtu.thread_unsafe_test() # Concurrent cache eviction means we may retrace. def test_grad_of_jit(self): side = [] @@ -1651,7 +1620,6 @@ def f(x): assert grad(f)(2.0) == 4.0 assert len(side) == 1 - @jtu.thread_unsafe_test() # Concurrent ache eviction means we may retrace. def test_jit_of_grad(self): side = [] @@ -2623,7 +2591,6 @@ def test_block_until_ready_mixed(self): self.assertAllClose(pytree[2], np.ones(3), check_dtypes=False) self.assertEqual(pytree[3], 4) - @jtu.thread_unsafe_test() # Weakref destruction seems unpredictable with threads def test_devicearray_weakref_friendly(self): x = device_put(1.) y = weakref.ref(x) @@ -2772,7 +2739,6 @@ def f(x): self.assertEqual(count(), 1) - @jtu.thread_unsafe_test() # jit cache misses aren't thread safe def test_jit_infer_params_cache(self): def f(x): return x @@ -3363,7 +3329,6 @@ def test_grad_object_array_error(self): with self.assertRaisesRegex(TypeError, ".*is not a valid JAX type"): jax.grad(lambda x: x)(x) - @jtu.thread_unsafe_test() # logging isn't thread-safe def test_jit_compilation_time_logging(self): @api.jit def f(x): @@ -3452,7 +3417,6 @@ def test_trivial_computations(self): self.assertNotEqual(z3.unsafe_buffer_pointer(), x1.unsafe_buffer_pointer()) self.assertEqual(z2, 1) - @jtu.thread_unsafe_test() # monkey-patching mlir.jaxpr_subcomp isn't thread-safe def test_nested_jit_hoisting(self): @api.jit def f(x, y): @@ -3490,7 +3454,6 @@ def mlir_jaxpr_subcomp_and_collect(c, jaxpr, *args, **kwargs): self.assertEqual(inner_jaxpr.eqns[-2].primitive.name, 'mul') self.assertEqual(inner_jaxpr.eqns[-1].primitive.name, 'add') - @jtu.thread_unsafe_test() # count_primitive_compiles isn't thread-safe def test_primitive_compilation_cache(self): with jtu.count_primitive_compiles() as count: lax.add(1, 2) @@ -4050,17 +4013,13 @@ def __jax_array__(self): a2 = jnp.array(((x, x), [x, x])) self.assertAllClose(np.array(((1, 1), (1, 1))), a2) - @jtu.thread_unsafe_test() # count_jit_tracing_cache_miss() isn't thread-safe def test_eval_shape_weak_type(self): # https://github.com/jax-ml/jax/issues/23302 arr = jax.numpy.array(1) - def f(x): - return jax.numpy.array(x) - with jtu.count_jit_tracing_cache_miss() as count: - jax.eval_shape(f, 1) - out = jax.eval_shape(f, 1) + jax.eval_shape(jax.numpy.array, 1) + out = jax.eval_shape(jax.numpy.array, 1) self.assertEqual(count(), 1) self.assertTrue(out.weak_type) @@ -4179,7 +4138,6 @@ def test_dot_precision_flag(self): jaxpr = jax.make_jaxpr(jnp.dot)(x, x) self.assertIn('Precision.HIGH', str(jaxpr)) - @jtu.thread_unsafe_test() # Updating global configs is not thread-safe. def test_dot_precision_forces_retrace(self): num_traces = 0 @@ -4352,7 +4310,6 @@ def test_jnp_array_doesnt_device_put(self): api.make_jaxpr(lambda: jnp.array(3))() self.assertEqual(count(), 0) - @jtu.thread_unsafe_test() # Updating global configs is not thread-safe. def test_rank_promotion_forces_retrace(self): num_traces = 0 @@ -4371,7 +4328,7 @@ def f_jit(x): for f in [f_jit, f_cond]: # Use _read() to read the flag value rather than threadlocal value. - allow_promotion = jax.numpy_rank_promotion.get_global() + allow_promotion = config._read("jax_numpy_rank_promotion") try: config.update("jax_numpy_rank_promotion", "allow") num_traces = 0 @@ -4393,9 +4350,9 @@ def f(x): self.assertGreaterEqual(num_traces, 2) nt = num_traces f(x) - self.assertEqual(num_traces, nt) + self.assertEqual(num_traces, nt + 1) f(x) - self.assertEqual(num_traces, nt) + self.assertEqual(num_traces, nt + 1) finally: config.update("jax_numpy_rank_promotion", allow_promotion) @@ -4493,7 +4450,6 @@ def foo(x, y, z): self.assertEqual(jfoo.__qualname__, f"make_jaxpr({foo.__qualname__})") self.assertEqual(jfoo.__module__, "jax") - @jtu.thread_unsafe_test() # Concurrent cache eviction means we may retrace def test_inner_jit_function_retracing(self): # https://github.com/jax-ml/jax/issues/7155 inner_count = outer_count = 0 @@ -4541,7 +4497,6 @@ def test_invalid_value_device_put(self): with self.assertRaisesRegex(ValueError, r".*Received invalid value.*"): jax.device_put(jnp.arange(8), 'cpu') - @jtu.thread_unsafe_test() # logging is not thread-safe def test_clear_cache(self): @jax.jit def add(x): @@ -4560,7 +4515,6 @@ def add(x): tracing_add_count += 1 self.assertEqual(tracing_add_count, 2) - @jtu.thread_unsafe_test() # logging is not thread-safe def test_cache_miss_explanations(self): @jax.jit def f(x, y): @@ -4620,7 +4574,6 @@ def f(x, y): msg = cm.output[0] self.assertIn("tracing context doesn't match", msg) - @jtu.thread_unsafe_test() # logging is not thread-safe def test_cache_miss_explanations_new_function_in_loop(self): @jax.jit def f(x, y): @@ -4642,7 +4595,6 @@ def f(x, y): _, msg = cm.output self.assertIn('another function defined on the same line', msg) - @jtu.thread_unsafe_test() # logging is not thread-safe def test_cache_miss_explanations_unpacks_transforms(self): # Tests that the explain_tracing_cache_miss() function does not throw an # error when unpacking `transforms` with a length greater than 3. @@ -4739,7 +4691,6 @@ def test_mesh_creation_error_message(self): with self.assertRaisesRegex(ValueError, "ndim of its first argument"): jax.sharding.Mesh(jax.devices(), ("x", "y")) - @jtu.thread_unsafe_test() # weakref gc doesn't seem predictable def test_jit_boundmethod_reference_cycle(self): class A: def __init__(self): @@ -4878,7 +4829,6 @@ class RematTest(jtu.JaxTestCase): ('_policy', partial(jax.remat, policy=lambda *_, **__: False)), ('_new', partial(new_checkpoint, policy=lambda *_, **__: False)), ]) - @jtu.thread_unsafe_test() # monkey patches sin_p and cos_p def test_remat_basic(self, remat): @remat def g(x): @@ -5216,7 +5166,6 @@ def f_yesremat(x): ('_policy', partial(jax.remat, policy=lambda *_, **__: False)), ('_new', partial(new_checkpoint, policy=lambda *_, **__: False)), ]) - @jtu.thread_unsafe_test() # monkey patches sin_p def test_remat_no_redundant_flops(self, remat): # see https://github.com/jax-ml/jax/pull/1749#issuecomment-558267584 @@ -6460,7 +6409,6 @@ def f(x): self.assertIn(' sin ', str(jaxpr)) self.assertIn(' cos ', str(jaxpr)) - @jtu.thread_unsafe_test() # logging isn't thread-safe def test_remat_residual_logging(self): def f(x): x = jnp.sin(x) @@ -9678,8 +9626,11 @@ def foo_bwd(_, g): foo.defvjp(foo_fwd, foo_bwd) - with config.custom_vjp_disable_shape_check(True): + try: + jax.config.update('jax_custom_vjp_disable_shape_check', True) jax.grad(lambda x, y: foo(x, y).sum(), 1)(jnp.ones(3), jnp.ones(4)) + finally: + jax.config.update('jax_custom_vjp_disable_shape_check', False) def test_bwd_rule_can_produce_list_or_tuple(self): @jax.custom_vjp @@ -11136,7 +11087,6 @@ def test_call_wrapped_second_phase_cleanup(self): class EnvironmentInfoTest(jtu.JaxTestCase): @parameterized.parameters([True, False]) - @jtu.thread_unsafe_test() def test_print_environment_info(self, return_string): # Flush stdout buffer before checking. sys.stdout.flush() @@ -11164,8 +11114,6 @@ def test_autodidax_smoketest(self): spec.loader.exec_module(autodidax_module) class GarbageCollectionTest(jtu.JaxTestCase): - - @jtu.thread_unsafe_test() # GC isn't predictable def test_xla_gc_callback(self): # https://github.com/jax-ml/jax/issues/14882 x_np = np.arange(10, dtype='int32') diff --git a/tests/jaxpr_effects_test.py b/tests/jaxpr_effects_test.py index 2d63a48341c3..2e91792aa950 100644 --- a/tests/jaxpr_effects_test.py +++ b/tests/jaxpr_effects_test.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - +import contextlib import threading import unittest @@ -34,7 +34,6 @@ import numpy as np config.parse_flags_with_absl() -jtu.request_cpu_devices(2) effect_p = core.Primitive('effect') effect_p.multiple_results = True @@ -133,6 +132,15 @@ def callback_effect_lowering(ctx: mlir.LoweringRuleContext, *args, callback, out mlir.register_lowering(callback_p, callback_effect_lowering) +_exit_stack = contextlib.ExitStack() + +def setUpModule(): + _exit_stack.enter_context(jtu.set_host_platform_device_count(2)) + +def tearDownModule(): + _exit_stack.close() + + class JaxprEffectsTest(jtu.JaxTestCase): def test_trivial_jaxpr_has_no_effects(self): @@ -269,7 +277,6 @@ def f(x): self.assertSetEqual(jaxpr.effects, {foo_effect, bar_effect}) -@jtu.thread_unsafe_test_class() # because of mlir.register_lowering calls class EffectfulJaxprLoweringTest(jtu.JaxTestCase): def setUp(self):