Skip to content

Commit

Permalink
Add parse_selection_fields utility and enhance histogram task with se…
Browse files Browse the repository at this point in the history
…lection handling
  • Loading branch information
haddadanas committed Nov 6, 2024
1 parent d67acd2 commit e725338
Show file tree
Hide file tree
Showing 2 changed files with 69 additions and 5 deletions.
31 changes: 27 additions & 4 deletions columnflow/tasks/histograms.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@

from __future__ import annotations

from functools import reduce
from operator import and_

import luigi
import law

Expand All @@ -19,7 +22,7 @@
from columnflow.tasks.reduction import ReducedEventsUser
from columnflow.tasks.production import ProduceColumns
from columnflow.tasks.ml import MLEvaluation
from columnflow.util import dev_sandbox
from columnflow.util import dev_sandbox, parse_selection_fields


class CreateHistograms(
Expand Down Expand Up @@ -148,12 +151,16 @@ def run(self):
self.config_inst.get_variable(var_name)
for var_name in law.util.flatten(self.variable_tuples.values())
)
for inp in (
for inp in ((
[variable_inst.expression]
if isinstance(variable_inst.expression, str)
# for variable_inst with custom expressions, read columns declared via aux key
else variable_inst.x("inputs", [])
)
) + (
parse_selection_fields(variable_inst.selection, only_fields=True)
if isinstance(variable_inst.selection, str)
else variable_inst.x("inputs", [])
))
}

# empty float array to use when input files have no entries
Expand Down Expand Up @@ -243,8 +250,24 @@ def expr(events, *args, **kwargs):
if len(events) == 0 and not has_ak_column(events, route):
return empty_f32
return route.apply(events, null_value=variable_inst.null_value)

# prepare the selection
mask = ak.Array(np.ones(len(events), dtype=np.bool))
sel = variable_inst.selection
if sel != "1":
if isinstance(sel, str):
selections = [
op(Route(s).apply(events, null_value=variable_inst.null_value), val)
for (s, op, val) in parse_selection_fields(variable_inst.selection)
]
selections = reduce(and_, selections)
mask = selections
elif callable(sel):
mask = sel(events)
else:
raise ValueError(f"invalid selection: {sel}")
# apply it
fill_data[variable_inst.name] = expr(events)
fill_data[variable_inst.name] = ak.where(mask, expr(events), variable_inst.null_value)

# fill it
fill_hist(
Expand Down
43 changes: 42 additions & 1 deletion columnflow/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
"maybe_import", "import_plt", "import_ROOT", "import_file", "create_random_name", "expand_path",
"real_path", "ensure_dir", "wget", "call_thread", "call_proc", "ensure_proxy", "dev_sandbox",
"safe_div", "try_float", "try_complex", "try_int", "is_pattern", "is_regex", "pattern_matcher",
"dict_add_strict", "get_source_code",
"dict_add_strict", "get_source_code", "parse_selection_fields",
"DotDict", "MockModule", "FunctionArgs", "ClassPropertyDescriptor", "classproperty",
"DerivableMeta", "Derivable",
]
Expand All @@ -27,6 +27,7 @@
import re
import inspect
import multiprocessing
import operator
import multiprocessing.pool
from functools import wraps
from collections import OrderedDict
Expand Down Expand Up @@ -532,6 +533,46 @@ def get_source_code(obj: Any, indent: str | int = None) -> str:
return code


def parse_selection_fields(selection, only_fields=False):
"""
Parses the fields used in the selection string and returns them as a list.
"""
if not isinstance(selection, str):
return []
if selection == "1":
return []

# Find mask concatenation
pattern = r"(\||&)"
sels = re.split(pattern, selection) if re.search(pattern, selection) else [selection]

operations = {
">": operator.gt,
">=": operator.ge,
"<": operator.lt,
"<=": operator.le,
"==": operator.eq,
"!=": operator.ne,
}

results = []
op_pattern = r"([<>]=?|==|!=)"
for sel in sels:
if sel.strip() in ("|", "&"):
continue
if re.search(op_pattern, sel):
parts = re.split(op_pattern, sel, maxsplit=1)
parts = tuple(part.strip(" ()") for part in parts)
results.append((parts[0], operations[parts[1]], float(parts[2])))
else:
results.append((sel, None, None))

if only_fields:
return [r[0] for r in results]

return results


class DotDict(OrderedDict):
"""
Subclass of *OrderedDict* that provides read and write access to items via attributes by
Expand Down

0 comments on commit e725338

Please sign in to comment.