Skip to content

Commit

Permalink
Feature/get events from categories (columnflow#365)
Browse files Browse the repository at this point in the history
* implement helper function to select events from categories

* implement unit test for get_events_from_categories

* review comments

* review comments

---------

Co-authored-by: Marcel R <[email protected]>
  • Loading branch information
mafrahm and riga authored Dec 8, 2023
1 parent 60eb7e6 commit 250c6ec
Show file tree
Hide file tree
Showing 4 changed files with 125 additions and 0 deletions.
42 changes: 42 additions & 0 deletions columnflow/config_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,50 @@
import law
import order as od

from columnflow.util import maybe_import
from columnflow.types import Callable, Any, Sequence

ak = maybe_import("awkward")
np = maybe_import("numpy")


def get_events_from_categories(
    events: ak.Array,
    categories: Sequence[str | od.Category],
    config_inst: od.Config | None = None,
) -> ak.Array:
    """
    Helper function that returns all events from an awkward array *events* that are categorized
    into one of the leaves of one of the *categories*. A category without leaf categories is
    treated as its own leaf. When *categories* is empty, no event is selected.

    :param events: Awkward array. Requires the 'category_ids' field to be present.
    :param categories: Sequence of category instances (a single category is also accepted). Can
        also be a sequence of strings when passing a *config_inst*.
    :param config_inst: Optional config instance to load category instances.
    :raises ValueError: If "category_ids" is not present in the *events* fields.
    :return: Awkward array of all events that are categorized into one of the leaves of one of the
        *categories*.
    """
    if "category_ids" not in events.fields:
        raise ValueError(
            f"{get_events_from_categories.__name__} requires the 'category_ids' field to be present",
        )

    categories = law.util.make_list(categories)
    if config_inst is not None:
        # resolve names (or instances) to the category instances known to the config
        categories = [config_inst.get_category(cat) for cat in categories]

    # collect the unique leaf categories; start from an empty set so that an empty *categories*
    # sequence yields an empty selection instead of failing inside an argument-less set.union
    leaf_category_insts = set()
    for cat in categories:
        leaf_category_insts.update(cat.get_leaf_categories() or [cat])

    # do the "or" of all leaf categories
    mask = np.zeros(len(events), dtype=bool)
    for leaf_inst in leaf_category_insts:
        # an event is selected if any of its category ids matches the leaf id
        cat_mask = ak.any(events.category_ids == leaf_inst.id, axis=1)
        mask = cat_mask | mask

    return events[mask]


def get_root_processes_from_campaign(campaign: od.Campaign) -> od.UniqueObjectIndex:
"""
Expand Down
1 change: 1 addition & 0 deletions tests/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,3 +17,4 @@
# import all tests
from .test_util import *
from .test_columnar_util import *
from .test_config_util import *
6 changes: 6 additions & 0 deletions tests/run_tests
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,12 @@ action() {
ret="$?"
[ "${gret}" = "0" ] && gret="${ret}"

# test_config_util
echo
bash "${this_dir}/run_test" test_config_util "${cf_dir}/sandboxes/venv_columnar${dev}.sh"
ret="$?"
[ "${gret}" = "0" ] && gret="${ret}"

# test_task_parameters
echo
bash "${this_dir}/run_test" test_task_parameters
Expand Down
76 changes: 76 additions & 0 deletions tests/test_config_util.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
# coding: utf-8


__all__ = ["ConfigUtilTests"]

import unittest

from columnflow.util import maybe_import
from columnflow.config_util import get_events_from_categories

import order as od

np = maybe_import("numpy")
ak = maybe_import("awkward")
dak = maybe_import("dask_awkward")
coffea = maybe_import("coffea")


class ConfigUtilTests(unittest.TestCase):
    """
    Unit tests for helpers in columnflow.config_util.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        # a plain category acts as the "config" here, it only needs to provide get_category
        self.config_inst = cfg = od.Category("config", 1)

        # category tree:
        #   main_1
        #   main_2 -> leaf_21
        #          -> leaf_22 -> leaf_221
        cfg.add_category("main_1", id=1)
        main_2 = cfg.add_category("main_2", id=2)
        main_2.add_category("leaf_21", id=21)
        leaf_22 = main_2.add_category("leaf_22", id=22)
        leaf_22.add_category("leaf_221", id=221)

        # awkward array with category_ids (only leaf ids)
        self.events = ak.Array({
            "category_ids": [[1], [21, 221], [21], [221]],
            "dummy_field": [1, 2, 3, 4],
        })

    def test_get_events_from_categories(self):
        # check that a category without leaves is working
        events = get_events_from_categories(self.events, ["main_1"], self.config_inst)
        self.assertTrue(ak.all(ak.any(events.category_ids == 1, axis=1)))
        self.assertTrue(ak.all(events.dummy_field == ak.Array([1])))

        # check that categories with leaves are working
        events = get_events_from_categories(self.events, ["main_2"], self.config_inst)
        self.assertTrue(ak.all(events.dummy_field == ak.Array([2, 3, 4])))

        # check that a leaf category is working
        events = get_events_from_categories(self.events, ["leaf_221"], self.config_inst)
        self.assertTrue(ak.all(ak.any(events.category_ids == 221, axis=1)))
        self.assertTrue(ak.all(events.dummy_field == ak.Array([2, 4])))

        # check that passing multiple categories is working
        events = get_events_from_categories(self.events, ["main_1", "main_2"], self.config_inst)
        self.assertTrue(ak.all(events.dummy_field == ak.Array([1, 2, 3, 4])))

        # check that directly passing a category inst is working
        events = get_events_from_categories(self.events, self.config_inst.get_category("main_1"))
        self.assertTrue(ak.all(ak.any(events.category_ids == 1, axis=1)))
        self.assertTrue(ak.all(events.dummy_field == ak.Array([1])))

        # never select events from non-leaf categories or not existing categories
        events = get_events_from_categories(
            ak.Array({"category_ids": [[2], [-1], [99]]}),
            ["main_2"],
            self.config_inst,
        )
        self.assertEqual(len(events), 0)

        # raises ValueError, when passing events without "category_ids" field
        # (the config is passed as its own argument, not as an element of the categories list)
        with self.assertRaises(ValueError):
            get_events_from_categories(
                ak.Array({"no_category_ids": [1, 2, 3]}),
                ["main_2"],
                self.config_inst,
            )

        # raises AttributeError, when passing categories as string, but not config_inst
        with self.assertRaises(AttributeError):
            get_events_from_categories(self.events, ["main_1", "main_2"])

        # raises ValueError, when passing strings of nonexisting categories
        with self.assertRaises(ValueError):
            get_events_from_categories(
                self.events,
                ["nonexisting", "categories", "main_1"],
                self.config_inst,
            )

0 comments on commit 250c6ec

Please sign in to comment.