Skip to content

Commit

Permalink
Apply suggestions from code review
Browse files Browse the repository at this point in the history
Co-authored-by: Mathis Frahm <[email protected]>
  • Loading branch information
Lara813 and mafrahm authored Jan 24, 2025
1 parent c38c458 commit 69460b1
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 27 deletions.
3 changes: 1 addition & 2 deletions hbw/plotting/plot_fits.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,7 @@ def scalable_exponnorm(x, A, loc, scale, K=1):


def plot_fit(
hists: dict[str, OrderedDict[od.Process, hist.Hist]],
# hists: OrderedDict[od.Process, hist.Hist],
hists: OrderedDict[od.Process, hist.Hist],
config_inst: od.Config,
category_inst: od.Category,
variable_insts: list[od.Variable],
Expand Down
15 changes: 6 additions & 9 deletions hbw/tasks/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@


# Function copied from Mathis Hist hook commit
# TODO: define once at central place (hist_util.py)
def apply_rebinning_edges(h: hist.Histogram, axis_name: str, edges: list):
"""
Generalized rebinning of a single axis from a hist.Histogram, using predefined edges.
Expand Down Expand Up @@ -322,15 +323,11 @@ def get_rebin_processes(self):

rebin_process_condition = self.inference_category_rebin_processes.get(config_category, None)
if not rebin_process_condition:
if "ggf" in config_category:
for proc in processes.copy():
proc_name = proc.config_process

logger.warning(
f"No rebin condition found for category {config_category}; rebinning will be flat "
f"on all processes {[proc.config_process for proc in processes]}",
)
return processes
logger.warning(
f"No rebin condition found for category {config_category}; rebinning will be flat "
f"on all processes {[proc.config_process for proc in processes]}",
)
return processes

# transform `rebin_process_condition` into Callable if required
if not isinstance(rebin_process_condition, Callable):
Expand Down
28 changes: 12 additions & 16 deletions hbw/tasks/postfit_plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,12 +37,10 @@ def load_hists_uproot(fit_diagnostics_path, fit_type):
""" Helper to load histograms from a fit_diagnostics file """
with uproot.open(fit_diagnostics_path) as tfile:
if any("shapes_fit_s" in _k for _k in tfile.keys()):
if fit_type != "prefit":
if fit_type == "postfit":
fit_type = "fit_s"
hists = get_hists_from_fit_diagnostics(tfile)[f"shapes_{fit_type}"]
else:
# if fit_type != "prefit":
# fit_type = "postfit"
hists = get_hists_from_multidimfit(tfile)[f"{fit_type}"]

return hists
Expand Down Expand Up @@ -119,7 +117,7 @@ def get_hists_from_multidimfit(tfile):


def plot_postfit_shapes(
h: OrderedDict, # [od.Process, hist.Hist],
hists: OrderedDict[od.Process, hist.Hist],
config_inst: od.Config,
category_inst: od.Category,
variable_insts: list[od.Variable],
Expand All @@ -133,7 +131,7 @@ def plot_postfit_shapes(
**kwargs,
) -> tuple(plt.Figure, tuple(plt.Axes)):
variable_inst = law.util.make_tuple(variable_insts)[0]
hists = apply_process_settings(h.copy(), process_settings)
hists = apply_process_settings(hists, process_settings)
plot_config = prepare_plot_config(
hists,
shape_norm=shape_norm,
Expand Down Expand Up @@ -200,11 +198,9 @@ class PlotPostfitShapes(
@property
def fit_type(self) -> str:
if self.prefit:
fit_type = "prefit"
return "prefit"
else:
fit_type = "postfit"
self._fit_type = fit_type
return self._fit_type
return "postfit"

def requires(self):
return {}
Expand Down Expand Up @@ -232,11 +228,10 @@ def run(self):
hist_map = defaultdict(list)

# First map process inst for plotting to processes of root shapes
for proc_key in list(h_in.keys()):
for proc_key in h_in.keys():
proc_inst = None
# try getting the config process via InferenceModel
if has_category:
# TODO: process customization based on inference process? e.g. scale
inference_process = self.inference_model_inst.get_process(proc_key, channel)
proc_inst = self.config_inst.get_process(inference_process.config_process)
else:
Expand All @@ -249,12 +244,13 @@ def run(self):
plot_proc = [
proc for proc in process_insts if proc.has_process(proc_inst) or proc.name == proc_inst.name
]
if len(plot_proc) != 1:
if len(plot_proc) > 1:
logger.warning(f"{proc_key} was assigned to more then one process insts ({plot_proc}) ")
else:
logger.warning(f"{proc_key} in root file, but won't be plotted.")
if len(plot_proc) > 1:
logger.warning(f"{proc_key} was assigned to ({", ".join([p.name for p in plot_proc])}) but {plot_proc[].name was chosen}")
elif len(plot_proc) == 0:
logger.warning(f"{proc_key} in root file, but won't be plotted.")
continue
plot_proc = plot_proc[0]


if plot_proc[0] not in hist_map:
hist_map[plot_proc[0]] = [proc_key]
Expand Down

0 comments on commit 69460b1

Please sign in to comment.