
start porting changes required for polarised fits
Radonirinaunimi committed Mar 4, 2024
1 parent 1623808 commit 3eedbb7
Showing 25 changed files with 433 additions and 132 deletions.
15 changes: 15 additions & 0 deletions n3fit/src/n3fit/backends/keras_backend/operations.py
@@ -315,6 +315,21 @@ def pow(tensor, power):
return tf.pow(tensor, power)


@tf.function
def absolute(tensor):
"""
Compute the absolute value of a tensor
"""
return K.abs(tensor)


def multiply_minusone(tensor):
"""
Multiply the elements of a given tensor by (-1)
"""
return keras_Lambda(lambda x: -1 * x)(tensor)


@tf.function(reduce_retracing=True)
def op_log(o_tensor, **kwargs):
"""
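The two new helpers are used together further down: the DIS layer chains them as `op.multiply_minusone(op.absolute(res))` to build `-|x Δf ⊗ FK|`. A minimal standalone sketch of that composition, assuming a TensorFlow backend and using `tf.abs` as a stand-in for the backend's `K.abs`:

```python
import tensorflow as tf
from tensorflow.keras.layers import Lambda as keras_Lambda

def absolute(tensor):
    # same operation as the new backend helper, with tf.abs standing in for K.abs
    return tf.abs(tensor)

def multiply_minusone(tensor):
    # wrap the sign flip in a Lambda layer, as operations.py does
    return keras_Lambda(lambda x: -1 * x)(tensor)

t = tf.constant([-2.0, 0.5, -1.5])
print(multiply_minusone(absolute(t)).numpy())  # [-2.  -0.5 -1.5]
```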
14 changes: 10 additions & 4 deletions n3fit/src/n3fit/checks.py
@@ -5,12 +5,9 @@
import numbers
import os

import numpy as np

from n3fit.hyper_optimization import penalties as penalties_module
from n3fit.hyper_optimization import rewards as rewards_module
from reportengine.checks import CheckError, make_argcheck
from validphys.core import PDF
from validphys.pdfbases import check_basis

log = logging.getLogger(__name__)
@@ -71,6 +68,15 @@ def check_stopping(parameters):
raise CheckError(f"Needs to run at least 1 epoch, got: {epochs}")


@make_argcheck
def check_polarised(fitbasis, fitting):
"""Checks that if the polarised basis is used, then the necessary entries
are specified correctly.
"""
if "POL" in fitbasis and fitting.get("sum_rules") != "TSR":
raise CheckError("'sum_rules' needs to be 'TSR' for polarised fits.")


def check_basis_with_layers(basis, parameters):
"""Check that the last layer matches the number of flavours defined in the runcard"""
number_of_flavours = len(basis)
@@ -334,7 +340,7 @@ def check_sumrules(sum_rules):
"""Checks that the chosen option for the sum rules are sensible"""
if isinstance(sum_rules, bool):
return
accepted_options = ["ALL", "MSR", "VSR", "ALLBUTCSR"]
accepted_options = ["ALL", "MSR", "VSR", "TSR", "ALLBUTCSR"]
if sum_rules.upper() in accepted_options:
return
raise CheckError(f"The only accepted options for the sum rules are: {accepted_options}")
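Stripped of the reportengine machinery, the new check amounts to a single condition on two runcard entries; a hypothetical sketch with made-up runcard values:

```python
# hypothetical runcard values, for illustration only
fitbasis = "POLARIZED_EVOL"     # any basis name containing "POL" triggers the check
fitting = {"sum_rules": "TSR"}  # anything other than "TSR" would fail for polarised fits

if "POL" in fitbasis and fitting.get("sum_rules") != "TSR":
    raise ValueError("'sum_rules' needs to be 'TSR' for polarised fits.")
```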
24 changes: 16 additions & 8 deletions n3fit/src/n3fit/layers/DIS.py
@@ -12,6 +12,9 @@

from .observable import Observable

POS_POLSD_IDEX = [0, 2] # Polarised POS Datasets
POS_UNPOL_IDEX = [1, 3] # Unpolarised POS Datasets


class DIS(Observable):
"""
@@ -64,16 +67,21 @@ def call(self, pdf):
raise ValueError("DIS layer call with a dataset that needs more than one xgrid?")

results = []
-        # Separate the two possible paths this layer can take
-        if self.many_masks:
-            for mask, fktable in zip(self.all_masks, self.fktables):
-                pdf_masked = op.boolean_mask(pdf, mask, axis=3)
-                res = op.tensor_product(pdf_masked, fktable, axes=[(2, 3), (2, 1)])
-                results.append(res)
-        else:
-            pdf_masked = op.boolean_mask(pdf, self.all_masks[0], axis=3)
-            for fktable in self.fktables:
-                res = op.tensor_product(pdf_masked, fktable, axes=[(2, 3), (2, 1)])
-                results.append(res)
+        for idx, fktable in enumerate(self.fktables):
+            mask = self.all_masks[idx] if self.many_masks else self.all_masks[0]
+            if self.is_polarised_pos() and idx in POS_UNPOL_IDEX:
+                # Convolute the FK tables with the pre-computed Unpolarised PDFs
+                pdf_masked = op.boolean_mask(self.computed_pdfs[idx], mask, axis=3)
+            else:
+                pdf_masked = op.boolean_mask(pdf, mask, axis=3)
+
+            if self.is_polarised_pos() and idx in POS_POLSD_IDEX:
+                # Compute the absolute value of `x \Delta f(x)` and multiply with (-1)
+                res = op.tensor_product(pdf_masked, fktable, axes=[(2, 3), (2, 1)])
+                res = op.multiply_minusone(op.absolute(res))
+            else:
+                res = op.tensor_product(pdf_masked, fktable, axes=[(2, 3), (2, 1)])
+
+            results.append(res)

return self.operation(results)
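The net effect for a polarised positivity dataset can be sketched with plain numpy. The shapes and the pairing of FK tables below are illustrative assumptions (the diff only fixes which indices are treated as polarised); the sign and absolute-value step mirrors `op.multiply_minusone(op.absolute(res))`:

```python
import numpy as np

POS_POLSD_IDEX = [0, 2]  # polarised FK tables -> contribute -|x·Δf ⊗ FK|
POS_UNPOL_IDEX = [1, 3]  # unpolarised FK tables -> convoluted with pre-computed unpolarised PDFs

def toy_convolution(pdf, fk):
    # stand-in for op.tensor_product over the (x, flavour) axes
    return np.tensordot(pdf, fk, axes=([0, 1], [2, 1]))

ndata, nx, nfl = 3, 10, 14
rng = np.random.default_rng(0)
fktables = [rng.random((ndata, nfl, nx)) for _ in range(4)]
polarised_pdf = rng.standard_normal((nx, nfl))  # x·Δf, may be negative
unpolarised_pdf = rng.random((nx, nfl))         # pre-computed x·f, positive

results = []
for idx, fk in enumerate(fktables):
    if idx in POS_UNPOL_IDEX:
        results.append(toy_convolution(unpolarised_pdf, fk))
    else:
        res = toy_convolution(polarised_pdf, fk)
        results.append(-np.abs(res))  # mirrors multiply_minusone(absolute(res))
```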
2 changes: 1 addition & 1 deletion n3fit/src/n3fit/layers/losses.py
@@ -154,7 +154,7 @@ class LossPositivity(LossLagrange):
True
"""

def __init__(self, alpha=1e-7, **kwargs):
def __init__(self, alpha=0.0, **kwargs):
self.alpha = alpha
super().__init__(**kwargs)

70 changes: 63 additions & 7 deletions n3fit/src/n3fit/layers/msr_normalization.py
@@ -11,7 +11,22 @@
from n3fit.backends import MetaLayer
from n3fit.backends import operations as op

IDX = {'photon': 0, 'sigma': 1, 'g': 2, 'v': 3, 'v3': 4, 'v8': 5, 'v15': 6, 'v24': 7, 'v35': 8}
IDX = {
'photon': 0,
'sigma': 1,
'g': 2,
'v': 3,
'v3': 4,
'v8': 5,
'v15': 6,
'v24': 7,
'v35': 8,
't3': 9,
't8': 10,
't15': 11,
't24': 12,
't35': 13,
}
MSR_COMPONENTS = ['g']
MSR_DENOMINATORS = {'g': 'g'}
# The VSR normalization factor of component f is given by
@@ -20,12 +35,42 @@
VSR_CONSTANTS = {'v': 3.0, 'v35': 3.0, 'v24': 3.0, 'v3': 1.0, 'v8': 3.0, 'v15': 3.0}
VSR_DENOMINATORS = {'v': 'v', 'v35': 'v', 'v24': 'v', 'v3': 'v3', 'v8': 'v8', 'v15': 'v15'}

CSR_COMPONENTS = ['v','v35','v24']
CSR_COMPONENTS = ['v', 'v35', 'v24']
CSR_DENOMINATORS = {'v': 'v', 'v35': 'v', 'v24': 'v'}
NOV15_COMPONENTS = ['v3', 'v8']
NOV15_CONSTANTS = {'v3': 1.0, 'v8': 3.0}
NOV15_DENOMINATORS = {'v3': 'v3', 'v8': 'v8'}

# The following lays out the SR for Polarised PDFs
TSR_COMPONENTS = ['t3', 't8']
TSR_DENOMINATORS = {'t3': 't3', 't8': 't8'}
# Sum Rules defined as in PDG 2023
TSR_CONSTANTS = {'t3': 1.2756, 't8': 0.5850}
TSR_CONSTANTS_UNC = {'t3': 0.0013, 't8': 0.025}


def sample_tsr(v: dict, e: dict, t: list, nr: int) -> list:
"""
Sample the triplet sum rules according to the PDG uncertainties.
Parameters
----------
v: dict
dictionary that maps the triplet component to its PDG value
e: dict
dictionary that maps the triplet component to its PDG uncertainty
t: list
list of triplet components for which the sum rules should be applied
nr: int
number of replicas that are fitted simultaneously
Returns
-------
list:
list of sum rule values sampled according to a normal distribution
"""
return [[np.random.normal(v[c], e[c]) for _ in range(nr)] for c in t]


class MSR_Normalization(MetaLayer):
"""
@@ -34,7 +79,8 @@ class MSR_Normalization(MetaLayer):

_msr_enabled = False
_vsr_enabled = False
_csr_enabled = False
_tsr_enabled = False
_csr_enabled = False

def __init__(self, mode: str = "ALL", replicas: int = 1, **kwargs):
if mode == True or mode.upper() == "ALL":
@@ -44,7 +90,8 @@ def __init__(self, mode: str = "ALL", replicas: int = 1, **kwargs):
self._msr_enabled = True
elif mode.upper() == "VSR":
self._vsr_enabled = True

elif mode.upper() == "TSR":
self._tsr_enabled = True
elif mode.upper() == "ALLBUTCSR":
self._msr_enabled = True
self._csr_enabled = True
@@ -63,6 +110,13 @@ def __init__(self, mode: str = "ALL", replicas: int = 1, **kwargs):
self.vsr_factors = op.numpy_to_tensor(
[np.repeat(VSR_CONSTANTS[c], replicas) for c in VSR_COMPONENTS]
)
if self._tsr_enabled:
self.divisor_indices += [IDX[TSR_DENOMINATORS[c]] for c in TSR_COMPONENTS]
indices += [IDX[c] for c in TSR_COMPONENTS]
self.tsr_factors = op.numpy_to_tensor(
sample_tsr(TSR_CONSTANTS, TSR_CONSTANTS_UNC, TSR_COMPONENTS, replicas)
)

if self._csr_enabled:
# modified vsr for V, V24, V35
indices += [IDX[c] for c in CSR_COMPONENTS]
@@ -114,11 +168,13 @@ def call(self, pdf_integrated, photon_integral):
]
if self._vsr_enabled:
numerators += [self.vsr_factors]
if self._tsr_enabled:
numerators += [self.tsr_factors]
if self._csr_enabled:
numerators += len(CSR_COMPONENTS)*[op.batchit(4.0 - 1./3. * y[IDX['v15']], batch_dimension=0)]
numerators += len(CSR_COMPONENTS) * [
op.batchit(4.0 - 1.0 / 3.0 * y[IDX['v15']], batch_dimension=0)
]
numerators += [self.vsr_factors]



numerators = op.concatenate(numerators, axis=0)
divisors = op.gather(y, self.divisor_indices, axis=0)
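For reference, running the sampling helper above on the constants from this file gives one independent Gaussian draw per replica and per triplet component, so the PDG uncertainty on the sum rules is propagated into the replica ensemble rather than a single fixed normalisation being imposed:

```python
import numpy as np

TSR_COMPONENTS = ['t3', 't8']
TSR_CONSTANTS = {'t3': 1.2756, 't8': 0.5850}     # central values (PDG 2023), as above
TSR_CONSTANTS_UNC = {'t3': 0.0013, 't8': 0.025}  # their uncertainties

def sample_tsr(v, e, t, nr):
    return [[np.random.normal(v[c], e[c]) for _ in range(nr)] for c in t]

factors = sample_tsr(TSR_CONSTANTS, TSR_CONSTANTS_UNC, TSR_COMPONENTS, nr=3)
# factors is a 2 x 3 nested list: one row per component (t3, t8),
# with one sampled sum-rule value per replica in each row
```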
32 changes: 31 additions & 1 deletion n3fit/src/n3fit/layers/observable.py
@@ -38,11 +38,27 @@ class Observable(MetaLayer, ABC):
number of flavours in the pdf (default:14)
"""

def __init__(self, fktable_data, fktable_arr, operation_name, nfl=14, **kwargs):
def __init__(
self,
fktable_data,
fktable_arr,
dataset_name,
fitbasis,
extern_lhapdf,
operation_name,
n_replicas=1,
nfl=14,
**kwargs
):
super(MetaLayer, self).__init__(**kwargs)

self.dataset_name = dataset_name
self.nfl = nfl
self.fitbasis = fitbasis
self.fktable_data = fktable_data
self.nfks = len(fktable_data)

self.computed_pdfs = []
basis = []
xgrids = []
self.fktables = []
@@ -51,6 +67,14 @@ def __init__(self, fktable_data, fktable_arr, operation_name, nfl=14, **kwargs):
basis.append(fkdata.luminosity_mapping)
self.fktables.append(op.numpy_to_tensor(fk))

if self.is_polarised_pos():
resx = extern_lhapdf(fkdata.xgrid.tolist())
mult_resx = np.repeat([resx], n_replicas, axis=0)
resx = np.expand_dims(mult_resx, axis=0)
self.computed_pdfs.append(op.numpy_to_tensor(resx))
# TODO: Ideally fetch info from Commondata Metadata
operation_name = "SMP" if self.nfks == 4 else "ADD"

# check how many xgrids this dataset needs
if is_unique(xgrids):
self.splitting = None
@@ -71,6 +95,12 @@ def __init__(self, fktable_data, fktable_arr, operation_name, nfl=14, **kwargs):
def compute_output_shape(self, input_shape):
return (self.output_dim, None)

def is_polarised_pos(self):
if "POL" in self.fitbasis and "_POS_" in self.dataset_name:
# Polarised POS contains at least 2 FK tables
return self.nfks >= 2
return False

# Overridables
@abstractmethod
def gen_mask(self, basis):
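The shape bookkeeping for the pre-computed PDFs can be followed in plain numpy; this sketch assumes (the diff does not show it) that `extern_lhapdf` returns a grid of shape (n_x, n_flavours) for the requested x points:

```python
import numpy as np

n_replicas, n_x, n_flavours = 2, 50, 14

# stand-in for extern_lhapdf(fkdata.xgrid.tolist())
resx = np.random.rand(n_x, n_flavours)

mult_resx = np.repeat([resx], n_replicas, axis=0)  # (n_replicas, n_x, n_flavours)
resx = np.expand_dims(mult_resx, axis=0)           # (1, n_replicas, n_x, n_flavours)

# this is the array stored in self.computed_pdfs, later masked along
# axis=3 (the flavour axis) inside the DIS layer
assert resx.shape == (1, n_replicas, n_x, n_flavours)
```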
21 changes: 16 additions & 5 deletions n3fit/src/n3fit/layers/x_operations.py
@@ -22,23 +22,34 @@ class xDivide(MetaLayer):
Create tensor of either 1/x or ones depending on the flavour,
to be used to divide some PDFs by x by multiplying with the result.
By default it utilizes the 14-flavour FK basis and divides [v, v3, v8, v15]
which corresponds to indices (3, 4, 5, 6) from
By default it utilizes the 14-flavour FK basis. In the unpolarized
case, one divides [v, v3, v8, v15] which corresponds to indices
(3, 4, 5, 6) from the FK basis:
(photon, sigma, g, v, v3, v8, v15, v24, v35, t3, t8, t15, t24, t35)
In the polarized case, only [T3, T8] are divided by `x` which
corresponds to the indices (9, 10).
Parameters:
-----------
output_dim: int
dimension of the pdf
div_list: list
list of indices to be divided by X (by default [3, 4, 5, 6]; [v, v3, v8, v15]
list of indices to be divided by `x`
"""

def __init__(
self, output_dim: int = BASIS_SIZE, div_list: Optional[List[int]] = None, **kwargs
self,
output_dim: int = BASIS_SIZE,
fitbasis: str = "NN31IC",
div_list: Optional[List[int]] = None,
**kwargs
):
if div_list is None:
if div_list is None: # Default Unpolarized Case
div_list = [3, 4, 5, 6]
div_list = [9, 10] if "POL" in fitbasis else div_list

self.output_dim = output_dim
self.div_list = div_list
super().__init__(**kwargs)
Expand Down
