diff --git a/.github/workflows/test_and_deploy.yml b/.github/workflows/test_and_deploy.yml
index 78be7bb5..0b51bb58 100644
--- a/.github/workflows/test_and_deploy.yml
+++ b/.github/workflows/test_and_deploy.yml
@@ -37,12 +37,15 @@ jobs:
     name: Run package tests
     timeout-minutes: 60
     runs-on: ${{ matrix.os }}
+    env:
+      KERAS_BACKEND: torch
+      CELLFINDER_TEST_DEVICE: cpu
     strategy:
       matrix:
         # Run all supported Python versions on linux
         os: [ubuntu-latest]
-        python-version: ["3.9", "3.10"]
-        # Include one windows, one macos run each for M1 (latest) and Intel (13)
+        python-version: ["3.9", "3.10", "3.11"]
+        # Include one Windows and two macOS (Intel-based and ARM-based) runs
         include:
           - os: macos-13
             python-version: "3.10"
@@ -80,11 +83,13 @@ jobs:
       NUMBA_DISABLE_JIT: "1"

     steps:
-      - name: Cache tensorflow model
+      - name: Cache brainglobe directory
         uses: actions/cache@v3
         with:
-          path: "~/.cellfinder"
-          key: models-${{ hashFiles('~/.brainglobe/**') }}
+          path: | # ensure we don't cache any interrupted atlas download and extraction, if e.g. we cancel the workflow manually
+            ~/.brainglobe
+            !~/.brainglobe/atlas.tar.gz
+          key: brainglobe
       # Setup pyqt libraries
       - name: Setup qtpy libraries
         uses: tlambert03/setup-qt-libs@v1
@@ -104,13 +109,17 @@ jobs:
     name: Run brainmapper tests to check for breakages
     timeout-minutes: 60
     runs-on: ubuntu-latest
+    env:
+      KERAS_BACKEND: torch
+      CELLFINDER_TEST_DEVICE: cpu
     steps:
-      - name: Cache tensorflow model
+      - name: Cache brainglobe directory
         uses: actions/cache@v3
         with:
-          path: "~/.cellfinder"
-          key: models-${{ hashFiles('~/.brainglobe/**') }}
-
+          path: | # ensure we don't cache any interrupted atlas download and extraction, if e.g. we cancel the workflow manually
+            ~/.brainglobe
+            !~/.brainglobe/atlas.tar.gz
+          key: brainglobe
       - name: Checkout brainglobe-workflows
         uses: actions/checkout@v3
         with:
@@ -124,8 +133,9 @@ jobs:

       - name: Install test dependencies
         run: |
           python -m pip install --upgrade pip wheel
-          # Install latest SHA on this brainglobe-workflows branch
-          python -m pip install git+$GITHUB_SERVER_URL/$GITHUB_REPOSITORY@$GITHUB_SHA
+          # Install cellfinder from the latest SHA on this branch
+          python -m pip install "cellfinder @ git+$GITHUB_SERVER_URL/$GITHUB_REPOSITORY@$GITHUB_SHA"
+          # Install checked out copy of brainglobe-workflows
           python -m pip install .[dev]

diff --git a/.github/workflows/test_include_guard.yaml b/.github/workflows/test_include_guard.yaml
index 70415a2e..07a5eda8 100644
--- a/.github/workflows/test_include_guard.yaml
+++ b/.github/workflows/test_include_guard.yaml
@@ -1,5 +1,5 @@
-name: Test Tensorflow include guards
-# These tests check that the include guards checking for tensorflow's availability
+name: Test Keras include guards
+# These tests check that the include guards checking for Keras availability
 # behave as expected on ubuntu and macOS.

 on:
@@ -9,7 +9,7 @@ on:
     - main

 jobs:
-  tensorflow_guards:
+  keras_guards:
     name: Test include guards
     strategy:
       matrix:
@@ -24,24 +24,21 @@ jobs:
         with:
           python-version: '3.10'

-      - name: Install via pip
-        run: python -m pip install -e .
+      - name: Install cellfinder via pip
+        run: python -m pip install -e "."

       - name: Test (working) import
         uses: jannekem/run-python-script-action@v1
+        env:
+          KERAS_BACKEND: torch
         with:
           fail-on-error: true
           script: |
             import cellfinder.core
             import cellfinder.napari

-      - name: Uninstall tensorflow-macos on Mac M1
-        if: matrix.os == 'macos-latest'
-        run: python -m pip uninstall -y tensorflow-macos
-
-      - name: Uninstall tensorflow on Ubuntu
-        if: matrix.os == 'ubuntu-latest'
-        run: python -m pip uninstall -y tensorflow
+      - name: Uninstall keras
+        run: python -m pip uninstall -y keras

       - name: Test (broken) import
         id: broken_import

diff --git a/cellfinder/__init__.py b/cellfinder/__init__.py
index fcd51af8..405524c9 100644
--- a/cellfinder/__init__.py
+++ b/cellfinder/__init__.py
@@ -1,25 +1,31 @@
+import os
 from importlib.metadata import PackageNotFoundError, version
 from pathlib import Path

+# Check cellfinder is installed
 try:
     __version__ = version("cellfinder")
 except PackageNotFoundError as e:
     raise PackageNotFoundError("cellfinder package not installed") from e

-# If tensorflow is not present, tools cannot be used.
+# If Keras is not present, tools cannot be used.
 # Throw an error in this case to prevent invocation of functions.
 try:
-    TF_VERSION = version("tensorflow")
+    KERAS_VERSION = version("keras")
 except PackageNotFoundError as e:
-    try:
-        TF_VERSION = version("tensorflow-macos")
-    except PackageNotFoundError as e:
-        raise PackageNotFoundError(
-            f"cellfinder tools cannot be invoked without tensorflow. "
-            f"Please install tensorflow into your environment to use cellfinder tools. "
-            f"For more information, please see "
-            f"https://github.com/brainglobe/brainglobe-meta#readme."
-        ) from e
+    raise PackageNotFoundError(
+        f"cellfinder tools cannot be invoked without Keras. "
+        f"Please install Keras with a backend into your environment "
+        f"to use cellfinder tools. "
+        f"For more information on Keras backends, please see "
+        f"https://keras.io/getting_started/#installing-keras-3. "
+        f"For more information on brainglobe, please see "
+        f"https://github.com/brainglobe/brainglobe-meta#readme."
+    ) from e
+
+
+# Set the Keras backend to torch
+os.environ["KERAS_BACKEND"] = "torch"

 __author__ = "Adam Tyson, Christian Niedworok, Charly Rousseau"
 __license__ = "BSD-3-Clause"
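Note: Keras 3 reads the KERAS_BACKEND environment variable once, when `import keras` first runs, so the assignment in `cellfinder/__init__.py` must execute before anything in the process imports Keras. A minimal sketch of the intended contract (the assertion is illustrative, not part of the patch):

    import os

    os.environ["KERAS_BACKEND"] = "torch"  # must run before the first `import keras`

    import keras

    assert keras.backend.backend() == "torch"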
diff --git a/cellfinder/core/classify/classify.py b/cellfinder/core/classify/classify.py
index 86b709f5..1da4dabe 100644
--- a/cellfinder/core/classify/classify.py
+++ b/cellfinder/core/classify/classify.py
@@ -1,10 +1,10 @@
 import os
 from typing import Any, Callable, Dict, List, Optional, Tuple

+import keras
 import numpy as np
 from brainglobe_utils.cells.cells import Cell
 from brainglobe_utils.general.system import get_num_processes
-from tensorflow import keras

 from cellfinder.core import logger, types
 from cellfinder.core.classify.cube_generator import CubeGeneratorFromFile
@@ -48,9 +48,7 @@ def main(
         callbacks = None

     # Too many workers doesn't increase speed, and uses huge amounts of RAM
-    workers = get_num_processes(
-        min_free_cpu_cores=n_free_cpus, n_max_processes=max_workers
-    )
+    workers = get_num_processes(min_free_cpu_cores=n_free_cpus)

     logger.debug("Initialising cube generator")
     inference_generator = CubeGeneratorFromFile(
@@ -63,6 +61,8 @@ def main(
         cube_width=cube_width,
         cube_height=cube_height,
         cube_depth=cube_depth,
+        use_multiprocessing=False,
+        workers=workers,
     )

     model = get_model(
@@ -73,10 +73,9 @@ def main(
     )

     logger.info("Running inference")
+    # in Keras 3.0 multiprocessing params are specified in the generator
     predictions = model.predict(
         inference_generator,
-        use_multiprocessing=True,
-        workers=workers,
         verbose=True,
         callbacks=callbacks,
     )
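Note: Keras 3 no longer accepts `workers` or `use_multiprocessing` in `Model.predict` and `Model.fit`; they now belong to the data object, since `keras.utils.Sequence` is an alias of `keras.utils.PyDataset`, whose constructor takes them. A sketch of the new calling convention, using a hypothetical `CubeBatches` stand-in for `CubeGeneratorFromFile` (`cubes` and `model` are assumed to exist already):

    import keras
    import numpy as np

    class CubeBatches(keras.utils.Sequence):
        def __init__(self, cubes: np.ndarray, batch_size: int, **kwargs):
            # forward `workers` / `use_multiprocessing` to the PyDataset base
            super().__init__(**kwargs)
            self.cubes = cubes
            self.batch_size = batch_size

        def __len__(self):
            return len(self.cubes) // self.batch_size

        def __getitem__(self, index):
            start = index * self.batch_size
            return self.cubes[start : start + self.batch_size]

    generator = CubeBatches(cubes, 32, workers=4, use_multiprocessing=False)
    predictions = model.predict(generator, verbose=True)  # no worker args here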
diff --git a/cellfinder/core/classify/cube_generator.py b/cellfinder/core/classify/cube_generator.py
index 56f226cd..4a24467f 100644
--- a/cellfinder/core/classify/cube_generator.py
+++ b/cellfinder/core/classify/cube_generator.py
@@ -2,13 +2,13 @@
 from random import shuffle
 from typing import Dict, List, Optional, Tuple, Union

+import keras
 import numpy as np
-import tensorflow as tf
 from brainglobe_utils.cells.cells import Cell, group_cells_by_z
 from brainglobe_utils.general.numerical import is_even
+from keras.utils import Sequence
 from scipy.ndimage import zoom
 from skimage.io import imread
-from tensorflow.keras.utils import Sequence

 from cellfinder.core import types
 from cellfinder.core.classify.augment import AugmentationParameters, augment
@@ -56,7 +56,14 @@ def __init__(
         translate: Tuple[float, float, float] = (0.05, 0.05, 0.05),
         shuffle: bool = False,
         interpolation_order: int = 2,
+        *args,
+        **kwargs,
     ):
+        # pass any additional arguments not specified in signature to the
+        # constructor of the superclass (e.g.: `use_multiprocessing` or
+        # `workers`)
+        super().__init__(*args, **kwargs)
+
         self.points = points
         self.signal_array = signal_array
         self.background_array = background_array
@@ -218,10 +225,10 @@ def __getitem__(self, index: int) -> Union[

         if self.train:
             batch_labels = [cell.type - 1 for cell in cell_batch]
-            batch_labels = tf.keras.utils.to_categorical(
+            batch_labels = keras.utils.to_categorical(
                 batch_labels, num_classes=self.classes
             )
-            return images, batch_labels
+            return images, batch_labels.astype(np.float32)
         elif self.extract:
             batch_info = self.__get_batch_dict(cell_batch)
             return images, batch_info
@@ -252,7 +259,8 @@ def __generate_cubes(
                 (number_images,)
                 + (self.cube_height, self.cube_width, self.cube_depth)
                 + (self.channels,)
-            )
+            ),
+            dtype=np.float32,
         )

         for idx, cell in enumerate(cell_batch):
@@ -350,7 +358,14 @@ def __init__(
         translate: Tuple[float, float, float] = (0.2, 0.2, 0.2),
         train: bool = False,  # also return labels
         interpolation_order: int = 2,
+        *args,
+        **kwargs,
     ):
+        # pass any additional arguments not specified in signature to the
+        # constructor of the superclass (e.g.: `use_multiprocessing` or
+        # `workers`)
+        super().__init__(*args, **kwargs)
+
         self.im_shape = shape
         self.batch_size = batch_size
         self.labels = labels
@@ -410,10 +425,10 @@ def __getitem__(self, index: int) -> Union[

         if self.train and self.labels is not None:
             batch_labels = [self.labels[k] for k in indexes]
-            batch_labels = tf.keras.utils.to_categorical(
+            batch_labels = keras.utils.to_categorical(
                 batch_labels, num_classes=self.classes
             )
-            return images, batch_labels
+            return images, batch_labels.astype(np.float32)
         else:
             return images

@@ -424,7 +439,8 @@ def __generate_cubes(
     ) -> np.ndarray:
         number_images = len(list_signal_tmp)
         images = np.empty(
-            ((number_images,) + self.im_shape + (self.channels,))
+            ((number_images,) + self.im_shape + (self.channels,)),
+            dtype=np.float32,
         )

         for idx, signal_im in enumerate(list_signal_tmp):
@@ -433,7 +449,7 @@ def __generate_cubes(
                 images, idx, signal_im, background_im
             )

-        return images.astype(np.float16)
+        return images

     def __populate_array_with_cubes(
         self,

diff --git a/cellfinder/core/classify/resnet.py b/cellfinder/core/classify/resnet.py
index ed313642..07f6253c 100644
--- a/cellfinder/core/classify/resnet.py
+++ b/cellfinder/core/classify/resnet.py
@@ -1,9 +1,11 @@
 from typing import Callable, Dict, List, Literal, Optional, Tuple, Union

-from tensorflow import Tensor
-from tensorflow.keras import Model
-from tensorflow.keras.initializers import Initializer
-from tensorflow.keras.layers import (
+from keras import (
+    KerasTensor as Tensor,
+)
+from keras import Model
+from keras.initializers import Initializer
+from keras.layers import (
     Activation,
     Add,
     BatchNormalization,
@@ -14,7 +16,7 @@
     MaxPooling3D,
     ZeroPadding3D,
 )
-from tensorflow.keras.optimizers import Adam, Optimizer
+from keras.optimizers import Adam, Optimizer

 #####################################################################
 # Define the types of ResNet
@@ -113,7 +115,7 @@ def non_residual_block(
     activation: str = "relu",
     use_bias: bool = False,
     bn_epsilon: float = 1e-5,
-    pooling_padding: str = "same",
+    pooling_padding: str = "valid",
     axis: int = 3,
 ) -> Tensor:
     """
@@ -131,6 +133,7 @@ def non_residual_block(
     )(x)
     x = BatchNormalization(axis=axis, epsilon=bn_epsilon, name="conv1_bn")(x)
     x = Activation(activation, name="conv1_activation")(x)
+
     x = MaxPooling3D(
         max_pool_size,
         strides=strides,
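Note: flipping `pooling_padding` from "same" to "valid" is a shape-affecting change, not a cosmetic one. For input length n, pool size k and stride s, "same" padding gives an output length of ceil(n / s), while "valid" gives floor((n - k) / s) + 1. For example, with n = 25, k = 2, s = 2, "same" yields 13 and "valid" yields 12, so downstream layer shapes (and any saved weights) differ between the two settings.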
diff --git a/cellfinder/core/classify/tools.py b/cellfinder/core/classify/tools.py
index 2d5c44b2..3bf48876 100644
--- a/cellfinder/core/classify/tools.py
+++ b/cellfinder/core/classify/tools.py
@@ -1,9 +1,10 @@
 import os
-from typing import List, Optional, Sequence, Tuple, Union
+from collections.abc import Sequence
+from typing import List, Optional, Tuple, Union

+import keras
 import numpy as np
-import tensorflow as tf
-from tensorflow.keras import Model
+from keras import Model

 from cellfinder.core import logger
 from cellfinder.core.classify.resnet import build_model, layer_type
@@ -17,8 +18,7 @@ def get_model(
     inference: bool = False,
     continue_training: bool = False,
 ) -> Model:
-    """
-    Returns the correct model based on the arguments passed
+    """Returns the correct model based on the arguments passed
     :param existing_model: An existing, trained model. This is returned if
         it exists
     :param model_weights: This file is used to set the model weights if it
@@ -30,29 +30,31 @@ def get_model(
         by using the default one
     :param continue_training: If True, will ensure that a trained model
         exists. E.g. by using the default one
-    :return: A tf.keras model
+    :return: A keras model
     """
     if existing_model is not None or network_depth is None:
         logger.debug(f"Loading model: {existing_model}")
-        return tf.keras.models.load_model(existing_model)
+        return keras.models.load_model(existing_model)
     else:
         logger.debug(f"Creating a new instance of model: {network_depth}")
         model = build_model(
-            network_depth=network_depth, learning_rate=learning_rate
+            network_depth=network_depth,
+            learning_rate=learning_rate,
         )
        if inference or continue_training:
            logger.debug(
-                f"Setting model weights according to: {model_weights}"
+                f"Setting model weights according to: {model_weights}",
            )
            if model_weights is None:
-                raise IOError("`model_weights` must be provided")
+                raise OSError("`model_weights` must be provided")
            model.load_weights(model_weights)
    return model


 def make_lists(
-    tiff_files: Sequence, train: bool = True
+    tiff_files: Sequence,
+    train: bool = True,
 ) -> Union[Tuple[List, List], Tuple[List, List, np.ndarray]]:
     signal_list = []
     background_list = []
""" - suppress_tf_logging(tf_suppress_log_messages) - from cellfinder.core.classify import classify from cellfinder.core.detect import detect from cellfinder.core.tools import prep @@ -98,7 +86,7 @@ def main( if not skip_classification: install_path = None model_weights = prep.prep_model_weights( - model_weights, install_path, model, n_free_cpus + model_weights, install_path, model ) if len(points) > 0: logger.info("Running classification") @@ -120,17 +108,4 @@ def main( ) else: logger.info("No candidates, skipping classification") - return points - - -def suppress_tf_logging(tf_suppress_log_messages: List[str]) -> None: - """ - Prevents many lines of logs such as: - "2019-10-24 16:54:41.363978: I tensorflow/stream_executor/platform/default - /dso_loader.cc:44] Successfully opened dynamic library libcuda.so.1" - """ - os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3" - - for message in tf_suppress_log_messages: - suppress_specific_logs("tensorflow", message) diff --git a/cellfinder/core/tools/prep.py b/cellfinder/core/tools/prep.py index a2625311..6609d711 100644 --- a/cellfinder/core/tools/prep.py +++ b/cellfinder/core/tools/prep.py @@ -9,9 +9,7 @@ from typing import Optional from brainglobe_utils.general.config import get_config_obj -from brainglobe_utils.general.system import get_num_processes -import cellfinder.core.tools.tf as tf_tools from cellfinder.core import logger from cellfinder.core.download.download import ( DEFAULT_DOWNLOAD_DIRECTORY, @@ -26,20 +24,13 @@ def prep_model_weights( model_weights: Optional[os.PathLike], install_path: Optional[os.PathLike], model_name: model_type, - n_free_cpus: int, ) -> Path: - n_processes = get_num_processes(min_free_cpu_cores=n_free_cpus) - prep_tensorflow(n_processes) + # prepare models (get default weights or provided ones) model_weights = prep_models(model_weights, install_path, model_name) return model_weights -def prep_tensorflow(max_threads: int) -> None: - tf_tools.set_tf_threads(max_threads) - tf_tools.allow_gpu_memory_growth() - - def prep_models( model_weights_path: Optional[os.PathLike], install_path: Optional[os.PathLike], diff --git a/cellfinder/core/tools/tf.py b/cellfinder/core/tools/tf.py deleted file mode 100644 index 778aa78f..00000000 --- a/cellfinder/core/tools/tf.py +++ /dev/null @@ -1,46 +0,0 @@ -import tensorflow as tf - -from cellfinder.core import logger - - -def allow_gpu_memory_growth(): - """ - If a gpu is present, prevent tensorflow from using all the memory straight - away. Allows multiple processes to use the GPU (and avoid occasional - errors on some systems) at the cost of a slight performance penalty. - """ - gpus = tf.config.experimental.list_physical_devices("GPU") - if gpus: - logger.debug("Allowing GPU memory growth") - try: - # Currently, memory growth needs to be the same across GPUs - for gpu in gpus: - tf.config.experimental.set_memory_growth(gpu, True) - logical_gpus = tf.config.experimental.list_logical_devices("GPU") - logger.debug( - f"{len(gpus)} physical GPUs, {len(logical_gpus)} logical GPUs" - ) - except RuntimeError as e: - # Memory growth must be set before GPUs have been initialized - print(e) - else: - logger.debug("No GPUs found, using CPU.") - - -def set_tf_threads(max_threads): - """ - Limit the number of threads that tensorflow uses - :param max_threads: Maximum number of threads to use - :return: - """ - logger.debug( - f"Setting maximum number of threads for tensorflow " - f"to: {max_threads}" - ) - - # If statements are for testing. 
diff --git a/cellfinder/core/train/train_yml.py b/cellfinder/core/train/train_yml.py
index bf916b3c..4d6e5bf3 100644
--- a/cellfinder/core/train/train_yml.py
+++ b/cellfinder/core/train/train_yml.py
@@ -22,7 +22,10 @@
     check_positive_float,
     check_positive_int,
 )
-from brainglobe_utils.general.system import ensure_directory_exists
+from brainglobe_utils.general.system import (
+    ensure_directory_exists,
+    get_num_processes,
+)
 from brainglobe_utils.IO.cells import find_relevant_tiffs
 from brainglobe_utils.IO.yaml import read_yaml_section
 from fancylog import fancylog
@@ -33,11 +36,6 @@
 from cellfinder.core.classify.resnet import layer_type
 from cellfinder.core.download.download import DEFAULT_DOWNLOAD_DIRECTORY

-tf_suppress_log_messages = [
-    "sample_weight modes were coerced from",
-    "multiprocessing can interact badly with TensorFlow",
-]
-
 depth_type = Literal["18", "34", "50", "101", "152"]

 models: Dict[depth_type, layer_type] = {
@@ -318,11 +316,7 @@ def run(
     save_progress=False,
     epochs=100,
 ):
-    from cellfinder.core.main import suppress_tf_logging
-
-    suppress_tf_logging(tf_suppress_log_messages)
-
-    from tensorflow.keras.callbacks import (
+    from keras.callbacks import (
         CSVLogger,
         ModelCheckpoint,
         TensorBoard,
@@ -339,7 +333,6 @@ def run(
         model_weights=model_weights,
         install_path=install_path,
         model_name=model,
-        n_free_cpus=n_free_cpus,
     )

     yaml_contents = parse_yaml(yaml_file)
@@ -361,6 +354,7 @@ def run(

     signal_train, background_train, labels_train = make_lists(tiff_files)

+    n_processes = get_num_processes(min_free_cpu_cores=n_free_cpus)
     if test_fraction > 0:
         logger.info("Splitting data into training and validation datasets")
         (
@@ -387,15 +381,17 @@ def run(
             labels=labels_test,
             batch_size=batch_size,
             train=True,
+            use_multiprocessing=False,
+            workers=n_processes,
         )

         # for saving checkpoints
-        base_checkpoint_file_name = "-epoch.{epoch:02d}-loss-{val_loss:.3f}.h5"
+        base_checkpoint_file_name = "-epoch.{epoch:02d}-loss-{val_loss:.3f}"

     else:
         logger.info("No validation data selected.")
         validation_generator = None
-        base_checkpoint_file_name = "-epoch.{epoch:02d}.h5"
+        base_checkpoint_file_name = "-epoch.{epoch:02d}"

     training_generator = CubeGeneratorFromDisk(
         signal_train,
@@ -405,6 +401,8 @@ def run(
         shuffle=True,
         train=True,
         augment=not no_augment,
+        use_multiprocessing=False,
+        workers=n_processes,
     )
     callbacks = []

@@ -421,9 +419,14 @@ def run(

     if not no_save_checkpoints:
         if save_weights:
-            filepath = str(output_dir / ("weight" + base_checkpoint_file_name))
+            filepath = str(
+                output_dir
+                / ("weight" + base_checkpoint_file_name + ".weights.h5")
+            )
         else:
-            filepath = str(output_dir / ("model" + base_checkpoint_file_name))
+            filepath = str(
+                output_dir / ("model" + base_checkpoint_file_name + ".keras")
+            )

         checkpoints = ModelCheckpoint(
             filepath,
@@ -432,25 +435,26 @@ def run(
         callbacks.append(checkpoints)

     if save_progress:
-        filepath = str(output_dir / "training.csv")
-        csv_logger = CSVLogger(filepath)
+        csv_filepath = str(output_dir / "training.csv")
+        csv_logger = CSVLogger(csv_filepath)
         callbacks.append(csv_logger)

     logger.info("Beginning training.")
+    # Keras 3.0: `use_multiprocessing` input is set in the
+    # `training_generator` (False by default)
     model.fit(
         training_generator,
         validation_data=validation_generator,
-        use_multiprocessing=False,
         epochs=epochs,
         callbacks=callbacks,
     )

     if save_weights:
         logger.info("Saving model weights")
-        model.save_weights(str(output_dir / "model_weights.h5"))
+        model.save_weights(output_dir / "model.weights.h5")
     else:
         logger.info("Saving model")
-        model.save(output_dir / "model.h5")
+        model.save(output_dir / "model.keras")

     logger.info(
         "Finished training, " "Total time taken: %s",
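Note: the new checkpoint suffixes are mandatory, not cosmetic. Keras 3's `ModelCheckpoint` raises a `ValueError` unless the filepath ends in `.weights.h5` when `save_weights_only=True`, or in `.keras` (the native format) for full-model checkpoints. A minimal sketch with illustrative paths (`model` and `training_generator` assumed to exist):

    from keras.callbacks import ModelCheckpoint

    weights_ckpt = ModelCheckpoint(
        "out/weight-epoch.{epoch:02d}.weights.h5",
        save_weights_only=True,
    )
    model_ckpt = ModelCheckpoint("out/model-epoch.{epoch:02d}.keras")
    model.fit(training_generator, callbacks=[weights_ckpt, model_ckpt])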
diff --git a/pyproject.toml b/pyproject.toml
index fa717476..d31bda7e 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -29,9 +29,8 @@ dependencies = [
     "numpy",
     "scikit-image",
     "scikit-learn",
-    # See https://github.com/brainglobe/cellfinder-core/issues/103 for < 2.12.0 pin
-    "tensorflow-macos>=2.5.0,<2.12.0; platform_system=='Darwin' and platform_machine=='arm64'",
-    "tensorflow>=2.5.0,<2.12.0; platform_system!='Darwin' or platform_machine!='arm64'",
+    "keras>=3.0.0",
+    "torch>=2.1.0",
     "tifffile",
     "tqdm",
 ]
@@ -79,7 +78,7 @@
 requires = ["setuptools>=45", "wheel", "setuptools_scm[toml]>=6.2"]
 build-backend = 'setuptools.build_meta'

 [tool.black]
-target-version = ['py39', 'py310']
+target-version = ['py39', 'py310', 'py311']
 skip-string-normalization = false
 line-length = 79
@@ -111,13 +110,14 @@ markers = ["slow: marks tests as slow (deselect with '-m \"not slow\"')"]
 legacy_tox_ini = """
 # For more information about tox, see https://tox.readthedocs.io/en/latest/
 [tox]
-envlist = py{39,310}
+envlist = py{39,310,311}
 isolated_build = true

 [gh-actions]
 python =
     3.9: py39
     3.10: py310
+    3.11: py311

 [testenv]
 commands = python -m pytest -v --color=yes --cov=cellfinder --cov-report=xml
@@ -132,6 +132,8 @@ deps =
     pytest-qt
 extras =
     napari
+setenv =
+    KERAS_BACKEND = torch
 passenv =
     NUMBA_DISABLE_JIT
     CI

diff --git a/tests/core/conftest.py b/tests/core/conftest.py
index 29e465ad..e97b878f 100644
--- a/tests/core/conftest.py
+++ b/tests/core/conftest.py
@@ -1,8 +1,10 @@
 import os
 from typing import Tuple

+import keras.src.backend.common.global_state
 import numpy as np
 import pytest
+import torch.backends.mps
 from skimage.filters import gaussian

 from cellfinder.core.download.download import (
@@ -11,6 +13,22 @@
 )


+@pytest.fixture(scope="session", autouse=True)
+def set_device_arm_macos_ci():
+    """
+    Ensure that the device is set to CPU when running on ARM-based macOS
+    GitHub runners. This is to avoid the following error:
+    https://discuss.pytorch.org/t/mps-back-end-out-of-memory-on-github-action/189773/5
+    """
+    if (
+        os.getenv("GITHUB_ACTIONS") == "true"
+        and torch.backends.mps.is_available()
+    ):
+        keras.src.backend.common.global_state.set_global_attribute(
+            "torch_device", "cpu"
+        )
+
+
 @pytest.fixture(scope="session")
 def no_free_cpus() -> int:
     """
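Note: the `set_global_attribute("torch_device", "cpu")` call reaches into `keras.src`, a private namespace, because Keras 3 does not expose a public API for pinning the torch backend's device. With the torch backend, Keras ops return `torch.Tensor` objects, so the fixture's effect can be spot-checked (illustrative sketch only):

    import keras  # assumes KERAS_BACKEND=torch is already set

    x = keras.ops.ones((2, 2))  # a torch.Tensor under the torch backend
    print(x.device)  # expected to report a CPU device once the fixture has run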
diff --git a/tests/core/test_integration/test_detection.py b/tests/core/test_integration/test_detection.py
index 7d90a007..fc7bf2f3 100644
--- a/tests/core/test_integration/test_detection.py
+++ b/tests/core/test_integration/test_detection.py
@@ -80,7 +80,10 @@ def test_detection_full(signal_array, background_array, free_cpus, request):


 def test_detection_small_planes(
-    signal_array, background_array, no_free_cpus, mocker
+    signal_array,
+    background_array,
+    no_free_cpus,
+    mocker,
 ):
     # Check that processing works when number of planes < number of processes
     nproc = get_num_processes(no_free_cpus)

diff --git a/tests/core/test_integration/test_train.py b/tests/core/test_integration/test_train.py
index 0d4f5c38..8cfc8c31 100644
--- a/tests/core/test_integration/test_train.py
+++ b/tests/core/test_integration/test_train.py
@@ -35,5 +35,5 @@ def test_train(tmpdir):
     sys.argv = train_args
     train_run()

-    model_file = os.path.join(tmpdir, "model.h5")
+    model_file = os.path.join(tmpdir, "model.keras")
     assert os.path.exists(model_file)
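Note: `model.keras` is Keras 3's native zip-based saved-model format, replacing the legacy HDF5 `model.h5`. If the test were ever to assert more than file existence, the saved model can be round-tripped (a sketch):

    import keras

    model = keras.models.load_model(model_file)  # restores architecture + weights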