Skip to content

Commit

Permalink
Merge pull request #2064 from NNPDF/square_singlet
Browse files Browse the repository at this point in the history
Add the square_singlet activation function
  • Loading branch information
scarlehoff authored May 7, 2024
2 parents afbe441 + 927d80a commit 124ae08
Show file tree
Hide file tree
Showing 3 changed files with 57 additions and 16 deletions.
15 changes: 12 additions & 3 deletions n3fit/src/n3fit/backends/keras_backend/base_layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
In order to add custom activation functions, they must be added to
the `custom_activations` dictionary with the following structure:
'name of the activation' : function
The names of the layer and the activation function are the ones to be used in the n3fit runcard.
Expand All @@ -24,8 +24,9 @@
from tensorflow.keras.layers import LSTM, Concatenate
from tensorflow.keras.regularizers import l1_l2

from n3fit.backends import MetaLayer
from n3fit.backends.keras_backend.multi_dense import MultiDense
from .MetaLayer import MetaLayer
from .multi_dense import MultiDense
from .operations import concatenate_function


# Custom activation functions
Expand All @@ -34,6 +35,13 @@ def square_activation(x):
return x * x


def square_singlet(x):
"""Square the singlet sector
Defined as the two first values of the NN"""
singlet_squared = x[..., :2] ** 2
return concatenate_function([singlet_squared, x[..., 2:]], axis=-1)


def modified_tanh(x):
"""A non-saturating version of the tanh function"""
return math.abs(x) * nn.tanh(x)
Expand All @@ -46,6 +54,7 @@ def leaky_relu(x):

custom_activations = {
"square": square_activation,
"square_singlet": square_singlet,
"leaky_relu": leaky_relu,
"modified_tanh": modified_tanh,
}
Expand Down
45 changes: 35 additions & 10 deletions n3fit/src/n3fit/checks.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
import os

from n3fit.hyper_optimization import penalties as penalties_module
from n3fit.hyper_optimization import rewards as rewards_module
from n3fit.hyper_optimization.rewards import IMPLEMENTED_LOSSES, IMPLEMENTED_STATS
from reportengine.checks import CheckError, make_argcheck
from validphys.pdfbases import check_basis
Expand Down Expand Up @@ -50,12 +49,20 @@ def check_existing_parameters(parameters):


def check_consistent_layers(parameters):
"""Checks that all layers have an activation function defined"""
"""Checks that all layers have an activation function defined
and that a final-activation function is not being used half-way through.
"""
final_activations = ["square_singlet"]
npl = len(parameters["nodes_per_layer"])
apl = len(parameters["activation_per_layer"])
act_per_layer = parameters["activation_per_layer"]
apl = len(act_per_layer)
if npl != apl:
raise CheckError(f"Number of layers ({npl}) does not match activation functions: {apl}")

for fin_act in final_activations:
if fin_act in act_per_layer[:-1]:
raise CheckError(f"The activation {fin_act} can only be used as last layer")


def check_stopping(parameters):
"""Checks whether the stopping-related options are sane:
Expand All @@ -70,8 +77,10 @@ def check_stopping(parameters):
raise CheckError(f"Needs to run at least 1 epoch, got: {epochs}")


def check_basis_with_layers(basis, parameters):
"""Check that the last layer matches the number of flavours defined in the runcard"""
def check_basis_with_layers(basis, validphys_basis, parameters):
"""Check that the last layer matches the number of flavours defined in the runcard.
And that the activation functions are compatible with the basis.
"""
number_of_flavours = len(basis)
last_layer = parameters["nodes_per_layer"][-1]
if number_of_flavours != last_layer:
Expand All @@ -80,6 +89,21 @@ def check_basis_with_layers(basis, parameters):
f" match the number of flavours: ({number_of_flavours})"
)

flavours = [i["fl"] for i in basis]
if parameters["activation_per_layer"][-1] == "square_singlet":
if not (("sng" in flavours) and ("g" in flavours)):
raise CheckError(
"square_singlet can only be used when `gluon` (g) and `singlet` (sng) are being fitted"
)
if (val := validphys_basis.indexes.get("sng")) > 1:
raise CheckError(
f"When using square_singlet, \\Sigma must be either element 0 or 1, found {val}"
)
if (val := validphys_basis.indexes.get("g")) > 1:
raise CheckError(
f"When using square_singlet, gluon must be either element 0 or 1, found {val}"
)


def check_optimizer(optimizer_dict):
"""Checks whether the optimizer setup is valid"""
Expand Down Expand Up @@ -176,13 +200,12 @@ def check_model_file(save, load):


@make_argcheck
def wrapper_check_NN(basis, tensorboard, save, load, parameters):
def wrapper_check_NN(tensorboard, save, load, parameters):
"""Wrapper function for all NN-related checks"""
check_tensorboard(tensorboard)
check_model_file(save, load)
check_existing_parameters(parameters)
check_consistent_layers(parameters)
check_basis_with_layers(basis, parameters)
check_stopping(parameters)
check_layer_type_implemented(parameters)
check_dropout(parameters)
Expand Down Expand Up @@ -359,7 +382,7 @@ def check_sumrules(sum_rules):

# Checks on the physics
@make_argcheck
def check_consistent_basis(sum_rules, fitbasis, basis, theoryid):
def check_consistent_basis(sum_rules, fitbasis, basis, theoryid, parameters):
"""Checks the fitbasis setup for inconsistencies
- Checks the sum rules can be imposed
- Correct flavours for the selected basis
Expand All @@ -381,14 +404,16 @@ def check_consistent_basis(sum_rules, fitbasis, basis, theoryid):
flavs.append(name)
# Finally check whether the basis considers or not charm
# Check that the basis given in the runcard is one of those defined in validphys.pdfbases
basis = check_basis(fitbasis, flavs)["basis"]
vp_basis = check_basis(fitbasis, flavs)["basis"]
# Now check that basis and theory id are consistent
has_c = basis.has_element("c") or basis.has_element("T15") or basis.has_element("cp")
has_c = vp_basis.has_element("c") or vp_basis.has_element("T15") or vp_basis.has_element("cp")
if theoryid.get_description()["IC"] and not has_c:
raise CheckError(f"{theoryid} (intrinsic charm) is incompatible with basis {fitbasis}")
if not theoryid.get_description()["IC"] and has_c:
raise CheckError(f"{theoryid} (perturbative charm) is incompatible with basis {fitbasis}")

check_basis_with_layers(basis, vp_basis, parameters)


@make_argcheck
def check_consistent_parallel(parameters, parallel_models, same_trvl_per_replica):
Expand Down
13 changes: 10 additions & 3 deletions n3fit/src/n3fit/tests/test_checks.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
"""
Test for the n3fit checks
"""

import pytest

from n3fit import checks
from n3fit.backends import MetaModel
from reportengine.checks import CheckError
from validphys.pdfbases import check_basis


def test_existing_parameters():
Expand Down Expand Up @@ -42,10 +43,16 @@ def test_check_stopping():

def test_check_basis_with_layer():
"""Test that it fails with layers that do not match flavours"""
flavs = ["g", "ubar"]
flavs = ["g", "u", "ubar", "d", "dbar"]
basis = [{"fl": i} for i in flavs]
layers = [4, 5, 9]
vp_basis = check_basis("flavour", flavs)["basis"]
with pytest.raises(CheckError):
checks.check_basis_with_layers(basis, vp_basis, {"nodes_per_layer": layers})
# Or when the wrong kind of basis is veing used
params = {"nodes_per_layer": [4, 5], "activation_per_layer": ["sigmoid", "square_singlet"]}
with pytest.raises(CheckError):
checks.check_basis_with_layers({"basis": flavs}, {"nodes_per_layer": layers})
checks.check_basis_with_layers(basis, vp_basis, params)


def test_check_optimizer():
Expand Down

0 comments on commit 124ae08

Please sign in to comment.