Commit

Port stitching updates from #55.
riga committed Jan 13, 2025
1 parent 3895d97 commit f31f769
Showing 4 changed files with 226 additions and 110 deletions.
16 changes: 15 additions & 1 deletion hbt/config/configs_hbt.py
@@ -473,7 +473,21 @@ def if_not_era(*, values: list[str | None] | None = None, **kwargs) -> list[str]
},
}
# w+jets
# TODO: add
cfg.x.w_lnu_stitching = {
"incl": {
"inclusive_dataset": cfg.datasets.n.w_lnu_amcatnlo,
"leaf_processes": [
# the following processes cover the full njet and pt phase space
procs.n.w_lnu_0j,
*(
procs.get(f"w_lnu_{nj}j_pt{pt}")
for nj in [1, 2]
for pt in ["0to40", "40to100", "100to200", "200to400", "400to600", "600toinf"]
),
procs.n.w_lnu_ge3j,
],
},
}

# dataset groups for conveniently looping over certain datasets
# (used in wrapper_factory and during plotting)
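The new w_lnu_stitching entry mirrors the existing dy_stitching block: one inclusive dataset plus the leaf processes that tile the njet/pt phase space via their auxiliary data. As a rough sketch of how such an entry can be walked (the cfg/proc accessors follow the order/columnflow conventions used in the diff; the helper itself is illustrative and not part of the commit):

# illustrative helper, not part of the commit: collect the auxiliary ranges of
# all leaf processes registered in a "*_stitching" config entry
def collect_stitching_ranges(cfg, key="w_lnu_stitching"):
    ranges = {}
    for name, entry in cfg.x(key).items():
        for proc in entry["leaf_processes"]:
            njets = proc.x("njets")  # e.g. (1, 2) for a "*_1j" process
            ptll = proc.x("ptll", None)  # None for leaves without a pt split
            ranges.setdefault(name, []).append((proc.name, njets, ptll))
    return ranges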
258 changes: 168 additions & 90 deletions hbt/production/processes.py
@@ -4,15 +4,18 @@
Process ID producer relevant for the stitching of the DY samples.
"""

import functools
from __future__ import annotations

import abc

import law
import order

from columnflow.production import Producer, producer
from columnflow.production import Producer
from columnflow.util import maybe_import, InsertableDict
from columnflow.columnar_util import set_ak_column
from columnflow.columnar_util import set_ak_column, Route

from hbt.util import IF_DATASET_IS_DY
from hbt.util import IF_DATASET_IS_DY, IF_DATASET_IS_W_LNU

np = maybe_import("numpy")
ak = maybe_import("awkward")
@@ -25,88 +25,160 @@
NJetsRange = tuple[int, int]
PtRange = tuple[float, float]

set_ak_column_i64 = functools.partial(set_ak_column, value_type=np.int64)

class stitched_process_ids(Producer):
"""General class to calculate process ids for stitched samples.
@producer(
uses={IF_DATASET_IS_DY("LHE.NpNLO", "LHE.Vpt")},
produces={IF_DATASET_IS_DY("process_id")},
)
def process_ids_dy(self: Producer, events: ak.Array, **kwargs) -> ak.Array:
"""
Assigns each dy event a single process id, based on the number of jets and the di-lepton pt of
the LHE record. This is used for the stitching of the DY samples.
Individual producers should derive from this class and set the following attributes:
:param id_table: scipy lookup table mapping processes variables (using key_func) to process ids
:param key_func: function to generate keys for the lookup, receiving values of stitching columns
:param stitching_columns: list of observables to use for stitching
:param cross_check_translation_dict: dictionary to translate stitching columns to auxiliary
fields of process objects, used for cross checking the validity of obtained ranges
:param include_condition: condition for including stitching columns in used columns
"""
# as always, we assume that each dataset has exactly one process associated to it
if len(self.dataset_inst.processes) != 1:
raise NotImplementedError(
f"dataset {self.dataset_inst.name} has {len(self.dataset_inst.processes)} processes "
"assigned, which is not yet implemented",
)
process_inst = self.dataset_inst.processes.get_first()

# get the number of nlo jets and the di-lepton pt
njets = events.LHE.NpNLO
pt = events.LHE.Vpt

# raise a warning if a dataset was already created for a specific "bin" (leaf process),
# but does not actually fit
njets_range = process_inst.x("njets", None)
if njets_range is not None:
outliers = (njets < njets_range[0]) | (njets >= njets_range[1])
if ak.any(outliers):
logger.warning(
f"dataset {self.dataset_inst.name} is meant to contain njet values in the range "
f"[{njets_range[0]}, {njets_range[0]}), but found {ak.sum(outliers)} events "
"outside this range",

@abc.abstractproperty
def id_table(self) -> sp.sparse._lil.lil_matrix:
# must be overwritten by inheriting classes
...

@abc.abstractmethod
def key_func(self, *values: ak.Array) -> int:
# must be overwritten by inheriting classes
...

@abc.abstractproperty
def cross_check_translation_dict(self) -> dict[str, str]:
# must be overwritten by inheriting classes
...

def init_func(self, *args, **kwargs):
# if there is an include_condition set, apply it to both used and produced columns
cond = lambda args: {self.include_condition(*args)} if self.include_condition else {*args}
self.uses |= cond(self.stitching_columns or [])
self.produces |= cond(["process_id"])

def call_func(self, events: ak.Array, **kwargs) -> ak.Array:
"""
Assigns each event a single process id, based on the stitching values extracted per event.
This id can be used for the stitching of the respective datasets downstream.
"""
# ensure that each dataset has exactly one process associated to it
if len(self.dataset_inst.processes) != 1:
raise NotImplementedError(
f"dataset {self.dataset_inst.name} has {len(self.dataset_inst.processes)} processes "
"assigned, which is not yet implemented",
)
pt_range = process_inst.x("ptll", None)
if pt_range is not None:
outliers = (pt < pt_range[0]) | (pt >= pt_range[1])
if ak.any(outliers):
logger.warning(
f"dataset {self.dataset_inst.name} is meant to contain ptll values in the range "
f"[{pt_range[0]}, {pt_range[1]}), but found {ak.sum(outliers)} events outside this "
"range",
process_inst = self.dataset_inst.processes.get_first()

# get stitching observables
stitching_values = [Route(obs).apply(events) for obs in self.stitching_columns]

# run the cross check function if defined
if callable(self.stitching_range_cross_check):
self.stitching_range_cross_check(process_inst, stitching_values)

# lookup the id and check for invalid values
process_ids = np.squeeze(np.asarray(self.id_table[self.key_func(*stitching_values)].todense()))
invalid_mask = process_ids == 0
if ak.any(invalid_mask):
raise ValueError(
f"found {sum(invalid_mask)} events that could not be assigned to a process",
)

# lookup the id and check for invalid values
process_ids = np.squeeze(np.asarray(self.id_table[self.key_func(njets, pt)].todense()))
invalid_mask = process_ids == 0
if ak.any(invalid_mask):
raise ValueError(
f"found {sum(invalid_mask)} dy events that could not be assigned to a process",
)

# store them
events = set_ak_column_i64(events, "process_id", process_ids)

return events


@process_ids_dy.setup
def process_ids_dy_setup(
self: Producer,
reqs: dict,
inputs: dict,
reader_targets: InsertableDict,
) -> None:
# define stitching ranges for the DY datasets covered by this producer's dy_inclusive_dataset
stitching_ranges: dict[NJetsRange, list[PtRange]] = {}
for proc in self.dy_leaf_processes:
njets = proc.x.njets
stitching_ranges.setdefault(njets, [])
if proc.has_aux("ptll"):
stitching_ranges[njets].append(proc.x.ptll)

# sort by the first element of the ptll range
sorted_stitching_ranges: list[tuple[NJetsRange, list[PtRange]]] = [
(nj_range, sorted(stitching_ranges[nj_range], key=lambda ptll_range: ptll_range[0]))
for nj_range in sorted(stitching_ranges.keys(), key=lambda nj_range: nj_range[0])
]

# define a key function that maps njets and pt to a unique key for use in a lookup table
def key_func(njets, pt):
# store them
events = set_ak_column(events, "process_id", process_ids, value_type=np.int64)

return events

def stitching_range_cross_check(
self: Producer,
process_inst: order.Process,
stitching_values: list[ak.Array],
) -> None:
# define lookup for stitching observable -> process auxiliary values to compare with
# raise a warning if a dataset was already created for a specific "bin" (leaf process),
# but does not actually fit
for column, values in zip(self.stitching_columns, stitching_values):
aux_name = self.cross_check_translation_dict[str(column)]
if not process_inst.has_aux(aux_name):
continue
aux_min, aux_max = process_inst.x(aux_name)
outliers = (values < aux_min) | (values >= aux_max)
if ak.any(outliers):
logger.warning(
f"dataset {self.dataset_inst.name} is meant to contain {aux_name} values in "
f"the range [{aux_min}, {aux_max}), but found {ak.sum(outliers)} events "
"outside this range",
)


class stiched_process_ids_nj_pt(stitched_process_ids):
"""
Process identifier for subprocesses spanned by a jet multiplicity and an optional pt range, such
as DY or W->lnu, which have (e.g.) "*_1j" as well as "*_1j_pt100to200" subprocesses.
"""

# id table is set during setup, create a non-abstract class member in the meantime
id_table = None

# required aux fields
njets_aux = "njets"
pt_aux = "ptll"

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

# set during setup
self.sorted_stitching_ranges: list[tuple[NJetsRange, list[PtRange]]]

# check that aux fields are present in cross_check_translation_dict
for field in (self.njets_aux, self.pt_aux):
if field not in self.cross_check_translation_dict.values():
raise ValueError(f"field {field} must be present in cross_check_translation_dict")

@abc.abstractproperty
def leaf_processes(self) -> list[order.Process]:
# must be overwritten by inheriting classes
...

def setup_func(
self,
reqs: dict,
inputs: dict,
reader_targets: InsertableDict,
) -> None:
# define stitching ranges for the datasets covered by this producer's leaf processes
stitching_ranges: dict[NJetsRange, list[PtRange]] = {}
for proc in self.leaf_processes:
njets = proc.x(self.njets_aux)
stitching_ranges.setdefault(njets, [])
if proc.has_aux(self.pt_aux):
stitching_ranges[njets].append(proc.x(self.pt_aux))

# sort by the first element of the ptll range
self.sorted_stitching_ranges = [
(nj_range, sorted(stitching_ranges[nj_range], key=lambda ptll_range: ptll_range[0]))
for nj_range in sorted(stitching_ranges.keys(), key=lambda nj_range: nj_range[0])
]

# define the lookup table
max_nj_bin = len(self.sorted_stitching_ranges)
max_pt_bin = max(map(len, stitching_ranges.values()))
self.id_table = sp.sparse.lil_matrix((max_nj_bin + 1, max_pt_bin + 1), dtype=np.int64)

# fill it
for proc in self.leaf_processes:
key = self.key_func(proc.x(self.njets_aux)[0], proc.x(self.pt_aux, [-1])[0])
self.id_table[key] = proc.id

def key_func(
self,
njets: int | np.ndarray,
pt: int | float | np.ndarray,
) -> tuple[int, int] | tuple[np.ndarray, np.ndarray]:
# potentially convert single values into arrays
single = False
if isinstance(njets, int):
@@ -118,7 +193,7 @@ def key_func(njets, pt):
# map into bins (index 0 means no binning)
nj_bins = np.zeros(len(njets), dtype=np.int32)
pt_bins = np.zeros(len(pt), dtype=np.int32)
for nj_bin, (nj_range, pt_ranges) in enumerate(sorted_stitching_ranges, 1):
for nj_bin, (nj_range, pt_ranges) in enumerate(self.sorted_stitching_ranges, 1):
# nj_bin
nj_mask = (nj_range[0] <= njets) & (njets < nj_range[1])
nj_bins[nj_mask] = nj_bin
Expand All @@ -129,14 +204,17 @@ def key_func(njets, pt):

return (nj_bins[0], pt_bins[0]) if single else (nj_bins, pt_bins)

self.key_func = key_func

# define the lookup table
max_nj_bin = len(sorted_stitching_ranges)
max_pt_bin = max(map(len, stitching_ranges.values()))
self.id_table = sp.sparse.lil_matrix((max_nj_bin + 1, max_pt_bin + 1), dtype=np.int64)
process_ids_dy = stiched_process_ids_nj_pt.derive("process_ids_dy", cls_dict={
"stitching_columns": ["LHE.NpNLO", "LHE.Vpt"],
"cross_check_translation_dict": {"LHE.NpNLO": "njets", "LHE.Vpt": "ptll"},
"include_condition": IF_DATASET_IS_DY,
# leaf_processes is still missing and must be set dynamically
})

# fill it
for proc in self.dy_leaf_processes:
key = key_func(proc.x.njets[0], proc.x("ptll", [-1])[0])
self.id_table[key] = proc.id
process_ids_w_lnu = stiched_process_ids_nj_pt.derive("process_ids_w_lnu", cls_dict={
"stitching_columns": ["LHE.NpNLO", "LHE.Vpt"],
"cross_check_translation_dict": {"LHE.NpNLO": "njets", "LHE.Vpt": "ptll"},
"include_condition": IF_DATASET_IS_W_LNU,
# leaf_processes is still missing and must be set dynamically
})
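
For reference, a condensed, self-contained sketch of the binning scheme behind key_func and the sparse id table: (njets, pt) pairs are digitized into bin indices, where bin 0 on either axis means "unbinned", and the table maps each (nj_bin, pt_bin) key to a process id. The ranges and ids below are made up for illustration; the real values come from the leaf process auxiliaries set in the config.

import numpy as np
import scipy.sparse as sp

# made-up stitching ranges: 0 jets unsplit, 1 jet split into two pt bins
sorted_stitching_ranges = [
    ((0, 1), []),
    ((1, 2), [(0.0, 100.0), (100.0, np.inf)]),
]

# placeholder process ids, one per leaf process
id_table = sp.lil_matrix((len(sorted_stitching_ranges) + 1, 3), dtype=np.int64)
id_table[1, 0] = 51001  # e.g. w_lnu_0j
id_table[2, 1] = 51002  # e.g. w_lnu_1j_pt0to100
id_table[2, 2] = 51003  # e.g. w_lnu_1j_pt100toinf

def key_func(njets, pt):
    nj_bins = np.zeros(len(njets), dtype=np.int32)
    pt_bins = np.zeros(len(pt), dtype=np.int32)
    for nj_bin, (nj_range, pt_ranges) in enumerate(sorted_stitching_ranges, 1):
        nj_mask = (nj_range[0] <= njets) & (njets < nj_range[1])
        nj_bins[nj_mask] = nj_bin
        for pt_bin, (pt_min, pt_max) in enumerate(pt_ranges, 1):
            pt_bins[nj_mask & (pt_min <= pt) & (pt < pt_max)] = pt_bin
    return nj_bins, pt_bins

njets = np.array([0, 1, 1])
pt = np.array([20.0, 50.0, 250.0])
ids = np.squeeze(np.asarray(id_table[key_func(njets, pt)].todense()))
# ids -> [51001 51002 51003]; a remaining 0 would flag an unassigned event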
51 changes: 32 additions & 19 deletions hbt/selection/default.py
@@ -27,7 +27,7 @@
from hbt.selection.trigger import trigger_selection
from hbt.selection.lepton import lepton_selection
from hbt.selection.jet import jet_selection
from hbt.production.processes import process_ids_dy
import hbt.production.processes as process_producers
from hbt.production.btag import btag_weights_deepjet, btag_weights_pnet
from hbt.production.features import cutflow_features
from hbt.production.patches import patch_ecalBadCalibFilter
@@ -148,6 +148,8 @@ def default(
# create process ids
if self.process_ids_dy is not None:
events = self[self.process_ids_dy](events, **kwargs)
elif self.process_ids_w_lnu is not None:
events = self[self.process_ids_w_lnu](events, **kwargs)
else:
events = self[process_ids](events, **kwargs)

@@ -189,24 +191,33 @@ def default_init(self: Selector) -> None:
if getattr(self, "dataset_inst", None) is None:
return

self.process_ids_dy: process_ids_dy | None = None
if self.dataset_inst.has_tag("dy"):
# check if this dataset is covered by any dy id producer
for name, dy_cfg in self.config_inst.x.dy_stitching.items():
dataset_inst = dy_cfg["inclusive_dataset"]
# the dataset is "covered" if its process is a subprocess of that of the dy dataset
if dataset_inst.has_process(self.dataset_inst.processes.get_first()):
self.process_ids_dy = process_ids_dy.derive(f"process_ids_dy_{name}", cls_dict={
"dy_inclusive_dataset": dataset_inst,
"dy_leaf_processes": dy_cfg["leaf_processes"],
})

# add it as a dependency
self.uses.add(self.process_ids_dy)
self.produces.add(self.process_ids_dy)

# stop after the first match
break
# build and store derived process id producers
for tag in ("dy", "w_lnu"):
prod_name = f"process_ids_{tag}"
setattr(self, prod_name, None)
if not self.dataset_inst.has_tag(tag):
continue
# check if the producer was already created and saved in the config
if (prod := self.config_inst.x(prod_name, None)) is None:
# check if this dataset is covered by any id producer for this tag
for stitch_name, cfg in self.config_inst.x(f"{tag}_stitching").items():
incl_dataset_inst = cfg["inclusive_dataset"]
# the dataset is "covered" if its process is a subprocess of that of the dy dataset
if incl_dataset_inst.has_process(self.dataset_inst.processes.get_first()):
base_prod = getattr(process_producers, prod_name)
prod = base_prod.derive(f"{prod_name}_{stitch_name}", cls_dict={
"leaf_processes": cfg["leaf_processes"],
})
# cache it
self.config_inst.set_aux(prod_name, prod)
# stop after the first match
break
if prod is not None:
# add it as a dependency
self.uses.add(prod)
self.produces.add(prod)
# save it as an attribute
setattr(self, prod_name, prod)


empty = default.derive("empty", cls_dict={})
@@ -288,6 +299,8 @@ def empty_call(
# create process ids
if self.process_ids_dy is not None:
events = self[self.process_ids_dy](events, **kwargs)
elif self.process_ids_w_lnu is not None:
events = self[self.process_ids_w_lnu](events, **kwargs)
else:
events = self[process_ids](events, **kwargs)

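
The reworked init hook derives one id producer per stitching tag and caches it in the config's aux data, so every selector instance built on the same config reuses a single derived class instead of re-deriving it. Stripped to its core (a sketch with a hypothetical base_prod standing in for e.g. process_producers.process_ids_w_lnu, and cfg for the current config instance):

# condensed sketch of the derive-and-cache pattern from default_init above
prod = cfg.x("process_ids_w_lnu", None)
if prod is None:
    prod = base_prod.derive("process_ids_w_lnu_incl", cls_dict={
        "leaf_processes": cfg.x.w_lnu_stitching["incl"]["leaf_processes"],
    })
    cfg.set_aux("process_ids_w_lnu", prod)  # cache so later inits reuse it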
(diff of the fourth changed file not rendered on this page)
