Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Cellfinder with Keras 3.0 and torch backend #380

Closed
wants to merge 7 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 1 addition & 3 deletions .github/workflows/test_and_deploy.yml
Original file line number Diff line number Diff line change
Expand Up @@ -116,9 +116,7 @@ jobs:
run: |
python -m pip install --upgrade pip wheel
# Install cellfinder from the latest SHA on this branch
python -m pip install git+$GITHUB_SERVER_URL/$GITHUB_REPOSITORY@$GITHUB_SHA
# Install tensorflow as keras' default backend
python -m pip install "tf-nightly==2.16.0.dev20240101"
python -m pip install "cellfinder[tf] @ git+$GITHUB_SERVER_URL/$GITHUB_REPOSITORY@$GITHUB_SHA"
# Install checked out copy of brainglobe-workflows
python -m pip install .[dev]

Expand Down
7 changes: 6 additions & 1 deletion cellfinder/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
if not os.getenv("KERAS_BACKEND"):
os.environ["KERAS_BACKEND"] = "tensorflow"
warnings.warn(
"Keras backend not configured, automatically set to Tensorflow"
"Keras backend not configured, automatically set to tensorflow"
)

# Check backend is installed
Expand All @@ -48,6 +48,11 @@
"Keras backend must be one of 'tensorflow', 'jax', or 'torch'"
)

# # Change image data format for Keras --- better somewhere else!
# import keras
# if keras.config.backend() == "torch":
# keras.config.set_image_data_format('channels_first')


__author__ = "Adam Tyson, Christian Niedworok, Charly Rousseau"
__license__ = "BSD-3-Clause"
6 changes: 6 additions & 0 deletions cellfinder/core/tools/prep.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
DEFAULT_INSTALL_PATH = home / ".cellfinder"


# should this be called prep_models, and the other one prep_model_weights?
def prep_model_weights(
model_weights: Optional[os.PathLike],
install_path: Optional[os.PathLike],
Expand All @@ -33,6 +34,11 @@ def prep_model_weights(
n_processes = get_num_processes(min_free_cpu_cores=n_free_cpus)
prep_tensorflow(n_processes)

# if torch backend: change image data format to channels first
# (this expects batch size last)
elif keras.config.backend() == "torch":
keras.config.set_image_data_format("channels_first")

# prepare models (get default weights or provided ones)
model_weights = prep_models(model_weights, install_path, model_name)

Expand Down
20 changes: 17 additions & 3 deletions cellfinder/core/train/train_yml.py
Original file line number Diff line number Diff line change
Expand Up @@ -337,27 +337,31 @@ def run(
start_time = datetime.now()

ensure_directory_exists(output_dir)

model_weights = prep_model_weights(
install_path, model_weights, model, n_free_cpus
)
model_weights, install_path, model, n_free_cpus
) # path

yaml_contents = parse_yaml(yaml_file)

tiff_files = get_tiff_files(yaml_contents)

logger.info(
f"Found {sum(len(imlist) for imlist in tiff_files)} images "
f"from {len(yaml_contents)} datasets "
f"in {len(yaml_file)} yaml files"
)

### Get model ready
model = get_model(
existing_model=trained_model,
model_weights=model_weights,
network_depth=models[network_depth],
learning_rate=learning_rate,
continue_training=continue_training,
)
) # keras.src.models.functional.Functional

### Prep data
signal_train, background_train, labels_train = make_lists(tiff_files)

if test_fraction > 0:
Expand Down Expand Up @@ -397,6 +401,7 @@ def run(
validation_generator = None
base_checkpoint_file_name = "-epoch.{epoch:02d}"

### Generate "dataloader"
training_generator = CubeGeneratorFromDisk(
signal_train,
background_train,
Expand All @@ -407,6 +412,11 @@ def run(
augment=not no_augment,
use_multiprocessing=False,
)

### Prepare callbacks
# - tensorboard
# - checkpoint saving
# - csv logger
callbacks = []

if tensorboard:
Expand Down Expand Up @@ -442,6 +452,7 @@ def run(
csv_logger = CSVLogger(csv_filepath)
callbacks.append(csv_logger)

### Begin training
logger.info("Beginning training.")
# Keras 3.0: `use_multiprocessing` input is set in the
# `training_generator` (False by default)
Expand All @@ -451,6 +462,9 @@ def run(
epochs=epochs,
callbacks=callbacks,
)
# model.built = True,
# model.compiled=True, and
# compile_metrics_unbuilt ---> I get error

if save_weights:
logger.info("Saving model weights")
Expand Down
8 changes: 5 additions & 3 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -120,13 +120,13 @@ markers = ["slow: marks tests as slow (deselect with '-m \"not slow\"')"]
legacy_tox_ini = """
# For more information about tox, see https://tox.readthedocs.io/en/latest/
[tox]
envlist = py{39,310}-{tf,jax}
envlist = py{39,310}-{tf,jax,torch}
isolated_build = true

[gh-actions]
python =
3.9: py39-{tf,jax} # On GA python=3.9 job, run tox with the tf and jax environments
3.10: py310-{tf,jax} # On GA python=3.10 job, run tox with the tf and jax environments
3.9: py39-{tf,jax,torch} # On GA python=3.9 job, run tox with the tf and jax environments
3.10: py310-{tf,jax,torch} # On GA python=3.10 job, run tox with the tf and jax environments

[testenv]
commands = python -m pytest -v --color=yes
Expand All @@ -143,9 +143,11 @@ extras =
napari
tf: tf
jax: jax
torch: torch
setenv =
tf: KERAS_BACKEND = tensorflow
jax: KERAS_BACKEND = jax
torch: KERAS_BACKEND = torch
passenv =
NUMBA_DISABLE_JIT
CI
Expand Down
Loading