diff --git a/n3fit/src/n3fit/backends/keras_backend/MetaModel.py b/n3fit/src/n3fit/backends/keras_backend/MetaModel.py index afcc4b6dad..ddd010b89a 100644 --- a/n3fit/src/n3fit/backends/keras_backend/MetaModel.py +++ b/n3fit/src/n3fit/backends/keras_backend/MetaModel.py @@ -9,6 +9,7 @@ import re from keras import Variable +from keras import backend as K from keras import ops as Kops from keras import optimizers as Kopt from keras.models import Model @@ -16,11 +17,6 @@ import n3fit.backends.keras_backend.operations as op -# Starting with TF 2.16, a memory leak in TF https://github.com/tensorflow/tensorflow/issues/64170 -# makes jit compilation unusable in GPU. -# Before TF 2.16 it was set to `False` by default. From 2.16 onwards, it is set to `True` -JIT_COMPILE = False - # Define in this dictionary new optimizers as well as the arguments they accept # (with default values if needed be) optimizers = { @@ -296,13 +292,16 @@ def compile( # If given target output is None, target_output is unnecesary, save just a zero per output if target_output is None: - self.target_tensors = [op.numpy_to_tensor(np.zeros((1, 1))) for i in self.output_shape] + self.target_tensors = [op.numpy_to_tensor(np.zeros((1, 1))) for _ in self.output_shape] else: if not isinstance(target_output, list): target_output = [target_output] self.target_tensors = target_output - super().compile(optimizer=opt, loss=loss, jit_compile=JIT_COMPILE) + # For debug purposes it may be interesting to set in the compile call + # jit_compile = False + # run_eagerly = True + super().compile(optimizer=opt, loss=loss) def set_masks_to(self, names, val=0.0): """Set all mask value to the selected value diff --git a/n3fit/src/n3fit/backends/keras_backend/internal_state.py b/n3fit/src/n3fit/backends/keras_backend/internal_state.py index 4235a73473..23d9ed3819 100644 --- a/n3fit/src/n3fit/backends/keras_backend/internal_state.py +++ b/n3fit/src/n3fit/backends/keras_backend/internal_state.py @@ -21,7 +21,7 @@ log = 
logging.getLogger(__name__) # Prepare Keras-backend dependent functions -if K.backend() == "torch": +if K.backend() in ("torch", "jax"): def set_eager(flag=True): """Pytorch is eager by default""" diff --git a/n3fit/src/n3fit/backends/keras_backend/operations.py b/n3fit/src/n3fit/backends/keras_backend/operations.py index ae65e1f1de..3f5f9c5736 100644 --- a/n3fit/src/n3fit/backends/keras_backend/operations.py +++ b/n3fit/src/n3fit/backends/keras_backend/operations.py @@ -23,8 +23,6 @@ equally operations are automatically converted to layers when used as such. """ -from typing import Optional - from keras import backend as K from keras import ops as Kops from keras.layers import ELU, Input @@ -37,9 +35,12 @@ # Backend dependent functions and operations if K.backend() == "torch": - tensor_to_numpy_or_python = lambda x: x.detach().numpy() + tensor_to_numpy_or_python = lambda x: x.detach().cpu().numpy() + decorator_compiler = lambda f: f +elif K.backend() == "jax": + tensor_to_numpy_or_python = lambda x: np.array(x.block_until_ready()) decorator_compiler = lambda f: f -else: +elif K.backend() == "tensorflow": tensor_to_numpy_or_python = lambda x: x.numpy() lambda ret: {k: i.numpy() for k, i in ret.items()} import tensorflow as tf diff --git a/n3fit/src/n3fit/layers/observable.py b/n3fit/src/n3fit/layers/observable.py index 29c5754ecc..14a7c8cd15 100644 --- a/n3fit/src/n3fit/layers/observable.py +++ b/n3fit/src/n3fit/layers/observable.py @@ -225,7 +225,7 @@ def compute_float_mask(bool_mask): """ # Create a tensor with the shape (**bool_mask.shape, num_active_flavours) masked_to_full = [] - for idx in np.argwhere(np.array(bool_mask)): + for idx in np.argwhere(op.tensor_to_numpy_or_python(bool_mask)): temp_matrix = np.zeros(bool_mask.shape) temp_matrix[tuple(idx)] = 1 masked_to_full.append(temp_matrix)