Skip to content

Commit

Permalink
implement custom WeightProducer
Browse files Browse the repository at this point in the history
  • Loading branch information
mafrahm committed Nov 27, 2023
1 parent b3fd898 commit 789d517
Show file tree
Hide file tree
Showing 3 changed files with 94 additions and 3 deletions.
6 changes: 3 additions & 3 deletions hbw/config/config_run2.py
Original file line number Diff line number Diff line change
Expand Up @@ -296,11 +296,11 @@ def add_shift_aliases(shift_source: str, aliases: dict[str], selection_dependent
cfg.add_shift(name="mu_trig_sf_down", id=53, type="shape")
add_shift_aliases("mu_sf", {"muon_weight": "muon_weight_{direction}"}, selection_dependent=False)

btag_uncs = [
cfg.x.btag_uncs = [
"hf", "lf", f"hfstats1_{year}", f"hfstats2_{year}",
f"lfstats1_{year}", f"lfstats2_{year}", "cferr1", "cferr2",
]
for i, unc in enumerate(btag_uncs):
for i, unc in enumerate(cfg.x.btag_uncs):
cfg.add_shift(name=f"btag_{unc}_up", id=100 + 2 * i, type="shape")
cfg.add_shift(name=f"btag_{unc}_down", id=101 + 2 * i, type="shape")
add_shift_aliases(
Expand Down Expand Up @@ -447,7 +447,7 @@ def make_jme_filename(jme_aux, sample_type, name, era=None):
# dataset.x.event_weights = {"top_pt_weight": get_shifts("top_pt")}

# NOTE: which to use, njet_btag_weight or btag_weight?
cfg.x.event_weights["normalized_btag_weight"] = get_shifts(*(f"btag_{unc}" for unc in btag_uncs))
cfg.x.event_weights["normalized_btag_weight"] = get_shifts(*(f"btag_{unc}" for unc in cfg.x.btag_uncs))
cfg.x.event_weights["normalized_pu_weight"] = get_shifts("minbias_xs")
cfg.x.event_weights["electron_weight"] = get_shifts("e_sf")
cfg.x.event_weights["muon_weight"] = get_shifts("mu_sf")
Expand Down
90 changes: 90 additions & 0 deletions hbw/weights/default.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
# coding: utf-8

"""
Event weight producer.
"""

import law

from columnflow.util import maybe_import
from columnflow.weight import WeightProducer, weight_producer
from columnflow.config_util import get_shifts_from_sources
from columnflow.columnar_util import Route

np = maybe_import("numpy")
ak = maybe_import("awkward")

logger = law.logger.get_logger(__name__)


@weight_producer(uses={"normalization_weight"}, mc_only=True)
def normalization_only(self: WeightProducer, events: ak.Array, **kwargs) -> ak.Array:
return events.normalization_weight


@weight_producer(mc_only=True)
def no_weights(self: WeightProducer, events: ak.Array, **kwargs) -> ak.Array:
return ak.Array(np.ones(len(events), dtype=np.float32))


@weight_producer(
# both used columns and dependent shifts are defined in init below
weight_columns=None,
# only run on mc
mc_only=True,
)
def default_weight_producer(self: WeightProducer, events: ak.Array, **kwargs) -> ak.Array:
# build the full event weight
weight = ak.Array(np.ones(len(events), dtype=np.float32))
for column in self.weight_columns.keys():
weight = weight * Route(column).apply(events)

return weight


@default_weight_producer.init
def default_weight_producer_init(self: WeightProducer) -> None:
# set the default weight_columns
if not self.weight_columns:
self.weight_columns = {
"normalization_weight": [],
"normalized_pu_weight": ["minbias_xs"],
"muon_weight": ["mu_sf"],
"electron_weight": ["e_sf"],
"normalized_btag_weight": [f"btag_{unc}" for unc in self.config_inst.x("btag_uncs")],
"normalized_murf_envelope_weight": ["murf_envelope"],
"normalized_mur_weight": ["mur"],
"normalized_muf_weight": ["muf"],
"normalized_pdf_weight": ["pdf"],
}

if self.dataset_inst.has_tag("skip_scale"):
# remove dependency towards mur/muf weights
for column in [
"normalized_mur_weight", "normalized_muf_weight", "normalized_murf_envelope_weight",
"mur_weight", "muf_weight", "murf_envelope_weight",
]:
self.weight_columns.pop(column, None)

if self.dataset_inst.has_tag("skip_pdf"):
# remove dependency towards pdf weights
for column in ["pdf_weight", "normalized_pdf_weight"]:
self.weight_columns.pop(column, None)

self.shifts = set()
for weight_column, shift_sources in self.weight_columns.items():
shift_sources = law.util.make_list(shift_sources)
shifts = get_shifts_from_sources(self.config_inst, *shift_sources)
for shift in shifts:
if weight_column not in shift.x("column_aliases").keys():
# make sure that column aliases are implemented
raise Exception(
f"Weight column {weight_column} implements shift {shift}, but does not use it"
f"in 'column_aliases' aux {shift.x('column_aliases')}",
)

# declare shifts that the produced event weight depends on
self.shifts |= set(shifts)

# store column names referring to weights to multiply
self.uses |= self.weight_columns.keys()
1 change: 1 addition & 0 deletions law.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ production_modules: columnflow.production.{categories,processes,pileup,normaliza
calibration_modules: columnflow.calibration.jets, hbw.calibration.default
selection_modules: hbw.selection.{common,sl,dl}
categorization_modules: hbw.selection.categories
weight_production_modules: columnflow.weight.{empty}, hbw.weights.default
ml_modules: hbw.ml.{base,dense_classifier,dl,sl_res}
inference_modules: hbw.inference.{derived,dl,sl_res}

Expand Down

0 comments on commit 789d517

Please sign in to comment.