diff --git a/.github/workflows/pytorch_test.yml b/.github/workflows/pytorch_test.yml new file mode 100644 index 0000000000..1621159b4c --- /dev/null +++ b/.github/workflows/pytorch_test.yml @@ -0,0 +1,25 @@ +name: Test pytorch + +on: [push] + +jobs: + run_pytorch: + runs-on: ubuntu-latest + env: + KERAS_BACKEND: torch + steps: + - uses: actions/checkout@v4 + - uses: actions/setup-python@v5 + with: + python-version: "3.12" + - name: Install nnpdf without LHAPDF + shell: bash -l {0} + run: | + pip install .[nolha,torch] + # Since there is no LHAPDF in the system, initialize the folder and download pdfsets.index + lhapdf-management update --init + - name: Test we can run one runcard + shell: bash -l {0} + run: | + cd n3fit/runcards/examples + n3fit Basic_runcard.yml 4 diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index e719ff4423..315f95e674 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -12,7 +12,7 @@ jobs: strategy: matrix: os: [ubuntu-latest, macos-14] - python-version: ["3.10"] # We need an older python version to avoid conflict with the pymongo pin + python-version: ["3.12"] fail-fast: false runs-on: ${{ matrix.os }} env: diff --git a/conda-recipe/meta.yaml b/conda-recipe/meta.yaml index e364904f7e..6f7432a8b8 100644 --- a/conda-recipe/meta.yaml +++ b/conda-recipe/meta.yaml @@ -19,8 +19,9 @@ requirements: - pip run: - python >=3.9,<3.13 - - tensorflow >=2.10,<2.17 # 2.17 works ok but the conda-forge package for macos doesn't - - psutil + - tensorflow >=2.17 + - keras >=3.1 + - psutil # to ensure n3fit affinity is with the right processors - hyperopt - mongodb - pymongo <4 diff --git a/doc/sphinx/source/get-started/nnpdfmodules.rst b/doc/sphinx/source/get-started/nnpdfmodules.rst index 054fefaa53..0643b43570 100644 --- a/doc/sphinx/source/get-started/nnpdfmodules.rst +++ b/doc/sphinx/source/get-started/nnpdfmodules.rst @@ -14,7 +14,7 @@ for an NNPDF fit is displayed in the figure below. The :ref:`n3fit ` fitting code -------------------------------------------------------------------------------- This module implements the core fitting methodology as implemented through -the ``TensorFlow`` framework. The ``n3fit`` library allows +the ``Keras`` framework. The ``n3fit`` library allows for a flexible specification of the neural network model adopted to parametrise the PDFs, whose settings can be selected automatically via the built-in :ref:`hyperoptimization algorithm `. These diff --git a/doc/sphinx/source/n3fit/index.rst b/doc/sphinx/source/n3fit/index.rst index 5d5705ba97..3225fa529a 100644 --- a/doc/sphinx/source/n3fit/index.rst +++ b/doc/sphinx/source/n3fit/index.rst @@ -6,8 +6,7 @@ Fitting code: ``n3fit`` - ``n3fit`` is the next generation fitting code for NNPDF developed by the N3PDF team :cite:p:`Carrazza:2019mzf` - ``n3fit`` is responsible for fitting PDFs from NNPDF4.0 onwards. -- The code is implemented in python using `Tensorflow <https://www.tensorflow.org/>`_ - and `Keras <https://keras.io/>`_. +- The code is implemented in python using `Keras <https://keras.io/>`_ and can run with `Tensorflow <https://www.tensorflow.org/>`_ (default) or `pytorch <https://pytorch.org/>`_ (with the environment variable ``KERAS_BACKEND=torch``). - The sections below are an overview of the ``n3fit`` design. diff --git a/doc/sphinx/source/n3fit/methodology.rst b/doc/sphinx/source/n3fit/methodology.rst index 8380d5526d..f460af5967 100644 --- a/doc/sphinx/source/n3fit/methodology.rst +++ b/doc/sphinx/source/n3fit/methodology.rst @@ -8,8 +8,8 @@ different in comparison to the latest NNPDF (i.e. `NNPDF3.1 <https://arxiv.org/abs/1706.00428>`_. ..
note:: @@ -90,7 +90,7 @@ random numbers used in training-validation, ``nnseed`` for the neural network in Neural network architecture --------------------------- -The main advantage of using a modern deep learning backend such as Keras/Tensorflow consists in the +The main advantage of using a modern deep learning backend such as Keras is the possibility to change the neural network architecture quickly as the developer is not forced to fine tune the code in order to achieve efficient memory management and PDF convolution performance. @@ -132,41 +132,36 @@ See the `Keras documentation `_. +It is possible to inspect the ``n3fit`` code using `TensorBoard <https://www.tensorflow.org/tensorboard>`_ when running with the tensorflow backend. In order to enable the TensorBoard callback in ``n3fit`` it is enough to add the following options in the runcard: @@ -333,7 +333,7 @@ top-level option: parallel_models: true Note that currently, in order to run with parallel models, one has to set ``savepseudodata: false`` -in the ``fitting`` section of the runcard. Once this is done, the user can run ``n3fit`` with a +in the ``fitting`` section of the runcard. Once this is done, the user can run ``n3fit`` with a replica range to be parallelized (in this case from replica 1 to replica 4). .. code-block:: bash @@ -346,8 +346,8 @@ should run by setting the environment variable ``CUDA_VISIBLE_DEVICES`` to the right index (usually ``0, 1, 2``) or leaving it explicitly empty to avoid running on GPU: ``export CUDA_VISIBLE_DEVICES=""`` -Note that in order to run the replicas in parallel using the GPUs of an Apple Silicon computer (like M1 Mac), it is necessary to also install -the following packages: +Note that in order to run the replicas in parallel using the GPUs of an Apple Silicon computer (like M1 Mac), it is necessary to also install +extra packages. At the time of writing this worked with ``tensorflow`` 2.13. .. code-block:: bash diff --git a/doc/sphinx/source/tutorials/run-fit.rst b/doc/sphinx/source/tutorials/run-fit.rst index 4293563fb2..b3d4437260 100644 --- a/doc/sphinx/source/tutorials/run-fit.rst +++ b/doc/sphinx/source/tutorials/run-fit.rst @@ -51,7 +51,7 @@ example of the ``parameter`` dictionary that defines the Machine Learning framew dropout: 0.0 ... -The runcard system is designed such that the user can utilize the program +The runcard system is designed such that the user can utilize the program without having to tinker with the codebase. One can simply modify the options in ``parameters`` to specify the desired architecture of the Neural Network as well as the settings for the optimization algorithm. @@ -164,7 +164,7 @@ folder, which contains a number of files: - ``runcard.exportgrid``: a file containing the PDF grid. - ``runcard.json``: Includes information about the fit (metadata, parameters, times) in json format. -.. note:: +.. note:: The reported χ² always refers to the actual χ², i.e., without positivity loss or other penalty terms. @@ -184,25 +184,26 @@ After obtaining the fit you can proceed with the fit upload and analysis by: Performance of the fit ---------------------- -The ``n3fit`` framework is currently based on `Tensorflow <https://www.tensorflow.org/>`_ and as such, to -first approximation, anything that makes Tensorflow faster will also make ``n3fit`` faster. - -.. note:: - - Tensorflow only supports the installation via pip. Note, however, that the TensorFlow - pip package has been known to break third party packages. Install it at your own risk. - Only the conda tensorflow-eigen package is tested by our CI systems.
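As a quick illustration of the backend switch documented above (the short driver script below is a sketch, not part of the codebase): ``KERAS_BACKEND`` must be set before ``keras`` is imported for the first time, so it is either exported in the shell or set at the very top of a script.

.. code-block:: python

    # Illustrative sketch: select the torch backend programmatically.
    # The environment variable must be set before the first import of keras,
    # otherwise the default (tensorflow) backend is loaded.
    import os

    os.environ.setdefault("KERAS_BACKEND", "torch")

    import keras

    # keras.backend.backend() reports the active backend
    assert keras.backend.backend() == "torch"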
- -When you install the nnpdf conda package, you get the -`tensorflow-eigen <https://anaconda.org/anaconda/tensorflow-eigen>`_ package, -which is not the default. This is due to a memory explosion found in some of +The ``n3fit`` framework is currently based on `Keras <https://keras.io/>`_ +and it is tested to run with the `Tensorflow <https://www.tensorflow.org/>`_ +and `pytorch <https://pytorch.org/>`_ backends. +This also means that anything that makes any of these packages faster will also +make ``n3fit`` faster. +Note that at the time of writing, ``TensorFlow`` is approximately 4 times faster than ``pytorch``. + +The default backend for ``keras`` is ``tensorflow``. +In order to change the backend, the environment variable ``KERAS_BACKEND`` needs to be set (e.g., ``KERAS_BACKEND=torch``). + +The best results are obtained with ``tensorflow[and-cuda]`` installed from pip. +When you install the nnpdf conda package, you get the +`tensorflow-eigen <https://anaconda.org/anaconda/tensorflow-eigen>`_ package, +which is not the default. This is due to a memory explosion found in some of the conda mkl builds. -If you want to disable MKL without installing ``tensorflow-eigen`` you can always +If you want to disable MKL without installing ``tensorflow-eigen`` you can always set the environment variable ``TF_DISABLE_MKL=1`` before running ``n3fit``. When running ``n3fit`` all versions of the package show similar performance. - When using the MKL version of tensorflow you gain more control of the way Tensorflow will use the multithreading capabilities of the machine by using the following environment variables: @@ -214,7 +215,7 @@ the multithreading capabilities of the machine by using the following environmen These are the best values found for ``n3fit`` when using the mkl version of Tensorflow from conda and were found for TF 2.1 as the default values were suboptimal. For a more detailed explanation on the effects of ``KMP_AFFINITY`` on the performance of -the code please see +the code please see `here `_. By default, ``n3fit`` will try to use as many cores as possible, but this behaviour can be overridden diff --git a/n3fit/src/n3fit/backends/keras_backend/MetaModel.py b/n3fit/src/n3fit/backends/keras_backend/MetaModel.py index 2399c16a64..d5f1c8c5bd 100644 --- a/n3fit/src/n3fit/backends/keras_backend/MetaModel.py +++ b/n3fit/src/n3fit/backends/keras_backend/MetaModel.py @@ -8,27 +8,12 @@ from pathlib import Path import re +from keras import Variable +from keras import optimizers as Kopt +from keras.models import Model import numpy as np -import tensorflow as tf -from tensorflow.keras import optimizers as Kopt -from tensorflow.keras.models import Model -from tensorflow.python.keras.utils import tf_utils # pylint: disable=no-name-in-module - -import n3fit.backends.keras_backend.operations as op - -# We need a function to transform tensors to numpy/python primitives -# which is not part of the official TF interface and can change with the version -if hasattr(tf_utils, "to_numpy_or_python_type"): - _to_numpy_or_python_type = tf_utils.to_numpy_or_python_type -elif hasattr(tf_utils, "sync_to_numpy_or_python_type"): # from TF 2.5 - _to_numpy_or_python_type = tf_utils.sync_to_numpy_or_python_type -else: # in case of disaster - _to_numpy_or_python_type = lambda ret: {k: i.numpy() for k, i in ret.items()} - -# Starting with TF 2.16, a memory leak in TF https://github.com/tensorflow/tensorflow/issues/64170 -# makes jit compilation unusable in GPU. -# Before TF 2.16 it was set to `False` by default. From 2.16 onwards, it is set to `True` -JIT_COMPILE = False + +from .
import operations as ops # Define in this dictionary new optimizers as well as the arguments they accept # (with default values if need be) @@ -55,7 +40,7 @@ def _default_loss(y_true, y_pred): # pylint: disable=unused-argument """Default loss to be used when the model is compiled with loss = Null (for instance if the prediction of the model is already the loss)""" - return op.sum(y_pred) + return ops.sum(y_pred) class MetaModel(Model): @@ -108,7 +93,7 @@ def __init__(self, input_tensors, output_tensors, scaler=None, input_values=None if k in input_values: x_in[k] = input_values[k] elif hasattr(v, "tensor_content"): - x_in[k] = op.numpy_to_tensor(v.tensor_content) + x_in[k] = ops.numpy_to_tensor(v.tensor_content) else: self.required_slots.add(k) super().__init__(input_tensors, output_tensors, **kwargs) @@ -121,7 +106,6 @@ def __init__(self, input_tensors, output_tensors, scaler=None, input_values=None self.compute_losses_function = None self._scaler = scaler - @tf.autograph.experimental.do_not_convert def _parse_input(self, extra_input=None): """Returns the input data the model was compiled with. Introduces the extra_input in the places assigned to the placeholders. @@ -173,8 +157,8 @@ def perform_fit(self, x=None, y=None, epochs=1, **kwargs): steps_per_epoch = self._determine_steps_per_epoch(epochs) for k, v in x_params.items(): - x_params[k] = tf.repeat(v, steps_per_epoch, axis=0) - y = [tf.repeat(yi, steps_per_epoch, axis=0) for yi in y] + x_params[k] = ops.repeat(v, steps_per_epoch, axis=0) + y = [ops.repeat(yi, steps_per_epoch, axis=0) for yi in y] history = super().fit( x=x_params, y=y, epochs=epochs // steps_per_epoch, batch_size=1, **kwargs ) @@ -228,13 +212,13 @@ def compute_losses(self): inputs[k] = v[:1] # Compile an evaluation function - @tf.function + @ops.decorator_compiler def losses_fun(): predictions = self(inputs) # If we only have one dataset the output changes if len(out_names) == 2: predictions = [predictions] - total_loss = tf.reduce_sum(predictions, axis=0) + total_loss = ops.sum(predictions, axis=0) ret = [total_loss] + predictions return dict(zip(out_names, ret)) @@ -244,7 +228,7 @@ def losses_fun(): # The output of this function is to be used by python (and numpy) # so we need to convert the tensors - return _to_numpy_or_python_type(ret) + return ops.dict_to_numpy_or_python(ret) def compile( self, @@ -305,13 +289,16 @@ def compile( # If given target output is None, target_output is unnecessary, save just a zero per output if target_output is None: - self.target_tensors = [op.numpy_to_tensor(np.zeros((1, 1))) for i in self.output_shape] + self.target_tensors = [ops.numpy_to_tensor(np.zeros((1, 1))) for _ in self.output_shape] else: if not isinstance(target_output, list): target_output = [target_output] self.target_tensors = target_output - super().compile(optimizer=opt, loss=loss, jit_compile=JIT_COMPILE) + # For debug purposes it may be interesting to set in the compile call + # jit_compile = False + # run_eagerly = True + super().compile(optimizer=opt, loss=loss) def set_masks_to(self, names, val=0.0): """Set all mask value to the selected value @@ -509,9 +496,9 @@ def get_layer_replica_weights(layer, i_replica: int): """ if is_stacked_single_replicas(layer): weights_ref = layer.get_layer(f"{NN_PREFIX}_{i_replica}").weights - weights = [tf.Variable(w, name=w.name) for w in weights_ref] + weights = [Variable(w, name=w.name) for w in weights_ref] else: - weights = [tf.Variable(w[i_replica : i_replica + 1], name=w.name) for w in layer.weights] + weights =
[Variable(w[i_replica : i_replica + 1], name=w.name) for w in layer.weights] return weights diff --git a/n3fit/src/n3fit/backends/keras_backend/base_layers.py b/n3fit/src/n3fit/backends/keras_backend/base_layers.py index 2ed2628293..849cd74175 100644 --- a/n3fit/src/n3fit/backends/keras_backend/base_layers.py +++ b/n3fit/src/n3fit/backends/keras_backend/base_layers.py @@ -17,16 +17,14 @@ The names of the layer and the activation function are the ones to be used in the n3fit runcard. """ -from tensorflow import expand_dims, math, nn -from tensorflow.keras.layers import Dense as KerasDense -from tensorflow.keras.layers import Dropout, Lambda -from tensorflow.keras.layers import Input # pylint: disable=unused-import -from tensorflow.keras.layers import LSTM, Concatenate -from tensorflow.keras.regularizers import l1_l2 +from keras.layers import Dense as KerasDense +from keras.layers import Dropout, Lambda +from keras.layers import Input # pylint: disable=unused-import +from keras.layers import LSTM, Concatenate +from keras.regularizers import l1_l2 +from . import operations as ops from .MetaLayer import MetaLayer -from .operations import concatenate_function - # Custom activation functions def square_activation(x): @@ -38,17 +36,17 @@ def square_singlet(x): """Square the singlet sector Defined as the two first values of the NN""" singlet_squared = x[..., :2] ** 2 - return concatenate_function([singlet_squared, x[..., 2:]], axis=-1) + return ops.concatenate([singlet_squared, x[..., 2:]], axis=-1) def modified_tanh(x): """A non-saturating version of the tanh function""" - return math.abs(x) * nn.tanh(x) + return ops.absolute(x) * ops.tanh(x) def leaky_relu(x): """Computes the Leaky ReLU activation function""" - return nn.leaky_relu(x, alpha=0.2) + return ops.leaky_relu(x, negative_slope=0.2) custom_activations = { @@ -64,7 +62,7 @@ def LSTM_modified(**kwargs): LSTM asks for a sample X timestep X features kind of thing so we need to reshape the input """ the_lstm = LSTM(**kwargs) - ExpandDim = Lambda(lambda x: expand_dims(x, axis=-1)) + ExpandDim = Lambda(lambda x: ops.expand_dims(x, axis=-1)) def ReshapedLSTM(input_tensor): if len(input_tensor.shape) == 2: diff --git a/n3fit/src/n3fit/backends/keras_backend/callbacks.py b/n3fit/src/n3fit/backends/keras_backend/callbacks.py index 911f069e5c..f3627e9e3b 100644 --- a/n3fit/src/n3fit/backends/keras_backend/callbacks.py +++ b/n3fit/src/n3fit/backends/keras_backend/callbacks.py @@ -15,9 +15,10 @@ import logging from time import time +from keras.callbacks import Callback, TensorBoard import numpy as np -import tensorflow as tf -from tensorflow.keras.callbacks import Callback, TensorBoard + +from .operations import decorator_compiler log = logging.getLogger(__name__) @@ -171,7 +172,7 @@ def on_train_begin(self, logs=None): layer = self.model.get_layer(layer_name) self.updateable_weights.append(layer.weights) - @tf.function + @decorator_compiler def _update_weights(self): """Update all the weights with the corresponding multipliers Wrapped with the compilation decorator to compensate the for loops as both weights variables @@ -194,7 +195,8 @@ def gen_tensorboard_callback(log_dir, profiling=False, histogram_freq=0): If the profiling flag is set to True, it will also attempt to save profiling data. - Note the usage of this callback can hurt performance.
+ Note the usage of this callback can hurt performance. + At the moment it can only be used with TensorFlow: https://github.com/keras-team/keras/issues/19121 Parameters ---------- diff --git a/n3fit/src/n3fit/backends/keras_backend/constraints.py b/n3fit/src/n3fit/backends/keras_backend/constraints.py index e943c1fcb6..7ac874e0d8 100644 --- a/n3fit/src/n3fit/backends/keras_backend/constraints.py +++ b/n3fit/src/n3fit/backends/keras_backend/constraints.py @@ -2,9 +2,10 @@ Implementations of weight constraints for initializers """ -import tensorflow as tf -from tensorflow.keras import backend as K -from tensorflow.keras.constraints import MinMaxNorm +from keras import backend as K +from keras.constraints import MinMaxNorm + +from . import operations as ops class MinMaxWeight(MinMaxNorm): @@ -17,8 +18,8 @@ def __init__(self, min_value, max_value, **kwargs): super().__init__(min_value=min_value, max_value=max_value, axis=1, **kwargs) def __call__(self, w): - norms = K.sum(w, axis=self.axis, keepdims=True) + norms = ops.sum(w, axis=self.axis, keepdims=True) desired = ( - self.rate * K.clip(norms, self.min_value, self.max_value) + (1 - self.rate) * norms + self.rate * ops.clip(norms, self.min_value, self.max_value) + (1 - self.rate) * norms ) return w * desired / (K.epsilon() + norms) diff --git a/n3fit/src/n3fit/backends/keras_backend/internal_state.py b/n3fit/src/n3fit/backends/keras_backend/internal_state.py index e818716940..3b7be3f7ed 100644 --- a/n3fit/src/n3fit/backends/keras_backend/internal_state.py +++ b/n3fit/src/n3fit/backends/keras_backend/internal_state.py @@ -1,6 +1,7 @@ """ Library of functions that modify the internal state of Keras/Tensorflow """ + import os import psutil @@ -13,20 +14,51 @@ import logging import random as rn +import keras +from keras import backend as K import numpy as np -import tensorflow as tf -from tensorflow import keras -from tensorflow.keras import backend as K log = logging.getLogger(__name__) +# Prepare Keras-backend dependent functions +if K.backend() in ("torch", "jax"): -def set_eager(flag=True): - """Set eager mode on or off - for a very slow but fine grained debugging call this function as early as possible - ideally after the first tf import - """ - tf.config.run_functions_eagerly(flag) + def set_eager(flag=True): + """Pytorch and jax are eager by default""" + + def set_threading(threads, cores): + """Set the maximum number of threads (only acts on the torch backend)""" + if K.backend() == "torch": + import torch + + log.info("Setting max number of threads to: %d", threads) + torch.set_num_threads(threads) + +elif K.backend() == "tensorflow": + import tensorflow as tf + + def set_eager(flag=True): + """Set eager mode on or off + for a very slow but fine grained debugging call this function as early as possible + ideally after the first tf import + """ + tf.config.run_functions_eagerly(flag) + + def set_threading(threads, cores): + """Set the Tensorflow inter and intra parallelism options""" + log.info("Setting the number of cores to: %d", cores) + try: + tf.config.threading.set_inter_op_parallelism_threads(threads) + tf.config.threading.set_intra_op_parallelism_threads(cores) + except RuntimeError: + # If tensorflow has already been initiated, the previous calls might fail. + # This may happen for instance if pdfflow is being used + log.warning( + "Could not set tensorflow parallelism settings from n3fit, maybe tensorflow is already initialized by a third program" + ) + +else: + # Keras should have failed by now; if it did not, this could be a new backend that works out of the box + log.warning(f"Backend {K.backend()} not recognized.
You are entering uncharted territory") def set_number_of_cores(max_cores=None, max_threads=None): @@ -62,16 +94,7 @@ def set_number_of_cores(max_cores=None, max_threads=None): if max_threads is not None: threads = min(max_threads, threads) - log.info("Setting the number of cores to: %d", cores) - try: - tf.config.threading.set_inter_op_parallelism_threads(threads) - tf.config.threading.set_intra_op_parallelism_threads(cores) - except RuntimeError: - # If pdfflow is being used, tensorflow will already be initialized by pdfflow - # maybe it would be good to drop completely pdfflow before starting the fit? (TODO ?) - log.warning( - "Could not set tensorflow parallelism settings from n3fit, maybe has already been initialized?" - ) + set_threading(threads, cores) def clear_backend_state(): @@ -129,7 +152,7 @@ def set_initial_state(debug=False, external_seed=None, max_cores=None, double_pr clear_backend_state() if double_precision: - tf.keras.backend.set_floatx('float64') + K.set_floatx('float64') # Set the number of cores depending on the user choice of max_cores # if debug mode and no number of cores set by the user, set to 1 @@ -142,7 +165,11 @@ def set_initial_state(debug=False, external_seed=None, max_cores=None, double_pr # Once again, if in debug mode or external_seed set, set also the TF seed if debug or external_seed: - tf.random.set_seed(use_seed) + if K.backend() == "tensorflow": + # if backend is tensorflow, keep the old seeding + tf.random.set_seed(use_seed) + else: + keras.utils.set_random_seed(use_seed) def get_physical_gpus(): diff --git a/n3fit/src/n3fit/backends/keras_backend/operations.py b/n3fit/src/n3fit/backends/keras_backend/operations.py index b6ad0e010e..f123e450e3 100644 --- a/n3fit/src/n3fit/backends/keras_backend/operations.py +++ b/n3fit/src/n3fit/backends/keras_backend/operations.py @@ -6,8 +6,6 @@ This includes an implementation of the NNPDF operations on fktable in the keras language (with the mapping ``c_to_py_fun``) into Keras ``Lambda`` layers. - Tensor operations are compiled through the @tf.function decorator for optimization - The rest of the operations in this module are divided into four categories: numpy to tensor: Operations that take a numpy array and return a tensorflow tensor @@ -18,37 +16,66 @@ layer generation: Instantiate a layer to be applied by the calling function - Some of these are just aliases to the backend (tensorflow or Keras) operations + Most of the operations in this module are just aliases to the backend + (Keras in this case) so that, when implementing new backends, it is clear + which operations may need to be overwritten. + For a few selected operations, a more complicated wrapper is included to, + e.g., make them into layers or apply some defaults. + Note that tensor operations can also be applied to layers as the output of a layer is a tensor; equally, operations are automatically converted to layers when used as such.
""" -from typing import Optional - +from keras import backend as K +from keras import ops as Kops +from keras.layers import ELU, Input +from keras.layers import Lambda as keras_Lambda import numpy as np -import numpy.typing as npt -import tensorflow as tf -from tensorflow import keras -from tensorflow.keras import backend as K -from tensorflow.keras.layers import Input -from tensorflow.keras.layers import Lambda as keras_Lambda -from tensorflow.keras.layers import multiply as keras_multiply -from tensorflow.keras.layers import subtract as keras_subtract from validphys.convolution import OP -# Select a concatenate function depending on the tensorflow version -try: - # For tensorflow >= 2.16, Keras >= 3 - concatenate_function = keras.ops.concatenate -except AttributeError: - # keras.ops was introduced in keras 3 - concatenate_function = tf.concat - - -def evaluate(tensor): - """Evaluate input tensor using the backend""" - return K.eval(tensor) +# The following operations are either loaded directly from keras and exposed here +# or the name is change slightly (usually for historical or collision reasons, +# e.g., our logs are always logs or we were using the tf version in the past) + +# isort: off +from keras.ops import ( + absolute, + clip, + einsum, + expand_dims, + leaky_relu, + reshape, + repeat, + split, + sum, + tanh, + transpose, +) +from keras.ops import log as op_log +from keras.ops import power as pow +from keras.ops import take as gather +from keras.ops import tensordot as tensor_product +from keras.layers import multiply as op_multiply +from keras.layers import subtract as op_subtract + +# isort: on + +# Backend dependent functions and operations +if K.backend() == "torch": + tensor_to_numpy_or_python = lambda x: x.detach().cpu().numpy() + decorator_compiler = lambda f: f +elif K.backend() == "jax": + tensor_to_numpy_or_python = lambda x: np.array(x.block_until_ready()) + decorator_compiler = lambda f: f +elif K.backend() == "tensorflow": + tensor_to_numpy_or_python = lambda x: x.numpy() + lambda ret: {k: i.numpy() for k, i in ret.items()} + import tensorflow as tf + + decorator_compiler = tf.function + +dict_to_numpy_or_python = lambda ret: {k: tensor_to_numpy_or_python(i) for k, i in ret.items()} def as_layer(operation, op_args=None, op_kwargs=None, **kwargs): @@ -101,7 +128,6 @@ def c_to_py_fun(op_name, name="dataset"): except KeyError as e: raise ValueError(f"Operation {op_name} not recognised") from e - @tf.function def operate_on_tensors(tensor_list): return operation(*tensor_list) @@ -113,19 +139,19 @@ def numpy_to_tensor(ival, **kwargs): """ Make the input into a tensor """ - if kwargs.get("dtype", None) is not bool: - kwargs["dtype"] = tf.keras.backend.floatx() - return K.constant(ival, **kwargs) + if (dtype := kwargs.get("dtype", None)) is not bool: + dtype = K.floatx() + return Kops.cast(ival, dtype) # f(x: tensor) -> y: tensor def batchit(x, batch_dimension=0, **kwarg): """Add a batch dimension to tensor x""" - return tf.expand_dims(x, batch_dimension, **kwarg) + return Kops.expand_dims(x, batch_dimension, **kwarg) # layer generation -def numpy_to_input(numpy_array: npt.NDArray, name: Optional[str] = None): +def numpy_to_input(numpy_array, name=None): """ Takes a numpy array and generates an Input layer with the same shape, but with a batch dimension (of size 1) added. 
@@ -146,33 +172,6 @@ def numpy_to_input(numpy_array: npt.NDArray, name: Optional[str] = None): return input_layer -# -# Layer to Layer operations -# -def op_multiply(o_list, **kwargs): - """ - Receives a list of layers of the same output size and multiply them element-wise - """ - return keras_multiply(o_list, **kwargs) - - -def op_multiply_dim(o_list, **kwargs): - """ - Bypass in order to multiply two layers with different output dimension - for instance: (10000 x 14) * (14) - as the normal keras multiply don't accept it (but somewhow it does accept it doing it like this) - """ - if len(o_list) != 2: - raise ValueError( - "The number of observables is incorrect, operations.py:op_multiply_dim, expected 2, received {}".format( - len(o_list) - ) - ) - - layer_op = as_layer(lambda inputs: inputs[0] * inputs[1]) - return layer_op(o_list) - - def op_gather_keep_dims(tensor, indices, axis=0, **kwargs): """A convoluted way of providing ``x[:, indices, :]`` @@ -183,82 +182,23 @@ def op_gather_keep_dims(tensor, indices, axis=0, **kwargs): indices = tensor.shape[axis] - 1 def tmp(x): - y = tf.gather(x, indices, axis=axis, **kwargs) - return tf.expand_dims(y, axis=axis) + y = gather(x, indices, axis=axis) + return Kops.expand_dims(y, axis=axis) layer_op = as_layer(tmp) return layer_op(tensor) -def gather(*args, **kwargs): - """ - Gather elements from a tensor along an axis - """ - return tf.gather(*args, **kwargs) - - -# -# Tensor operations -# f(x: tensor[s]) -> y: tensor -# - - -# Generation operations -# generate tensors of given shape/content -@tf.function -def tensor_ones_like(*args, **kwargs): - """ - Generates a tensor of ones of the same shape as the input tensor - See full `docs `_ - """ - return K.ones_like(*args, **kwargs) - - -# Property operations -# modify properties of the tensor like the shape or elements it has -@tf.function def flatten(x): """Flatten tensor x""" - return tf.reshape(x, (-1,)) - - -@tf.function -def reshape(x, shape): - """reshape tensor x""" - return tf.reshape(x, shape) - - -@tf.function -def boolean_mask(*args, target_shape=None, **kwargs): - """ - Applies a boolean mask to a tensor - - Relevant parameters: (tensor, mask, axis=None) - see full `docs `_. 
- - tensorflow's masking concatenates the masked dimensions, it is possible to - provide a `target_shape` to reshape the output to the desired shape - """ - ret = tf.boolean_mask(*args, **kwargs) - if target_shape is not None: - ret = reshape(ret, target_shape) - return ret - - -@tf.function -def transpose(tensor, **kwargs): - """ - Transpose a layer, - see full `docs `_ - """ - return K.transpose(tensor, **kwargs) + return reshape(x, (-1,)) def stack(tensor_list, axis=0, **kwargs): """Stack a list of tensors see full `docs `_ """ - return tf.stack(tensor_list, axis=axis, **kwargs) + return Kops.stack(tensor_list, axis=axis) def concatenate(tensor_list, axis=-1, target_shape=None, name=None): @@ -266,77 +206,20 @@ Concatenates a list of numbers or tensor into a bigger tensor If the target shape is given, the output is reshaped to said shape """ - concatenated_tensor = concatenate_function(tensor_list, axis=axis) + concatenated_tensor = Kops.concatenate(tensor_list, axis=axis) if target_shape is None: return concatenated_tensor - return K.reshape(concatenated_tensor, target_shape) + return reshape(concatenated_tensor, target_shape) -def einsum(equation, *args, **kwargs): - """ - Computes the tensor product using einsum - See full `docs `_ - """ - return tf.einsum(equation, *args, **kwargs) - - -def tensor_product(*args, **kwargs): - """ - Computes the tensordot product between tensor_x and tensor_y - See full `docs `_ - """ - return tf.tensordot(*args, **kwargs) - - -@tf.function -def pow(tensor, power): - """ - Computes the power of the tensor - """ - return tf.pow(tensor, power) - - -@tf.function(reduce_retracing=True) -def op_log(o_tensor, **kwargs): - """ - Computes the logarithm of the input - """ - return K.log(o_tensor) - - -@tf.function -def sum(*args, **kwargs): - """ - Computes the sum of the elements of the tensor - see full `docs `_ - """ - return K.sum(*args, **kwargs) - - -def split(*args, **kwargs): - """ - Splits the tensor on the selected axis - see full `docs `_ - """ - return tf.split(*args, **kwargs) - - def scatter_to_one(values, indices, output_shape): """ Like scatter_nd initialized to one instead of zero see full `docs `_ """ - ones = numpy_to_tensor(np.ones(output_shape)) - return tf.tensor_scatter_nd_update(ones, indices, values) + ones = Kops.ones(output_shape) + return Kops.scatter_update(ones, indices, values) def swapaxes(tensor, source, destination): """ Moves the axis of the tensor from source to destination, as in numpy.swapaxes.
see full `docs `_ """ - indices = list(range(tensor.shape.rank)) + rank = len(tensor.shape) + indices = list(range(rank)) if source < 0: - source += tensor.shape.rank + source += rank if destination < 0: - destination += tensor.shape.rank + destination += rank indices[source], indices[destination] = indices[destination], indices[source] - return tf.transpose(tensor, indices) + return Kops.transpose(tensor, indices) + +def elu(x, alpha=1.0, **kwargs): + """Apply the Exponential Linear Unit activation as a layer""" + new_layer = ELU(alpha=alpha, **kwargs) + return new_layer(x) -@tf.function -def backend_function(fun_name, *args, **kwargs): + +def tensor_splitter(ishape, split_sizes, axis=2, name="splitter"): """ - Wrapper to call non-explicitly implemented backend functions by name: (``fun_name``) - see full `docs `_ for some possibilities + Generates a Lambda layer to apply the split operation to a given tensor shape. + This wrapper cannot split along the batch index (axis=0). + + Parameters + ---------- + ishape: list(int) + input shape of the tensor that will be split + split_sizes: list(int) + size of each chunk + axis: int + axis along which the split will be applied + name: str + name of the layer + Returns + ------- + sp_layer: layer + a keras layer that applies the split operation upon call """ - fun = getattr(K, fun_name) - return fun(*args, **kwargs) + if axis < 1: + raise ValueError("tensor_splitter wrapper can only split along non-batch dimensions") + + # Check that we can indeed split this + if ishape[axis] != np.sum(split_sizes): + raise ValueError( + f"Cannot split tensor of shape {ishape} along axis {axis} in chunks of {split_sizes}" + ) + + # Output shape of each split + oshapes = [] + # Indices at which to put the splits + # NB: tensorflow's split function would've taken the split_sizes directly + # keras instead takes the indices at which to split + indices = [] + current_idx = 0 + + for xsize in split_sizes: + current_idx += xsize + indices.append(current_idx) + oshapes.append((*ishape[1:axis], xsize, *ishape[axis + 1 :])) + + # The last index is the total size of the axis and would produce a trailing + # empty chunk, so it must not be passed to ``split`` + indices = indices[:-1] + + sp_layer = keras_Lambda( + lambda x: Kops.split(x, indices, axis=axis), output_shape=oshapes, name=name + ) + return sp_layer diff --git a/n3fit/src/n3fit/checks.py b/n3fit/src/n3fit/checks.py index 32c2d26acf..ee7cdaee43 100644 --- a/n3fit/src/n3fit/checks.py +++ b/n3fit/src/n3fit/checks.py @@ -159,6 +159,14 @@ def check_dropout(parameters): def check_tensorboard(tensorboard): """Check that the tensorboard callback can be enabled correctly""" if tensorboard is not None: + # Check that Tensorflow is installed + try: + import tensorflow + except ModuleNotFoundError as e: + raise ModuleNotFoundError( + "The tensorboard callback requires `tensorflow` to be installed" + ) from e + weight_freq = tensorboard.get("weight_freq", 0) if weight_freq < 0: raise CheckError( diff --git a/n3fit/src/n3fit/io/writer.py b/n3fit/src/n3fit/io/writer.py index 053fc6d229..7b6e40e140 100644 --- a/n3fit/src/n3fit/io/writer.py +++ b/n3fit/src/n3fit/io/writer.py @@ -394,23 +394,24 @@ def jsonfit( def version(): """Generates a dictionary with misc version info for this run""" versions = {} + try: - # Wrap tf in try-except block as it could possible to run n3fit without tf - import tensorflow as tf - from tensorflow.python.framework import test_util - - versions["keras"] = tf.keras.__version__ - mkl = test_util.IsMklEnabled() - versions["tensorflow"] = f"{tf.__version__}, mkl={mkl}" - except ImportError: - versions["tensorflow"] = "Not available" - versions["keras"] = "Not available" - except AttributeError: - # Check for MKL was only
recently introduced and is not part of the official API - versions["tensorflow"] = f"{tf.__version__}, mkl=??" + import keras + + versions["keras"] = f"{keras.__version__} backend={keras.backend.backend()}" + + if keras.backend.backend() == "tensorflow": + import tensorflow as tf + + versions["tensorflow"] = tf.__version__ + elif keras.backend.backend() == "torch": + import torch + + versions["torch"] = torch.__version__ except: # We don't want _any_ uncaught exception to crash the whole program at this point pass + versions["numpy"] = np.__version__ versions["nnpdf"] = n3fit.__version__ try: diff --git a/n3fit/src/n3fit/layers/DY.py b/n3fit/src/n3fit/layers/DY.py index f05416c5e4..94c982a391 100644 --- a/n3fit/src/n3fit/layers/DY.py +++ b/n3fit/src/n3fit/layers/DY.py @@ -86,7 +86,7 @@ def compute_dy_observable_many_replica(pdf, padded_fk): """ pdfa = pdf[1] pdfb = pdf[0] - + temp = op.einsum('nxfyg, bryg -> brnxf', padded_fk, pdfa) return op.einsum('brnxf, brxf -> brn', temp, pdfb) @@ -96,11 +96,13 @@ def compute_dy_observable_one_replica(pdf, mask_and_fk): Same operations as above but a specialized implementation that is more efficient for 1 replica, masking the PDF rather than the fk table. """ + # mask: (channels, flavs_b, flavs_a) Ffg + # fk: (npoints, channels, x_a, x_b) nFyx mask, fk = mask_and_fk # Retrieve the two PDFs (which may potentially be coming from different initial states) # Since this is the one-replica function, remove the batch and replica dimension - pdfb = pdf[0][0][0] # xf - pdfa = pdf[1][0][0] # yg + pdfb = pdf[0][0][0] # (x_b, flavs_b) xf + pdfa = pdf[1][0][0] # (x_a, flavs_a) yg # TODO: check which PDF must go first in case of different initial states!!! mask_x_pdf = op.tensor_product(mask, pdfa, axes=[(2,), (1,)]) # Ffg, yg -> Ffy diff --git a/n3fit/src/n3fit/layers/losses.py b/n3fit/src/n3fit/layers/losses.py index b33547a6ce..ee6162a8d4 100644 --- a/n3fit/src/n3fit/layers/losses.py +++ b/n3fit/src/n3fit/layers/losses.py @@ -160,7 +160,7 @@ def __init__(self, alpha=1e-7, **kwargs): super().__init__(**kwargs) def apply_loss(self, y_pred): - loss = op.backend_function("elu", -y_pred, alpha=self.alpha) + loss = op.elu(-y_pred, alpha=self.alpha) # Sum over the batch and the datapoints return op.sum(loss, axis=[0, -1]) @@ -180,6 +180,6 @@ class LossIntegrability(LossLagrange): """ def apply_loss(self, y_pred): - y = op.backend_function("square", y_pred) + y = y_pred * y_pred # Sum over the batch and the datapoints return op.sum(y, axis=[0, -1]) diff --git a/n3fit/src/n3fit/layers/mask.py b/n3fit/src/n3fit/layers/mask.py index 3ed007a18f..d89a8942af 100644 --- a/n3fit/src/n3fit/layers/mask.py +++ b/n3fit/src/n3fit/layers/mask.py @@ -1,4 +1,4 @@ -from numpy import count_nonzero +import numpy as np from n3fit.backends import MetaLayer from n3fit.backends import operations as op @@ -23,12 +23,14 @@ class Mask(MetaLayer): """ def __init__(self, bool_mask=None, c=None, **kwargs): + self._raw_mask = bool_mask + self._flattened_indices = None if bool_mask is None: self.mask = None self.last_dim = -1 else: self.mask = op.numpy_to_tensor(bool_mask, dtype=bool) - self.last_dim = count_nonzero(bool_mask[0, ...]) + self.last_dim = np.count_nonzero(bool_mask[0, ...]) self.c = c self.masked_output_shape = None super().__init__(**kwargs) @@ -40,9 +42,15 @@ def build(self, input_shape): # Make sure reshape will succeed: set the last dimension to the unmasked data length and before-last to # the number of replicas if self.mask is not None: + + # Prepare the indices to mask + indices =
np.where(self._raw_mask) + # The batch dimension can be ignored + nreps = self.mask.shape[-2] + self._flattened_indices = np.ravel_multi_index(indices, self._raw_mask.shape) self.masked_output_shape = [-1 if d is None else d for d in input_shape] self.masked_output_shape[-1] = self.last_dim - self.masked_output_shape[-2] = self.mask.shape[-2] + self.masked_output_shape[-2] = nreps super().build(input_shape) def call(self, ret): @@ -58,7 +66,8 @@ def call(self, ret): Tensor of shape (batch_size, n_replicas, n_features) """ if self.mask is not None: - ret = op.boolean_mask(ret, self.mask, axis=1, target_shape=self.masked_output_shape) + ret = op.gather(op.flatten(ret), self._flattened_indices) + ret = op.reshape(ret, self.masked_output_shape) if self.c is not None: ret = ret * self.kernel return ret diff --git a/n3fit/src/n3fit/layers/msr_normalization.py b/n3fit/src/n3fit/layers/msr_normalization.py index 7695d4f11f..5159628c0d 100644 --- a/n3fit/src/n3fit/layers/msr_normalization.py +++ b/n3fit/src/n3fit/layers/msr_normalization.py @@ -194,6 +194,7 @@ def call(self, pdf_integrated, photon_integral): numerators += [self.vsr_factors] numerators = op.concatenate(numerators, axis=0) + divisors = op.gather(y, self.divisor_indices, axis=0) # Fill in the rest of the flavours with 1 diff --git a/n3fit/src/n3fit/layers/observable.py b/n3fit/src/n3fit/layers/observable.py index 8945cc4da4..14a7c8cd15 100644 --- a/n3fit/src/n3fit/layers/observable.py +++ b/n3fit/src/n3fit/layers/observable.py @@ -89,7 +89,7 @@ def __init__( operation_name="NULL", nfl=14, n_replicas=1, - **kwargs + **kwargs, ): super(MetaLayer, self).__init__(**kwargs) @@ -178,7 +178,10 @@ def call(self, pdf): rank 3 tensor (batchsize, replicas, ndata) """ if self.splitting: - pdfs = op.split(pdf, self.splitting, axis=2) + splitter = op.tensor_splitter( + pdf.shape, self.splitting, axis=2, name=f"pdf_splitter_{self.name}" + ) + pdfs = splitter(pdf) else: pdfs = [pdf] * len(self.padded_fk_tables) @@ -222,7 +225,7 @@ def compute_float_mask(bool_mask): """ # Create a tensor with the shape (**bool_mask.shape, num_active_flavours) masked_to_full = [] - for idx in np.argwhere(bool_mask): + for idx in np.argwhere(op.tensor_to_numpy_or_python(bool_mask)): temp_matrix = np.zeros(bool_mask.shape) temp_matrix[tuple(idx)] = 1 masked_to_full.append(temp_matrix) diff --git a/n3fit/src/n3fit/model_gen.py b/n3fit/src/n3fit/model_gen.py index 852f93caf3..868409f489 100644 --- a/n3fit/src/n3fit/model_gen.py +++ b/n3fit/src/n3fit/model_gen.py @@ -99,11 +99,8 @@ def _generate_experimental_layer(self, pdf): the input PDF is evaluated in all points that the experiment needs and needs to be split """ if len(self.dataset_xsizes) > 1: - splitting_layer = op.as_layer( - op.split, - op_args=[self.dataset_xsizes], - op_kwargs={"axis": 2}, - name=f"{self.name}_split", + splitting_layer = op.tensor_splitter( + pdf.shape, self.dataset_xsizes, axis=2, name=f"{self.name}_split" ) sp_pdf = splitting_layer(pdf) output_layers = [obs(p) for obs, p in zip(self.observables, sp_pdf)] diff --git a/n3fit/src/n3fit/model_trainer.py b/n3fit/src/n3fit/model_trainer.py index d92e7cf51d..d864d2c6e5 100644 --- a/n3fit/src/n3fit/model_trainer.py +++ b/n3fit/src/n3fit/model_trainer.py @@ -16,7 +16,7 @@ import numpy as np from n3fit import model_gen -from n3fit.backends import NN_LAYER_ALL_REPLICAS, MetaModel, callbacks, clear_backend_state +from n3fit.backends import NN_LAYER_ALL_REPLICAS, Lambda, MetaModel, callbacks, clear_backend_state from n3fit.backends import operations as op 
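For reference, the ``Mask`` rewrite above trades ``tf.boolean_mask`` for a backend-agnostic flatten-and-gather: the boolean mask is converted once, at build time, into indices of the flattened tensor. A small numpy-level sketch of the equivalence (the shapes are made up):

.. code-block:: python

    import numpy as np

    mask = np.array([[True, False, True], [False, True, True]])
    x = np.arange(6.0).reshape(2, 3)

    # Computed once, as in Mask.build
    flat_indices = np.ravel_multi_index(np.where(mask), mask.shape)

    # gather(flatten(x), flat_indices) reproduces x[mask]
    assert np.allclose(x.reshape(-1)[flat_indices], x[mask])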
from n3fit.hyper_optimization.hyper_scan import HYPEROPT_STATUSES import n3fit.hyper_optimization.penalties @@ -40,6 +40,9 @@ # Each how many epochs do we increase the integrability Lagrange Multiplier PUSH_INTEGRABILITY_EACH = 100 +# Final number of flavours +FLAVOURS = 14 + # See ModelTrainer::_xgrid_generation for the definition of each field and how they are generated InputInfo = namedtuple("InputInfo", ["input", "split", "idx"]) @@ -354,11 +357,13 @@ def _xgrid_generation(self): input_arr = self._scaler(input_arr) input_layer = op.numpy_to_input(input_arr, name="pdf_input") - # The PDF model will be called with a concatenation of all inputs - # now the output needs to be splitted so that each experiment takes its corresponding input - sp_ar = [[i.shape[1] for i in inputs_unique]] - sp_kw = {"axis": 2} - sp_layer = op.as_layer(op.split, op_args=sp_ar, op_kwargs=sp_kw, name="pdf_split") + # The PDF model is called with a concatenation of all inputs; + # however, each output layer might require a different subset, so the + # output is split back into the per-experiment pieces + # Input shape: (batch size, replicas, input array, flavours) + ishape = (1, len(self.replicas), input_arr.shape[0], FLAVOURS) + xsizes = [i.shape[1] for i in inputs_unique] + sp_layer = op.tensor_splitter(ishape, xsizes, axis=2, name="splitter") return InputInfo(input_layer, sp_layer, inputs_idx) @@ -936,8 +941,10 @@ def hyperparametrizable(self, params): ) if photons: - if self._scaler: # select only the non-scaled input - pdf_model.get_layer("add_photon").register_photon(xinput.input.tensor_content[:,:,1:]) + if self._scaler: # select only the non-scaled input + pdf_model.get_layer("add_photon").register_photon( + xinput.input.tensor_content[:, :, 1:] + ) else: pdf_model.get_layer("add_photon").register_photon(xinput.input.tensor_content) diff --git a/n3fit/src/n3fit/msr.py b/n3fit/src/n3fit/msr.py index a66e03a3fe..721eb6b38d 100644 --- a/n3fit/src/n3fit/msr.py +++ b/n3fit/src/n3fit/msr.py @@ -84,7 +84,9 @@ def generate_msr_model_and_grid( # 3. Prepare the pdf for integration by dividing by x pdf_integrand = Lambda( - lambda x_pdf: op.batchit(x_pdf[0], batch_dimension=1) * x_pdf[1], name="pdf_integrand" + lambda x_pdf: op.batchit(x_pdf[0], batch_dimension=1) * x_pdf[1], + name="pdf_integrand", + output_shape=pdf_xgrid_integration.shape[1:], )([x_divided, pdf_xgrid_integration]) # 4.
Integrate the pdf diff --git a/n3fit/src/n3fit/performfit.py b/n3fit/src/n3fit/performfit.py index 04703ef924..7e91c1b5ca 100644 --- a/n3fit/src/n3fit/performfit.py +++ b/n3fit/src/n3fit/performfit.py @@ -3,11 +3,8 @@ """ # Backend-independent imports -import copy import logging -import numpy as np - import n3fit.checks from n3fit.vpinterface import N3PDF diff --git a/n3fit/src/n3fit/tests/test_backend.py b/n3fit/src/n3fit/tests/test_backend.py index eaae5667c8..9e434eaf8b 100644 --- a/n3fit/src/n3fit/tests/test_backend.py +++ b/n3fit/src/n3fit/tests/test_backend.py @@ -2,8 +2,11 @@ This module tests the mathematical functions in the n3fit backend and ensures they do the same thing as their numpy counterparts """ + import operator + import numpy as np + from n3fit.backends import operations as op # General parameters @@ -24,14 +27,14 @@ def are_equal(result, reference, threshold=THRESHOLD): - """ checks the difference between array `reference` and tensor `result` is - below `threshold` for all elements """ - res = op.evaluate(result) + """checks the difference between array `reference` and tensor `result` is + below `threshold` for all elements""" + res = op.tensor_to_numpy_or_python(result) assert np.allclose(res, reference, atol=threshold) def numpy_check(backend_op, python_op, mode="same"): - """ Receives a backend operation (`backend_op`) and a python operation + """Receives a backend operation (`backend_op`) and a python operation `python_op` and asserts that, applied to two random arrays, the result is the same. The option `mode` selects the two arrays to be tested and accepts the following @@ -53,7 +56,28 @@ def numpy_check(backend_op, python_op, mode="same"): arrays = [ARR1, ARR2, ARR1, ARR1] elif mode == "twenty": tensors = [T1, T2, T1, T1, T1, T1, T1, T1, T1, T1, T1, T2, T1, T1, T1, T1, T1, T1, T1, T1] - arrays = [ARR1, ARR2, ARR1, ARR1, ARR1, ARR1, ARR1, ARR1, ARR1, ARR1, ARR1, ARR2, ARR1, ARR1, ARR1, ARR1, ARR1, ARR1, ARR1, ARR1] + arrays = [ + ARR1, + ARR2, + ARR1, + ARR1, + ARR1, + ARR1, + ARR1, + ARR1, + ARR1, + ARR1, + ARR1, + ARR2, + ARR1, + ARR1, + ARR1, + ARR1, + ARR1, + ARR1, + ARR1, + ARR1, + ] elif mode == "ten": tensors = [T1, T2, T1, T1, T1, T1, T1, T1, T1, T1] arrays = [ARR1, ARR2, ARR1, ARR1, ARR1, ARR1, ARR1, ARR1, ARR1, ARR1] @@ -98,22 +122,21 @@ def test_c_to_py_fun(): numpy_check(op_smp, reference, "four") # COM op_com = op.c_to_py_fun("COM") - reference = lambda x, y, z, d, e, f, g, h, i, j, k, l, m, n, o, p, q, r, s, t : (x + y + z + d + e + f + g + h + i + j) / (k + l + m + n + o + p + q + r + s + t) + reference = lambda x, y, z, d, e, f, g, h, i, j, k, l, m, n, o, p, q, r, s, t: ( + x + y + z + d + e + f + g + h + i + j + ) / (k + l + m + n + o + p + q + r + s + t) numpy_check(op_com, reference, "twenty") # SMT op_smt = op.c_to_py_fun("SMT") - reference = lambda x, y, z, d, e, f, g, h, i, j : (x + y + z + d + e + f + g + h + i + j) + reference = lambda x, y, z, d, e, f, g, h, i, j: (x + y + z + d + e + f + g + h + i + j) numpy_check(op_smt, reference, "ten") + # Tests operations def test_op_multiply(): numpy_check(op.op_multiply, operator.mul) -def test_op_multiply_dim(): - numpy_check(op.op_multiply_dim, operator.mul, mode="diff") - - def test_op_log(): numpy_check(op.op_log, np.log, mode='single') @@ -122,17 +145,11 @@ def test_flatten(): numpy_check(op.flatten, np.ndarray.flatten, mode=(T3, [ARR3])) -def test_boolean_mask(): - bools = np.random.randint(0, 2, DIM, dtype=bool) - np_result = ARR1[bools] - tf_bools = op.numpy_to_tensor(bools) - tf_result = 
op.boolean_mask(T1, tf_bools, axis=0) - are_equal(np_result, tf_result) - def test_tensor_product(): np_result = np.tensordot(ARR3, ARR1, axes=1) tf_result = op.tensor_product(T3, T1, axes=1) - are_equal(np_result, tf_result) + are_equal(tf_result, np_result) + def test_sum(): numpy_check(op.sum, np.sum, mode='single') diff --git a/n3fit/src/n3fit/tests/test_layers.py b/n3fit/src/n3fit/tests/test_layers.py index 8615414c2f..84ef8c8eaf 100644 --- a/n3fit/src/n3fit/tests/test_layers.py +++ b/n3fit/src/n3fit/tests/test_layers.py @@ -169,7 +169,7 @@ def test_DIS(): kp = op.numpy_to_tensor([[pdf]]) # add batch and replica dimension # generate the n3fit results result_tensor = obs_layer(kp) - result = op.evaluate(result_tensor) + result = op.tensor_to_numpy_or_python(result_tensor) # Compute the numpy version of this layer all_masks = obs_layer.all_masks if len(all_masks) < nfk: @@ -195,7 +195,7 @@ def test_DY(): kp = op.numpy_to_tensor([[pdf]]) # add batch and replica dimension # generate the n3fit results result_tensor = obs_layer(kp) - result = op.evaluate(result_tensor) + result = op.tensor_to_numpy_or_python(result_tensor) # Compute the numpy version of this layer all_masks = obs_layer.all_masks if len(all_masks) < nfk: diff --git a/pyproject.toml b/pyproject.toml index e9f4ce5d9a..9f10d7b934 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -73,6 +73,7 @@ reportengine = { git = "https://github.com/NNPDF/reportengine" } # Fit psutil = "*" tensorflow = "*" +keras = "^3.1" eko = "^0.14.1" joblib = "*" # Hyperopt @@ -97,6 +98,8 @@ fiatlux = {version = "*", optional = true} # without lhapdf pdfflow = {version = "^1.2.1", optional = true} lhapdf-management = {version = "^0.5", optional = true} +# torch +torch = {version = "*", optional = true} # Optional dependencies [tool.poetry.extras] @@ -104,6 +107,7 @@ tests = ["pytest", "pytest-mpl", "hypothesis"] docs = ["sphinxcontrib", "sphinx-rtd-theme", "sphinx", "tabulate"] qed = ["fiatlux"] nolha = ["pdfflow", "lhapdf-management"] +torch = ["torch"] [tool.poetry-dynamic-versioning] enable = true
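Finally, a short sketch of the semantics behind the new ``tensor_splitter`` helper used by ``model_gen`` and ``model_trainer`` (the shapes here are made up): ``keras.ops.split`` takes cumulative indices in the style of ``np.split``, excluding the final total, which is why the helper precomputes them from the per-dataset chunk sizes.

.. code-block:: python

    import numpy as np
    from keras import ops as Kops

    # Split a (batch, replicas, x, flavours) tensor into per-dataset chunks
    pdf = Kops.reshape(np.arange(2 * 5 * 14, dtype="float32"), (1, 2, 5, 14))
    split_sizes = [2, 3]
    indices = np.cumsum(split_sizes)[:-1].tolist()  # np.split-style: [2]

    chunks = Kops.split(pdf, indices, axis=2)
    assert [int(c.shape[2]) for c in chunks] == split_sizes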