From f23c402c86dda7dc84805aec4e9290fdbc909022 Mon Sep 17 00:00:00 2001 From: sfmig <33267254+sfmig@users.noreply.github.com> Date: Wed, 7 Feb 2024 16:21:06 +0000 Subject: [PATCH 01/50] replace tensorflow Tensor with keras tensor --- cellfinder/core/classify/resnet.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/cellfinder/core/classify/resnet.py b/cellfinder/core/classify/resnet.py index de6e720f..fb68dbdc 100644 --- a/cellfinder/core/classify/resnet.py +++ b/cellfinder/core/classify/resnet.py @@ -1,5 +1,8 @@ from typing import Callable, Dict, List, Literal, Optional, Tuple, Union +from keras import ( + KerasTensor as Tensor, # from tensorflow import Tensor # tf.Tensor +) from keras import Model from keras.initializers import Initializer from keras.layers import ( @@ -14,7 +17,6 @@ ZeroPadding3D, ) from keras.optimizers import Adam, Optimizer -from tensorflow import Tensor ##################################################################### # Define the types of ResNet From d508b9fc80334b4abd27d069639298b0e7ef4e23 Mon Sep 17 00:00:00 2001 From: sfmig <33267254+sfmig@users.noreply.github.com> Date: Wed, 7 Feb 2024 16:37:35 +0000 Subject: [PATCH 02/50] add case for TF prep in prep_model_weights --- cellfinder/core/tools/prep.py | 14 +++++++++++--- cellfinder/core/tools/tf.py | 6 ++++-- 2 files changed, 15 insertions(+), 5 deletions(-) diff --git a/cellfinder/core/tools/prep.py b/cellfinder/core/tools/prep.py index 0e7e7217..dfea225a 100644 --- a/cellfinder/core/tools/prep.py +++ b/cellfinder/core/tools/prep.py @@ -3,10 +3,12 @@ ================== Functions to prepare files and directories needed for other functions """ + import os from pathlib import Path from typing import Optional +import keras from brainglobe_utils.general.config import get_config_obj from brainglobe_utils.general.system import get_num_processes @@ -26,8 +28,13 @@ def prep_model_weights( model_name: model_download.model_type, n_free_cpus: int, ) -> Path: - n_processes = get_num_processes(min_free_cpu_cores=n_free_cpus) - prep_tensorflow(n_processes) + # if TF backend: + if keras.config.backend() == "tensorflow": + # prep TF + n_processes = get_num_processes(min_free_cpu_cores=n_free_cpus) + prep_tensorflow(n_processes) + + # prep models (get default weights or provided ones?) model_weights = prep_models(model_weights, install_path, model_name) return model_weights @@ -44,7 +51,8 @@ def prep_models( model_name: model_download.model_type, ) -> Path: install_path = install_path or DEFAULT_INSTALL_PATH - # if no model or weights, set default weights + + # if no model or weights, set to default weights if model_weights_path is None: logger.debug("No model supplied, so using the default") diff --git a/cellfinder/core/tools/tf.py b/cellfinder/core/tools/tf.py index 778aa78f..b50067f6 100644 --- a/cellfinder/core/tools/tf.py +++ b/cellfinder/core/tools/tf.py @@ -1,5 +1,3 @@ -import tensorflow as tf - from cellfinder.core import logger @@ -9,6 +7,8 @@ def allow_gpu_memory_growth(): away. Allows multiple processes to use the GPU (and avoid occasional errors on some systems) at the cost of a slight performance penalty. """ + import tensorflow as tf # import inside fns? + gpus = tf.config.experimental.list_physical_devices("GPU") if gpus: logger.debug("Allowing GPU memory growth") @@ -38,6 +38,8 @@ def set_tf_threads(max_threads): f"to: {max_threads}" ) + import tensorflow as tf # import inside fns? + # If statements are for testing. 
If tf is initialised, then setting these # parameters throws an error if tf.config.threading.get_inter_op_parallelism_threads() != 0: From 559c9ab397207e59f150ea1c10a2094a89eed125 Mon Sep 17 00:00:00 2001 From: sfmig <33267254+sfmig@users.noreply.github.com> Date: Wed, 7 Feb 2024 18:01:42 +0000 Subject: [PATCH 03/50] add different backends to pyproject.toml --- pyproject.toml | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 5c73d68c..2ca0c9d8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -29,8 +29,7 @@ dependencies = [ "numpy", "scikit-image", "scikit-learn", - "keras", - "tf-nightly==2.16.0.dev20240101", # pinning to same TF as Keras 3.0 + "keras==3.0.0", "tifffile", "tqdm", ] @@ -60,6 +59,16 @@ napari = [ "pooch >= 1", "qtpy", ] +tf_backend = [ + "tf-nightly==2.16.0.dev20240101", # pinning to same TF as Keras 3.0 --> tensorflow==2.15.0 and keras==3.0.0 +] +jax_backend = [ + "jax==0.4.20", + "jaxlib==0.4.20" +] +torch_backend = [ + "torch==2.1.0" +] [project.scripts] cellfinder_download = "cellfinder.core.download.cli:main" From d3cb209fa57125e8f29d6730f93bff0c6b3579e1 Mon Sep 17 00:00:00 2001 From: sfmig <33267254+sfmig@users.noreply.github.com> Date: Wed, 7 Feb 2024 18:10:49 +0000 Subject: [PATCH 04/50] add backend configuration to cellfinder init file. tests passing with jax locally --- cellfinder/__init__.py | 18 +++++++++++++++++- 1 file changed, 17 insertions(+), 1 deletion(-) diff --git a/cellfinder/__init__.py b/cellfinder/__init__.py index 1b9c1505..84b2b242 100644 --- a/cellfinder/__init__.py +++ b/cellfinder/__init__.py @@ -5,7 +5,7 @@ except PackageNotFoundError as e: raise PackageNotFoundError("cellfinder package not installed") from e -# If Keras is not present with a backend, tools cannot be used. +# If Keras is not present, tools cannot be used. # Throw an error in this case to prevent invocation of functions. try: KERAS_VERSION = version("keras") @@ -17,5 +17,21 @@ f"https://github.com/brainglobe/brainglobe-meta#readme." ) from e +# Configure Keras backend: +# Note that Keras should only be imported after the backend +# has been configured. The backend cannot be changed once the +# package is imported. +# https://keras.io/getting_started/intro_to_keras_for_engineers/ +# https://github.com/keras-team/keras/blob/5bc8488c0ea3f43c70c70ebca919093cd56066eb/keras/backend/config.py#L263 +try: + import os + + # check if environment variable exists? + os.environ["KERAS_BACKEND"] = "jax" # "torch" "jax", "tensorflow" + +except PackageNotFoundError as e: + raise PackageNotFoundError("error setting up Keras backend") from e + + __author__ = "Adam Tyson, Christian Niedworok, Charly Rousseau" __license__ = "BSD-3-Clause" From c8e0ac8ca02a33625e690bfca045bf3b5a7b0ba6 Mon Sep 17 00:00:00 2001 From: sfmig <33267254+sfmig@users.noreply.github.com> Date: Thu, 8 Feb 2024 17:05:02 +0000 Subject: [PATCH 05/50] define extra dependencies for cellfinder with different backends. run tox with TF backend --- cellfinder/__init__.py | 7 +++++-- pyproject.toml | 10 ++++++---- 2 files changed, 11 insertions(+), 6 deletions(-) diff --git a/cellfinder/__init__.py b/cellfinder/__init__.py index 84b2b242..31945b5d 100644 --- a/cellfinder/__init__.py +++ b/cellfinder/__init__.py @@ -26,8 +26,11 @@ try: import os - # check if environment variable exists? - os.environ["KERAS_BACKEND"] = "jax" # "torch" "jax", "tensorflow" + # check if environment variable exists, otherwise set to tensorflow? 
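# A minimal standalone sketch (not part of this patch series) of the
# ordering constraint noted in the commit above, assuming a Keras 3
# install: KERAS_BACKEND must be set before the first `import keras`
# in the process, because the backend cannot be switched afterwards.
import os

os.environ.setdefault("KERAS_BACKEND", "tensorflow")

import keras  # noqa: E402  (deliberately imported after the env is set)

print(keras.config.backend())  # "tensorflow", "jax" or "torch"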
+ if not os.getenv("KERAS_BACKEND"): + os.environ[ + "KERAS_BACKEND" + ] = "tensorflow" # "torch" "jax", "tensorflow" except PackageNotFoundError as e: raise PackageNotFoundError("error setting up Keras backend") from e diff --git a/pyproject.toml b/pyproject.toml index 2ca0c9d8..5a1f4361 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -59,14 +59,14 @@ napari = [ "pooch >= 1", "qtpy", ] -tf_backend = [ - "tf-nightly==2.16.0.dev20240101", # pinning to same TF as Keras 3.0 --> tensorflow==2.15.0 and keras==3.0.0 +tf-backend = [ + "tf-nightly==2.16.0.dev20240101", # pinning to same TF as Keras 3.0 ] -jax_backend = [ +jax-backend = [ "jax==0.4.20", "jaxlib==0.4.20" ] -torch_backend = [ +torch-backend = [ "torch==2.1.0" ] @@ -140,6 +140,7 @@ deps = pytest-qt extras = napari + tf-backend passenv = NUMBA_DISABLE_JIT CI @@ -148,4 +149,5 @@ passenv = XAUTHORITY NUMPY_EXPERIMENTAL_ARRAY_FUNCTION PYVISTA_OFF_SCREEN + KERAS_BACKEND """ From 313b988923506a5c9d79ec826a99b86d6553e07b Mon Sep 17 00:00:00 2001 From: sfmig <33267254+sfmig@users.noreply.github.com> Date: Thu, 8 Feb 2024 18:05:54 +0000 Subject: [PATCH 06/50] run tox using TF and JAX backend --- pyproject.toml | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 5a1f4361..805ae8da 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -119,13 +119,13 @@ markers = ["slow: marks tests as slow (deselect with '-m \"not slow\"')"] legacy_tox_ini = """ # For more information about tox, see https://tox.readthedocs.io/en/latest/ [tox] -envlist = py{39,310} +envlist = py{39,310}-{tf,jax} isolated_build = true [gh-actions] python = - 3.9: py39 - 3.10: py310 + 3.9: py39-{tf,jax} # On GA python=3.9 job, run tox with the tf and jax environments + 3.10: py310-{tf,jax} # On GA python=3.10 job, run tox with the tf and jax environments [testenv] commands = python -m pytest -v --color=yes @@ -140,7 +140,11 @@ deps = pytest-qt extras = napari - tf-backend + tf: tf-backend + jax: jax-backend +setenv = + tf: KERAS_BACKEND = tensorflow + jax: KERAS_BACKEND = jax passenv = NUMBA_DISABLE_JIT CI @@ -149,5 +153,4 @@ passenv = XAUTHORITY NUMPY_EXPERIMENTAL_ARRAY_FUNCTION PYVISTA_OFF_SCREEN - KERAS_BACKEND """ From 681925e02064852d03a8e83a762e7ccf4b741b98 Mon Sep 17 00:00:00 2001 From: sfmig <33267254+sfmig@users.noreply.github.com> Date: Thu, 8 Feb 2024 18:35:26 +0000 Subject: [PATCH 07/50] install TF in brainmapper environment before running tests in CI --- .github/workflows/test_and_deploy.yml | 4 +++- cellfinder/__init__.py | 10 ++++------ 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/.github/workflows/test_and_deploy.yml b/.github/workflows/test_and_deploy.yml index 54312332..94a29673 100644 --- a/.github/workflows/test_and_deploy.yml +++ b/.github/workflows/test_and_deploy.yml @@ -115,8 +115,10 @@ jobs: - name: Install test dependencies run: | python -m pip install --upgrade pip wheel - # Install latest SHA on this brainglobe-workflows branch + # Install cellfinder from the latest SHA on this branch python -m pip install git+$GITHUB_SERVER_URL/$GITHUB_REPOSITORY@$GITHUB_SHA + # Install tensorflow as keras' default backend + python -m pip install "tf-nightly==2.16.0.dev20240101" # Install checked out copy of brainglobe-workflows python -m pip install .[dev] diff --git a/cellfinder/__init__.py b/cellfinder/__init__.py index 31945b5d..ec4aa0bd 100644 --- a/cellfinder/__init__.py +++ b/cellfinder/__init__.py @@ -21,16 +21,14 @@ # Note that Keras should only be imported after the backend # 
has been configured. The backend cannot be changed once the # package is imported. -# https://keras.io/getting_started/intro_to_keras_for_engineers/ -# https://github.com/keras-team/keras/blob/5bc8488c0ea3f43c70c70ebca919093cd56066eb/keras/backend/config.py#L263 try: import os - # check if environment variable exists, otherwise set to tensorflow? + # if environment variable does not exist, assign TF + # options: "torch" "jax", "tensorflow" if not os.getenv("KERAS_BACKEND"): - os.environ[ - "KERAS_BACKEND" - ] = "tensorflow" # "torch" "jax", "tensorflow" + os.environ["KERAS_BACKEND"] = "tensorflow" + except PackageNotFoundError as e: raise PackageNotFoundError("error setting up Keras backend") from e From 0ff45468c0c169bdd1675007cdfb4473138f1813 Mon Sep 17 00:00:00 2001 From: sfmig <33267254+sfmig@users.noreply.github.com> Date: Thu, 8 Feb 2024 19:07:11 +0000 Subject: [PATCH 08/50] add backends check to cellfinder init file --- cellfinder/__init__.py | 45 +++++++++++++++++++++++++++--------------- 1 file changed, 29 insertions(+), 16 deletions(-) diff --git a/cellfinder/__init__.py b/cellfinder/__init__.py index ec4aa0bd..bceb4403 100644 --- a/cellfinder/__init__.py +++ b/cellfinder/__init__.py @@ -1,5 +1,8 @@ +import os +import warnings from importlib.metadata import PackageNotFoundError, version +# Check cellfinder is installed try: __version__ = version("cellfinder") except PackageNotFoundError as e: @@ -12,26 +15,36 @@ except PackageNotFoundError as e: raise PackageNotFoundError( f"cellfinder tools cannot be invoked without Keras. " - f"Please install Keras into your environment to use cellfinder tools. " - f"For more information, please see " + f"Please install Keras with a backend into your environment " + f"to use cellfinder tools. " + f"For more information on Keras backends, please see " + f"https://keras.io/getting_started/#installing-keras-3." + f"For more information on brainglobe, please see " f"https://github.com/brainglobe/brainglobe-meta#readme." ) from e -# Configure Keras backend: -# Note that Keras should only be imported after the backend -# has been configured. The backend cannot be changed once the -# package is imported. 
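# For reference, a self-contained sketch (hypothetical helper, not code
# from this commit) of the importlib.metadata check the commit introduces
# below. Note that version() looks up the installed *distribution* name,
# which is why the nightly TensorFlow build ("tf-nightly") has to be
# special-cased in a later commit of this series.
from importlib.metadata import PackageNotFoundError, version


def backend_distribution_installed(name: str) -> bool:
    # version() reads installed package metadata and raises
    # PackageNotFoundError when no distribution with that name exists
    try:
        version(name)
        return True
    except PackageNotFoundError:
        return False


print(backend_distribution_installed("torch"))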
-try: - import os - - # if environment variable does not exist, assign TF - # options: "torch" "jax", "tensorflow" - if not os.getenv("KERAS_BACKEND"): - os.environ["KERAS_BACKEND"] = "tensorflow" - -except PackageNotFoundError as e: - raise PackageNotFoundError("error setting up Keras backend") from e +# If no backend is configured and installed for Keras, tools cannot be used +# Check backend is configured +if not os.getenv("KERAS_BACKEND"): + os.environ["KERAS_BACKEND"] = "tensorflow" + warnings.warn( + "Keras backend not configured, automatically set to Tensorflow" + ) + +# Check backend is installed +if os.getenv("KERAS_BACKEND") in ["tensorflow", "jax", "torch"]: + backend = os.getenv("KERAS_BACKEND") + try: + BACKEND_VERSION = version(backend) + except PackageNotFoundError as e: + raise PackageNotFoundError( + f"{backend} package set as Keras backend but not installed" + ) from e +else: + raise PackageNotFoundError( + "Keras backend must be one of 'tensorflow', 'jax', or 'torch'" + ) __author__ = "Adam Tyson, Christian Niedworok, Charly Rousseau" From 6ec3c33ef8034b97d3c224a5857ee5a76326b9df Mon Sep 17 00:00:00 2001 From: sfmig <33267254+sfmig@users.noreply.github.com> Date: Thu, 8 Feb 2024 19:26:25 +0000 Subject: [PATCH 09/50] clean up comments --- cellfinder/core/classify/resnet.py | 2 +- cellfinder/core/tools/prep.py | 5 ++--- cellfinder/core/tools/tf.py | 4 ++-- 3 files changed, 5 insertions(+), 6 deletions(-) diff --git a/cellfinder/core/classify/resnet.py b/cellfinder/core/classify/resnet.py index fb68dbdc..d25d7dae 100644 --- a/cellfinder/core/classify/resnet.py +++ b/cellfinder/core/classify/resnet.py @@ -1,7 +1,7 @@ from typing import Callable, Dict, List, Literal, Optional, Tuple, Union from keras import ( - KerasTensor as Tensor, # from tensorflow import Tensor # tf.Tensor + KerasTensor as Tensor, ) from keras import Model from keras.initializers import Initializer diff --git a/cellfinder/core/tools/prep.py b/cellfinder/core/tools/prep.py index dfea225a..0f614890 100644 --- a/cellfinder/core/tools/prep.py +++ b/cellfinder/core/tools/prep.py @@ -28,13 +28,12 @@ def prep_model_weights( model_name: model_download.model_type, n_free_cpus: int, ) -> Path: - # if TF backend: + # if tensorflow backend: do required prep if keras.config.backend() == "tensorflow": - # prep TF n_processes = get_num_processes(min_free_cpu_cores=n_free_cpus) prep_tensorflow(n_processes) - # prep models (get default weights or provided ones?) + # prepare models (get default weights or provided ones) model_weights = prep_models(model_weights, install_path, model_name) return model_weights diff --git a/cellfinder/core/tools/tf.py b/cellfinder/core/tools/tf.py index b50067f6..7691875c 100644 --- a/cellfinder/core/tools/tf.py +++ b/cellfinder/core/tools/tf.py @@ -7,7 +7,7 @@ def allow_gpu_memory_growth(): away. Allows multiple processes to use the GPU (and avoid occasional errors on some systems) at the cost of a slight performance penalty. """ - import tensorflow as tf # import inside fns? + import tensorflow as tf gpus = tf.config.experimental.list_physical_devices("GPU") if gpus: @@ -38,7 +38,7 @@ def set_tf_threads(max_threads): f"to: {max_threads}" ) - import tensorflow as tf # import inside fns? + import tensorflow as tf # If statements are for testing. 
If tf is initialised, then setting these # parameters throws an error From f4857e1d2611484a1bfb231a058f78fc9f9be75f Mon Sep 17 00:00:00 2001 From: sfmig <33267254+sfmig@users.noreply.github.com> Date: Thu, 8 Feb 2024 19:26:36 +0000 Subject: [PATCH 10/50] fix tf-nightly import check --- cellfinder/__init__.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/cellfinder/__init__.py b/cellfinder/__init__.py index bceb4403..eb1a8bf4 100644 --- a/cellfinder/__init__.py +++ b/cellfinder/__init__.py @@ -36,10 +36,12 @@ if os.getenv("KERAS_BACKEND") in ["tensorflow", "jax", "torch"]: backend = os.getenv("KERAS_BACKEND") try: - BACKEND_VERSION = version(backend) + backend_package = "tf-nightly" if backend == "tensorflow" else backend + BACKEND_VERSION = version(backend_package) except PackageNotFoundError as e: raise PackageNotFoundError( - f"{backend} package set as Keras backend but not installed" + f"{backend} package ({backend_package}) set as Keras backend " + f"but not installed" ) from e else: raise PackageNotFoundError( From b32fab4f0fe67e7735480142424664841d4dc863 Mon Sep 17 00:00:00 2001 From: sfmig <33267254+sfmig@users.noreply.github.com> Date: Thu, 8 Feb 2024 19:30:01 +0000 Subject: [PATCH 11/50] specify TF backend in include guard check --- .github/workflows/test_include_guard.yaml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/test_include_guard.yaml b/.github/workflows/test_include_guard.yaml index 970421ca..926763b8 100644 --- a/.github/workflows/test_include_guard.yaml +++ b/.github/workflows/test_include_guard.yaml @@ -24,8 +24,8 @@ jobs: with: python-version: '3.10' - - name: Install via pip - run: python -m pip install -e . + - name: Install via pip using tensorflow backend + run: python -m pip install -e ".[tf-backend]" - name: Test (working) import uses: jannekem/run-python-script-action@v1 From 8bfa0d9ff848472ddf6afe18076f0707e7ac09bf Mon Sep 17 00:00:00 2001 From: sfmig <33267254+sfmig@users.noreply.github.com> Date: Thu, 8 Feb 2024 19:31:53 +0000 Subject: [PATCH 12/50] clarify comment --- .github/workflows/test_include_guard.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/test_include_guard.yaml b/.github/workflows/test_include_guard.yaml index 926763b8..9023c57d 100644 --- a/.github/workflows/test_include_guard.yaml +++ b/.github/workflows/test_include_guard.yaml @@ -24,7 +24,7 @@ jobs: with: python-version: '3.10' - - name: Install via pip using tensorflow backend + - name: Install cellfinder via pip, specifying tensorflow as keras' backend run: python -m pip install -e ".[tf-backend]" - name: Test (working) import From 576340bf74997e50838b6f3946a3793dba0210dc Mon Sep 17 00:00:00 2001 From: sfmig <33267254+sfmig@users.noreply.github.com> Date: Fri, 9 Feb 2024 16:12:46 +0000 Subject: [PATCH 13/50] remove 'backend' from dependencies specifications --- .github/workflows/test_include_guard.yaml | 2 +- pyproject.toml | 11 ++++++----- 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/.github/workflows/test_include_guard.yaml b/.github/workflows/test_include_guard.yaml index 9023c57d..be8609fd 100644 --- a/.github/workflows/test_include_guard.yaml +++ b/.github/workflows/test_include_guard.yaml @@ -25,7 +25,7 @@ jobs: python-version: '3.10' - name: Install cellfinder via pip, specifying tensorflow as keras' backend - run: python -m pip install -e ".[tf-backend]" + run: python -m pip install -e ".[tf]" - name: Test (working) import uses: 
jannekem/run-python-script-action@v1 diff --git a/pyproject.toml b/pyproject.toml index 805ae8da..4292c00d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -59,14 +59,15 @@ napari = [ "pooch >= 1", "qtpy", ] -tf-backend = [ +# Keras backends +tf = [ "tf-nightly==2.16.0.dev20240101", # pinning to same TF as Keras 3.0 ] -jax-backend = [ +jax = [ "jax==0.4.20", "jaxlib==0.4.20" ] -torch-backend = [ +torch = [ "torch==2.1.0" ] @@ -140,8 +141,8 @@ deps = pytest-qt extras = napari - tf: tf-backend - jax: jax-backend + tf: tf + jax: jax setenv = tf: KERAS_BACKEND = tensorflow jax: KERAS_BACKEND = jax From 89339fabfa25a4c9cd9db155ea4975b21fdeda94 Mon Sep 17 00:00:00 2001 From: sfmig <33267254+sfmig@users.noreply.github.com> Date: Fri, 9 Feb 2024 16:15:49 +0000 Subject: [PATCH 14/50] Apply suggestions from code review Co-authored-by: Igor Tatarnikov <61896994+IgorTatarnikov@users.noreply.github.com> --- cellfinder/__init__.py | 2 +- pyproject.toml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/cellfinder/__init__.py b/cellfinder/__init__.py index eb1a8bf4..7ea9825b 100644 --- a/cellfinder/__init__.py +++ b/cellfinder/__init__.py @@ -40,7 +40,7 @@ BACKEND_VERSION = version(backend_package) except PackageNotFoundError as e: raise PackageNotFoundError( - f"{backend} package ({backend_package}) set as Keras backend " + f"{backend}, ({backend_package}) set as Keras backend " f"but not installed" ) from e else: diff --git a/pyproject.toml b/pyproject.toml index 4292c00d..6dd1bc41 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -29,7 +29,7 @@ dependencies = [ "numpy", "scikit-image", "scikit-learn", - "keras==3.0.0", + "keras>=3.0.0", "tifffile", "tqdm", ] From 01394af608afe070f794a74946622c3d3714dd4c Mon Sep 17 00:00:00 2001 From: Igor Tatarnikov Date: Wed, 14 Feb 2024 17:10:44 +0000 Subject: [PATCH 15/50] PyTorch runs utilizing multiple cores --- cellfinder/core/train/train_yml.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/cellfinder/core/train/train_yml.py b/cellfinder/core/train/train_yml.py index dd83ebc5..118fcbe3 100644 --- a/cellfinder/core/train/train_yml.py +++ b/cellfinder/core/train/train_yml.py @@ -22,7 +22,10 @@ check_positive_float, check_positive_int, ) -from brainglobe_utils.general.system import ensure_directory_exists +from brainglobe_utils.general.system import ( + ensure_directory_exists, + get_num_processes, +) from brainglobe_utils.IO.cells import find_relevant_tiffs from brainglobe_utils.IO.yaml import read_yaml_section from fancylog import fancylog @@ -338,7 +341,7 @@ def run( ensure_directory_exists(output_dir) model_weights = prep_model_weights( - install_path, model_weights, model, n_free_cpus + model_weights, install_path, model, n_free_cpus ) yaml_contents = parse_yaml(yaml_file) @@ -360,6 +363,7 @@ def run( signal_train, background_train, labels_train = make_lists(tiff_files) + n_processes = get_num_processes(min_free_cpu_cores=n_free_cpus) if test_fraction > 0: logger.info("Splitting data into training and validation datasets") ( @@ -386,7 +390,8 @@ def run( labels=labels_test, batch_size=batch_size, train=True, - use_multiprocessing=False, + use_multiprocessing=True, + workers=n_processes, ) # for saving checkpoints @@ -405,7 +410,8 @@ def run( shuffle=True, train=True, augment=not no_augment, - use_multiprocessing=False, + use_multiprocessing=True, + workers=n_processes, ) callbacks = [] From 22db8f48b55f8124a271cfe3a19fb57457320c27 Mon Sep 17 00:00:00 2001 From: Igor Tatarnikov Date: Fri, 5 Apr 2024 
14:32:11 +0100 Subject: [PATCH 16/50] PyTorch fix with default models --- cellfinder/core/classify/resnet.py | 5 +++++ pyproject.toml | 6 ++++-- 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/cellfinder/core/classify/resnet.py b/cellfinder/core/classify/resnet.py index d25d7dae..e0bc98d4 100644 --- a/cellfinder/core/classify/resnet.py +++ b/cellfinder/core/classify/resnet.py @@ -1,5 +1,6 @@ from typing import Callable, Dict, List, Literal, Optional, Tuple, Union +import keras.config from keras import ( KerasTensor as Tensor, ) @@ -133,6 +134,10 @@ def non_residual_block( )(x) x = BatchNormalization(axis=axis, epsilon=bn_epsilon, name="conv1_bn")(x) x = Activation(activation, name="conv1_activation")(x) + + if keras.config.backend() == "torch": + pooling_padding = "valid" + x = MaxPooling3D( max_pool_size, strides=strides, diff --git a/pyproject.toml b/pyproject.toml index 6dd1bc41..4a13c000 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -125,8 +125,8 @@ isolated_build = true [gh-actions] python = - 3.9: py39-{tf,jax} # On GA python=3.9 job, run tox with the tf and jax environments - 3.10: py310-{tf,jax} # On GA python=3.10 job, run tox with the tf and jax environments + 3.9: py39-{tf,jax,torch} # On GA python=3.9 job, run tox with the tf and jax environments + 3.10: py310-{tf,jax,torch} # On GA python=3.10 job, run tox with the tf and jax environments [testenv] commands = python -m pytest -v --color=yes @@ -143,9 +143,11 @@ extras = napari tf: tf jax: jax + torch: torch setenv = tf: KERAS_BACKEND = tensorflow jax: KERAS_BACKEND = jax + torch: KERAS_BACKEND = torch passenv = NUMBA_DISABLE_JIT CI From 64bde716c0bb048a590d1f192ecf9de77d8c60d2 Mon Sep 17 00:00:00 2001 From: Igor Tatarnikov Date: Fri, 5 Apr 2024 14:35:30 +0100 Subject: [PATCH 17/50] Tests run on every push for now --- .github/workflows/test_and_deploy.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/test_and_deploy.yml b/.github/workflows/test_and_deploy.yml index 94a29673..ab71b5d4 100644 --- a/.github/workflows/test_and_deploy.yml +++ b/.github/workflows/test_and_deploy.yml @@ -4,9 +4,9 @@ on: # Only run on pushes to main, or when version tags are pushed push: branches: - - "main" + - "*" tags: - - "v**" + - "*" # Run on all pull-requests pull_request: # Allow workflow dispatch from GitHub From 2934f35e736dd805fe5ac1863eb3f8bbdfa88027 Mon Sep 17 00:00:00 2001 From: Igor Tatarnikov Date: Fri, 5 Apr 2024 14:58:39 +0100 Subject: [PATCH 18/50] Run test on torch backend only --- .github/workflows/test_and_deploy.yml | 6 +++--- .github/workflows/test_include_guard.yaml | 4 ++-- pyproject.toml | 8 ++++---- 3 files changed, 9 insertions(+), 9 deletions(-) diff --git a/.github/workflows/test_and_deploy.yml b/.github/workflows/test_and_deploy.yml index 0f1325cd..36344cdf 100644 --- a/.github/workflows/test_and_deploy.yml +++ b/.github/workflows/test_and_deploy.yml @@ -4,9 +4,9 @@ on: # Only run on pushes to main, or when version tags are pushed push: branches: - - "*" + - "main" tags: - - "*" + - "v**" # Run on all pull-requests pull_request: # Allow workflow dispatch from GitHub @@ -120,7 +120,7 @@ jobs: run: | python -m pip install --upgrade pip wheel # Install cellfinder from the latest SHA on this branch (Keras with JAX backend) - python -m pip install "cellfinder[jax] @ git+$GITHUB_SERVER_URL/$GITHUB_REPOSITORY@$GITHUB_SHA" + python -m pip install "cellfinder[torch] @ git+$GITHUB_SERVER_URL/$GITHUB_REPOSITORY@$GITHUB_SHA" # Install checked out copy of 
brainglobe-workflows python -m pip install .[dev] diff --git a/.github/workflows/test_include_guard.yaml b/.github/workflows/test_include_guard.yaml index be8609fd..ff32c4f2 100644 --- a/.github/workflows/test_include_guard.yaml +++ b/.github/workflows/test_include_guard.yaml @@ -24,8 +24,8 @@ jobs: with: python-version: '3.10' - - name: Install cellfinder via pip, specifying tensorflow as keras' backend - run: python -m pip install -e ".[tf]" + - name: Install cellfinder via pip, specifying torch as keras' backend + run: python -m pip install -e ".[torch]" - name: Test (working) import uses: jannekem/run-python-script-action@v1 diff --git a/pyproject.toml b/pyproject.toml index dc0f0e17..61f8a46c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -61,7 +61,7 @@ napari = [ ] # Keras backends tf = [ - "tf-nightly==2.16.0.dev20240101", # pinning to same TF as Keras 3.0 + "tensorflow", # pinning to same TF as Keras 3.0 ] jax = [ "jax==0.4.20", @@ -120,13 +120,13 @@ markers = ["slow: marks tests as slow (deselect with '-m \"not slow\"')"] legacy_tox_ini = """ # For more information about tox, see https://tox.readthedocs.io/en/latest/ [tox] -envlist = py{39,310}-{tf,jax} +envlist = py{39,310}-{torch} isolated_build = true [gh-actions] python = - 3.9: py39-{tf,jax,torch} # On GA python=3.9 job, run tox with the tf and jax environments - 3.10: py310-{tf,jax,torch} # On GA python=3.10 job, run tox with the tf and jax environments + 3.9: py39-{torch} # On GA python=3.9 job, run tox with the tf and jax environments + 3.10: py310-{torch} # On GA python=3.10 job, run tox with the tf and jax environments [testenv] From 014b5491fa16dfb56c262654eddd440cc8d635a6 Mon Sep 17 00:00:00 2001 From: Igor Tatarnikov Date: Fri, 5 Apr 2024 16:12:32 +0100 Subject: [PATCH 19/50] Fixed guard test to set torch as KERAS_BACKEND --- .github/workflows/test_include_guard.yaml | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/.github/workflows/test_include_guard.yaml b/.github/workflows/test_include_guard.yaml index ff32c4f2..4ef96b3f 100644 --- a/.github/workflows/test_include_guard.yaml +++ b/.github/workflows/test_include_guard.yaml @@ -25,7 +25,9 @@ jobs: python-version: '3.10' - name: Install cellfinder via pip, specifying torch as keras' backend - run: python -m pip install -e ".[torch]" + run: | + python -m pip install -e ".[torch]" + KERAS_BACKEND="torch" - name: Test (working) import uses: jannekem/run-python-script-action@v1 From 315dbc4d1dfe3fe5a3fe2eb2946903d505885946 Mon Sep 17 00:00:00 2001 From: Igor Tatarnikov Date: Fri, 5 Apr 2024 16:18:51 +0100 Subject: [PATCH 20/50] KERAS_BACKEND env variable set directly in test_include_guard.yaml --- .github/workflows/test_include_guard.yaml | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/.github/workflows/test_include_guard.yaml b/.github/workflows/test_include_guard.yaml index 4ef96b3f..f080e210 100644 --- a/.github/workflows/test_include_guard.yaml +++ b/.github/workflows/test_include_guard.yaml @@ -25,12 +25,12 @@ jobs: python-version: '3.10' - name: Install cellfinder via pip, specifying torch as keras' backend - run: | - python -m pip install -e ".[torch]" - KERAS_BACKEND="torch" + run: python -m pip install -e ".[torch]" - name: Test (working) import uses: jannekem/run-python-script-action@v1 + env: + KERAS_BACKEND: torch with: fail-on-error: true script: | From 4590106d8239ef9bd21274689ade5b1377151a97 Mon Sep 17 00:00:00 2001 From: Igor Tatarnikov Date: Fri, 5 Apr 2024 16:41:40 +0100 Subject: [PATCH 21/50] Run test on 
python 3.11 --- .github/workflows/test_and_deploy.yml | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/.github/workflows/test_and_deploy.yml b/.github/workflows/test_and_deploy.yml index 36344cdf..26ee8447 100644 --- a/.github/workflows/test_and_deploy.yml +++ b/.github/workflows/test_and_deploy.yml @@ -40,14 +40,14 @@ jobs: strategy: matrix: # Run all supported Python versions on linux - os: [ubuntu-latest] - python-version: ["3.9", "3.10"] - # Include one macos run - include: - - os: macos-latest - python-version: "3.10" - - os: windows-latest - python-version: "3.10" + os: [ubuntu-latest,macos-latest, windows-latest] + python-version: ["3.9", "3.10", "3.11"] +# # Include one macos run +# include: +# - os: macos-latest +# python-version: "3.10" +# - os: windows-latest +# python-version: "3.10" steps: # Cache the Keras model so we don't have to remake it every time From 0ddb2e074ae2e857a00493fb0c321f6091d4c9e0 Mon Sep 17 00:00:00 2001 From: Igor Tatarnikov Date: Fri, 5 Apr 2024 16:58:28 +0100 Subject: [PATCH 22/50] Remove tf-nightly from __init__ version check --- cellfinder/__init__.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/cellfinder/__init__.py b/cellfinder/__init__.py index 7ea9825b..8fe562d0 100644 --- a/cellfinder/__init__.py +++ b/cellfinder/__init__.py @@ -36,11 +36,10 @@ if os.getenv("KERAS_BACKEND") in ["tensorflow", "jax", "torch"]: backend = os.getenv("KERAS_BACKEND") try: - backend_package = "tf-nightly" if backend == "tensorflow" else backend - BACKEND_VERSION = version(backend_package) + BACKEND_VERSION = version(backend) except PackageNotFoundError as e: raise PackageNotFoundError( - f"{backend}, ({backend_package}) set as Keras backend " + f"{backend}, ({backend}) set as Keras backend " f"but not installed" ) from e else: From fe7b79891a8363349997ab692c965a59a9f8b19f Mon Sep 17 00:00:00 2001 From: Igor Tatarnikov Date: Fri, 5 Apr 2024 17:06:42 +0100 Subject: [PATCH 23/50] Added 3.11 to legacy tox config --- pyproject.toml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 5741103e..866f7037 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -120,13 +120,14 @@ markers = ["slow: marks tests as slow (deselect with '-m \"not slow\"')"] legacy_tox_ini = """ # For more information about tox, see https://tox.readthedocs.io/en/latest/ [tox] -envlist = py{39,310}-{torch} +envlist = py{39,310,311}-{torch} isolated_build = true [gh-actions] python = 3.9: py39-{torch} # On GA python=3.9 job, run tox with the tf and jax environments 3.10: py310-{torch} # On GA python=3.10 job, run tox with the tf and jax environments + 3.10: py310-{torch} # On GA python=3.11 job, run tox with the tf and jax environments [testenv] From dcb315c93e976231c05666bddcf2bdcfda590002 Mon Sep 17 00:00:00 2001 From: Igor Tatarnikov Date: Fri, 5 Apr 2024 17:14:18 +0100 Subject: [PATCH 24/50] Changed legacy tox config for real this time --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 866f7037..16a99c68 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -127,7 +127,7 @@ isolated_build = true python = 3.9: py39-{torch} # On GA python=3.9 job, run tox with the tf and jax environments 3.10: py310-{torch} # On GA python=3.10 job, run tox with the tf and jax environments - 3.10: py310-{torch} # On GA python=3.11 job, run tox with the tf and jax environments + 3.11: py311-{torch} # On GA python=3.11 job, run tox with the tf 
and jax environments [testenv] From 8e77b68f4d0571e2e10fbf39b198885e70b300fd Mon Sep 17 00:00:00 2001 From: Igor Tatarnikov Date: Mon, 15 Apr 2024 11:50:38 +0100 Subject: [PATCH 25/50] Don't set the wrong max_processing value --- cellfinder/core/classify/classify.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/cellfinder/core/classify/classify.py b/cellfinder/core/classify/classify.py index 2193ac08..0acf10ca 100644 --- a/cellfinder/core/classify/classify.py +++ b/cellfinder/core/classify/classify.py @@ -48,9 +48,7 @@ def main( callbacks = None # Too many workers doesn't increase speed, and uses huge amounts of RAM - workers = get_num_processes( - min_free_cpu_cores=n_free_cpus, n_max_processes=max_workers - ) + workers = get_num_processes(min_free_cpu_cores=n_free_cpus) logger.debug("Initialising cube generator") inference_generator = CubeGeneratorFromFile( From 30b72f1820288ac0d8ae3098de4c2c5b2efda790 Mon Sep 17 00:00:00 2001 From: Igor Tatarnikov Date: Tue, 16 Apr 2024 15:24:18 +0100 Subject: [PATCH 26/50] Torch is now set as the default backend --- cellfinder/__init__.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/cellfinder/__init__.py b/cellfinder/__init__.py index 8fe562d0..700a4129 100644 --- a/cellfinder/__init__.py +++ b/cellfinder/__init__.py @@ -27,10 +27,8 @@ # If no backend is configured and installed for Keras, tools cannot be used # Check backend is configured if not os.getenv("KERAS_BACKEND"): - os.environ["KERAS_BACKEND"] = "tensorflow" - warnings.warn( - "Keras backend not configured, automatically set to Tensorflow" - ) + os.environ["KERAS_BACKEND"] = "torch" + warnings.warn("Keras backend not configured, automatically set to Torch") # Check backend is installed if os.getenv("KERAS_BACKEND") in ["tensorflow", "jax", "torch"]: @@ -44,7 +42,7 @@ ) from e else: raise PackageNotFoundError( - "Keras backend must be one of 'tensorflow', 'jax', or 'torch'" + "Keras backend must be one of 'torch', 'tensorflow', or 'jax'" ) From 1de446c8b2e113fb03e30a6396d3f5573fde2f03 Mon Sep 17 00:00:00 2001 From: Igor Tatarnikov Date: Tue, 16 Apr 2024 16:11:49 +0100 Subject: [PATCH 27/50] Tests only run with torch, updated comments --- .github/workflows/test_and_deploy.yml | 18 +++++++++--------- pyproject.toml | 8 ++++---- 2 files changed, 13 insertions(+), 13 deletions(-) diff --git a/.github/workflows/test_and_deploy.yml b/.github/workflows/test_and_deploy.yml index ddfb0cc7..e375e41f 100644 --- a/.github/workflows/test_and_deploy.yml +++ b/.github/workflows/test_and_deploy.yml @@ -40,14 +40,14 @@ jobs: strategy: matrix: # Run all supported Python versions on linux - os: [ubuntu-latest,macos-latest, windows-latest] + os: [ubuntu-latest] python-version: ["3.9", "3.10", "3.11"] -# # Include one macos run -# include: -# - os: macos-latest -# python-version: "3.10" -# - os: windows-latest -# python-version: "3.10" + # Include one macos run + include: + - os: macos-latest + python-version: "3.10" + - os: windows-latest + python-version: "3.10" steps: # Cache the Keras model so we don't have to remake it every time @@ -98,7 +98,7 @@ jobs: name: Run brainmapper tests to check for breakages runs-on: ubuntu-latest env: - KERAS_BACKEND: jax + KERAS_BACKEND: torch steps: - name: Cache Keras model uses: actions/cache@v3 @@ -119,7 +119,7 @@ jobs: - name: Install test dependencies run: | python -m pip install --upgrade pip wheel - # Install cellfinder from the latest SHA on this branch (Keras with JAX backend) + # Install cellfinder from the latest 
SHA on this branch (Keras with torch backend) python -m pip install "cellfinder[torch] @ git+$GITHUB_SERVER_URL/$GITHUB_REPOSITORY@$GITHUB_SHA" # Install checked out copy of brainglobe-workflows python -m pip install .[dev] diff --git a/pyproject.toml b/pyproject.toml index e487dfc0..f7122a5b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -61,7 +61,7 @@ napari = [ ] # Keras backends tf = [ - "tensorflow", # pinning to same TF as Keras 3.0 + "tensorflow", ] jax = [ "jax==0.4.20", @@ -125,9 +125,9 @@ isolated_build = true [gh-actions] python = - 3.9: py39-{torch} # On GA python=3.9 job, run tox with the tf and jax environments - 3.10: py310-{torch} # On GA python=3.10 job, run tox with the tf and jax environments - 3.11: py311-{torch} # On GA python=3.11 job, run tox with the tf and jax environments + 3.9: py39-{torch} # On GA python=3.9 job, run tox with the torch environment + 3.10: py310-{torch} # On GA python=3.10 job, run tox with the torch environment + 3.11: py311-{torch} # On GA python=3.11 job, run tox with the torch environment [testenv] From e4bd6655bab1eb4d67fb3fba8dbdaacc76339631 Mon Sep 17 00:00:00 2001 From: Igor Tatarnikov Date: Tue, 16 Apr 2024 16:51:04 +0100 Subject: [PATCH 28/50] Unpinned torch version --- pyproject.toml | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index f7122a5b..12bade6d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -61,14 +61,14 @@ napari = [ ] # Keras backends tf = [ - "tensorflow", + "tensorflow>=2.16.1", ] jax = [ "jax==0.4.20", "jaxlib==0.4.20" ] torch = [ - "torch==2.1.0" + "torch>=2.1.0" ] [project.scripts] @@ -88,7 +88,7 @@ requires = ["setuptools>=45", "wheel", "setuptools_scm[toml]>=6.2"] build-backend = 'setuptools.build_meta' [tool.black] -target-version = ['py39', 'py310'] +target-version = ['py39', 'py310','py311'] skip-string-normalization = false line-length = 79 From 560909ff6eb74072d970b4752665a29382be368a Mon Sep 17 00:00:00 2001 From: Kimberly Meechan <24316371+K-Meech@users.noreply.github.com> Date: Thu, 18 Apr 2024 10:49:17 +0100 Subject: [PATCH 29/50] Add codecov token (#403) * add codecov token * generate xml coverage report * add timeout to testing jobs --- .github/workflows/test_and_deploy.yml | 4 ++++ pyproject.toml | 2 +- 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/.github/workflows/test_and_deploy.yml b/.github/workflows/test_and_deploy.yml index 1ae8ae98..2bf9e976 100644 --- a/.github/workflows/test_and_deploy.yml +++ b/.github/workflows/test_and_deploy.yml @@ -65,11 +65,13 @@ jobs: - uses: neuroinformatics-unit/actions/test@v2 with: python-version: ${{ matrix.python-version }} + secret-codecov-token: ${{ secrets.CODECOV_TOKEN }} use-xvfb: true test_numba_disabled: needs: [linting, manifest] name: Run tests with numba disabled + timeout-minutes: 60 runs-on: ubuntu-latest env: NUMBA_DISABLE_JIT: "1" @@ -89,6 +91,7 @@ jobs: - uses: neuroinformatics-unit/actions/test@v2 with: python-version: "3.10" + secret-codecov-token: ${{ secrets.CODECOV_TOKEN }} codecov-flags: "numba" # Run brainglobe-workflows brainmapper-CLI tests to check for @@ -96,6 +99,7 @@ jobs: test_brainmapper_cli: needs: [linting, manifest] name: Run brainmapper tests to check for breakages + timeout-minutes: 60 runs-on: ubuntu-latest steps: - name: Cache tensorflow model diff --git a/pyproject.toml b/pyproject.toml index 6e5ac59d..2eeeb61e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -120,7 +120,7 @@ python = 3.10: py310 [testenv] -commands = python -m pytest -v 
--color=yes +commands = python -m pytest -v --color=yes --cov=cellfinder --cov-report=xml deps = pytest pytest-cov From e6a887f1721b89d51328214b108e0da5401a24a9 Mon Sep 17 00:00:00 2001 From: Matt Einhorn Date: Mon, 22 Apr 2024 06:44:02 -0400 Subject: [PATCH 30/50] Allow turning off classification or detection in GUI (#402) * Allow turning off classification or detection in GUI. * Fix test. * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Refactor to fix code analysis errors. * Ensure array is always 2d. * Apply suggestions from code review Co-authored-by: Igor Tatarnikov <61896994+IgorTatarnikov@users.noreply.github.com> --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Igor Tatarnikov <61896994+IgorTatarnikov@users.noreply.github.com> --- cellfinder/core/main.py | 92 +++--- cellfinder/napari/detect/detect.py | 264 +++++++++++++----- cellfinder/napari/detect/detect_containers.py | 10 +- cellfinder/napari/detect/thread_worker.py | 14 + cellfinder/napari/utils.py | 100 +++++-- tests/napari/test_utils.py | 59 +++- 6 files changed, 396 insertions(+), 143 deletions(-) diff --git a/cellfinder/core/main.py b/cellfinder/core/main.py index cb78cca4..c74a9d44 100644 --- a/cellfinder/core/main.py +++ b/cellfinder/core/main.py @@ -7,6 +7,7 @@ from typing import Callable, List, Optional, Tuple import numpy as np +from brainglobe_utils.cells.cells import Cell from brainglobe_utils.general.logging import suppress_specific_logs from cellfinder.core import logger @@ -42,6 +43,9 @@ def main( cube_height: int = 50, cube_depth: int = 20, network_depth: depth_type = "50", + skip_detection: bool = False, + skip_classification: bool = False, + detected_cells: List[Cell] = None, *, detect_callback: Optional[Callable[[int], None]] = None, classify_callback: Optional[Callable[[int], None]] = None, @@ -65,52 +69,58 @@ def main( from cellfinder.core.detect import detect from cellfinder.core.tools import prep - logger.info("Detecting cell candidates") + if not skip_detection: + logger.info("Detecting cell candidates") - points = detect.main( - signal_array, - start_plane, - end_plane, - voxel_sizes, - soma_diameter, - max_cluster_size, - ball_xy_size, - ball_z_size, - ball_overlap_fraction, - soma_spread_factor, - n_free_cpus, - log_sigma_size, - n_sds_above_mean_thresh, - callback=detect_callback, - ) - - if detect_finished_callback is not None: - detect_finished_callback(points) - - install_path = None - model_weights = prep.prep_model_weights( - model_weights, install_path, model, n_free_cpus - ) - if len(points) > 0: - logger.info("Running classification") - points = classify.main( - points, + points = detect.main( signal_array, - background_array, - n_free_cpus, + start_plane, + end_plane, voxel_sizes, - network_voxel_sizes, - batch_size, - cube_height, - cube_width, - cube_depth, - trained_model, - model_weights, - network_depth, - callback=classify_callback, + soma_diameter, + max_cluster_size, + ball_xy_size, + ball_z_size, + ball_overlap_fraction, + soma_spread_factor, + n_free_cpus, + log_sigma_size, + n_sds_above_mean_thresh, + callback=detect_callback, ) + + if detect_finished_callback is not None: + detect_finished_callback(points) else: - logger.info("No candidates, skipping classification") + points = detected_cells or [] # if None + detect_finished_callback(points) + + if not skip_classification: + install_path = None + model_weights = prep.prep_model_weights( + model_weights, 
install_path, model, n_free_cpus + ) + if len(points) > 0: + logger.info("Running classification") + points = classify.main( + points, + signal_array, + background_array, + n_free_cpus, + voxel_sizes, + network_voxel_sizes, + batch_size, + cube_height, + cube_width, + cube_depth, + trained_model, + model_weights, + network_depth, + callback=classify_callback, + ) + else: + logger.info("No candidates, skipping classification") + return points diff --git a/cellfinder/napari/detect/detect.py b/cellfinder/napari/detect/detect.py index eaaab20f..cdf36939 100644 --- a/cellfinder/napari/detect/detect.py +++ b/cellfinder/napari/detect/detect.py @@ -1,8 +1,11 @@ +from functools import partial from math import ceil from pathlib import Path -from typing import Optional +from typing import Any, Callable, Dict, Optional, Tuple import napari +import napari.layers +from brainglobe_utils.cells.cells import Cell from magicgui import magicgui from magicgui.widgets import FunctionGui, ProgressBar from napari.utils.notifications import show_info @@ -10,9 +13,11 @@ from cellfinder.core.classify.cube_generator import get_cube_depth_min_max from cellfinder.napari.utils import ( - add_layers, + add_classified_layers, + add_single_layer, cellfinder_header, html_label_widget, + napari_array_to_cells, ) from .detect_containers import ( @@ -32,16 +37,10 @@ MIN_PLANES_ANALYSE = 0 -def detect_widget() -> FunctionGui: - """ - Create a detection plugin GUI. - """ - progress_bar = ProgressBar() - - # options that is filled in from the gui - options = {"signal_image": None, "background_image": None, "viewer": None} - - # signal and background images are separated out from the main magicgui +def get_heavy_widgets( + options: Dict[str, Any] +) -> Tuple[Callable, Callable, Callable]: + # signal and other input are separated out from the main magicgui # parameter selections and are inserted as widget children in their own # sub-containers of the root. Because if these image parameters are # included in the root widget, every time *any* parameter updates, the gui @@ -91,6 +90,140 @@ def background_image_opt( """ options["background_image"] = background_image + @magicgui( + call_button=False, + persist=False, + scrollable=False, + labels=False, + auto_call=True, + ) + def cell_layer_opt( + cell_layer: napari.layers.Points, + ): + """ + magicgui widget for setting the cell layer input when detection is + skipped. + + Parameters + ---------- + cell_layer : napari.layers.Points + If detection is skipped, select the cell layer containing the + detected cells to use for classification + """ + options["cell_layer"] = cell_layer + + return signal_image_opt, background_image_opt, cell_layer_opt + + +def add_heavy_widgets( + root: FunctionGui, + widgets: Tuple[FunctionGui, ...], + new_names: Tuple[str, ...], + insertions: Tuple[str, ...], +) -> None: + for widget, new_name, insertion in zip(widgets, new_names, insertions): + # make it look as if it's directly in the root container + widget.margins = 0, 0, 0, 0 + # the parameters of these widgets are updated using `auto_call` only. + # If False, magicgui passes these as args to root() when the root's + # function runs. But that doesn't list them as args of its function + widget.gui_only = True + root.insert(root.index(insertion) + 1, widget) + getattr(root, widget.name).label = new_name + + +def restore_options_defaults(widget: FunctionGui) -> None: + """ + Restore default widget values. 
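# A hedged usage sketch of the skip_detection / skip_classification flags
# added to cellfinder.core.main in this commit (the arrays and voxel sizes
# below are made up for illustration; this is not code from the patch):
import numpy as np

from cellfinder.core.main import main

signal = np.random.random((30, 512, 512))
background = np.random.random((30, 512, 512))

# Detection only: candidates are returned without being classified.
candidates = main(
    signal,
    background,
    voxel_sizes=(5, 2, 2),
    skip_classification=True,
)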
+ """ + defaults = { + **DataInputs.defaults(), + **DetectionInputs.defaults(), + **ClassificationInputs.defaults(), + **MiscInputs.defaults(), + } + for name, value in defaults.items(): + if value is not None: # ignore fields with no default + getattr(widget, name).value = value + + +def get_results_callback( + skip_classification: bool, viewer: napari.Viewer +) -> Callable: + """ + Returns the callback that is connected to output of the pipeline. + It returns the detected points that we have to visualize. + """ + if skip_classification: + # after detection w/o classification, everything is unknown + def done_func(points): + add_single_layer( + points, + viewer=viewer, + name="Cell candidates", + cell_type=Cell.UNKNOWN, + ) + + else: + # after classification we have either cell or unknown + def done_func(points): + add_classified_layers( + points, + viewer=viewer, + unknown_name="Rejected", + cell_name="Detected", + ) + + return done_func + + +def find_local_planes( + viewer: napari.Viewer, + voxel_size_z: float, + signal_image: napari.layers.Image, +) -> Tuple[int, int]: + """ + When detecting only locally, it returns the start and end planes to use. + """ + current_plane = viewer.dims.current_step[0] + + # so a reasonable number of cells in the plane are detected + planes_needed = MIN_PLANES_ANALYSE + int( + ceil((CUBE_DEPTH * NETWORK_VOXEL_SIZES[0]) / voxel_size_z) + ) + + start_plane, end_plane = get_cube_depth_min_max( + current_plane, planes_needed + ) + start_plane = max(0, start_plane) + end_plane = min(len(signal_image.data), end_plane) + + return start_plane, end_plane + + +def reraise(e: Exception) -> None: + """Re-raises the exception.""" + raise Exception from e + + +def detect_widget() -> FunctionGui: + """ + Create a detection plugin GUI. 
+ """ + progress_bar = ProgressBar() + + # options that is filled in from the gui + options = { + "signal_image": None, + "background_image": None, + "viewer": None, + "cell_layer": None, + } + + signal_image_opt, background_image_opt, cell_layer_opt = get_heavy_widgets( + options + ) + @magicgui( detection_label=html_label_widget("Cell detection", tag="h3"), **DataInputs.widget_representation(), @@ -109,6 +242,7 @@ def widget( voxel_size_y: float, voxel_size_x: float, detection_options, + skip_detection: bool, soma_diameter: float, ball_xy_size: float, ball_z_size: float, @@ -118,6 +252,7 @@ def widget( soma_spread_factor: float, max_cluster_size: int, classification_options, + skip_classification: bool, trained_model: Optional[Path], use_pre_trained_weights: bool, misc_options, @@ -139,6 +274,10 @@ def widget( Size of your voxels in the y direction (top to bottom) voxel_size_x : float Size of your voxels in the x direction (left to right) + skip_detection : bool + If selected, the detection step is skipped and instead we get the + detected cells from the cell layer below (from a previous + detection run or import) soma_diameter : float The expected in-plane soma diameter (microns) ball_xy_size : float @@ -159,6 +298,9 @@ def widget( should be attempted use_pre_trained_weights : bool Select to use pre-trained model weights + skip_classification : bool + If selected, the classification step is skipped and all cells from + the detection stage are added trained_model : Optional[Path] Trained model file path (home directory (default) -> pretrained weights) @@ -184,24 +326,39 @@ def widget( # cellfinder plugin is fully open and initialized signal_image_opt() background_image_opt() + cell_layer_opt() signal_image = options["signal_image"] - background_image = options["background_image"] - viewer = options["viewer"] - if signal_image is None or background_image is None: + if signal_image is None or options["background_image"] is None: show_info("Both signal and background images must be specified.") return + detected_cells = [] + if skip_detection: + if options["cell_layer"] is None: + show_info( + "Skip detection selected, but no existing cell layer " + "is selected." 
+ ) + return + + # set cells as unknown so that classification will process them + detected_cells = napari_array_to_cells( + options["cell_layer"], Cell.UNKNOWN + ) + data_inputs = DataInputs( signal_image.data, - background_image.data, + options["background_image"].data, voxel_size_z, voxel_size_y, voxel_size_x, ) detection_inputs = DetectionInputs( + skip_detection, + detected_cells, soma_diameter, ball_xy_size, ball_z_size, @@ -215,24 +372,15 @@ def widget( if use_pre_trained_weights: trained_model = None classification_inputs = ClassificationInputs( - use_pre_trained_weights, trained_model + skip_classification, use_pre_trained_weights, trained_model ) - end_plane = len(signal_image.data) if end_plane == 0 else end_plane - if analyse_local: - current_plane = viewer.dims.current_step[0] - - # so a reasonable number of cells in the plane are detected - planes_needed = MIN_PLANES_ANALYSE + int( - ceil((CUBE_DEPTH * NETWORK_VOXEL_SIZES[0]) / voxel_size_z) - ) - - start_plane, end_plane = get_cube_depth_min_max( - current_plane, planes_needed + start_plane, end_plane = find_local_planes( + options["viewer"], voxel_size_z, signal_image ) - start_plane = max(0, start_plane) - end_plane = min(len(signal_image.data), end_plane) + elif not end_plane: + end_plane = len(signal_image.data) misc_inputs = MiscInputs( start_plane, end_plane, n_free_cpus, analyse_local, debug @@ -244,58 +392,34 @@ def widget( classification_inputs, misc_inputs, ) + worker.returned.connect( - lambda points: add_layers(points, viewer=viewer) + get_results_callback(skip_classification, options["viewer"]) ) - # Make sure if the worker emits an error, it is propagated to this # thread - def reraise(e): - raise Exception from e - worker.errored.connect(reraise) + worker.connect_progress_bar_callback(progress_bar) - def update_progress_bar(label: str, max: int, value: int): - progress_bar.label = label - progress_bar.max = max - progress_bar.value = value - - worker.update_progress_bar.connect(update_progress_bar) worker.start() widget.native.layout().insertWidget(0, cellfinder_header()) - @widget.reset_button.changed.connect - def restore_defaults(): - """ - Restore default widget values. - """ - defaults = { - **DataInputs.defaults(), - **DetectionInputs.defaults(), - **ClassificationInputs.defaults(), - **MiscInputs.defaults(), - } - for name, value in defaults.items(): - if value is not None: # ignore fields with no default - getattr(widget, name).value = value + # reset restores defaults + widget.reset_button.changed.connect( + partial(restore_options_defaults, widget) + ) # Insert progress bar before the run and reset buttons - widget.insert(-3, progress_bar) - - # add the signal and background image parameters - # make it look as if it's directly in the root container - signal_image_opt.margins = 0, 0, 0, 0 - # the parameters are updated using `auto_call` only. If False, magicgui - # passes these as args to widget(), which doesn't list them as args - signal_image_opt.gui_only = True - widget.insert(3, signal_image_opt) - widget.signal_image_opt.label = "Signal image" - - background_image_opt.margins = 0, 0, 0, 0 - background_image_opt.gui_only = True - widget.insert(4, background_image_opt) - widget.background_image_opt.label = "Background image" + widget.insert(widget.index("debug") + 1, progress_bar) + + # add the signal and background image etc. 
+ add_heavy_widgets( + widget, + (background_image_opt, signal_image_opt, cell_layer_opt), + ("Background image", "Signal image", "Candidate cell layer"), + ("voxel_size_z", "voxel_size_z", "soma_diameter"), + ) scroll = QScrollArea() scroll.setWidget(widget._widget._qwidget) diff --git a/cellfinder/napari/detect/detect_containers.py b/cellfinder/napari/detect/detect_containers.py index 824a2a0b..39fda163 100644 --- a/cellfinder/napari/detect/detect_containers.py +++ b/cellfinder/napari/detect/detect_containers.py @@ -1,8 +1,9 @@ from dataclasses import dataclass from pathlib import Path -from typing import Optional +from typing import List, Optional import numpy +from brainglobe_utils.cells.cells import Cell from cellfinder.napari.input_container import InputContainer from cellfinder.napari.utils import html_label_widget @@ -59,6 +60,8 @@ def widget_representation(cls) -> dict: class DetectionInputs(InputContainer): """Container for cell candidate detection inputs.""" + skip_detection: bool = False + detected_cells: Optional[List[Cell]] = None soma_diameter: float = 16.0 ball_xy_size: float = 6 ball_z_size: float = 15 @@ -75,6 +78,7 @@ def as_core_arguments(self) -> dict: def widget_representation(cls) -> dict: return dict( detection_options=html_label_widget("Detection:"), + skip_detection=dict(value=cls.defaults()["skip_detection"]), soma_diameter=cls._custom_widget("soma_diameter"), ball_xy_size=cls._custom_widget( "ball_xy_size", custom_label="Ball filter (xy)" @@ -107,6 +111,7 @@ def widget_representation(cls) -> dict: class ClassificationInputs(InputContainer): """Container for classification inputs.""" + skip_classification: bool = False use_pre_trained_weights: bool = True trained_model: Optional[Path] = Path.home() @@ -123,6 +128,9 @@ def widget_representation(cls) -> dict: value=cls.defaults()["use_pre_trained_weights"] ), trained_model=dict(value=cls.defaults()["trained_model"]), + skip_classification=dict( + value=cls.defaults()["skip_classification"] + ), ) diff --git a/cellfinder/napari/detect/thread_worker.py b/cellfinder/napari/detect/thread_worker.py index ea44dded..c4392860 100644 --- a/cellfinder/napari/detect/thread_worker.py +++ b/cellfinder/napari/detect/thread_worker.py @@ -1,3 +1,4 @@ +from magicgui.widgets import ProgressBar from napari.qt.threading import WorkerBase, WorkerBaseSignals from qtpy.QtCore import Signal @@ -41,6 +42,19 @@ def __init__( self.classification_inputs = classification_inputs self.misc_inputs = misc_inputs + def connect_progress_bar_callback(self, progress_bar: ProgressBar): + """ + Connects the progress bar to the work so that updates are shown on + the bar. + """ + + def update_progress_bar(label: str, max: int, value: int): + progress_bar.label = label + progress_bar.max = max + progress_bar.value = value + + self.update_progress_bar.connect(update_progress_bar) + def work(self) -> list: self.update_progress_bar.emit("Setting up detection...", 1, 0) diff --git a/cellfinder/napari/utils.py b/cellfinder/napari/utils.py index d48b81fb..2f853fe0 100644 --- a/cellfinder/napari/utils.py +++ b/cellfinder/napari/utils.py @@ -1,8 +1,8 @@ from typing import List, Tuple import napari +import napari.layers import numpy as np -import pandas as pd from brainglobe_utils.cells.cells import Cell from brainglobe_utils.qtpy.logo import header_widget @@ -31,16 +31,28 @@ def cellfinder_header(): ) -def add_layers(points: List[Cell], viewer: napari.Viewer) -> None: +# the xyz axis order in napari relative to ours. I.e. 
our zeroth axis is the +# napari last axis. Ours is XYZ. +napari_points_axis_order = 2, 1, 0 +# the xyz axis order in brainglobe relative to napari. I.e. napari's zeroth +# axis is our last axis - it's just flipped +brainglobe_points_axis_order = napari_points_axis_order + + +def add_classified_layers( + points: List[Cell], + viewer: napari.Viewer, + unknown_name: str = "Rejected", + cell_name: str = "Detected", +) -> None: """ - Adds classified cell candidates as two separate point layers to the napari - viewer. + Adds cell candidates as two separate point layers - unknowns and cells, to + the napari viewer. Does not add any other cell types, only Cell.UNKNOWN + and Cell.CELL from the list of cells. """ - detected, rejected = cells_to_array(points) - viewer.add_points( - rejected, - name="Rejected", + cells_to_array(points, Cell.UNKNOWN, napari_order=True), + name=unknown_name, size=15, n_dimensional=True, opacity=0.6, @@ -50,8 +62,8 @@ def add_layers(points: List[Cell], viewer: napari.Viewer) -> None: metadata=dict(point_type=Cell.UNKNOWN), ) viewer.add_points( - detected, - name="Detected", + cells_to_array(points, Cell.CELL, napari_order=True), + name=cell_name, size=15, n_dimensional=True, opacity=0.6, @@ -61,23 +73,61 @@ def add_layers(points: List[Cell], viewer: napari.Viewer) -> None: ) -def cells_df_as_np( - cells_df: pd.DataFrame, - new_order: List[int] = [2, 1, 0], - type_column: str = "type", +def add_single_layer( + points: List[Cell], + viewer: napari.Viewer, + name: str, + cell_type: int, +) -> None: + """ + Adds all cells of cell_type Cell.TYPE to a new point layer in the napari + viewer, with given name. + """ + viewer.add_points( + cells_to_array(points, cell_type, napari_order=True), + name=name, + size=15, + n_dimensional=True, + opacity=0.6, + symbol="ring", + face_color="lightskyblue", + visible=True, + metadata=dict(point_type=cell_type), + ) + + +def cells_to_array( + cells: List[Cell], cell_type: int, napari_order: bool = True ) -> np.ndarray: """ - Convert a dataframe to an array, dropping *type_column* and re-ordering - the columns with *new_order*. + Converts all the cells of the given type as a 2D pos array. + The column order is either XYZ, otherwise it's the napari ordering + of the 3 axes (napari_points_axis_order). """ - cells_df = cells_df.drop(columns=[type_column]) - cells = cells_df[cells_df.columns[new_order]] - cells = cells.to_numpy() - return cells + cells = [c for c in cells if c.type == cell_type] + if not cells: + # make sure we return 2d array if cells is empty + return np.zeros((0, 3), dtype=np.int_) + points = np.array([(c.x, c.y, c.z) for c in cells]) + + if napari_order: + return points[:, napari_points_axis_order] + return points -def cells_to_array(cells: List[Cell]) -> Tuple[np.ndarray, np.ndarray]: - df = pd.DataFrame([c.to_dict() for c in cells]) - points = cells_df_as_np(df[df["type"] == Cell.CELL]) - rejected = cells_df_as_np(df[df["type"] == Cell.UNKNOWN]) - return points, rejected +def napari_array_to_cells( + points: napari.layers.Points, + cell_type: int, + brainglobe_order: Tuple[int, int, int] = brainglobe_points_axis_order, +) -> List[Cell]: + """ + Takes a napari Points layer and returns a list of cell objects, one for + each point in the layer. 
+ """ + data = np.asarray(points.data)[:, brainglobe_order].tolist() + + cells = [] + for row in data: + cells.append(Cell(pos=row, cell_type=cell_type)) + + return cells diff --git a/tests/napari/test_utils.py b/tests/napari/test_utils.py index f8eff803..6cf864f3 100644 --- a/tests/napari/test_utils.py +++ b/tests/napari/test_utils.py @@ -1,22 +1,69 @@ +import numpy as np from brainglobe_utils.cells.cells import Cell from cellfinder.napari.utils import ( - add_layers, + add_classified_layers, + cells_to_array, html_label_widget, + napari_array_to_cells, + napari_points_axis_order, ) -def test_add_layers(make_napari_viewer): - """Smoke test for add_layers utility""" +def test_add_classified_layers(make_napari_viewer): + """Smoke test for add_classified_layers utility""" + cell_pos = [1, 2, 3] + unknown_pos = [4, 5, 6] points = [ - Cell(pos=[1, 2, 3], cell_type=Cell.CELL), - Cell(pos=[4, 5, 6], cell_type=Cell.UNKNOWN), + Cell(pos=cell_pos, cell_type=Cell.CELL), + Cell(pos=unknown_pos, cell_type=Cell.UNKNOWN), ] viewer = make_napari_viewer() n_layers = len(viewer.layers) - add_layers(points, viewer) # adds a "detected" and a "rejected layer" + # adds a "detected" and a "rejected layer" + add_classified_layers( + points, viewer, unknown_name="rejected", cell_name="accepted" + ) assert len(viewer.layers) == n_layers + 2 + # check names match + rej_layer = cell_layer = None + for layer in reversed(viewer.layers): + if layer.name == "accepted" and cell_layer is None: + cell_layer = layer + if layer.name == "rejected" and rej_layer is None: + rej_layer = layer + assert cell_layer is not None + assert rej_layer is not None + assert cell_layer.data is not None + assert rej_layer.data is not None + + # check data added in correct column order + # CELL types + cell_data = np.array([cell_pos]) + assert np.all( + cells_to_array(points, Cell.CELL, napari_order=False) == cell_data + ) + # convert to napari order and check it is in napari + cell_data = cell_data[:, napari_points_axis_order] + assert np.all(cell_layer.data == cell_data) + + # UNKNOWN type + rej_data = np.array([unknown_pos]) + assert np.all( + cells_to_array(points, Cell.UNKNOWN, napari_order=False) == rej_data + ) + # convert to napari order and check it is in napari + rej_data = rej_data[:, napari_points_axis_order] + assert np.all(rej_layer.data == rej_data) + + # get cells back from napari points + cells_again = napari_array_to_cells(cell_layer.data, cell_type=Cell.CELL) + cells_again.extend( + napari_array_to_cells(rej_layer.data, cell_type=Cell.UNKNOWN) + ) + assert cells_again == points + def test_html_label_widget(): """Simple unit test for the HTML Label widget""" From b1b285ce9f81815e899ac1d1d78a617349c768c0 Mon Sep 17 00:00:00 2001 From: Matt Einhorn Date: Wed, 1 May 2024 06:13:51 -0400 Subject: [PATCH 31/50] Support single z-stack tif file for input (#397) * Support single z-stack tif file for input. * Fix commit hook. * Apply review suggestions. --- cellfinder/core/tools/IO.py | 34 ++++++++++++++++++++++++++++++++++ 1 file changed, 34 insertions(+) diff --git a/cellfinder/core/tools/IO.py b/cellfinder/core/tools/IO.py index 84eafdd6..359a6da5 100644 --- a/cellfinder/core/tools/IO.py +++ b/cellfinder/core/tools/IO.py @@ -24,6 +24,40 @@ def get_tiff_meta( lazy_imread = delayed(imread) # lazy reader +def read_z_stack(path): + """ + Reads z-stack, lazily, if possible. + + If it's a text file or folder with 2D tiff files use dask to read lazily, + otherwise it's a single file tiff stack and is read into memory. 
+ + :param path: Filename of text file listing 2D tiffs, folder of 2D tiffs, + or single file tiff z-stack. + :return: The data as a dask/numpy array. + """ + if path.endswith(".tiff") or path.endswith(".tif"): + with TiffFile(path) as tiff: + if not len(tiff.series): + raise ValueError( + f"Attempted to load {path} but couldn't read a z-stack" + ) + if len(tiff.series) != 1: + raise ValueError( + f"Attempted to load {path} but found multiple stacks" + ) + + axes = tiff.series[0].axes.lower() + if set(axes) != {"x", "y", "z"} or axes[0] != "z": + raise ValueError( + f"Attempted to load {path} but didn't find a zyx or " + f"zxy stack. Found {axes} axes" + ) + + return imread(path) + + return read_with_dask(path) + + def read_with_dask(path): """ Based on https://github.com/tlambert03/napari-ndtiffs From 0fd0a5e3712047b5cac138f329a3c70dd48bca66 Mon Sep 17 00:00:00 2001 From: sfmig <33267254+sfmig@users.noreply.github.com> Date: Wed, 1 May 2024 12:32:11 +0200 Subject: [PATCH 32/50] Remove modular asv benchmarks (#406) * remove modular asv benchmarks * recover old structure * remove asv-specific lines from gitignore and manifest * prune benchmarks --- .gitignore | 5 - MANIFEST.in | 2 +- benchmarks/README.md | 51 ++--- benchmarks/asv.conf.json | 188 ------------------ benchmarks/benchmarks/__init__.py | 0 benchmarks/benchmarks/imports.py | 43 ---- benchmarks/benchmarks/tools/IO.py | 64 ------ benchmarks/benchmarks/tools/__init__.py | 0 benchmarks/benchmarks/tools/prep.py | 65 ------ .../detect_and_classify.py | 0 benchmarks/{mem_benchmarks => }/filter_2d.py | 0 benchmarks/{mem_benchmarks => }/filter_3d.py | 0 benchmarks/mem_benchmarks/README.md | 12 -- 13 files changed, 13 insertions(+), 417 deletions(-) delete mode 100644 benchmarks/asv.conf.json delete mode 100644 benchmarks/benchmarks/__init__.py delete mode 100644 benchmarks/benchmarks/imports.py delete mode 100644 benchmarks/benchmarks/tools/IO.py delete mode 100644 benchmarks/benchmarks/tools/__init__.py delete mode 100644 benchmarks/benchmarks/tools/prep.py rename benchmarks/{mem_benchmarks => }/detect_and_classify.py (100%) rename benchmarks/{mem_benchmarks => }/filter_2d.py (100%) rename benchmarks/{mem_benchmarks => }/filter_3d.py (100%) delete mode 100644 benchmarks/mem_benchmarks/README.md diff --git a/.gitignore b/.gitignore index 31dec9e3..1d06b6b4 100644 --- a/.gitignore +++ b/.gitignore @@ -130,11 +130,6 @@ mprofile*.dat *.DS_Store -# asv -.asv -benchmarks/results -benchmarks/html -benchmarks/env # OS .DS_Store diff --git a/MANIFEST.in b/MANIFEST.in index 83b493ec..af608e03 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -12,7 +12,7 @@ exclude tox.ini graft cellfinder include cellfinder/napari/napari.yaml -prune benchmarks prune examples prune resources prune tests +prune benchmarks diff --git a/benchmarks/README.md b/benchmarks/README.md index 04a355ed..3854e59f 100644 --- a/benchmarks/README.md +++ b/benchmarks/README.md @@ -1,39 +1,12 @@ -# Benchmarking with asv -[Install asv](https://asv.readthedocs.io/en/stable/installing.html) by running: -``` -pip install asv -``` - -`asv` works roughly as follows: -1. It creates a virtual environment (as defined in the config) -2. It installs the software package version of a specific commit (or of a local commit) -3. It times the benchmarking tests and saves the results to json files -4. The json files are 'published' into an html dir -5. The html dir can be visualised in a static website - -## Running benchmarks -To run benchmarks on a specific commit: -``` -$ asv run 88fbbc33^! 
-``` - -To run them up to a specific commit: -``` -$ asv run 88fbbc33 -``` - -To run them on a range of commits: -``` -$ asv run 827f322b..729abcf3 -``` - -To collate the benchmarks' results into a viewable website: -``` -$ asv publish -``` -This will create a tree of files in the `html` directory, but this cannot be viewed directly from the local filesystem, so we need to put them in a static site. `asv publish` also detects statistically significant decreases of performance, the results can be inspected in the 'Regression' tab of the static site. - -To visualise the results in a static site: -``` -$ asv preview -``` +# Benchmarks +`detect_and_classify.py` contains a simple script that runs +detection and classification with the small test dataset. + +## Memory +[memory_profiler](https://github.com/pythonprofilers/memory_profiler) +can be used to profile memory useage. Install, and then run +`mprof run --include-children --multiprocess detect_and_classify.py`. It is **very** +important to use these two flags to capture memory usage by the additional +processes that cellfinder.core uses. + +To show the results of the latest profile run, run `mprof plot`. diff --git a/benchmarks/asv.conf.json b/benchmarks/asv.conf.json deleted file mode 100644 index 1a4e32d9..00000000 --- a/benchmarks/asv.conf.json +++ /dev/null @@ -1,188 +0,0 @@ -{ - // The version of the config file format. Do not change, unless - // you know what you are doing. - "version": 1, - - // The name of the project being benchmarked - "project": "cellfinder-core", - - // The project's homepage - "project_url": "https://brainglobe.info/documentation/cellfinder/index.html", - - // The URL or local path of the source code repository for the - // project being benchmarked - // To use the upstream repository: uncomment the 1st line (and comment the 2nd) - // To use the local repository: comment the 1st line (and uncomment the 2nd) - //"repo": "https://github.com/brainglobe/cellfinder-core.git", - "repo": "..", - - // The Python project's subdirectory in your repo. If missing or - // the empty string, the project is assumed to be located at the root - // of the repository (where setup.py is located) - // "repo_subdir": "", - - // Customizable commands for building, installing, and - // uninstalling the project. See asv.conf.json documentation. - // - "install_command": ["in-dir={env_dir} python -mpip install {wheel_file}"], - "uninstall_command": ["return-code=any python -mpip uninstall -y {project}"], - "build_command": [ - "python -m pip install build", - "python -m build", - "PIP_NO_BUILD_ISOLATION=false python -mpip wheel --no-deps --no-index -w {build_cache_dir} {build_dir}" - ], - - // List of branches to benchmark. If not provided, defaults to "master" - // (for git) or "default" (for mercurial). - "branches": ["main"], // for git - // "branches": ["default"], // for mercurial - - // The DVCS being used. If not set, it will be automatically - // determined from "repo" by looking at the protocol in the URL - // (if remote), or by looking for special directories, such as - // ".git" (if local). - // "dvcs": "git", - - // The tool to use to create environments. May be "conda", - // "virtualenv" or other value depending on the plugins in use. - // If missing or the empty string, the tool will be automatically - // determined by looking for tools on the PATH environment - // variable. 
- "environment_type": "conda", - - // timeout in seconds for installing any dependencies in environment - // defaults to 10 min - //"install_timeout": 600, - - // the base URL to show a commit for the project. - "show_commit_url": "http://github.com/brainglobe/cellfinder-core/commit/", - - // The Pythons you'd like to test against. If not provided, defaults - // to the current version of Python used to run `asv`. - "pythons": ["3.10"], // same as pyproject.toml? ["3.8", "3.9", "3.10"] - - // The list of conda channel names to be searched for benchmark - // dependency packages in the specified order - "conda_channels": ["conda-forge", "defaults"], - - // A conda environment file that is used for environment creation. - // "conda_environment_file": "environment.yml", - - // The matrix of dependencies to test. Each key of the "req" - // requirements dictionary is the name of a package (in PyPI) and - // the values are version numbers. An empty list or empty string - // indicates to just test against the default (latest) - // version. null indicates that the package is to not be - // installed. If the package to be tested is only available from - // PyPi, and the 'environment_type' is conda, then you can preface - // the package name by 'pip+', and the package will be installed - // via pip (with all the conda available packages installed first, - // followed by the pip installed packages). - // - // The ``@env`` and ``@env_nobuild`` keys contain the matrix of - // environment variables to pass to build and benchmark commands. - // An environment will be created for every combination of the - // cartesian product of the "@env" variables in this matrix. - // Variables in "@env_nobuild" will be passed to every environment - // during the benchmark phase, but will not trigger creation of - // new environments. A value of ``null`` means that the variable - // will not be set for the current combination. - // - "matrix": { - "req": {}, - // "napari": ["", null], // test with and without - // // "six": ["", null], // test with and without six installed - // // "pip+emcee": [""] // emcee is only available for install with pip. - // }, - // "env": {"ENV_VAR_1": ["val1", "val2"]}, - // "env_nobuild": {"ENV_VAR_2": ["val3", null]}, - }, - - // Combinations of libraries/python versions can be excluded/included - // from the set to test. Each entry is a dictionary containing additional - // key-value pairs to include/exclude. - // - // An exclude entry excludes entries where all values match. The - // values are regexps that should match the whole string. - // - // An include entry adds an environment. Only the packages listed - // are installed. The 'python' key is required. The exclude rules - // do not apply to includes. - // - // In addition to package names, the following keys are available: - // - // - python - // Python version, as in the *pythons* variable above. - // - environment_type - // Environment type, as above. - // - sys_platform - // Platform, as in sys.platform. Possible values for the common - // cases: 'linux2', 'win32', 'cygwin', 'darwin'. 
- // - req - // Required packages - // - env - // Environment variables - // - env_nobuild - // Non-build environment variables - // - // "exclude": [ - // {"python": "3.2", "sys_platform": "win32"}, // skip py3.2 on windows - // {"environment_type": "conda", "req": {"six": null}}, // don't run without six on conda - // {"env": {"ENV_VAR_1": "val2"}}, // skip val2 for ENV_VAR_1 - // ], - // - // "include": [ - // // additional env for python2.7 - // {"python": "2.7", "req": {"numpy": "1.8"}, "env_nobuild": {"FOO": "123"}}, - // // additional env if run on windows+conda - // {"platform": "win32", "environment_type": "conda", "python": "2.7", "req": {"libpython": ""}}, - // ], - - // The directory (relative to the current directory) that benchmarks are - // stored in. If not provided, defaults to "benchmarks" - "benchmark_dir": "benchmarks", - - // The directory (relative to the current directory) to cache the Python - // environments in. If not provided, defaults to "env" - "env_dir": "env", - - // The directory (relative to the current directory) that raw benchmark - // results are stored in. If not provided, defaults to "results". - "results_dir": "results", - - // The directory (relative to the current directory) that the html tree - // should be written to. If not provided, defaults to "html". - "html_dir": "html", - - // The number of characters to retain in the commit hashes. - // "hash_length": 8, - - // `asv` will cache results of the recent builds in each - // environment, making them faster to install next time. This is - // the number of builds to keep, per environment. - "build_cache_size": 2, - - // The commits after which the regression search in `asv publish` - // should start looking for regressions. Dictionary whose keys are - // regexps matching to benchmark names, and values corresponding to - // the commit (exclusive) after which to start looking for - // regressions. The default is to start from the first commit - // with results. If the commit is `null`, regression detection is - // skipped for the matching benchmark. - // - // "regressions_first_commits": { - // "some_benchmark": "352cdf", // Consider regressions only after this commit - // "another_benchmark": null, // Skip regression detection altogether - // }, - - // The thresholds for relative change in results, after which `asv - // publish` starts reporting regressions. Dictionary of the same - // form as in ``regressions_first_commits``, with values - // indicating the thresholds. If multiple entries match, the - // maximum is taken. If no entry matches, the default is 5%. 
- // - // "regressions_thresholds": { - // "some_benchmark": 0.01, // Threshold of 1% - // "another_benchmark": 0.5, // Threshold of 50% - // }, -} diff --git a/benchmarks/benchmarks/__init__.py b/benchmarks/benchmarks/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/benchmarks/benchmarks/imports.py b/benchmarks/benchmarks/imports.py deleted file mode 100644 index e7356cdd..00000000 --- a/benchmarks/benchmarks/imports.py +++ /dev/null @@ -1,43 +0,0 @@ -# ------------------------------------ -# Runtime benchmarks -# ------------------------------------ -def timeraw_import_main(): - return """ - from cellfinder.core.main import main - """ - - -def timeraw_import_io_dask(): - return """ - from cellfinder.core.tools.IO import read_with_dask - """ - - -def timeraw_import_io_tiff_meta(): - return """ - from cellfinder.core.tools.IO import get_tiff_meta - """ - - -def timeraw_import_prep_tensorflow(): - return """ - from cellfinder.core.tools.prep import prep_tensorflow - """ - - -def timeraw_import_prep_models(): - return """ - from cellfinder.core.tools.prep import prep_models - """ - - -def timeraw_import_prep_classification(): - return """ - from cellfinder.core.tools.prep import prep_classification - """ - - -def timeraw_import_prep_training(): - return """ - from cellfinder.core.tools.prep import prep_training - """ diff --git a/benchmarks/benchmarks/tools/IO.py b/benchmarks/benchmarks/tools/IO.py deleted file mode 100644 index 57d3733f..00000000 --- a/benchmarks/benchmarks/tools/IO.py +++ /dev/null @@ -1,64 +0,0 @@ -from pathlib import Path - -from cellfinder.core.tools.IO import get_tiff_meta, read_with_dask - -CELLFINDER_CORE_PATH = Path(__file__).parents[3] -TESTS_DATA_INTEGRATION_PATH = ( - Path(CELLFINDER_CORE_PATH) / "tests" / "data" / "integration" -) - - -class Read: - # ------------------------------------ - # Data - # ------------------------------ - detection_crop_planes_ch0 = TESTS_DATA_INTEGRATION_PATH / Path( - "detection", "crop_planes", "ch0" - ) - detection_crop_planes_ch1 = TESTS_DATA_INTEGRATION_PATH / Path( - "detection", "crop_planes", "ch1" - ) - cells_tif_files = list( - Path(TESTS_DATA_INTEGRATION_PATH, "training", "cells").glob("*.tif") - ) - non_cells_tif_files = list( - Path(TESTS_DATA_INTEGRATION_PATH, "training", "non_cells").glob( - "*.tif" - ) - ) - - # --------------------------------------------- - # Setup function - # -------------------------------------------- - def setup(self, subdir): - self.data_dir = str(subdir) - - # --------------------------------------------- - # Reading 3d arrays with dask - # -------------------------------------------- - def time_read_with_dask(self, subdir): - read_with_dask(self.data_dir) - - # parameters to sweep across - time_read_with_dask.param_names = [ - "tests_data_integration_subdir", - ] - time_read_with_dask.params = ( - [detection_crop_planes_ch0, detection_crop_planes_ch1], - ) - - # ----------------------------------------------- - # Reading metadata from tif files - # ------------------------------------------------- - def time_get_tiff_meta( - self, - subdir, - ): - get_tiff_meta(self.data_dir) - - # parameters to sweep across - time_get_tiff_meta.param_names = [ - "tests_data_integration_tiffile", - ] - - time_get_tiff_meta.params = cells_tif_files + non_cells_tif_files diff --git a/benchmarks/benchmarks/tools/__init__.py b/benchmarks/benchmarks/tools/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/benchmarks/benchmarks/tools/prep.py 
b/benchmarks/benchmarks/tools/prep.py deleted file mode 100644 index 0e2bbce4..00000000 --- a/benchmarks/benchmarks/tools/prep.py +++ /dev/null @@ -1,65 +0,0 @@ -import shutil -from pathlib import Path - -from brainglobe_utils.general.system import get_num_processes - -from cellfinder.core.tools.prep import ( - prep_model_weights, - prep_models, - prep_tensorflow, -) - - -class PrepModels: - # parameters to sweep across - param_names = ["model_name"] - params = ["resnet50_tv", "resnet50_all"] - - # increase default timeout to allow for download - timeout = 600 - - # install path - def benchmark_install_path(self): - # also allow to run as "user" on GH actions? - return Path(Path.home() / ".cellfinder-benchmarks") - - def setup(self, model_name): - self.n_free_cpus = 2 - self.n_processes = get_num_processes( - min_free_cpu_cores=self.n_free_cpus - ) - self.trained_model = None - self.model_weights = None - self.install_path = self.benchmark_install_path() - self.model_name = model_name - - # remove .cellfinder-benchmarks dir if it exists - shutil.rmtree(self.install_path, ignore_errors=True) - - def teardown(self, model_name): - # remove .cellfinder-benchmarks dir after benchmarks - shutil.rmtree(self.install_path) - - def time_prep_models(self, model_name): - prep_models( - self.model_weights, - self.install_path, - model_name, - ) - - def time_prep_classification(self, model_name): - prep_model_weights( - self.model_weights, - self.install_path, - model_name, - self.n_free_cpus, - ) - - -class PrepTF: - def setup(self): - n_free_cpus = 2 - self.n_processes = get_num_processes(min_free_cpu_cores=n_free_cpus) - - def time_prep_tensorflow(self): - prep_tensorflow(self.n_processes) diff --git a/benchmarks/mem_benchmarks/detect_and_classify.py b/benchmarks/detect_and_classify.py similarity index 100% rename from benchmarks/mem_benchmarks/detect_and_classify.py rename to benchmarks/detect_and_classify.py diff --git a/benchmarks/mem_benchmarks/filter_2d.py b/benchmarks/filter_2d.py similarity index 100% rename from benchmarks/mem_benchmarks/filter_2d.py rename to benchmarks/filter_2d.py diff --git a/benchmarks/mem_benchmarks/filter_3d.py b/benchmarks/filter_3d.py similarity index 100% rename from benchmarks/mem_benchmarks/filter_3d.py rename to benchmarks/filter_3d.py diff --git a/benchmarks/mem_benchmarks/README.md b/benchmarks/mem_benchmarks/README.md deleted file mode 100644 index 3854e59f..00000000 --- a/benchmarks/mem_benchmarks/README.md +++ /dev/null @@ -1,12 +0,0 @@ -# Benchmarks -`detect_and_classify.py` contains a simple script that runs -detection and classification with the small test dataset. - -## Memory -[memory_profiler](https://github.com/pythonprofilers/memory_profiler) -can be used to profile memory useage. Install, and then run -`mprof run --include-children --multiprocess detect_and_classify.py`. It is **very** -important to use these two flags to capture memory usage by the additional -processes that cellfinder.core uses. - -To show the results of the latest profile run, run `mprof plot`. 
From 6b529dc72110114f0d17c8139e8645f8fecef317 Mon Sep 17 00:00:00 2001 From: Alessandro Felder Date: Wed, 1 May 2024 11:39:18 +0100 Subject: [PATCH 33/50] Adapt CI so it covers both new and old Macs, and installs required additional dependencies on M1 (#408) * naive attempt at adapting to silicon mac CI * run include guard test on Silicon CI * double-check hdf5 is needed --- .github/workflows/test_and_deploy.yml | 12 +++++++----- .github/workflows/test_include_guard.yaml | 7 ++++++- 2 files changed, 13 insertions(+), 6 deletions(-) diff --git a/.github/workflows/test_and_deploy.yml b/.github/workflows/test_and_deploy.yml index 2bf9e976..5197f102 100644 --- a/.github/workflows/test_and_deploy.yml +++ b/.github/workflows/test_and_deploy.yml @@ -42,12 +42,14 @@ jobs: # Run all supported Python versions on linux os: [ubuntu-latest] python-version: ["3.9", "3.10"] - # Include one windows, one macos run + # Include one windows, one macos run each for M1 (latest) and Intel (13) include: - - os: macos-latest - python-version: "3.10" - - os: windows-latest - python-version: "3.10" + - os: macos-13 + python-version: "3.10" + - os: macos-latest + python-version: "3.10" + - os: windows-latest + python-version: "3.10" steps: # Cache the tensorflow model so we don't have to remake it every time diff --git a/.github/workflows/test_include_guard.yaml b/.github/workflows/test_include_guard.yaml index 26277d98..70415a2e 100644 --- a/.github/workflows/test_include_guard.yaml +++ b/.github/workflows/test_include_guard.yaml @@ -35,7 +35,12 @@ jobs: import cellfinder.core import cellfinder.napari - - name: Uninstall tensorflow + - name: Uninstall tensorflow-macos on Mac M1 + if: matrix.os == 'macos-latest' + run: python -m pip uninstall -y tensorflow-macos + + - name: Uninstall tensorflow on Ubuntu + if: matrix.os == 'ubuntu-latest' run: python -m pip uninstall -y tensorflow - name: Test (broken) import From a28305618a228fc5175e05f72b555d7b29073bc5 Mon Sep 17 00:00:00 2001 From: Alessandro Felder Date: Fri, 3 May 2024 14:34:33 +0100 Subject: [PATCH 34/50] Optimize cell detection (#398) (#407) * Replace coord map values with numba list/tuple for optim. * Switch to fortran layout for faster update of last dim. * Cache kernel. * jit ball filter. * Put z as first axis to speed z rolling (row-major memory). * Unroll recursion (no perf impact either way). * Parallelize cell cluster splitting. * Parallelize walking for full images. * Cleanup docs and pep8 etc. * Add pre-commit fixes. * Fix parallel always being selected and numba function 1st class warning. * Run hook. * Older python needs Union instead of |. * Accept review suggestion. * Address review changes. * num_threads must be an int. 
--------- Co-authored-by: Matt Einhorn --- cellfinder/core/detect/detect.py | 13 +- .../core/detect/filters/volume/ball_filter.py | 311 +++++++++++------- .../filters/volume/structure_detection.py | 146 +++++--- .../filters/volume/structure_splitting.py | 2 +- .../detect/filters/volume/volume_filter.py | 97 +++--- .../test_structure_detection.py | 2 +- 6 files changed, 365 insertions(+), 206 deletions(-) diff --git a/cellfinder/core/detect/detect.py b/cellfinder/core/detect/detect.py index c28ea557..9d70e541 100644 --- a/cellfinder/core/detect/detect.py +++ b/cellfinder/core/detect/detect.py @@ -22,6 +22,7 @@ import numpy as np from brainglobe_utils.cells.cells import Cell from brainglobe_utils.general.system import get_num_processes +from numba import set_num_threads from cellfinder.core import logger, types from cellfinder.core.detect.filters.plane import TileProcessor @@ -157,6 +158,13 @@ def main( ) n_processes = get_num_processes(min_free_cpu_cores=n_free_cpus) n_ball_procs = max(n_processes - 1, 1) + + # we parallelize 2d filtering, which typically lags behind the 3d + # processing so for n_ball_procs 2d filtering threads, ball_z_size will + # typically be in use while the others stall waiting for 3d processing + # so we can use those for other things, such as numba threading + set_num_threads(max(n_ball_procs - int(ball_z_size), 1)) + start_time = datetime.now() ( @@ -236,7 +244,10 @@ def main( # then 3D filtering has finished. As batches of planes are filtered # by the 3D filter, it releases the locks of subsequent 2D filter # processes. - cells = mp_3d_filter.process(async_results, locks, callback=callback) + mp_3d_filter.process(async_results, locks, callback=callback) + + # it's now done filtering, get results with pool + cells = mp_3d_filter.get_results(worker_pool) time_elapsed = datetime.now() - start_time logger.debug( diff --git a/cellfinder/core/detect/filters/volume/ball_filter.py b/cellfinder/core/detect/filters/volume/ball_filter.py index ebd3642e..c5f5f5b8 100644 --- a/cellfinder/core/detect/filters/volume/ball_filter.py +++ b/cellfinder/core/detect/filters/volume/ball_filter.py @@ -1,12 +1,83 @@ +from functools import lru_cache + import numpy as np -from numba import njit +from numba import njit, objmode, prange +from numba.core import types +from numba.experimental import jitclass from cellfinder.core.tools.array_operations import bin_mean_3d from cellfinder.core.tools.geometry import make_sphere DEBUG = False +uint32_3d_type = types.uint32[:, :, :] +bool_3d_type = types.bool_[:, :, :] +float_3d_type = types.float64[:, :, :] + + +@lru_cache(maxsize=50) +def get_kernel(ball_xy_size: int, ball_z_size: int) -> np.ndarray: + # Create a spherical kernel. + # + # This is done by: + # 1. Generating a binary sphere at a resolution *upscale_factor* larger + # than desired. + # 2. 
Downscaling the binary sphere to get a 'fuzzy' sphere at the + # original intended scale + upscale_factor: int = 7 + upscaled_kernel_shape = ( + upscale_factor * ball_xy_size, + upscale_factor * ball_xy_size, + upscale_factor * ball_z_size, + ) + upscaled_ball_centre_position = ( + np.floor(upscaled_kernel_shape[0] / 2), + np.floor(upscaled_kernel_shape[1] / 2), + np.floor(upscaled_kernel_shape[2] / 2), + ) + upscaled_ball_radius = upscaled_kernel_shape[0] / 2.0 + + sphere_kernel = make_sphere( + upscaled_kernel_shape, + upscaled_ball_radius, + upscaled_ball_centre_position, + ) + sphere_kernel = sphere_kernel.astype(np.float64) + kernel = bin_mean_3d( + sphere_kernel, + bin_height=upscale_factor, + bin_width=upscale_factor, + bin_depth=upscale_factor, + ) + + assert ( + kernel.shape[2] == ball_z_size + ), "Kernel z dimension should be {}, got {}".format( + ball_z_size, kernel.shape[2] + ) + + return kernel + + +# volume indices/size is 64 bit for very large brains(!) +spec = [ + ("ball_xy_size", types.uint32), + ("ball_z_size", types.uint32), + ("tile_step_width", types.uint64), + ("tile_step_height", types.uint64), + ("THRESHOLD_VALUE", types.uint32), + ("SOMA_CENTRE_VALUE", types.uint32), + ("overlap_fraction", types.float64), + ("overlap_threshold", types.float64), + ("middle_z_idx", types.uint32), + ("_num_z_added", types.uint32), + ("kernel", float_3d_type), + ("volume", uint32_3d_type), + ("inside_brain_tiles", bool_3d_type), +] + +@jitclass(spec=spec) class BallFilter: """ A 3D ball filter. @@ -62,72 +133,39 @@ def __init__( self.THRESHOLD_VALUE = threshold_value self.SOMA_CENTRE_VALUE = soma_centre_value - # Create a spherical kernel. - # - # This is done by: - # 1. Generating a binary sphere at a resolution *upscale_factor* larger - # than desired. - # 2. 
Downscaling the binary sphere to get a 'fuzzy' sphere at the - # original intended scale - upscale_factor: int = 7 - upscaled_kernel_shape = ( - upscale_factor * ball_xy_size, - upscale_factor * ball_xy_size, - upscale_factor * ball_z_size, - ) - upscaled_ball_centre_position = ( - np.floor(upscaled_kernel_shape[0] / 2), - np.floor(upscaled_kernel_shape[1] / 2), - np.floor(upscaled_kernel_shape[2] / 2), - ) - upscaled_ball_radius = upscaled_kernel_shape[0] / 2.0 - sphere_kernel = make_sphere( - upscaled_kernel_shape, - upscaled_ball_radius, - upscaled_ball_centre_position, - ) - sphere_kernel = sphere_kernel.astype(np.float64) - self.kernel = bin_mean_3d( - sphere_kernel, - bin_height=upscale_factor, - bin_width=upscale_factor, - bin_depth=upscale_factor, - ) - - assert ( - self.kernel.shape[2] == ball_z_size - ), "Kernel z dimension should be {}, got {}".format( - ball_z_size, self.kernel.shape[2] - ) + # getting kernel is not jitted + with objmode(kernel=float_3d_type): + kernel = get_kernel(ball_xy_size, ball_z_size) + self.kernel = kernel self.overlap_threshold = np.sum(self.overlap_fraction * self.kernel) # Stores the current planes that are being filtered + # first axis is z for faster rotating the z-axis self.volume = np.empty( - (plane_width, plane_height, ball_z_size), dtype=np.uint32 + (ball_z_size, plane_width, plane_height), + dtype=np.uint32, ) # Index of the middle plane in the volume self.middle_z_idx = int(np.floor(ball_z_size / 2)) + self._num_z_added = 0 - # TODO: lazy initialisation + # first axis is z self.inside_brain_tiles = np.empty( ( + ball_z_size, int(np.ceil(plane_width / tile_step_width)), int(np.ceil(plane_height / tile_step_height)), - ball_z_size, ), - dtype=bool, + dtype=np.bool_, ) - # Stores the z-index in volume at which new planes are inserted when - # append() is called - self.__current_z = -1 @property def ready(self) -> bool: """ Return `True` if enough planes have been appended to run the filter. 
""" - return self.__current_z == self.ball_z_size - 1 + return self._num_z_added >= self.ball_z_size def append(self, plane: np.ndarray, mask: np.ndarray) -> None: """ @@ -135,76 +173,106 @@ def append(self, plane: np.ndarray, mask: np.ndarray) -> None: """ if DEBUG: assert [e for e in plane.shape[:2]] == [ - e for e in self.volume.shape[:2] + e for e in self.volume.shape[1:] ], 'plane shape mismatch, expected "{}", got "{}"'.format( - [e for e in self.volume.shape[:2]], + [e for e in self.volume.shape[1:]], [e for e in plane.shape[:2]], ) assert [e for e in mask.shape[:2]] == [ - e for e in self.inside_brain_tiles.shape[:2] + e for e in self.inside_brain_tiles.shape[1:] ], 'mask shape mismatch, expected"{}", got {}"'.format( - [e for e in self.inside_brain_tiles.shape[:2]], + [e for e in self.inside_brain_tiles.shape[1:]], [e for e in mask.shape[:2]], ) - if not self.ready: - self.__current_z += 1 - else: + + if self.ready: # Shift everything down by one to make way for the new plane - self.volume = np.roll( - self.volume, -1, axis=2 - ) # WARNING: not in place - self.inside_brain_tiles = np.roll( - self.inside_brain_tiles, -1, axis=2 - ) + # this is faster than np.roll, especially with z-axis first + self.volume[:-1, :, :] = self.volume[1:, :, :] + self.inside_brain_tiles[:-1, :, :] = self.inside_brain_tiles[ + 1:, :, : + ] + + # index for *next* slice is num we added *so far* until max + idx = min(self._num_z_added, self.ball_z_size - 1) + self._num_z_added += 1 + # Add the new plane to the top of volume and inside_brain_tiles - self.volume[:, :, self.__current_z] = plane[:, :] - self.inside_brain_tiles[:, :, self.__current_z] = mask[:, :] + self.volume[idx, :, :] = plane + self.inside_brain_tiles[idx, :, :] = mask def get_middle_plane(self) -> np.ndarray: """ Get the plane in the middle of self.volume. """ - z = self.middle_z_idx - return np.array(self.volume[:, :, z], dtype=np.uint32) + return self.volume[self.middle_z_idx, :, :].copy() - def walk(self) -> None: # Highly optimised because most time critical + def walk(self, parallel: bool = False) -> None: + # **don't** pass parallel as keyword arg - numba struggles with it + # Highly optimised because most time critical ball_radius = self.ball_xy_size // 2 # Get extents of image that are covered by tiles tile_mask_covered_img_width = ( - self.inside_brain_tiles.shape[0] * self.tile_step_width + self.inside_brain_tiles.shape[1] * self.tile_step_width ) tile_mask_covered_img_height = ( - self.inside_brain_tiles.shape[1] * self.tile_step_height + self.inside_brain_tiles.shape[2] * self.tile_step_height ) # Get maximum offsets for the ball max_width = tile_mask_covered_img_width - self.ball_xy_size max_height = tile_mask_covered_img_height - self.ball_xy_size - _walk( - max_height, - max_width, - self.tile_step_width, - self.tile_step_height, - self.inside_brain_tiles, - self.volume, - self.kernel, - ball_radius, - self.middle_z_idx, - self.overlap_threshold, - self.THRESHOLD_VALUE, - self.SOMA_CENTRE_VALUE, - ) + # we have to pass the raw volume so walk doesn't use its edits as it + # processes the volume. 
self.volume is the one edited in place + input_volume = self.volume.copy() + + if parallel: + _walk_parallel( + max_height, + max_width, + self.tile_step_width, + self.tile_step_height, + self.inside_brain_tiles, + input_volume, + self.volume, + self.kernel, + ball_radius, + self.middle_z_idx, + self.overlap_threshold, + self.THRESHOLD_VALUE, + self.SOMA_CENTRE_VALUE, + ) + else: + _walk_single( + max_height, + max_width, + self.tile_step_width, + self.tile_step_height, + self.inside_brain_tiles, + input_volume, + self.volume, + self.kernel, + ball_radius, + self.middle_z_idx, + self.overlap_threshold, + self.THRESHOLD_VALUE, + self.SOMA_CENTRE_VALUE, + ) @njit(cache=True) def _cube_overlaps( - cube: np.ndarray, + volume: np.ndarray, + x_start: int, + x_end: int, + y_start: int, + y_end: int, overlap_threshold: float, - THRESHOLD_VALUE: int, + threshold_value: int, kernel: np.ndarray, ) -> bool: # Highly optimised because most time critical """ - For each pixel in cube that is greater than THRESHOLD_VALUE, sum + For each pixel in cube in volume that is greater than THRESHOLD_VALUE, sum up the corresponding pixels in *kernel*. If the total is less than overlap_threshold, return False, otherwise return True. @@ -214,23 +282,26 @@ def _cube_overlaps( Parameters ---------- - cube : + volume : 3D array. + x_start, x_end, y_start, y_end : + The start and end indices in volume that form the cube. End is + exclusive overlap_threshold : Threshold above which to return True. - THRESHOLD_VALUE : + threshold_value : Value above which a pixel is marked as being part of a cell. kernel : - 3D array, with the same shape as *cube*. + 3D array, with the same shape as *cube* in the volume. """ - current_overlap_value = 0 + current_overlap_value = 0.0 - middle = np.floor(cube.shape[2] / 2) + 1 + middle = np.floor(volume.shape[0] / 2) + 1 halfway_overlap_thresh = ( overlap_threshold * 0.4 ) # FIXME: do not hard code value - for z in range(cube.shape[2]): + for z in range(volume.shape[0]): # TODO: OPTIMISE: step from middle to outer boundaries to check # more data first # @@ -238,11 +309,17 @@ def _cube_overlaps( # 0.4 * the overlap threshold, return if z == middle and current_overlap_value < halfway_overlap_thresh: return False # DEBUG: optimisation attempt - for y in range(cube.shape[1]): - for x in range(cube.shape[0]): + + for y in range(y_start, y_end): + for x in range(x_start, x_end): # includes self.SOMA_CENTRE_VALUE - if cube[x, y, z] >= THRESHOLD_VALUE: - current_overlap_value += kernel[x, y, z] + if volume[z, x, y] >= threshold_value: + # x/y must be shifted in kernel because we x/y is relative + # to the full volume, so shift it to relative to the cube + current_overlap_value += kernel[ + x - x_start, y - y_start, z + ] + return current_overlap_value > overlap_threshold @@ -260,23 +337,23 @@ def _is_tile_to_check( """ x_in_mask = x // tile_step_width # TEST: test bounds (-1 range) y_in_mask = y // tile_step_height # TEST: test bounds (-1 range) - return inside_brain_tiles[x_in_mask, y_in_mask, middle_z] + return inside_brain_tiles[middle_z, x_in_mask, y_in_mask] -@njit -def _walk( +def _walk_base( max_height: int, max_width: int, tile_step_width: int, tile_step_height: int, inside_brain_tiles: np.ndarray, + input_volume: np.ndarray, volume: np.ndarray, kernel: np.ndarray, ball_radius: int, middle_z: int, overlap_threshold: float, - THRESHOLD_VALUE: int, - SOMA_CENTRE_VALUE: int, + threshold_value: int, + soma_centre_value: int, ) -> None: """ Scan through *volume*, and mark pixels where there are 
enough surrounding @@ -289,23 +366,28 @@ def _walk( max_height, max_width : Maximum offsets for the ball filter. inside_brain_tiles : - Array containing information on whether a tile is inside the brain - or not. Tiles outside the brain are skipped. + 3d array containing information on whether a tile is + inside the brain or not. Tiles outside the brain are skipped. + input_volume : + 3D array containing the plane-filtered data passed to the function + before walking. volume is edited in place, so this is the original + volume to prevent the changes for some cubes affective other cubes + during a single walk call. volume : - 3D array containing the plane-filtered data. + 3D array containing the plane-filtered data - edited in place. kernel : 3D array ball_radius : Radius of the ball in the xy plane. - SOMA_CENTRE_VALUE : + soma_centre_value : Value that is used to mark pixels in *volume*. Notes ----- Warning: modifies volume in place! """ - for y in range(max_height): - for x in range(max_width): + for y in prange(max_height): + for x in prange(max_width): ball_centre_x = x + ball_radius ball_centre_y = y + ball_radius if _is_tile_to_check( @@ -316,17 +398,20 @@ def _walk( tile_step_height, inside_brain_tiles, ): - cube = volume[ - x : x + kernel.shape[0], - y : y + kernel.shape[1], - :, - ] if _cube_overlaps( - cube, + input_volume, + x, + x + kernel.shape[0], + y, + y + kernel.shape[1], overlap_threshold, - THRESHOLD_VALUE, + threshold_value, kernel, ): - volume[ball_centre_x, ball_centre_y, middle_z] = ( - SOMA_CENTRE_VALUE + volume[middle_z, ball_centre_x, ball_centre_y] = ( + soma_centre_value ) + + +_walk_parallel = njit(parallel=True)(_walk_base) +_walk_single = njit(parallel=False)(_walk_base) diff --git a/cellfinder/core/detect/filters/volume/structure_detection.py b/cellfinder/core/detect/filters/volume/structure_detection.py index 6147b6d3..536f00ad 100644 --- a/cellfinder/core/detect/filters/volume/structure_detection.py +++ b/cellfinder/core/detect/filters/volume/structure_detection.py @@ -1,14 +1,22 @@ from dataclasses import dataclass -from typing import Dict, Optional, TypeVar +from typing import Dict, Optional, Tuple, TypeVar, Union import numba.typed import numpy as np import numpy.typing as npt -from numba import njit +from numba import njit, typed from numba.core import types from numba.experimental import jitclass from numba.types import DictType +T = TypeVar("T") +# type used for the domain of the volume - the size of the vol +vol_np_type = np.int64 +vol_numba_type = types.int64 +# type used for the structure id +sid_np_type = np.int64 +sid_numba_type = types.int64 + @dataclass class Point: @@ -32,18 +40,15 @@ def get_non_zero_dtype_min(values: np.ndarray) -> int: return min_val -T = TypeVar("T") - - @njit def traverse_dict(d: Dict[T, T], a: T) -> T: """ Traverse d, until a is not present as a key. """ - if a in d: - return traverse_dict(d, d[a]) - else: - return a + value = a + while value in d: + value = d[value] + return value @njit @@ -54,14 +59,28 @@ def get_structure_centre(structure: np.ndarray) -> np.ndarray: Centre calculated as the mean of each pixel coordinate, rounded to the nearest integer. """ - # can't do np.mean(structure, axis=0) - # because axis is not supported by numba + # numba support axis for sum, but not mean + return np.round(np.sum(structure, axis=0) / structure.shape[0]) + + +@njit +def _get_structure_centre(structure: types.ListType) -> np.ndarray: + # See get_structure_centre. 
+ # this is for our own points stored as list optimized by numba + a_sum = 0.0 + b_sum = 0.0 + c_sum = 0.0 + for a, b, c in structure: + a_sum += a + b_sum += b + c_sum += c + return np.round( np.array( [ - np.mean(structure[:, 0]), - np.mean(structure[:, 1]), - np.mean(structure[:, 2]), + a_sum / len(structure), + b_sum / len(structure), + c_sum / len(structure), ] ) ) @@ -69,15 +88,18 @@ def get_structure_centre(structure: np.ndarray) -> np.ndarray: # Type declaration has to come outside of the class, # see https://github.com/numba/numba/issues/8808 -uint_2d_type = types.uint64[:, :] +tuple_point_type = types.Tuple( + (vol_numba_type, vol_numba_type, vol_numba_type) +) +list_of_points_type = types.ListType(tuple_point_type) spec = [ - ("z", types.uint64), - ("next_structure_id", types.uint64), - ("shape", types.UniTuple(types.int64, 2)), - ("obsolete_ids", DictType(types.int64, types.int64)), - ("coords_maps", DictType(types.uint64, uint_2d_type)), + ("z", vol_numba_type), + ("next_structure_id", sid_numba_type), + ("shape", types.UniTuple(vol_numba_type, 2)), + ("obsolete_ids", DictType(sid_numba_type, sid_numba_type)), + ("coords_maps", DictType(sid_numba_type, list_of_points_type)), ] @@ -103,8 +125,12 @@ class CellDetector: are scanned. coords_maps : Mapping from structure ID to the coordinates of pixels within that - structure. Coordinates are stored in a 2D array, with the second - axis indexing (x, y, z) coordinates. + structure. Coordinates are stored in a list of (x, y, z) tuples of + the coordinates. + + Use `get_structures` to get it as a dict whose values are each + a 2D array, where rows are points, and columns x, y, z of the + points. """ def __init__(self, width: int, height: int, start_z: int): @@ -123,11 +149,11 @@ def __init__(self, width: int, height: int, start_z: int): # Mapping from obsolete IDs to the IDs that they have been # made obsolete by self.obsolete_ids = numba.typed.Dict.empty( - key_type=types.int64, value_type=types.int64 + key_type=sid_numba_type, value_type=sid_numba_type ) # Mapping from IDs to list of points in that structure self.coords_maps = numba.typed.Dict.empty( - key_type=types.int64, value_type=uint_2d_type + key_type=sid_numba_type, value_type=list_of_points_type ) def process( @@ -136,7 +162,7 @@ def process( """ Process a new plane. """ - if [e for e in plane.shape[:2]] != [e for e in self.shape]: + if plane.shape[:2] != self.shape: raise ValueError("plane does not have correct shape") plane = self.connect_four(plane, previous_plane) @@ -166,7 +192,7 @@ def connect_four( for x in range(plane.shape[0]): if plane[x, y] == SOMA_CENTRE_VALUE: # Labels of structures below, left and behind - neighbour_ids = np.zeros(3, dtype=np.uint64) + neighbour_ids = np.zeros(3, dtype=sid_np_type) # If in bounds look at neighbours if x > 0: neighbour_ids[0] = plane[x - 1, y] @@ -191,17 +217,54 @@ def connect_four( def get_cell_centres(self) -> np.ndarray: return self.structures_to_cells() - def get_coords_dict(self) -> Dict: - return self.coords_maps + def get_structures(self) -> Dict[int, np.ndarray]: + """ + Gets the structures as a dict of structure IDs mapped to the 2D array + of structure points. 
+ """ + d = {} + for sid, points in self.coords_maps.items(): + # numba silliness - it cannot handle + # `item = np.array(points, dtype=vol_np_type)` so we need to create + # array and then fill in the point + item = np.empty((len(points), 3), dtype=vol_np_type) + d[sid] = item + + for i, point in enumerate(points): + item[i, :] = point + + return d + + def add_point( + self, sid: int, point: Union[tuple, list, np.ndarray] + ) -> None: + """ + Add single 3d *point* to the structure with the given *sid*. + """ + if sid not in self.coords_maps: + self.coords_maps[sid] = typed.List.empty_list(tuple_point_type) + + self._add_point(sid, (int(point[0]), int(point[1]), int(point[2]))) - def add_point(self, sid: int, point: np.ndarray) -> None: + def add_points(self, sid: int, points: np.ndarray): """ - Add *point* to the structure with the given *sid*. + Adds ndarray of *points* to the structure with the given *sid*. + Each row is a 3d point. """ - self.coords_maps[sid] = np.row_stack((self.coords_maps[sid], point)) + if sid not in self.coords_maps: + self.coords_maps[sid] = typed.List.empty_list(tuple_point_type) + + append = self.coords_maps[sid].append + pts = np.round(points).astype(vol_np_type) + for point in pts: + append((point[0], point[1], point[2])) + + def _add_point(self, sid: int, point: Tuple[int, int, int]) -> None: + # sid must exist + self.coords_maps[sid].append(point) def add( - self, x: int, y: int, z: int, neighbour_ids: npt.NDArray[np.uint64] + self, x: int, y: int, z: int, neighbour_ids: npt.NDArray[sid_np_type] ) -> int: """ For the current coordinates takes all the neighbours and find the @@ -215,17 +278,16 @@ def add( """ updated_id = self.sanitise_ids(neighbour_ids) if updated_id not in self.coords_maps: - self.coords_maps[updated_id] = np.zeros( - shape=(0, 3), dtype=np.uint64 + self.coords_maps[updated_id] = typed.List.empty_list( + tuple_point_type ) self.merge_structures(updated_id, neighbour_ids) # Add point for that structure - point = np.array([[x, y, z]], dtype=np.uint64) - self.add_point(updated_id, point) + self._add_point(updated_id, (int(x), int(y), int(z))) return updated_id - def sanitise_ids(self, neighbour_ids: npt.NDArray[np.uint64]) -> int: + def sanitise_ids(self, neighbour_ids: npt.NDArray[sid_np_type]) -> int: """ Get the smallest ID of all the structures that are connected to IDs in `neighbour_ids`. 
@@ -246,7 +308,7 @@ def sanitise_ids(self, neighbour_ids: npt.NDArray[np.uint64]) -> int: return int(updated_id) def merge_structures( - self, updated_id: int, neighbour_ids: npt.NDArray[np.uint64] + self, updated_id: int, neighbour_ids: npt.NDArray[sid_np_type] ) -> None: """ For all the neighbours, reassign all the points of neighbour to @@ -261,14 +323,16 @@ def merge_structures( # minimise ID so if neighbour with higher ID, reassign its points # to current if neighbour_id > updated_id: - self.add_point(updated_id, self.coords_maps[neighbour_id]) + self.coords_maps[updated_id].extend( + self.coords_maps[neighbour_id] + ) self.coords_maps.pop(neighbour_id) self.obsolete_ids[neighbour_id] = updated_id def structures_to_cells(self) -> np.ndarray: - cell_centres = np.empty((len(self.coords_maps.keys()), 3)) + cell_centres = np.empty((len(self.coords_maps), 3)) for idx, structure in enumerate(self.coords_maps.values()): - p = get_structure_centre(structure) + p = _get_structure_centre(structure) cell_centres[idx] = p return cell_centres diff --git a/cellfinder/core/detect/filters/volume/structure_splitting.py b/cellfinder/core/detect/filters/volume/structure_splitting.py index 0573e615..f0018df8 100644 --- a/cellfinder/core/detect/filters/volume/structure_splitting.py +++ b/cellfinder/core/detect/filters/volume/structure_splitting.py @@ -71,7 +71,7 @@ def ball_filter_imgs( """ # OPTIMISE: reuse ball filter instance - good_tiles_mask = np.ones((1, 1, volume.shape[2]), dtype=bool) + good_tiles_mask = np.ones((1, 1, volume.shape[2]), dtype=np.bool_) plane_width, plane_height = volume.shape[:2] diff --git a/cellfinder/core/detect/filters/volume/volume_filter.py b/cellfinder/core/detect/filters/volume/volume_filter.py index d64ae71a..949f9f91 100644 --- a/cellfinder/core/detect/filters/volume/volume_filter.py +++ b/cellfinder/core/detect/filters/volume/volume_filter.py @@ -1,5 +1,7 @@ import math +import multiprocessing.pool import os +from functools import partial from queue import Queue from threading import Lock from typing import Any, Callable, List, Optional, Tuple @@ -77,7 +79,7 @@ def process( locks: List[Lock], *, callback: Callable[[int], None], - ) -> List[Cell]: + ) -> None: progress_bar = tqdm(total=self.n_planes, desc="Processing planes") for z in range(self.n_planes): # Get result from the queue. @@ -108,11 +110,13 @@ def process( progress_bar.close() logger.debug("3D filter done") - return self.get_results() def _run_filter(self) -> None: logger.debug(f"🏐 Ball filtering plane {self.z}") - self.ball_filter.walk() + # filtering original images, the images should be large enough in x/y + # to benefit from parallelization. 
Note: don't pass arg as keyword arg + # because numba gets stuck (probably b/c class jit is new) + self.ball_filter.walk(True) middle_plane = self.ball_filter.get_middle_plane() if self.save_planes: @@ -134,7 +138,7 @@ def save_plane(self, plane: np.ndarray) -> None: f_path = os.path.join(self.plane_directory, plane_name) tifffile.imsave(f_path, plane.T) - def get_results(self) -> List[Cell]: + def get_results(self, worker_pool: multiprocessing.Pool) -> List[Cell]: logger.info("Splitting cell clusters and writing results") max_cell_volume = sphere_volume( @@ -142,62 +146,57 @@ def get_results(self) -> List[Cell]: ) cells = [] + needs_split = [] + structures = self.cell_detector.get_structures().items() + logger.debug(f"Processing {len(structures)} found cells") - logger.debug( - f"Processing {len(self.cell_detector.coords_maps.items())} cells" - ) - for cell_id, cell_points in self.cell_detector.coords_maps.items(): + # first get all the cells that are not clusters + for cell_id, cell_points in structures: cell_volume = len(cell_points) if cell_volume < max_cell_volume: cell_centre = get_structure_centre(cell_points) - cells.append( - Cell( - ( - cell_centre[0], - cell_centre[1], - cell_centre[2], - ), - Cell.UNKNOWN, - ) - ) + cells.append(Cell(cell_centre.tolist(), Cell.UNKNOWN)) else: if cell_volume < self.max_cluster_size: - try: - cell_centres = split_cells( - cell_points, outlier_keep=self.outlier_keep - ) - except (ValueError, AssertionError) as err: - raise StructureSplitException( - f"Cell {cell_id}, error; {err}" - ) - for cell_centre in cell_centres: - cells.append( - Cell( - ( - cell_centre[0], - cell_centre[1], - cell_centre[2], - ), - Cell.UNKNOWN, - ) - ) + needs_split.append((cell_id, cell_points)) else: cell_centre = get_structure_centre(cell_points) - cells.append( - Cell( - ( - cell_centre[0], - cell_centre[1], - cell_centre[2], - ), - Cell.ARTIFACT, - ) - ) - - logger.debug("Finished splitting cell clusters.") + cells.append(Cell(cell_centre.tolist(), Cell.ARTIFACT)) + + if not needs_split: + logger.debug("Finished splitting cell clusters - none found") + return cells + + # now split clusters into cells + logger.debug(f"Splitting {len(needs_split)} clusters") + progress_bar = tqdm( + total=len(needs_split), desc="Splitting cell clusters" + ) + + # we are not returning Cell instances from func because it'd be pickled + # by multiprocess which slows it down + func = partial(_split_cells, outlier_keep=self.outlier_keep) + for cell_centres in worker_pool.imap_unordered(func, needs_split): + for cell_centre in cell_centres: + cells.append(Cell(cell_centre.tolist(), Cell.UNKNOWN)) + progress_bar.update() + + progress_bar.close() + logger.debug( + f"Finished splitting cell clusters. 
Found {len(cells)} total cells" + ) + return cells +def _split_cells(arg, outlier_keep): + cell_id, cell_points = arg + try: + return split_cells(cell_points, outlier_keep=outlier_keep) + except (ValueError, AssertionError) as err: + raise StructureSplitException(f"Cell {cell_id}, error; {err}") + + def sphere_volume(radius: float) -> float: return (4 / 3) * math.pi * radius**3 diff --git a/tests/core/test_unit/test_detect/test_filters/test_volume_filters/test_structure_detection.py b/tests/core/test_unit/test_detect/test_filters/test_volume_filters/test_structure_detection.py index 9895e2c9..d1e1af7a 100644 --- a/tests/core/test_unit/test_detect/test_filters/test_volume_filters/test_structure_detection.py +++ b/tests/core/test_unit/test_detect/test_filters/test_volume_filters/test_structure_detection.py @@ -145,5 +145,5 @@ def test_detection(dtype, pixels, expected_coords): for plane in data: previous_plane = detector.process(plane, previous_plane) - coords = detector.get_coords_dict() + coords = detector.get_structures() assert coords_to_points(coords) == expected_coords From eeffd7873b561ffab8926d9b737feb7676e05022 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 7 May 2024 10:23:47 +0100 Subject: [PATCH 35/50] [pre-commit.ci] pre-commit autoupdate (#412) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit updates: - [github.com/pre-commit/pre-commit-hooks: v4.5.0 → v4.6.0](https://github.com/pre-commit/pre-commit-hooks/compare/v4.5.0...v4.6.0) - [github.com/astral-sh/ruff-pre-commit: v0.3.5 → v0.4.3](https://github.com/astral-sh/ruff-pre-commit/compare/v0.3.5...v0.4.3) - [github.com/psf/black: 24.3.0 → 24.4.2](https://github.com/psf/black/compare/24.3.0...24.4.2) Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- .pre-commit-config.yaml | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 0bff4083..db6568cd 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -4,7 +4,7 @@ ci: repos: - repo: https://github.com/pre-commit/pre-commit-hooks - rev: v4.5.0 + rev: v4.6.0 hooks: - id: check-docstring-first - id: check-executables-have-shebangs @@ -16,10 +16,10 @@ repos: - id: requirements-txt-fixer - id: trailing-whitespace - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.3.5 + rev: v0.4.3 hooks: - id: ruff - repo: https://github.com/psf/black - rev: 24.3.0 + rev: 24.4.2 hooks: - id: black From 8de63bd6ccadab0faad020ee760d3dd8ab45e6fe Mon Sep 17 00:00:00 2001 From: Igor Tatarnikov <61896994+IgorTatarnikov@users.noreply.github.com> Date: Wed, 8 May 2024 11:28:04 +0100 Subject: [PATCH 36/50] Apply suggestions from code review Co-authored-by: sfmig <33267254+sfmig@users.noreply.github.com> --- cellfinder/__init__.py | 8 ++++---- cellfinder/core/classify/resnet.py | 2 +- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/cellfinder/__init__.py b/cellfinder/__init__.py index 700a4129..a5e954d3 100644 --- a/cellfinder/__init__.py +++ b/cellfinder/__init__.py @@ -25,19 +25,19 @@ # If no backend is configured and installed for Keras, tools cannot be used -# Check backend is configured +# Check if backend is configured. 
If not, set to "torch" if not os.getenv("KERAS_BACKEND"): os.environ["KERAS_BACKEND"] = "torch" warnings.warn("Keras backend not configured, automatically set to Torch") # Check backend is installed -if os.getenv("KERAS_BACKEND") in ["tensorflow", "jax", "torch"]: - backend = os.getenv("KERAS_BACKEND") +backend = os.getenv("KERAS_BACKEND") +if backend in ["tensorflow", "jax", "torch"]: try: BACKEND_VERSION = version(backend) except PackageNotFoundError as e: raise PackageNotFoundError( - f"{backend}, ({backend}) set as Keras backend " + f"{backend} set as Keras backend " f"but not installed" ) from e else: diff --git a/cellfinder/core/classify/resnet.py b/cellfinder/core/classify/resnet.py index e0bc98d4..e172f712 100644 --- a/cellfinder/core/classify/resnet.py +++ b/cellfinder/core/classify/resnet.py @@ -1,6 +1,6 @@ from typing import Callable, Dict, List, Literal, Optional, Tuple, Union -import keras.config +import keras from keras import ( KerasTensor as Tensor, ) From 860942c800ba687b2cc5ddbb5241cddc2a60baf4 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 8 May 2024 10:28:59 +0000 Subject: [PATCH 37/50] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- cellfinder/__init__.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/cellfinder/__init__.py b/cellfinder/__init__.py index a5e954d3..a10b386b 100644 --- a/cellfinder/__init__.py +++ b/cellfinder/__init__.py @@ -37,8 +37,7 @@ BACKEND_VERSION = version(backend) except PackageNotFoundError as e: raise PackageNotFoundError( - f"{backend} set as Keras backend " - f"but not installed" + f"{backend} set as Keras backend " f"but not installed" ) from e else: raise PackageNotFoundError( From 5f4882ef3946723a83e520b54b7ed711c29d6baa Mon Sep 17 00:00:00 2001 From: Adam Tyson Date: Thu, 9 May 2024 09:54:41 +0100 Subject: [PATCH 38/50] Simplify model download (#414) * Simplify model download * Update model cache --- .github/workflows/test_and_deploy.yml | 6 +- cellfinder/__init__.py | 3 + cellfinder/core/download/cli.py | 71 +++++++------- cellfinder/core/download/download.py | 100 +++++++++----------- cellfinder/core/download/models.py | 49 ---------- cellfinder/core/main.py | 2 +- cellfinder/core/tools/prep.py | 21 ++-- cellfinder/core/tools/source_files.py | 8 +- cellfinder/core/train/train_yml.py | 10 +- cellfinder/napari/train/train_containers.py | 6 +- tests/core/conftest.py | 8 +- 11 files changed, 118 insertions(+), 166 deletions(-) delete mode 100644 cellfinder/core/download/models.py diff --git a/.github/workflows/test_and_deploy.yml b/.github/workflows/test_and_deploy.yml index 5197f102..3655b0ba 100644 --- a/.github/workflows/test_and_deploy.yml +++ b/.github/workflows/test_and_deploy.yml @@ -57,7 +57,7 @@ jobs: uses: actions/cache@v3 with: path: "~/.cellfinder" - key: models-${{ hashFiles('~/.cellfinder/**') }} + key: models-${{ hashFiles('~/.brainglobe/**') }} # Setup pyqt libraries - name: Setup qtpy libraries uses: tlambert03/setup-qt-libs@v1 @@ -83,7 +83,7 @@ jobs: uses: actions/cache@v3 with: path: "~/.cellfinder" - key: models-${{ hashFiles('~/.cellfinder/**') }} + key: models-${{ hashFiles('~/.brainglobe/**') }} # Setup pyqt libraries - name: Setup qtpy libraries uses: tlambert03/setup-qt-libs@v1 @@ -108,7 +108,7 @@ jobs: uses: actions/cache@v3 with: path: "~/.cellfinder" - key: models-${{ hashFiles('~/.cellfinder/**') }} + key: models-${{ hashFiles('~/.brainglobe/**') }} - name: 
Checkout brainglobe-workflows uses: actions/checkout@v3 diff --git a/cellfinder/__init__.py b/cellfinder/__init__.py index 9971f648..fcd51af8 100644 --- a/cellfinder/__init__.py +++ b/cellfinder/__init__.py @@ -1,4 +1,5 @@ from importlib.metadata import PackageNotFoundError, version +from pathlib import Path try: __version__ = version("cellfinder") @@ -22,3 +23,5 @@ __author__ = "Adam Tyson, Christian Niedworok, Charly Rousseau" __license__ = "BSD-3-Clause" + +DEFAULT_CELLFINDER_DIRECTORY = Path.home() / ".brainglobe" / "cellfinder" diff --git a/cellfinder/core/download/cli.py b/cellfinder/core/download/cli.py index c97bbd18..0446cce6 100644 --- a/cellfinder/core/download/cli.py +++ b/cellfinder/core/download/cli.py @@ -1,17 +1,29 @@ -import tempfile from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser from pathlib import Path -from cellfinder.core.download import models -from cellfinder.core.download.download import amend_user_configuration +from cellfinder.core.download.download import ( + DEFAULT_DOWNLOAD_DIRECTORY, + amend_user_configuration, + download_models, +) -home = Path.home() -DEFAULT_DOWNLOAD_DIRECTORY = home / ".cellfinder" -temp_dir = tempfile.TemporaryDirectory() -temp_dir_path = Path(temp_dir.name) +def download_parser(parser: ArgumentParser) -> ArgumentParser: + """ + Configure the argument parser for downloading files. + + Parameters + ---------- + parser : ArgumentParser + The argument parser to configure. + + Returns + ------- + ArgumentParser + The configured argument parser. + + """ -def download_directory_parser(parser): parser.add_argument( "--install-path", dest="install_path", @@ -19,29 +31,12 @@ def download_directory_parser(parser): default=DEFAULT_DOWNLOAD_DIRECTORY, help="The path to install files to.", ) - parser.add_argument( - "--download-path", - dest="download_path", - type=Path, - default=temp_dir_path, - help="The path to download files into.", - ) parser.add_argument( "--no-amend-config", dest="no_amend_config", action="store_true", help="Don't amend the config file", ) - return parser - - -def model_parser(parser): - parser.add_argument( - "--no-models", - dest="no_models", - action="store_true", - help="Don't download the model", - ) parser.add_argument( "--model", dest="model", @@ -52,17 +47,29 @@ def model_parser(parser): return parser -def download_parser(): +def get_parser() -> ArgumentParser: + """ + Create an argument parser for downloading files. + + Returns + ------- + ArgumentParser + The configured argument parser. + + """ parser = ArgumentParser(formatter_class=ArgumentDefaultsHelpFormatter) - parser = model_parser(parser) - parser = download_directory_parser(parser) + parser = download_parser(parser) return parser -def main(): - args = download_parser().parse_args() - if not args.no_models: - model_path = models.main(args.model, args.install_path) +def main() -> None: + """ + Run the main download function, and optionally amend the user + configuration. 
+ + """ + args = get_parser().parse_args() + model_path = download_models(args.model, args.install_path) if not args.no_amend_config: amend_user_configuration(new_model_path=model_path) diff --git a/cellfinder/core/download/download.py b/cellfinder/core/download/download.py index cd96616f..42806287 100644 --- a/cellfinder/core/download/download.py +++ b/cellfinder/core/download/download.py @@ -1,79 +1,67 @@ import os -import shutil -import tarfile -import urllib.request from pathlib import Path +from typing import Literal +import pooch from brainglobe_utils.general.config import get_config_obj -from brainglobe_utils.general.system import disk_free_gb +from cellfinder import DEFAULT_CELLFINDER_DIRECTORY from cellfinder.core.tools.source_files import ( default_configuration_path, user_specific_configuration_path, ) +DEFAULT_DOWNLOAD_DIRECTORY = DEFAULT_CELLFINDER_DIRECTORY / "models" -class DownloadError(Exception): - pass +MODEL_URL = "https://gin.g-node.org/cellfinder/models/raw/master" -def download_file(destination_path, file_url, filename): - direct_download = True - file_url = file_url.format(int(direct_download)) - print(f"Downloading file: {filename}") - with urllib.request.urlopen(file_url) as response: - with open(destination_path, "wb") as outfile: - shutil.copyfileobj(response, outfile) +model_filenames = { + "resnet50_tv": "resnet50_tv.h5", + "resnet50_all": "resnet50_weights.h5", +} +model_hashes = { + "resnet50_tv": "63d36af456640590ba6c896dc519f9f29861015084f4c40777a54c18c1fc4edd", # noqa: E501 + "resnet50_all": None, +} -def extract_file(tar_file_path, destination_path): - tar = tarfile.open(tar_file_path) - tar.extractall(path=destination_path) - tar.close() +model_type = Literal["resnet50_tv", "resnet50_all"] -# TODO: check that intermediate folders exist -def download( - download_path, - url, - file_name, - install_path=None, - download_requires=None, - extract_requires=None, -): - if not os.path.exists(os.path.dirname(download_path)): - raise DownloadError( - f"Could not find directory '{os.path.dirname(download_path)}' " - f"to download file: {file_name}" - ) - if (download_requires is not None) and ( - disk_free_gb(os.path.dirname(download_path)) < download_requires - ): - raise DownloadError( - f"Insufficient disk space in {os.path.dirname(download_path)} to" - f"download file: {file_name}" - ) +def download_models( + model_name: model_type, download_path: os.PathLike +) -> Path: + """ + For a given model name and download path, download the model file + and return the path to the downloaded file. + + Parameters + ---------- + model_name : model_type + The name of the model to be downloaded. + download_path : os.PathLike + The path where the model file will be downloaded. - if install_path is not None: - if not os.path.exists(install_path): - raise DownloadError( - f"Could not find directory '{install_path}' " - f"to extract file: {file_name}" - ) + Returns + ------- + Path + The path to the downloaded model file. 
+
+    """
+
+    download_path = Path(download_path)
+    filename = model_filenames[model_name]
+    model_path = pooch.retrieve(
+        url=f"{MODEL_URL}/{filename}",
+        known_hash=model_hashes[model_name],
+        path=download_path,
+        fname=filename,
+        progressbar=True,
+    )
+
+    return Path(model_path)
 
 
 def amend_user_configuration(new_model_path=None) -> None:
@@ -83,7 +71,7 @@ def amend_user_configuration(new_model_path=None) -> None:
 
     Parameters
     ----------
-    new_model_path : str, optional
+    new_model_path : Path, optional
         The path to the new model configuration.
     """
     print("(Over-)writing custom user configuration")
diff --git a/cellfinder/core/download/models.py b/cellfinder/core/download/models.py
deleted file mode 100644
index dbb0f3cb..00000000
--- a/cellfinder/core/download/models.py
+++ /dev/null
@@ -1,49 +0,0 @@
-import os
-from pathlib import Path
-from typing import Literal
-
-from cellfinder.core import logger
-from cellfinder.core.download.download import download
-
-model_weight_urls = {
-    "resnet50_tv": "https://gin.g-node.org/cellfinder/models/raw/"
-    "master/resnet50_tv.h5",
-    "resnet50_all": "https://gin.g-node.org/cellfinder/models/raw/"
-    "master/resnet50_weights.h5",
-}
-
-download_requirements_gb = {
-    "resnet50_tv": 0.18,
-    "resnet50_all": 0.18,
-}
-
-model_type = Literal["resnet50_tv", "resnet50_all"]
-
-
-def main(model_name: model_type, download_path: os.PathLike) -> Path:
-    """
-    For a given model name and download path, download the model file
-    and return the path to the downloaded file.
-    """
-    download_path = Path(download_path)
-
-    model_weight_dir = download_path / "model_weights"
-    model_path = model_weight_dir / f"{model_name}.h5"
-    if not model_path.exists():
-        model_weight_dir.mkdir(parents=True)
-
-        logger.info(
-            f"Downloading '{model_name}' model. This may take a little while."
-        )
-
-        download(
-            model_path,
-            model_weight_urls[model_name],
-            model_name,
-            download_requires=download_requirements_gb[model_name],
-        )
-
-    else:
-        logger.info(f"Model already exists at {model_path}. Skipping download")
-
-    return model_path
diff --git a/cellfinder/core/main.py b/cellfinder/core/main.py
index c74a9d44..926fe545 100644
--- a/cellfinder/core/main.py
+++ b/cellfinder/core/main.py
@@ -11,7 +11,7 @@
 from brainglobe_utils.general.logging import suppress_specific_logs
 
 from cellfinder.core import logger
-from cellfinder.core.download.models import model_type
+from cellfinder.core.download.download import model_type
 from cellfinder.core.train.train_yml import depth_type
 
 tf_suppress_log_messages = [
diff --git a/cellfinder/core/tools/prep.py b/cellfinder/core/tools/prep.py
index 1e0bccf3..a2625311 100644
--- a/cellfinder/core/tools/prep.py
+++ b/cellfinder/core/tools/prep.py
@@ -13,18 +13,19 @@
 
 import cellfinder.core.tools.tf as tf_tools
 from cellfinder.core import logger
-from cellfinder.core.download import models as model_download
-from cellfinder.core.download.download import amend_user_configuration
+from cellfinder.core.download.download import (
+    DEFAULT_DOWNLOAD_DIRECTORY,
+    amend_user_configuration,
+    download_models,
+    model_type,
+)
 from cellfinder.core.tools.source_files import user_specific_configuration_path
 
-home = Path.home()
-DEFAULT_INSTALL_PATH = home / ".cellfinder"
-
 
 def prep_model_weights(
     model_weights: Optional[os.PathLike],
     install_path: Optional[os.PathLike],
-    model_name: model_download.model_type,
+    model_name: model_type,
     n_free_cpus: int,
 ) -> Path:
     n_processes = get_num_processes(min_free_cpu_cores=n_free_cpus)
@@ -42,9 +43,9 @@ def prep_tensorflow(max_threads: int) -> None:
 def prep_models(
     model_weights_path: Optional[os.PathLike],
     install_path: Optional[os.PathLike],
-    model_name: model_download.model_type,
+    model_name: model_type,
 ) -> Path:
-    install_path = install_path or DEFAULT_INSTALL_PATH
+    install_path = install_path or DEFAULT_DOWNLOAD_DIRECTORY
     # if no model or weights, set default weights
     if model_weights_path is None:
         logger.debug("No model supplied, so using the default")
@@ -53,13 +54,13 @@ def prep_models(
 
         if not Path(config_file).exists():
             logger.debug("Custom config does not exist, downloading models")
-            model_path = model_download.main(model_name, install_path)
+            model_path = download_models(model_name, install_path)
             amend_user_configuration(new_model_path=model_path)
 
         model_weights = get_model_weights(config_file)
         if not model_weights.exists():
             logger.debug("Model weights do not exist, downloading")
-            model_path = model_download.main(model_name, install_path)
+            model_path = download_models(model_name, install_path)
             amend_user_configuration(new_model_path=model_path)
             model_weights = get_model_weights(config_file)
     else:
diff --git a/cellfinder/core/tools/source_files.py b/cellfinder/core/tools/source_files.py
index 474cc51e..99fe4108 100644
--- a/cellfinder/core/tools/source_files.py
+++ b/cellfinder/core/tools/source_files.py
@@ -1,5 +1,7 @@
 from pathlib import Path
 
+from cellfinder import DEFAULT_CELLFINDER_DIRECTORY
+
 
 def default_configuration_path():
     """
@@ -17,11 +19,11 @@ def user_specific_configuration_path():
 
     This function returns the path to the user-specific configuration file
     for cellfinder. The user-specific configuration file is located in the
-    user's home directory under the ".cellfinder" folder and is named
-    "cellfinder.conf.custom".
+    user's home directory under the ".brainglobe/cellfinder" folder
+    and is named "cellfinder.conf.custom".
 
     Returns:
         Path: The path to the custom configuration file.
""" - return Path.home() / ".cellfinder" / "cellfinder.conf.custom" + return DEFAULT_CELLFINDER_DIRECTORY / "cellfinder.conf.custom" diff --git a/cellfinder/core/train/train_yml.py b/cellfinder/core/train/train_yml.py index fbb59968..bf916b3c 100644 --- a/cellfinder/core/train/train_yml.py +++ b/cellfinder/core/train/train_yml.py @@ -31,7 +31,7 @@ import cellfinder.core as program_for_log from cellfinder.core import logger from cellfinder.core.classify.resnet import layer_type -from cellfinder.core.tools.prep import DEFAULT_INSTALL_PATH +from cellfinder.core.download.download import DEFAULT_DOWNLOAD_DIRECTORY tf_suppress_log_messages = [ "sample_weight modes were coerced from", @@ -112,8 +112,7 @@ def misc_parse(parser): def training_parse(): from cellfinder.core.download.cli import ( - download_directory_parser, - model_parser, + download_parser, ) training_parser = ArgumentParser( @@ -223,8 +222,7 @@ def training_parse(): ) training_parser = misc_parse(training_parser) - training_parser = model_parser(training_parser) - training_parser = download_directory_parser(training_parser) + training_parser = download_parser(training_parser) args = training_parser.parse_args() return args @@ -306,7 +304,7 @@ def run( n_free_cpus=2, trained_model=None, model_weights=None, - install_path=DEFAULT_INSTALL_PATH, + install_path=DEFAULT_DOWNLOAD_DIRECTORY, model="resnet50_tv", network_depth="50", learning_rate=0.0001, diff --git a/cellfinder/napari/train/train_containers.py b/cellfinder/napari/train/train_containers.py index 73c6ae1e..7259f9a2 100644 --- a/cellfinder/napari/train/train_containers.py +++ b/cellfinder/napari/train/train_containers.py @@ -4,7 +4,7 @@ from magicgui.types import FileDialogMode -from cellfinder.core.download.models import model_weight_urls +from cellfinder.core.download.download import model_filenames from cellfinder.core.train.train_yml import models from cellfinder.napari.input_container import InputContainer from cellfinder.napari.utils import html_label_widget @@ -46,7 +46,7 @@ class OptionalNetworkInputs(InputContainer): trained_model: Optional[Path] = Path.home() model_weights: Optional[Path] = Path.home() model_depth: str = list(models.keys())[2] - pretrained_model: str = str(list(model_weight_urls.keys())[0]) + pretrained_model: str = str(list(model_filenames.keys())[0]) def as_core_arguments(self) -> dict: arguments = super().as_core_arguments() @@ -65,7 +65,7 @@ def widget_representation(cls) -> dict: ), pretrained_model=cls._custom_widget( "pretrained_model", - choices=list(model_weight_urls.keys()), + choices=list(model_filenames.keys()), ), ) diff --git a/tests/core/conftest.py b/tests/core/conftest.py index f05ec88a..29e465ad 100644 --- a/tests/core/conftest.py +++ b/tests/core/conftest.py @@ -5,8 +5,10 @@ import pytest from skimage.filters import gaussian -from cellfinder.core.download import models -from cellfinder.core.tools.prep import DEFAULT_INSTALL_PATH +from cellfinder.core.download.download import ( + DEFAULT_DOWNLOAD_DIRECTORY, + download_models, +) @pytest.fixture(scope="session") @@ -35,7 +37,7 @@ def download_default_model(): Check that the classification model is already downloaded at the beginning of a pytest session. 
""" - models.main("resnet50_tv", DEFAULT_INSTALL_PATH) + download_models("resnet50_tv", DEFAULT_DOWNLOAD_DIRECTORY) @pytest.fixture(scope="session") From 906f4d60685efc31dcdc61551764ebccad828b34 Mon Sep 17 00:00:00 2001 From: Igor Tatarnikov Date: Thu, 9 May 2024 10:43:29 +0100 Subject: [PATCH 39/50] Remove jax and tf tests --- .github/workflows/test_and_deploy.yml | 2 ++ pyproject.toml | 29 ++++++--------------------- 2 files changed, 8 insertions(+), 23 deletions(-) diff --git a/.github/workflows/test_and_deploy.yml b/.github/workflows/test_and_deploy.yml index c079df0d..0c41b4ae 100644 --- a/.github/workflows/test_and_deploy.yml +++ b/.github/workflows/test_and_deploy.yml @@ -37,6 +37,8 @@ jobs: name: Run package tests timeout-minutes: 60 runs-on: ${{ matrix.os }} + env: + PYTORCH_ENABLE_MPS_FALLBACK: "1" strategy: matrix: # Run all supported Python versions on linux diff --git a/pyproject.toml b/pyproject.toml index 645f22dc..16313642 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -30,6 +30,7 @@ dependencies = [ "scikit-image", "scikit-learn", "keras>=3.0.0", + "torch>=2.1.0", "tifffile", "tqdm", ] @@ -59,17 +60,6 @@ napari = [ "pooch >= 1", "qtpy", ] -# Keras backends -tf = [ - "tensorflow>=2.16.1", -] -jax = [ - "jax==0.4.20", - "jaxlib==0.4.20" -] -torch = [ - "torch>=2.1.0" -] [project.scripts] cellfinder_download = "cellfinder.core.download.cli:main" @@ -120,14 +110,14 @@ markers = ["slow: marks tests as slow (deselect with '-m \"not slow\"')"] legacy_tox_ini = """ # For more information about tox, see https://tox.readthedocs.io/en/latest/ [tox] -envlist = py{39,310,311}-{torch} +envlist = py{39,310,311} isolated_build = true [gh-actions] python = - 3.9: py39-{torch} # On GA python=3.9 job, run tox with the torch environment - 3.10: py310-{torch} # On GA python=3.10 job, run tox with the torch environment - 3.11: py311-{torch} # On GA python=3.11 job, run tox with the torch environment + 3.9: py39 + 3.10: py310 + 3.11: py311 [testenv] commands = python -m pytest -v --color=yes @@ -142,13 +132,8 @@ deps = pytest-qt extras = napari - tf: tf - jax: jax - torch: torch setenv = - tf: KERAS_BACKEND = tensorflow - jax: KERAS_BACKEND = jax - torch: KERAS_BACKEND = torch + KERAS_BACKEND = torch passenv = NUMBA_DISABLE_JIT CI @@ -157,6 +142,4 @@ passenv = XAUTHORITY NUMPY_EXPERIMENTAL_ARRAY_FUNCTION PYVISTA_OFF_SCREEN -platform = - tf: linux|darwin # skip TF backend on windows """ From 7e40ffcb86cc5a723f1cde280158b04a3b8154f6 Mon Sep 17 00:00:00 2001 From: Igor Tatarnikov Date: Thu, 9 May 2024 11:51:20 +0100 Subject: [PATCH 40/50] Standardise the data types for inputs to all be float32 --- cellfinder/core/classify/cube_generator.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/cellfinder/core/classify/cube_generator.py b/cellfinder/core/classify/cube_generator.py index 4fb8b5e6..b3e9dc21 100644 --- a/cellfinder/core/classify/cube_generator.py +++ b/cellfinder/core/classify/cube_generator.py @@ -259,7 +259,8 @@ def __generate_cubes( (number_images,) + (self.cube_height, self.cube_width, self.cube_depth) + (self.channels,) - ) + ), + dtype=np.float32, ) for idx, cell in enumerate(cell_batch): @@ -438,7 +439,8 @@ def __generate_cubes( ) -> np.ndarray: number_images = len(list_signal_tmp) images = np.empty( - ((number_images,) + self.im_shape + (self.channels,)) + ((number_images,) + self.im_shape + (self.channels,)), + dtype=np.float32, ) for idx, signal_im in enumerate(list_signal_tmp): @@ -447,7 +449,7 @@ def __generate_cubes( images, idx, signal_im, 
background_im ) - return images.astype(np.float16) + return images def __populate_array_with_cubes( self, From 933e5ddaa35d3c2f5d9d8af916bf1bf0ee300457 Mon Sep 17 00:00:00 2001 From: Igor Tatarnikov Date: Thu, 9 May 2024 14:07:49 +0100 Subject: [PATCH 41/50] Force torch to use CPU on arm based macOS during tests --- tests/core/conftest.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/tests/core/conftest.py b/tests/core/conftest.py index 83778a9f..7df4acdc 100644 --- a/tests/core/conftest.py +++ b/tests/core/conftest.py @@ -1,13 +1,24 @@ +import platform from typing import Tuple import numpy as np import pytest +import torch from skimage.filters import gaussian from cellfinder.core.download import models from cellfinder.core.tools.prep import DEFAULT_INSTALL_PATH +@pytest.fixture(scope="session", autouse=True) +def macos_use_cpu_only(): + """ + Ensure torch only uses the CPU when running on arm based macOS. + """ + if platform.system() == "Darwin" and platform.processor() == "arm": + torch.set_default_device("cpu") + + @pytest.fixture(scope="session") def download_default_model(): """ From ada5f77f8d018d3282d910ce5e11b1547f26d18f Mon Sep 17 00:00:00 2001 From: Igor Tatarnikov Date: Thu, 9 May 2024 14:31:14 +0100 Subject: [PATCH 42/50] Added PYTORCH_MPS_HIGH_WATERMARK_RATION env variable --- .github/workflows/test_and_deploy.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/test_and_deploy.yml b/.github/workflows/test_and_deploy.yml index 0c41b4ae..a1f87e3f 100644 --- a/.github/workflows/test_and_deploy.yml +++ b/.github/workflows/test_and_deploy.yml @@ -39,6 +39,7 @@ jobs: runs-on: ${{ matrix.os }} env: PYTORCH_ENABLE_MPS_FALLBACK: "1" + PYTORCH_MPS_HIGH_WATERMARK_RATIO: "0.0" strategy: matrix: # Run all supported Python versions on linux From 546f223ba5f37f80c10fd301df884bf3cf357921 Mon Sep 17 00:00:00 2001 From: Igor Tatarnikov Date: Thu, 9 May 2024 15:02:00 +0100 Subject: [PATCH 43/50] Set env variables in test setup --- tests/core/conftest.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/tests/core/conftest.py b/tests/core/conftest.py index 7df4acdc..82a53998 100644 --- a/tests/core/conftest.py +++ b/tests/core/conftest.py @@ -1,9 +1,9 @@ +import os import platform from typing import Tuple import numpy as np import pytest -import torch from skimage.filters import gaussian from cellfinder.core.download import models @@ -16,7 +16,8 @@ def macos_use_cpu_only(): Ensure torch only uses the CPU when running on arm based macOS. 
""" if platform.system() == "Darwin" and platform.processor() == "arm": - torch.set_default_device("cpu") + os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1" + os.environ["PYTORCH_MPS_HIGH_WATERMARK_RATIO"] = "0.0" @pytest.fixture(scope="session") From 0546a56f93377f47a3a4c0f59c9e04fc63268390 Mon Sep 17 00:00:00 2001 From: Igor Tatarnikov Date: Thu, 9 May 2024 15:39:41 +0100 Subject: [PATCH 44/50] Try to set the default device to cpu in the test itself --- tests/core/test_integration/test_detection.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/core/test_integration/test_detection.py b/tests/core/test_integration/test_detection.py index 6e20d679..c1a2e8b7 100644 --- a/tests/core/test_integration/test_detection.py +++ b/tests/core/test_integration/test_detection.py @@ -4,6 +4,7 @@ import brainglobe_utils.IO.cells as cell_io import numpy as np import pytest +import torch from brainglobe_utils.general.system import get_num_processes from cellfinder.core.main import main @@ -112,6 +113,7 @@ def test_callbacks( signal_array, background_array, cpus_to_leave_free: int = 0 ): # 20 is minimum number of planes needed to find > 0 cells + torch.set_default_device("cpu") signal_array = signal_array[0:20] background_array = background_array[0:20] From 8eb5ee35103f37f67103d89d17d284fe5b037357 Mon Sep 17 00:00:00 2001 From: Igor Tatarnikov Date: Thu, 9 May 2024 15:45:41 +0100 Subject: [PATCH 45/50] Add device call to Conv3D to force cpu --- cellfinder/core/classify/resnet.py | 1 + 1 file changed, 1 insertion(+) diff --git a/cellfinder/core/classify/resnet.py b/cellfinder/core/classify/resnet.py index e172f712..e8cb6d3b 100644 --- a/cellfinder/core/classify/resnet.py +++ b/cellfinder/core/classify/resnet.py @@ -131,6 +131,7 @@ def non_residual_block( strides=strides, use_bias=use_bias, name="conv1", + device="cpu", )(x) x = BatchNormalization(axis=axis, epsilon=bn_epsilon, name="conv1_bn")(x) x = Activation(activation, name="conv1_activation")(x) From b995bb365fbbb1fcd072a75a0965f3b0bd99a8f1 Mon Sep 17 00:00:00 2001 From: Igor Tatarnikov Date: Thu, 9 May 2024 16:03:50 +0100 Subject: [PATCH 46/50] Revert changes, request one cpu left free --- cellfinder/core/classify/resnet.py | 1 - tests/core/conftest.py | 12 ------------ tests/core/test_integration/test_detection.py | 4 +--- 3 files changed, 1 insertion(+), 16 deletions(-) diff --git a/cellfinder/core/classify/resnet.py b/cellfinder/core/classify/resnet.py index e8cb6d3b..e172f712 100644 --- a/cellfinder/core/classify/resnet.py +++ b/cellfinder/core/classify/resnet.py @@ -131,7 +131,6 @@ def non_residual_block( strides=strides, use_bias=use_bias, name="conv1", - device="cpu", )(x) x = BatchNormalization(axis=axis, epsilon=bn_epsilon, name="conv1_bn")(x) x = Activation(activation, name="conv1_activation")(x) diff --git a/tests/core/conftest.py b/tests/core/conftest.py index 82a53998..83778a9f 100644 --- a/tests/core/conftest.py +++ b/tests/core/conftest.py @@ -1,5 +1,3 @@ -import os -import platform from typing import Tuple import numpy as np @@ -10,16 +8,6 @@ from cellfinder.core.tools.prep import DEFAULT_INSTALL_PATH -@pytest.fixture(scope="session", autouse=True) -def macos_use_cpu_only(): - """ - Ensure torch only uses the CPU when running on arm based macOS. 
- """ - if platform.system() == "Darwin" and platform.processor() == "arm": - os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1" - os.environ["PYTORCH_MPS_HIGH_WATERMARK_RATIO"] = "0.0" - - @pytest.fixture(scope="session") def download_default_model(): """ diff --git a/tests/core/test_integration/test_detection.py b/tests/core/test_integration/test_detection.py index c1a2e8b7..940e62f4 100644 --- a/tests/core/test_integration/test_detection.py +++ b/tests/core/test_integration/test_detection.py @@ -4,7 +4,6 @@ import brainglobe_utils.IO.cells as cell_io import numpy as np import pytest -import torch from brainglobe_utils.general.system import get_num_processes from cellfinder.core.main import main @@ -110,10 +109,9 @@ def test_detection_small_planes( def test_callbacks( - signal_array, background_array, cpus_to_leave_free: int = 0 + signal_array, background_array, cpus_to_leave_free: int = 1 ): # 20 is minimum number of planes needed to find > 0 cells - torch.set_default_device("cpu") signal_array = signal_array[0:20] background_array = background_array[0:20] From 78d1588f479e5373c9f5fdd968c5b31609243d10 Mon Sep 17 00:00:00 2001 From: Igor Tatarnikov Date: Thu, 9 May 2024 16:12:30 +0100 Subject: [PATCH 47/50] Revers the numb cores, don't use arm based mac runner --- .github/workflows/test_and_deploy.yml | 5 +---- tests/core/test_integration/test_detection.py | 2 +- 2 files changed, 2 insertions(+), 5 deletions(-) diff --git a/.github/workflows/test_and_deploy.yml b/.github/workflows/test_and_deploy.yml index a1f87e3f..b4460f24 100644 --- a/.github/workflows/test_and_deploy.yml +++ b/.github/workflows/test_and_deploy.yml @@ -37,9 +37,6 @@ jobs: name: Run package tests timeout-minutes: 60 runs-on: ${{ matrix.os }} - env: - PYTORCH_ENABLE_MPS_FALLBACK: "1" - PYTORCH_MPS_HIGH_WATERMARK_RATIO: "0.0" strategy: matrix: # Run all supported Python versions on linux @@ -47,7 +44,7 @@ jobs: python-version: ["3.9", "3.10", "3.11"] # Include one macos run include: - - os: macos-latest + - os: macos-13 python-version: "3.10" - os: windows-latest python-version: "3.10" diff --git a/tests/core/test_integration/test_detection.py b/tests/core/test_integration/test_detection.py index 940e62f4..6e20d679 100644 --- a/tests/core/test_integration/test_detection.py +++ b/tests/core/test_integration/test_detection.py @@ -109,7 +109,7 @@ def test_detection_small_planes( def test_callbacks( - signal_array, background_array, cpus_to_leave_free: int = 1 + signal_array, background_array, cpus_to_leave_free: int = 0 ): # 20 is minimum number of planes needed to find > 0 cells signal_array = signal_array[0:20] From 806f52bc40601f5fc6db787b1671beeb0f4af1fd Mon Sep 17 00:00:00 2001 From: Igor Tatarnikov Date: Thu, 9 May 2024 16:33:37 +0100 Subject: [PATCH 48/50] Merged main, removed torch flags on cellfinder install for guards and brainmapper --- .github/workflows/test_and_deploy.yml | 8 +++----- .github/workflows/test_include_guard.yaml | 4 ++-- 2 files changed, 5 insertions(+), 7 deletions(-) diff --git a/.github/workflows/test_and_deploy.yml b/.github/workflows/test_and_deploy.yml index c569ea9a..5918a8f3 100644 --- a/.github/workflows/test_and_deploy.yml +++ b/.github/workflows/test_and_deploy.yml @@ -42,12 +42,10 @@ jobs: # Run all supported Python versions on linux os: [ubuntu-latest] python-version: ["3.9", "3.10", "3.11"] - # Include one windows, one macos run each for M1 (latest) and Intel (13) + # Include one windows, one macos run (intel based macOS 13 runner) include: - os: macos-13 python-version: "3.10" - - os: 
macos-latest - python-version: "3.10" - os: windows-latest python-version: "3.10" @@ -125,8 +123,8 @@ jobs: - name: Install test dependencies run: | python -m pip install --upgrade pip wheel - # Install cellfinder from the latest SHA on this branch (Keras with torch backend) - python -m pip install "cellfinder[torch] @ git+$GITHUB_SERVER_URL/$GITHUB_REPOSITORY@$GITHUB_SHA" + # Install cellfinder from the latest SHA on this branch + python -m pip install "cellfinder @ git+$GITHUB_SERVER_URL/$GITHUB_REPOSITORY@$GITHUB_SHA" # Install checked out copy of brainglobe-workflows python -m pip install .[dev] diff --git a/.github/workflows/test_include_guard.yaml b/.github/workflows/test_include_guard.yaml index f080e210..07a5eda8 100644 --- a/.github/workflows/test_include_guard.yaml +++ b/.github/workflows/test_include_guard.yaml @@ -24,8 +24,8 @@ jobs: with: python-version: '3.10' - - name: Install cellfinder via pip, specifying torch as keras' backend - run: python -m pip install -e ".[torch]" + - name: Install cellfinder via pip + run: python -m pip install -e "." - name: Test (working) import uses: jannekem/run-python-script-action@v1 From a38257bd38c91c2f4637f90ba7d1c25b9f57d1c1 Mon Sep 17 00:00:00 2001 From: Igor Tatarnikov Date: Thu, 9 May 2024 17:48:43 +0100 Subject: [PATCH 49/50] Lowercase Torch --- cellfinder/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cellfinder/__init__.py b/cellfinder/__init__.py index fe5c067d..a0467820 100644 --- a/cellfinder/__init__.py +++ b/cellfinder/__init__.py @@ -29,7 +29,7 @@ # Check if backend is configured. If not, set to "torch" if not os.getenv("KERAS_BACKEND"): os.environ["KERAS_BACKEND"] = "torch" - warnings.warn("Keras backend not configured, automatically set to Torch") + warnings.warn("Keras backend not configured, automatically set to torch") # Check backend is installed backend = os.getenv("KERAS_BACKEND") From 5a1f0a8d3597462f6d1e14c3a416728aadeb83d9 Mon Sep 17 00:00:00 2001 From: Igor Tatarnikov Date: Thu, 9 May 2024 18:16:18 +0100 Subject: [PATCH 50/50] Change cache directory --- .github/workflows/test_and_deploy.yml | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/.github/workflows/test_and_deploy.yml b/.github/workflows/test_and_deploy.yml index 5918a8f3..931037fa 100644 --- a/.github/workflows/test_and_deploy.yml +++ b/.github/workflows/test_and_deploy.yml @@ -54,7 +54,7 @@ jobs: - name: Cache Keras model uses: actions/cache@v3 with: - path: "~/.cellfinder" + path: "~/.brainglobe" key: models-${{ hashFiles('~/.brainglobe/**') }} # Setup pyqt libraries - name: Setup qtpy libraries @@ -80,7 +80,7 @@ jobs: - name: Cache Keras model uses: actions/cache@v3 with: - path: "~/.cellfinder" + path: "~/.brainglobe" key: models-${{ hashFiles('~/.brainglobe/**') }} # Setup pyqt libraries - name: Setup qtpy libraries @@ -107,7 +107,7 @@ jobs: - name: Cache Keras model uses: actions/cache@v3 with: - path: "~/.cellfinder" + path: "~/.brainglobe" key: models-${{ hashFiles('~/.brainglobe/**') }} - name: Checkout brainglobe-workflows