Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Smg/cellfinder default jax #381

Closed
wants to merge 9 commits into from
6 changes: 3 additions & 3 deletions .github/workflows/test_and_deploy.yml
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,8 @@ jobs:
include:
- os: macos-latest
python-version: "3.10"
- os: windows-latest
python-version: "3.10"

steps:
# Cache the Keras model so we don't have to remake it every time
Expand Down Expand Up @@ -115,10 +117,8 @@ jobs:
- name: Install test dependencies
run: |
python -m pip install --upgrade pip wheel
# Install cellfinder from the latest SHA on this branch
# Install cellfinder from the latest SHA on this branch (Keras with default JAX backend)
python -m pip install git+$GITHUB_SERVER_URL/$GITHUB_REPOSITORY@$GITHUB_SHA
# Install tensorflow as keras' default backend
python -m pip install "tf-nightly==2.16.0.dev20240101"
# Install checked out copy of brainglobe-workflows
python -m pip install .[dev]

Expand Down
4 changes: 2 additions & 2 deletions .github/workflows/test_include_guard.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,8 @@ jobs:
with:
python-version: '3.10'

- name: Install cellfinder via pip, specifying tensorflow as keras' backend
run: python -m pip install -e ".[tf]"
- name: Install cellfinder via pip, specifying JAX as keras' backend
run: python -m pip install -e ".[jax]"

- name: Test (working) import
uses: jannekem/run-python-script-action@v1
Expand Down
30 changes: 22 additions & 8 deletions cellfinder/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import os
import warnings
from importlib.metadata import PackageNotFoundError, version
from importlib.util import find_spec

# Check cellfinder is installed
try:
Expand All @@ -25,29 +26,42 @@


# If no backend is configured and installed for Keras, tools cannot be used
# Check backend is configured
# Check backend is configured (default: JAX)
# do not use default in getenv: any changes to environment after importing
# os module are not picked up except if directly modifying the dict
if not os.getenv("KERAS_BACKEND"):
os.environ["KERAS_BACKEND"] = "tensorflow"
warnings.warn(
"Keras backend not configured, automatically set to Tensorflow"
)
os.environ["KERAS_BACKEND"] = "jax"
warnings.warn("Keras backend not configured, automatically set to JAX")

# Check backend is installed
if os.getenv("KERAS_BACKEND") in ["tensorflow", "jax", "torch"]:
backend = os.getenv("KERAS_BACKEND")
backend = os.getenv("KERAS_BACKEND")
if backend in ["tensorflow", "jax", "torch"]:
try:
backend_package = "tf-nightly" if backend == "tensorflow" else backend
BACKEND_VERSION = version(backend_package)

warnings.warn(f"Using Keras with {backend} backend")

except PackageNotFoundError as e:
raise PackageNotFoundError(
f"{backend}, ({backend_package}) set as Keras backend "
f"but not installed"
) from e
else:
raise PackageNotFoundError(
"Keras backend must be one of 'tensorflow', 'jax', or 'torch'"
f"Keras backend must be one of 'tensorflow', "
f"'jax', or 'torch' (not {backend})."
)

# If TF is installed but backend not set to TF, raise an error
tf_spec = find_spec("tensorflow")
if tf_spec:
if tf_spec.name != backend:
raise ImportError(
f"Tensorflow package installed"
f"but not set as Keras backend ({backend})"
) # replace by ImportWarning?


__author__ = "Adam Tyson, Christian Niedworok, Charly Rousseau"
__license__ = "BSD-3-Clause"
21 changes: 11 additions & 10 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@ dependencies = [
"scikit-image",
"scikit-learn",
"keras>=3.0.0",
"jax==0.4.20",
"jaxlib==0.4.20",
"tifffile",
"tqdm",
]
Expand Down Expand Up @@ -63,10 +65,6 @@ napari = [
tf = [
"tf-nightly==2.16.0.dev20240101", # pinning to same TF as Keras 3.0
]
jax = [
"jax==0.4.20",
"jaxlib==0.4.20"
]
torch = [
"torch==2.1.0"
]
Expand Down Expand Up @@ -120,13 +118,13 @@ markers = ["slow: marks tests as slow (deselect with '-m \"not slow\"')"]
legacy_tox_ini = """
# For more information about tox, see https://tox.readthedocs.io/en/latest/
[tox]
envlist = py{39,310}-{tf,jax}
envlist = py{39,310}-{default-jax, tf}
isolated_build = true

[gh-actions]
python =
3.9: py39-{tf,jax} # On GA python=3.9 job, run tox with the tf and jax environments
3.10: py310-{tf,jax} # On GA python=3.10 job, run tox with the tf and jax environments
3.9: py39-{default-jax,tf} # On GA python=3.9 job, run tox with the jax (default) and tf environments
3.10: py310-{default-jax, tf} # On GA python=3.10 job, run tox with jax (default) and tf environments

[testenv]
commands = python -m pytest -v --color=yes
Expand All @@ -142,10 +140,10 @@ deps =
extras =
napari
tf: tf
jax: jax
# default-jax: ---now shipping JAX with normal dependencies
setenv =
tf: KERAS_BACKEND = tensorflow
jax: KERAS_BACKEND = jax
tf: KERAS_BACKEND = tensorflow # if I dont set this, it will default to JAX
# default-jax: ---now shipping JAX with normal dependencies
passenv =
NUMBA_DISABLE_JIT
CI
Expand All @@ -154,4 +152,7 @@ passenv =
XAUTHORITY
NUMPY_EXPERIMENTAL_ARRAY_FUNCTION
PYVISTA_OFF_SCREEN
platform =
default-jax: # any
tf: linux|darwin # only run if platform is either linux or darwin
"""
Loading