From fbfb039e7216dcefa8327bf79c272fa73eaad0ac Mon Sep 17 00:00:00 2001 From: Mathis Frahm Date: Thu, 5 Dec 2024 12:02:34 +0100 Subject: [PATCH] rebinning (WIP) --- hbw/config/hist_hooks.py | 213 ++++++++++++++++++++++++++++++++++++--- 1 file changed, 198 insertions(+), 15 deletions(-) diff --git a/hbw/config/hist_hooks.py b/hbw/config/hist_hooks.py index 8518374d..3083e54e 100644 --- a/hbw/config/hist_hooks.py +++ b/hbw/config/hist_hooks.py @@ -6,6 +6,7 @@ from __future__ import annotations +import math import law import order as od @@ -17,7 +18,15 @@ logger = law.logger.get_logger(__name__) -def rebin_hist(h, axis_name, edges): +def apply_rebinning_edges(h: hist.Histogram, axis_name: str, edges: list): + """ + Generalized rebinning of a single axis from a hist.Histogram, using predefined edges. + + :param h: histogram to rebin + :param axis_name: string representing the axis to rebin + :param edges: list of floats representing the new bin edges. Must be a subset of the original edges. + :return: rebinned histogram + """ if isinstance(edges, int): return h[{axis_name: hist.rebin(edges)}] @@ -64,18 +73,190 @@ def rebin_hist(h, axis_name, edges): return hnew +def merge_hists_per_config( + task, + hists: dict[str, dict[od.Process, hist.Histogram]], +): + if len(task.config_insts) != 1: + process_memory = {} + merged_hists = {} + for config, _hists in hists.items(): + for process_inst, h in _hists.items(): + + if process_inst.id in merged_hists: + merged_hists[process_inst.id] += h + else: + merged_hists[process_inst.id] = h + process_memory[process_inst.id] = process_inst + + process_insts = list(process_memory.values()) + hists = {process_memory[process_id]: h for process_id, h in merged_hists.items()} + else: + hists = hists[task.config_inst.name] + process_insts = list(hists.keys()) + + return hists, process_insts + + +def apply_rebin_edges_to_all( + hists: dict[str, dict[od.Process, hist.Histogram]], + edges: list[float], + axis_name: str, +) -> dict[str, dict[od.Process, hist.Histogram]]: + """ + Apply rebin edges to histograms for all configs and processes. + """ + h_out = {} + for config, _hists in hists.items(): + h_rebinned = DotDict() + for proc, h in _hists.items(): + old_axis = h.axes[axis_name] + h_rebin = apply_rebinning_edges(h.copy(), old_axis.name, edges) + + if not np.isclose(h.sum().value, h_rebin.sum().value): + raise Exception(f"Rebinning changed histogram value: {h.sum().value} -> {h_rebin.sum().value}") + if not np.isclose(h.sum().variance, h_rebin.sum().variance): + raise Exception(f"Rebinning changed histogram variance: {h.sum().variance} -> {h_rebin.sum().variance}") + h_rebinned[proc] = h_rebin + + h_out[config] = h_rebinned + + return h_out + + +def select_category_and_shift( + task, + h: hist.Histogram, +): + # get the shifts to extract and plot + plot_shifts = law.util.make_list(task.get_plot_shifts()) + + category_inst = task.config_inst.get_category(task.branch_data.category) + leaf_category_insts = category_inst.get_leaf_categories() or [category_inst] + + # selections + h = h[{ + "category": [ + hist.loc(c.id) + for c in leaf_category_insts + if c.id in h.axes["category"] + ], + "shift": [ + hist.loc(s.id) + for s in plot_shifts + if s.id in h.axes["shift"] + ], + }] + # reductions + h = h[{"category": sum, "shift": sum}] + + return h + + def add_hist_hooks(config: od.Config) -> None: """ Add histogram hooks to a configuration. """ - def rebin(task, hists: hist.Histogram): + from hbw.util import timeit_multiple + @timeit_multiple + def rebin(task, hists: dict[str, dict[od.Process, hist.Histogram]]) -> dict[str, hist.Histogram]: """ Rebin histograms with edges that are pre-defined for a certain variable and category. Lots of hard-coded stuff at the moment. """ + logger.info("Rebinning histograms") + + # category_inst = task.config_inst.get_category(task.branch_data.category) + # get variable inst assuming we created a 1D histogram variable_inst = task.config_inst.get_variable(task.branch_data.variable) + variable_inst.x.rebin = None + rebin_config = variable_inst.x("rebin_config", None) + if rebin_config is None: + logger.info("No rebinning configuration found, skipping rebinning") + return hists + + # merge histograms over all configs + hists_per_process, hist_process_insts = merge_hists_per_config(task, hists) + + # # get process instances for rebinning (sub procs pls) + # rebin_process_insts = [task.config_inst.get_process(proc) for proc in rebin_config["processes"]] + # rebin_sub_process_insts = { + # process_inst.name: [ + # sub + # for sub, _, _ in process_inst.walk_processes(include_self=True) + # if sub.id in [p.id for p in hist_process_insts] + # ] + # for process_inst in rebin_process_insts + # } + + # rebin_process_insts = [ + # process_inst.name for p in process_insts if p.name in rebin_config["processes"]] + + if missing_procs := set(rebin_config["processes"]) - set([p.name for p in hists_per_process]): + raise ValueError( + f"Processes {missing_procs} not found in histograms. For rebinning, the process names " + "requested in plotting/datacards need to match the processes required for rebinning." + ) + + # get histograms used for rebinning by merging over rebin processes defined by variable inst + # work on a copy to not modify original hist + rebin_hist = sum([ + h for proc_inst, h in hists_per_process.items() + if proc_inst.name in rebin_config["processes"] + ]).copy() + + # select and reduce category and shift axis + rebin_hist = select_category_and_shift(task, rebin_hist) + + # the effective number of events should be larger than a certain number + # error_criterium = lambda value, variance: value ** 2 / variance > rebin.min_entries + # equal_width_criterium = lambda value, n_bins, integral: value > integral / n_bins + + edges = [] + + def get_bin_edges_simple( + h, + n_bins, + reversed_order: bool = False, + ): + requested_cumsum = h.sum().value / n_bins + h_copy = h.copy() + + cumsum_value = np.cumsum(h_copy.values()[::-1])[::-1] if reversed_order else np.cumsum(h_copy.values()) + # cumsum_variance = np.cumsum(h_copy.variances()[::-1])[::-1] if reversed_order else np.cumsum(h_copy.variances()) + + current_bin_edge = np.astype(cumsum_value / requested_cumsum, int) + + diffs = np.diff(current_bin_edge) + indices = np.where(diffs > 0)[0] + + edges = [0.] + list(h_copy.axes[0].edges[indices]) + edges[-1] = 1.0 + + return edges + + edges = get_bin_edges_simple(rebin_hist, rebin_config.get("n_bins", 4)) + print(edges) + + h_out = {} + h_out = apply_rebin_edges_to_all(hists, edges, variable_inst.name) + return h_out + + # rebin default parameters + rebin.default_n_bins = 10 + + def rebin_example(task, hists: dict[str, dict[od.Process, hist.Histogram]]) -> dict[str, hist.Histogram]: + """ + Rebin histograms with edges that are pre-defined for a certain variable and category. + Lots of hard-coded stuff at the moment. + """ + logger.info("Rebinning histograms") + + # get variable inst assuming we created a 1D histogram + variable_inst = task.config_inst.get_variable(task.branch_data.variable) + variable_inst.x.rebin = None # edges for 2b channel edges = { @@ -85,23 +266,25 @@ def rebin(task, hists: hist.Histogram): "mlscore.h": [0.0, 0.494, 0.651, 1.0], } - h_rebinned = DotDict() - + h_out = {} edges = edges[variable_inst.name] - for proc, h in hists.items(): - old_axis = h.axes[variable_inst.name] - - h_rebin = rebin_hist(h.copy(), old_axis.name, edges) - - if not np.isclose(h.sum().value, h_rebin.sum().value): - raise Exception(f"Rebinning changed histogram value: {h.sum().value} -> {h_rebin.sum().value}") - if not np.isclose(h.sum().variance, h_rebin.sum().variance): - raise Exception(f"Rebinning changed histogram variance: {h.sum().variance} -> {h_rebin.sum().variance}") - h_rebinned[proc] = h_rebin + h_out = apply_rebin_edges_to_all(hists, edges, variable_inst.name) + return h_out - return h_rebinned # add hist hooks to config config.x.hist_hooks = { + "rebin_example": rebin_example, "rebin": rebin, } + + +def rebinning(N=10): + # pseudo code + get_final_bin_width = rebin_rest = None + check_final_bin = True + while check_final_bin: + check_final_bin, bin_width = get_final_bin_width(N) + N = N - 1 + + rebin_rest(N)