Skip to content

Commit

Permalink
Adjust res dnn producers for control categories.
Browse files Browse the repository at this point in the history
  • Loading branch information
riga committed Jan 30, 2025
1 parent 041033f commit b9b983e
Showing 1 changed file with 38 additions and 10 deletions.
48 changes: 38 additions & 10 deletions hbt/production/res_networks.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,8 @@
"Electron.{eta,phi,pt,mass,charge}",
"Muon.{eta,phi,pt,mass,charge}",
"HHBJet.{pt,eta,phi,mass,hhbtag,btagDeepFlav*,btagPNet*}",
"MET.{pt,phi,covXX,covXY,covYY}",
"FatJet.{eta,phi,pt,mass}",
# MET variables added in dynamic init,
},
# whether the model is parameterized in mass, spin and year
# (this is a slight forward declaration but simplifies the code reasonably well in our use case)
Expand All @@ -65,7 +65,7 @@ def _res_dnn_evaluation(
correct order can be found in the tautauNN repo:
https://github.com/uhh-cms/tautauNN/blob/f1ca194/evaluation/interface.py#L67
"""
tf = maybe_import("tensorflow")
import tensorflow as tf

# ensure coffea behavior
events = self[attach_coffea_behavior](
Expand All @@ -89,12 +89,19 @@ def _res_dnn_evaluation(
# get decay mode of first lepton (e, mu or tau)
tautau_mask = events.channel_id == self.config_inst.channels.n.tautau.id
dm1 = -1 * np.ones(len(events), dtype=np.int32)
dm1[tautau_mask] = events.Tau.decayMode[tautau_mask][:, 0]
if ak.any(tautau_mask):
dm1[tautau_mask] = events.Tau.decayMode[tautau_mask][:, 0]

# get decay mode of second lepton (also a tau, but position depends on channel)
dm2 = np.zeros(len(events), dtype=np.int32)
dm2[~tautau_mask] = events.Tau.decayMode[~tautau_mask][:, 0]
dm2[tautau_mask] = events.Tau.decayMode[tautau_mask][:, 1]
leptau_mask = (
(events.channel_id == self.config_inst.channels.n.etau.id) |
(events.channel_id == self.config_inst.channels.n.mutau.id)
)
dm2 = -1 * np.ones(len(events), dtype=np.int32)
if ak.any(leptau_mask):
dm2[leptau_mask] = events.Tau.decayMode[leptau_mask][:, 0]
if ak.any(tautau_mask):
dm2[tautau_mask] = events.Tau.decayMode[tautau_mask][:, 1]

# the dnn treats dm 2 as 1, so we need to map it
dm1 = np.where(dm1 == 2, 1, dm1)
Expand Down Expand Up @@ -204,12 +211,13 @@ def mask_values(mask, value, *fields):
mask_values(~has_fatjet, 0.0, "httfatjet_e", "httfatjet_px", "httfatjet_py", "httfatjet_pz")

# MET variables
_met = _events[self.config_inst.x.met_name]
f.met_px, f.met_py = rotate_to_phi(
phi_lep,
_events.MET.pt * np.cos(_events.MET.phi),
_events.MET.pt * np.sin(_events.MET.phi),
_met.pt * np.cos(_met.phi),
_met.pt * np.sin(_met.phi),
)
f.met_cov00, f.met_cov01, f.met_cov11 = _events.MET.covXX, _events.MET.covXY, _events.MET.covYY
f.met_cov00, f.met_cov01, f.met_cov11 = _met.covXX, _met.covXY, _met.covYY

# build continuous inputs
# (order exactly as documented in link above)
Expand Down Expand Up @@ -268,6 +276,11 @@ def mask_values(mask, value, *fields):
return events


@_res_dnn_evaluation.init
def _res_dnn_evaluation_init(self: Producer) -> None:
    """
    Dynamic init hook that adds the MET columns to the producer's ``uses`` set.

    The collection name is not hard-coded but looked up from
    ``self.config_inst.x.met_name``; the added entry is a brace-expansion pattern
    covering the ``pt``, ``phi``, ``covXX``, ``covXY`` and ``covYY`` fields of that
    collection.
    """
    # build the column pattern first, then add it to the used columns
    met_columns = f"{self.config_inst.x.met_name}.{{pt,phi,covXX,covXY,covYY}}"
    self.uses.add(met_columns)


@_res_dnn_evaluation.requires
def _res_dnn_evaluation_requires(self: Producer, reqs: dict) -> None:
if "external_files" in reqs:
Expand All @@ -284,12 +297,18 @@ def _res_dnn_evaluation_setup(
inputs: dict,
reader_targets: InsertableDict,
) -> None:
tf = maybe_import("tensorflow")
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
import tensorflow as tf

# some checks
if not isinstance(self.parametrized, bool):
raise AttributeError("'parametrized' must be set in the producer configuration")

# constrain tf to use only one core
tf.config.threading.set_inter_op_parallelism_threads(1)
tf.config.threading.set_intra_op_parallelism_threads(1)

# unpack the model archive
bundle = reqs["external_files"]
bundle.files
Expand Down Expand Up @@ -317,9 +336,14 @@ def _res_dnn_evaluation_setup(

# our channel ids mapped to KLUB "pair_type"
self.channel_id_to_pair_type = {
# known during training
self.config_inst.channels.n.mutau.id: 0,
self.config_inst.channels.n.etau.id: 1,
self.config_inst.channels.n.tautau.id: 2,
# unknown during training
self.config_inst.channels.n.ee.id: 1,
self.config_inst.channels.n.mumu.id: 0,
self.config_inst.channels.n.emu.id: 1,
}

# define the year based on the incoming campaign
Expand Down Expand Up @@ -351,6 +375,8 @@ def _res_dnn_evaluation_setup(

@res_pdnn.init
def res_pdnn_init(self: Producer) -> None:
super(res_pdnn, self).init_func()

# check spin value and mass values
if self.spin not in {0, 2}:
raise ValueError(f"invalid spin value: {self.spin}")
Expand Down Expand Up @@ -380,6 +406,8 @@ def res_pdnn_init(self: Producer) -> None:

@res_dnn.init
def res_dnn_init(self: Producer) -> None:
super(res_dnn, self).init_func()

# output column names (in this order)
self.output_columns = [
f"res_dnn_{name}"
Expand Down

0 comments on commit b9b983e

Please sign in to comment.