Add the square_singlet activation function #2064

Merged · 3 commits · May 7, 2024
15 changes: 12 additions & 3 deletions n3fit/src/n3fit/backends/keras_backend/base_layers.py
@@ -11,7 +11,7 @@

In order to add custom activation functions, they must be added to
the `custom_activations` dictionary with the following structure:

'name of the activation' : function

The names of the layer and the activation function are the ones to be used in the n3fit runcard.
@@ -24,8 +24,9 @@
from tensorflow.keras.layers import LSTM, Concatenate
from tensorflow.keras.regularizers import l1_l2

from n3fit.backends import MetaLayer
from n3fit.backends.keras_backend.multi_dense import MultiDense
from .MetaLayer import MetaLayer
from .multi_dense import MultiDense
from .operations import concatenate_function


# Custom activation functions
@@ -34,6 +35,13 @@ def square_activation(x):
    return x * x


def square_singlet(x):
    """Square the singlet sector,
    defined as the first two values of the NN."""
    singlet_squared = x[..., :2] ** 2
    return concatenate_function([singlet_squared, x[..., 2:]], axis=-1)


def modified_tanh(x):
    """A non-saturating version of the tanh function"""
    return math.abs(x) * nn.tanh(x)
@@ -46,6 +54,7 @@ def leaky_relu(x):

custom_activations = {
    "square": square_activation,
    "square_singlet": square_singlet,
    "leaky_relu": leaky_relu,
    "modified_tanh": modified_tanh,
}
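For illustration only (not part of this diff), here is a minimal sketch of what the new activation does to a sample network output. It uses plain TensorFlow instead of the n3fit operations wrapper, and `square_singlet_sketch` is a hypothetical stand-in for the function added above:

```python
import tensorflow as tf


def square_singlet_sketch(x):
    # Square only the first two output channels (the singlet sector: Sigma and the gluon),
    # leaving the remaining flavour channels untouched.
    singlet_squared = x[..., :2] ** 2
    return tf.concat([singlet_squared, x[..., 2:]], axis=-1)


x = tf.constant([[-1.0, 2.0, 3.0, -4.0]])
print(square_singlet_sketch(x).numpy())  # [[ 1.  4.  3. -4.]]
```

Squaring guarantees the first two channels are non-negative, while every other channel passes through unchanged.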
45 changes: 35 additions & 10 deletions n3fit/src/n3fit/checks.py
@@ -7,7 +7,6 @@
import os

from n3fit.hyper_optimization import penalties as penalties_module
from n3fit.hyper_optimization import rewards as rewards_module
from n3fit.hyper_optimization.rewards import IMPLEMENTED_LOSSES, IMPLEMENTED_STATS
from reportengine.checks import CheckError, make_argcheck
from validphys.pdfbases import check_basis
@@ -50,12 +49,20 @@ def check_existing_parameters(parameters):


def check_consistent_layers(parameters):
"""Checks that all layers have an activation function defined"""
"""Checks that all layers have an activation function defined
and that a final-activation function is not being used half-way through.
"""
final_activations = ["square_singlet"]
npl = len(parameters["nodes_per_layer"])
apl = len(parameters["activation_per_layer"])
act_per_layer = parameters["activation_per_layer"]
apl = len(act_per_layer)
if npl != apl:
raise CheckError(f"Number of layers ({npl}) does not match activation functions: {apl}")

for fin_act in final_activations:
if fin_act in act_per_layer[:-1]:
raise CheckError(f"The activation {fin_act} can only be used as last layer")


def check_stopping(parameters):
"""Checks whether the stopping-related options are sane:
Expand All @@ -70,8 +77,10 @@ def check_stopping(parameters):
raise CheckError(f"Needs to run at least 1 epoch, got: {epochs}")


def check_basis_with_layers(basis, parameters):
    """Check that the last layer matches the number of flavours defined in the runcard"""
def check_basis_with_layers(basis, validphys_basis, parameters):
    """Check that the last layer matches the number of flavours defined in the runcard,
    and that the activation functions are compatible with the basis.
    """
    number_of_flavours = len(basis)
    last_layer = parameters["nodes_per_layer"][-1]
    if number_of_flavours != last_layer:
@@ -80,6 +89,21 @@ def check_basis_with_layers(basis, parameters):
            f" match the number of flavours: ({number_of_flavours})"
        )

    flavours = [i["fl"] for i in basis]
    if parameters["activation_per_layer"][-1] == "square_singlet":
        if not (("sng" in flavours) and ("g" in flavours)):
            raise CheckError(
                "square_singlet can only be used when `gluon` (g) and `singlet` (sng) are being fitted"
            )
        if (val := validphys_basis.indexes.get("sng")) > 1:
            raise CheckError(
                f"When using square_singlet, \\Sigma must be either element 0 or 1, found {val}"
            )
        if (val := validphys_basis.indexes.get("g")) > 1:
            raise CheckError(
                f"When using square_singlet, gluon must be either element 0 or 1, found {val}"
            )
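Again as an aside (not part of the diff), a small sketch showing how the basis/activation compatibility check can be exercised. The `SimpleNamespace` object is a hypothetical stand-in for the validphys basis, used only so the index requirement is visible without assuming a concrete validphys basis ordering:

```python
from types import SimpleNamespace

from n3fit import checks
from reportengine.checks import CheckError

basis = [{"fl": f} for f in ["sng", "g", "v", "t3"]]
good_vp = SimpleNamespace(indexes={"sng": 0, "g": 1, "v": 2, "t3": 3})
bad_vp = SimpleNamespace(indexes={"v": 0, "t3": 1, "sng": 2, "g": 3})
params = {"nodes_per_layer": [10, 4], "activation_per_layer": ["tanh", "square_singlet"]}

# Passes: Sigma and the gluon occupy the first two positions
checks.check_basis_with_layers(basis, good_vp, params)

# Raises CheckError: Sigma sits beyond position 1
try:
    checks.check_basis_with_layers(basis, bad_vp, params)
except CheckError as err:
    print(err)
```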


def check_optimizer(optimizer_dict):
"""Checks whether the optimizer setup is valid"""
@@ -176,13 +200,12 @@ def check_model_file(save, load):


@make_argcheck
def wrapper_check_NN(basis, tensorboard, save, load, parameters):
def wrapper_check_NN(tensorboard, save, load, parameters):
"""Wrapper function for all NN-related checks"""
check_tensorboard(tensorboard)
check_model_file(save, load)
check_existing_parameters(parameters)
check_consistent_layers(parameters)
check_basis_with_layers(basis, parameters)
check_stopping(parameters)
check_layer_type_implemented(parameters)
check_dropout(parameters)
@@ -359,7 +382,7 @@ def check_sumrules(sum_rules):

# Checks on the physics
@make_argcheck
def check_consistent_basis(sum_rules, fitbasis, basis, theoryid):
def check_consistent_basis(sum_rules, fitbasis, basis, theoryid, parameters):
"""Checks the fitbasis setup for inconsistencies
- Checks the sum rules can be imposed
- Correct flavours for the selected basis
Expand All @@ -381,14 +404,16 @@ def check_consistent_basis(sum_rules, fitbasis, basis, theoryid):
flavs.append(name)
# Finally check whether the basis considers or not charm
# Check that the basis given in the runcard is one of those defined in validphys.pdfbases
basis = check_basis(fitbasis, flavs)["basis"]
vp_basis = check_basis(fitbasis, flavs)["basis"]
# Now check that basis and theory id are consistent
has_c = basis.has_element("c") or basis.has_element("T15") or basis.has_element("cp")
has_c = vp_basis.has_element("c") or vp_basis.has_element("T15") or vp_basis.has_element("cp")
if theoryid.get_description()["IC"] and not has_c:
raise CheckError(f"{theoryid} (intrinsic charm) is incompatible with basis {fitbasis}")
if not theoryid.get_description()["IC"] and has_c:
raise CheckError(f"{theoryid} (perturbative charm) is incompatible with basis {fitbasis}")

check_basis_with_layers(basis, vp_basis, parameters)


@make_argcheck
def check_consistent_parallel(parameters, parallel_models, same_trvl_per_replica):
13 changes: 10 additions & 3 deletions n3fit/src/n3fit/tests/test_checks.py
@@ -1,11 +1,12 @@
"""
Test for the n3fit checks
"""

import pytest

from n3fit import checks
from n3fit.backends import MetaModel
from reportengine.checks import CheckError
from validphys.pdfbases import check_basis


def test_existing_parameters():
@@ -42,10 +43,16 @@ def test_check_stopping():

def test_check_basis_with_layer():
"""Test that it fails with layers that do not match flavours"""
flavs = ["g", "ubar"]
flavs = ["g", "u", "ubar", "d", "dbar"]
basis = [{"fl": i} for i in flavs]
layers = [4, 5, 9]
vp_basis = check_basis("flavour", flavs)["basis"]
with pytest.raises(CheckError):
checks.check_basis_with_layers(basis, vp_basis, {"nodes_per_layer": layers})
# Or when the wrong kind of basis is veing used
params = {"nodes_per_layer": [4, 5], "activation_per_layer": ["sigmoid", "square_singlet"]}
with pytest.raises(CheckError):
checks.check_basis_with_layers({"basis": flavs}, {"nodes_per_layer": layers})
checks.check_basis_with_layers(basis, vp_basis, params)


def test_check_optimizer():