From d7d29bbfd151d46b03ded1edfd8f6cf10968c143 Mon Sep 17 00:00:00 2001 From: juacrumar Date: Mon, 22 Apr 2024 17:21:38 +0200 Subject: [PATCH 1/3] add the square_singlet activation function --- .../n3fit/backends/keras_backend/base_layers.py | 15 ++++++++++++--- n3fit/src/n3fit/checks.py | 12 ++++++++++-- 2 files changed, 22 insertions(+), 5 deletions(-) diff --git a/n3fit/src/n3fit/backends/keras_backend/base_layers.py b/n3fit/src/n3fit/backends/keras_backend/base_layers.py index 6cf545e06a..a68f058c5a 100644 --- a/n3fit/src/n3fit/backends/keras_backend/base_layers.py +++ b/n3fit/src/n3fit/backends/keras_backend/base_layers.py @@ -11,7 +11,7 @@ In order to add custom activation functions, they must be added to the `custom_activations` dictionary with the following structure: - + 'name of the activation' : function The names of the layer and the activation function are the ones to be used in the n3fit runcard. @@ -24,8 +24,9 @@ from tensorflow.keras.layers import LSTM, Concatenate from tensorflow.keras.regularizers import l1_l2 -from n3fit.backends import MetaLayer -from n3fit.backends.keras_backend.multi_dense import MultiDense +from .MetaLayer import MetaLayer +from .multi_dense import MultiDense +from .operations import concatenate_function # Custom activation functions @@ -34,6 +35,13 @@ def square_activation(x): return x * x +def square_singlet(x): + """Square the singlet sector + Defined as the two first values of the NN""" + singlet_squared = x[..., :2] ** 2 + return concatenate_function([singlet_squared, x[..., 2:]], axis=-1) + + def modified_tanh(x): """A non-saturating version of the tanh function""" return math.abs(x) * nn.tanh(x) @@ -46,6 +54,7 @@ def leaky_relu(x): custom_activations = { "square": square_activation, + "square_singlet": square_singlet, "leaky_relu": leaky_relu, "modified_tanh": modified_tanh, } diff --git a/n3fit/src/n3fit/checks.py b/n3fit/src/n3fit/checks.py index cea5797550..409d19d56d 100644 --- a/n3fit/src/n3fit/checks.py +++ b/n3fit/src/n3fit/checks.py @@ -50,12 +50,20 @@ def check_existing_parameters(parameters): def check_consistent_layers(parameters): - """Checks that all layers have an activation function defined""" + """Checks that all layers have an activation function defined + and that a final-activation function is not being used half-way through. + """ + final_activations = ["square_singlet"] npl = len(parameters["nodes_per_layer"]) - apl = len(parameters["activation_per_layer"]) + act_per_layer = parameters["activation_per_layer"] + apl = len(act_per_layer) if npl != apl: raise CheckError(f"Number of layers ({npl}) does not match activation functions: {apl}") + for fin_act in final_activations: + if fin_act in act_per_layer[:-1]: + raise CheckError(f"The activation {fin_act} can only be used as last layer") + def check_stopping(parameters): """Checks whether the stopping-related options are sane: From ae08024fa755a428a2bc2d65bdbd6c21456bf98e Mon Sep 17 00:00:00 2001 From: juacrumar Date: Tue, 7 May 2024 10:31:23 +0200 Subject: [PATCH 2/3] add checks for square singlet --- n3fit/src/n3fit/checks.py | 32 ++++++++++++++++++++++------ n3fit/src/n3fit/tests/test_checks.py | 13 ++++++++--- 2 files changed, 35 insertions(+), 10 deletions(-) diff --git a/n3fit/src/n3fit/checks.py b/n3fit/src/n3fit/checks.py index 409d19d56d..0cb2e8c593 100644 --- a/n3fit/src/n3fit/checks.py +++ b/n3fit/src/n3fit/checks.py @@ -78,8 +78,10 @@ def check_stopping(parameters): raise CheckError(f"Needs to run at least 1 epoch, got: {epochs}") -def check_basis_with_layers(basis, parameters): - """Check that the last layer matches the number of flavours defined in the runcard""" +def check_basis_with_layers(basis, validphys_basis, parameters): + """Check that the last layer matches the number of flavours defined in the runcard. + And that the activation functions are compatible with the basis. + """ number_of_flavours = len(basis) last_layer = parameters["nodes_per_layer"][-1] if number_of_flavours != last_layer: @@ -88,6 +90,21 @@ def check_basis_with_layers(basis, parameters): f" match the number of flavours: ({number_of_flavours})" ) + flavours = [i["fl"] for i in basis] + if parameters["activation_per_layer"][-1] == "square_singlet": + if not (("sng" in flavours) and ("g" in flavours)): + raise CheckError( + "square_singlet can only be used when `gluon` (g) and `singlet` (sng) are being fitted" + ) + if (val := validphys_basis.indexes.get("sng")) > 1: + raise CheckError( + f"When using square_singlet, \\Sigma must be either element 0 or 1, found {val}" + ) + if (val := validphys_basis.indexes.get("g")) > 1: + raise CheckError( + f"When using square_singlet, gluon must be either element 0 or 1, found {val}" + ) + def check_optimizer(optimizer_dict): """Checks whether the optimizer setup is valid""" @@ -184,13 +201,12 @@ def check_model_file(save, load): @make_argcheck -def wrapper_check_NN(basis, tensorboard, save, load, parameters): +def wrapper_check_NN(tensorboard, save, load, parameters): """Wrapper function for all NN-related checks""" check_tensorboard(tensorboard) check_model_file(save, load) check_existing_parameters(parameters) check_consistent_layers(parameters) - check_basis_with_layers(basis, parameters) check_stopping(parameters) check_layer_type_implemented(parameters) check_dropout(parameters) @@ -367,7 +383,7 @@ def check_sumrules(sum_rules): # Checks on the physics @make_argcheck -def check_consistent_basis(sum_rules, fitbasis, basis, theoryid): +def check_consistent_basis(sum_rules, fitbasis, basis, theoryid, parameters): """Checks the fitbasis setup for inconsistencies - Checks the sum rules can be imposed - Correct flavours for the selected basis @@ -389,14 +405,16 @@ def check_consistent_basis(sum_rules, fitbasis, basis, theoryid): flavs.append(name) # Finally check whether the basis considers or not charm # Check that the basis given in the runcard is one of those defined in validphys.pdfbases - basis = check_basis(fitbasis, flavs)["basis"] + vp_basis = check_basis(fitbasis, flavs)["basis"] # Now check that basis and theory id are consistent - has_c = basis.has_element("c") or basis.has_element("T15") or basis.has_element("cp") + has_c = vp_basis.has_element("c") or vp_basis.has_element("T15") or vp_basis.has_element("cp") if theoryid.get_description()["IC"] and not has_c: raise CheckError(f"{theoryid} (intrinsic charm) is incompatible with basis {fitbasis}") if not theoryid.get_description()["IC"] and has_c: raise CheckError(f"{theoryid} (perturbative charm) is incompatible with basis {fitbasis}") + check_basis_with_layers(basis, vp_basis, parameters) + @make_argcheck def check_consistent_parallel(parameters, parallel_models, same_trvl_per_replica): diff --git a/n3fit/src/n3fit/tests/test_checks.py b/n3fit/src/n3fit/tests/test_checks.py index 9efeb2e22f..da3d325649 100644 --- a/n3fit/src/n3fit/tests/test_checks.py +++ b/n3fit/src/n3fit/tests/test_checks.py @@ -1,11 +1,12 @@ """ Test for the n3fit checks """ + import pytest from n3fit import checks -from n3fit.backends import MetaModel from reportengine.checks import CheckError +from validphys.pdfbases import check_basis def test_existing_parameters(): @@ -42,10 +43,16 @@ def test_check_stopping(): def test_check_basis_with_layer(): """Test that it fails with layers that do not match flavours""" - flavs = ["g", "ubar"] + flavs = ["g", "u", "ubar", "d", "dbar"] + basis = [{"fl": i} for i in flavs] layers = [4, 5, 9] + vp_basis = check_basis("flavour", flavs)["basis"] + with pytest.raises(CheckError): + checks.check_basis_with_layers(basis, vp_basis, {"nodes_per_layer": layers}) + # Or when the wrong kind of basis is veing used + params = {"nodes_per_layer": [4, 5], "activation_per_layer": ["sigmoid", "square_singlet"]} with pytest.raises(CheckError): - checks.check_basis_with_layers({"basis": flavs}, {"nodes_per_layer": layers}) + checks.check_basis_with_layers(basis, vp_basis, params) def test_check_optimizer(): From 927d80a5dbd3bb051cceedf41a84f5d59fcaa214 Mon Sep 17 00:00:00 2001 From: RoyStegeman Date: Tue, 7 May 2024 10:35:54 +0100 Subject: [PATCH 3/3] remove unused import --- n3fit/src/n3fit/checks.py | 1 - 1 file changed, 1 deletion(-) diff --git a/n3fit/src/n3fit/checks.py b/n3fit/src/n3fit/checks.py index 0cb2e8c593..607d838e55 100644 --- a/n3fit/src/n3fit/checks.py +++ b/n3fit/src/n3fit/checks.py @@ -7,7 +7,6 @@ import os from n3fit.hyper_optimization import penalties as penalties_module -from n3fit.hyper_optimization import rewards as rewards_module from n3fit.hyper_optimization.rewards import IMPLEMENTED_LOSSES, IMPLEMENTED_STATS from reportengine.checks import CheckError, make_argcheck from validphys.pdfbases import check_basis