Skip to content

Commit

Permalink
Merge branch 'master' of github.com:haddadanas/columnflow into ROC_an…
Browse files Browse the repository at this point in the history
…d_tests
  • Loading branch information
haddadanas committed Dec 8, 2023
2 parents a560e84 + f901c9e commit ec10c7f
Show file tree
Hide file tree
Showing 6 changed files with 166 additions and 19 deletions.
42 changes: 42 additions & 0 deletions columnflow/config_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,50 @@
import law
import order as od

from columnflow.util import maybe_import
from columnflow.types import Callable, Any, Sequence

ak = maybe_import("awkward")
np = maybe_import("numpy")


def get_events_from_categories(
    events: ak.Array,
    categories: Sequence[str | od.Category],
    config_inst: od.Config | None = None,
) -> ak.Array:
    """
    Helper function that returns all events from an awkward array *events* that are categorized
    into one of the leafs of one of the *categories*.

    A category that has no leaf categories is treated as a leaf itself. When *categories* is
    empty, no event can match and an empty selection of *events* is returned.

    :param events: Awkward array. Requires the 'category_ids' field to be present.
    :param categories: Sequence of category instances (a single, non-sequence value is wrapped
        into a list). Can also be a sequence of strings when passing a *config_inst*.
    :param config_inst: Optional config instance to load category instances.
    :raises ValueError: If "category_ids" is not present in the *events* fields.
    :return: Awkward array of all events that are categorized into one of the leafs of one of the
        *categories*
    """
    if "category_ids" not in events.fields:
        raise ValueError(
            f"{get_events_from_categories.__name__} requires the 'category_ids' field to be present",
        )

    categories = law.util.make_list(categories)
    if config_inst is not None:
        # get category insts
        categories = [config_inst.get_category(cat) for cat in categories]

    # collect the unique leaf categories; starting from an empty set instance (rather than calling
    # the unbound set.union) keeps this well-defined when no categories are given
    leaf_category_insts = set().union(*(
        cat.get_leaf_categories() or {cat}
        for cat in categories
    ))

    # do the "or" of all leaf categories
    mask = np.zeros(len(events), dtype=bool)
    for cat in leaf_category_insts:
        # an event matches when any of its category ids equals the id of the leaf category
        cat_mask = ak.any(events.category_ids == cat.id, axis=1)
        mask = cat_mask | mask

    return events[mask]


def get_root_processes_from_campaign(campaign: od.Campaign) -> od.UniqueObjectIndex:
"""
Expand Down
27 changes: 25 additions & 2 deletions columnflow/tasks/framework/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import luigi

from columnflow.types import Any, Callable
from columnflow.tasks.framework.base import ConfigTask
from columnflow.tasks.framework.base import ConfigTask, RESOLVE_DEFAULT
from columnflow.tasks.framework.mixins import DatasetsProcessesMixin, VariablesMixin
from columnflow.tasks.framework.parameters import SettingsParameter, MultiSettingsParameter
from columnflow.util import DotDict, dict_add_strict
Expand Down Expand Up @@ -47,6 +47,13 @@ class PlotBase(ConfigTask):
description="Parameter to set a list of custom plotting parameters. Format: "
"'option1=val1,option2=val2,...'",
)
# name of a custom style config; the implicitly concatenated description fragments need trailing
# spaces, otherwise words run together in the rendered help text
custom_style_config = luigi.Parameter(
    default=RESOLVE_DEFAULT,
    significant=False,
    description="Parameter to overwrite the *style_config* that is passed to the plot function "
    "via a dictionary in the `custom_style_config_groups` auxiliary in the config; "
    "defaults to the `default_custom_style_config` aux",
)
skip_legend = law.OptionalBoolParameter(
default=None,
significant=False,
Expand All @@ -68,7 +75,7 @@ def resolve_param_values(cls, params):
return params
config_inst = params["config_inst"]

# resolve variable_settings
# resolve general_settings
if "general_settings" in params:
settings = params["general_settings"]
# when empty and default general_settings are defined, use them instead
Expand All @@ -94,6 +101,7 @@ def get_plot_parameters(self) -> DotDict:
dict_add_strict(params, "skip_legend", self.skip_legend)
dict_add_strict(params, "cms_label", self.cms_label)
dict_add_strict(params, "general_settings", self.general_settings)
dict_add_strict(params, "custom_style_config", self.custom_style_config)
return params

def get_plot_names(self, name: str) -> list[str]:
Expand Down Expand Up @@ -155,6 +163,21 @@ def update_plot_kwargs(self, kwargs: dict) -> dict:
for key, value in general_settings.items():
kwargs.setdefault(key, value)

# resolve custom_style_config
custom_style_config = kwargs.get("custom_style_config", None)
if custom_style_config == RESOLVE_DEFAULT:
custom_style_config = self.config_inst.x("default_custom_style_config", RESOLVE_DEFAULT)

groups = self.config_inst.x("custom_style_config_groups", {})
if isinstance(custom_style_config, str) and custom_style_config in groups.keys():
custom_style_config = groups[custom_style_config]

# update style_config
style_config = kwargs.get("style_config", {})
if isinstance(custom_style_config, dict) and isinstance(style_config, dict):
style_config = law.util.merge_dicts(custom_style_config, style_config)
kwargs["style_config"] = style_config

return kwargs


Expand Down
33 changes: 16 additions & 17 deletions sandboxes/_setup_venv.sh
Original file line number Diff line number Diff line change
Expand Up @@ -303,7 +303,7 @@ setup_venv() {
# handle remote job environments
if [ "${CF_REMOTE_ENV}" = "1" ]; then
# in this case, the environment is inside a remote job, i.e., these variables are present:
# CF_ENV_BASH_SANDBOX_URIS, CF_ENV_BASH_SANDBOX_PATTERNS and CF_ENV_BASH_SANDBOX_NAMES
# CF_JOB_BASH_SANDBOX_URIS, CF_JOB_BASH_SANDBOX_PATTERNS and CF_JOB_BASH_SANDBOX_NAMES
if [ ! -f "${CF_SANDBOX_FLAG_FILE}" ]; then
if [ -z "${CF_WLCG_TOOLS}" ] || [ ! -f "${CF_WLCG_TOOLS}" ]; then
>&2 echo "CF_WLCG_TOOLS (${CF_WLCG_TOOLS}) files is empty or does not exist"
Expand All @@ -312,26 +312,25 @@ setup_venv() {

# fetch the bundle and unpack it
echo "looking for bash sandbox bundle for venv ${CF_VENV_NAME}"
local sandbox_names=( ${CF_ENV_BASH_SANDBOX_NAMES} )
local sandbox_uris=( ${CF_ENV_BASH_SANDBOX_URIS} )
local sandbox_patterns=( ${CF_ENV_BASH_SANDBOX_PATTERNS} )
local sandbox_names=( ${CF_JOB_BASH_SANDBOX_NAMES} )
local sandbox_uris=( ${CF_JOB_BASH_SANDBOX_URIS} )
local sandbox_patterns=( ${CF_JOB_BASH_SANDBOX_PATTERNS} )
local found_sandbox="false"
for (( i=0; i<${#sandbox_names[@]}; i+=1 )); do
if [ "${sandbox_names[i]}" = "${CF_VENV_NAME}" ]; then
echo "found bundle ${CF_VENV_NAME}, index ${i}, pattern ${sandbox_patterns[i]}, uri ${sandbox_uris[i]}"
(
source "${CF_WLCG_TOOLS}" "" &&
mkdir -p "${install_path}" &&
cd "${install_path}" &&
law_wlcg_get_file "${sandbox_uris[i]}" "${sandbox_patterns[i]}" "bundle.tgz" &&
tar -xzf "bundle.tgz"
) || return "$?"
found_sandbox="true"
break
fi
[ "${sandbox_names[i]}" != "${CF_VENV_NAME}" ] && continue
echo "found bundle ${CF_VENV_NAME}, index ${i}, pattern ${sandbox_patterns[i]}, uri ${sandbox_uris[i]}"
(
source "${CF_WLCG_TOOLS}" "" &&
mkdir -p "${install_path}" &&
cd "${install_path}" &&
law_wlcg_get_file "${sandbox_uris[i]}" "${sandbox_patterns[i]}" "bundle.tgz" &&
tar -xzf "bundle.tgz"
) || return "$?"
found_sandbox="true"
break
done
if ! ${found_sandbox}; then
>&2 echo "bash sandbox ${CF_VENV_BASE} not found in job configuration, stopping"
>&2 echo "bash sandbox '${CF_VENV_NAME}' not found in job configuration, stopping"
return "31"
fi
fi
Expand Down
1 change: 1 addition & 0 deletions tests/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,5 +17,6 @@
# import all tests
from .test_util import *
from .test_columnar_util import *
from .test_config_util import *
from .test_task_parameters import *
from .test_plotting import *
6 changes: 6 additions & 0 deletions tests/run_tests
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,12 @@ action() {
ret="$?"
[ "${gret}" = "0" ] && gret="${ret}"

# test_config_util
echo
bash "${this_dir}/run_test" test_config_util "${cf_dir}/sandboxes/venv_columnar${dev}.sh"
ret="$?"
[ "${gret}" = "0" ] && gret="${ret}"

# test_task_parameters
echo
bash "${this_dir}/run_test" test_task_parameters
Expand Down
76 changes: 76 additions & 0 deletions tests/test_config_util.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
# coding: utf-8


__all__ = ["ConfigUtilTests"]

import unittest

from columnflow.util import maybe_import
from columnflow.config_util import get_events_from_categories

import order as od

np = maybe_import("numpy")
ak = maybe_import("awkward")
dak = maybe_import("dask_awkward")
coffea = maybe_import("coffea")


class ConfigUtilTests(unittest.TestCase):
    """
    Tests for :py:func:`columnflow.config_util.get_events_from_categories`.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        # an od.Category serves as a stand-in for a config here; the tests only rely on its
        # add_category / get_category methods
        self.config_inst = cfg = od.Category("config", 1)

        # category tree:
        #   main_1 (1)
        #   main_2 (2) -> leaf_21 (21)
        #              -> leaf_22 (22) -> leaf_221 (221)
        cfg.add_category("main_1", id=1)
        main_2 = cfg.add_category("main_2", id=2)
        main_2.add_category("leaf_21", id=21)
        leaf_22 = main_2.add_category("leaf_22", id=22)
        leaf_22.add_category("leaf_221", id=221)

        # awkward array with category_ids (only leaf ids)
        self.events = ak.Array({
            "category_ids": [[1], [21, 221], [21], [221]],
            "dummy_field": [1, 2, 3, 4],
        })

    def test_get_events_from_categories(self):
        # check that category without leafs is working
        events = get_events_from_categories(self.events, ["main_1"], self.config_inst)
        self.assertTrue(ak.all(ak.any(events.category_ids == 1, axis=1)))
        self.assertTrue(ak.all(events.dummy_field == ak.Array([1])))

        # check that categories with leafs are working
        events = get_events_from_categories(self.events, ["main_2"], self.config_inst)
        self.assertTrue(ak.all(events.dummy_field == ak.Array([2, 3, 4])))

        # check that leaf category is working
        events = get_events_from_categories(self.events, ["leaf_221"], self.config_inst)
        self.assertTrue(ak.all(ak.any(events.category_ids == 221, axis=1)))
        self.assertTrue(ak.all(events.dummy_field == ak.Array([2, 4])))

        # check that passing multiple categories is working
        events = get_events_from_categories(self.events, ["main_1", "main_2"], self.config_inst)
        self.assertTrue(ak.all(events.dummy_field == ak.Array([1, 2, 3, 4])))

        # check that directly passing category inst is working
        events = get_events_from_categories(self.events, self.config_inst.get_category("main_1"))
        self.assertTrue(ak.all(ak.any(events.category_ids == 1, axis=1)))
        self.assertTrue(ak.all(events.dummy_field == ak.Array([1])))

        # never select events from non-leaf categories or not existing categories
        events = get_events_from_categories(
            ak.Array({"category_ids": [[2], [-1], [99]]}),
            ["main_2"],
            self.config_inst,
        )
        self.assertEqual(len(events), 0)

        # raises ValueError, when passing events without "category_ids" field
        # (the config instance is passed as its own argument, not as an element of the categories)
        with self.assertRaises(ValueError):
            get_events_from_categories(
                ak.Array({"no_category_ids": [1, 2, 3]}),
                ["main_2"],
                self.config_inst,
            )

        # raises AttributeError, when passing categories as string, but not config_inst
        with self.assertRaises(AttributeError):
            get_events_from_categories(self.events, ["main_1", "main_2"])

        # raises ValueError, when passing strings of nonexisting categories
        with self.assertRaises(ValueError):
            get_events_from_categories(self.events, ["nonexisting", "categories", "main_1"], self.config_inst)

0 comments on commit ec10c7f

Please sign in to comment.