Skip to content

Commit

Permalink
feat: wildcard branch names for adding and dropping (#67)
Browse files Browse the repository at this point in the history
* added branch skimming to copy_root

* added filtering to root_to_parquet

* pylint fixes

* formatting...

* added keep branches and slimming and wildcard selection to merge
  • Loading branch information
zbilodea authored Feb 12, 2024
1 parent 984575c commit 7d9d2ea
Show file tree
Hide file tree
Showing 9 changed files with 1,186 additions and 189 deletions.
25 changes: 9 additions & 16 deletions src/hepconvert/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,9 @@ def main() -> None:
help="Use a compression level particular to the chosen compressor. By default the compression level is 1.",
)
@click.option(
"-f",
"--force",
default=True,
default=False,
help="If True, overwrites destination file if it already exists.",
)
def parquet_to_root(
Expand All @@ -57,7 +58,7 @@ def parquet_to_root(
resize_factor=10.0,
compression="zlib",
compression_level=1,
force=True,
force=False,
):
"""
Convert Parquet file to ROOT file.
Expand Down Expand Up @@ -97,8 +98,9 @@ def parquet_to_root(
help="When the TTree metadata needs to be rewritten, this specifies how many more TBasket slots to allocate as a multiplicative factor.",
)
@click.option(
"-f",
"--force",
default=True,
default=False,
help="If True, overwrites destination file if it already exists.",
)
def copy_root(
Expand All @@ -107,7 +109,7 @@ def copy_root(
*,
drop_branches=None,
drop_trees=None,
force=True,
force=False,
title="",
field_name=lambda outer, inner: inner if outer == "" else outer + "_" + inner,
initial_basket_capacity=10,
Expand Down Expand Up @@ -145,7 +147,7 @@ def copy_root(
@click.option(
"-f",
"--force",
default=True,
default=False,
help="Overwrite destination file if it already exists",
)
@click.option("--append", default=False, help="Append histograms to an existing file")
Expand Down Expand Up @@ -178,7 +180,7 @@ def add(
destination,
files,
*,
force=True,
force=False,
append=False,
compression="zlib",
compression_level=1,
Expand Down Expand Up @@ -207,13 +209,6 @@ def add(
@main.command()
@click.argument("destination")
@click.argument("files")
@click.option(
"--branch-types",
default=None,
type=dict,
required=False,
help="Manually enter branch names and types to improve performance slightly.",
)
@click.option("--title", required=False, default="", help="Set title of new TTree.")
@click.option(
"--initial-basket-capacity",
Expand Down Expand Up @@ -254,7 +249,6 @@ def merge_root(
files,
*,
fieldname_separator="_",
branch_types=None,
title="",
field_name=lambda outer, inner: inner if outer == "" else outer + "_" + inner,
initial_basket_capacity=10,
Expand All @@ -276,7 +270,6 @@ def merge_root(
destination,
files,
fieldname_separator=fieldname_separator,
branch_types=branch_types,
title=title,
field_name=field_name,
initial_basket_capacity=initial_basket_capacity,
Expand Down Expand Up @@ -304,7 +297,7 @@ def merge_root(
@click.option(
"-f",
"--force",
default=True,
default=False,
type=bool,
help="If a file already exists at specified path, it gets replaced",
)
Expand Down
86 changes: 86 additions & 0 deletions src/hepconvert/_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
# def get_counter_branches(tree):
# counter_branches =
from __future__ import annotations

import numpy as np


def group_branches(tree, keep_branches):
    """
    Creates groups for ak.zip to avoid duplicate counters being created.

    Branches are placed in the same group when their single leaf refers to the
    very same ``.member("fLeafCount")`` object. Branches whose leaf has no
    ``fLeafCount`` are not placed in any group.

    :param tree: Object providing ``keys(filter_name=...)`` and item access by
        branch name (used here the way an uproot TTree is used).
    :param keep_branches: Passed unchanged as ``filter_name`` to ``tree.keys`` to
        select which branches are considered.
    :returns: Tuple ``(groups, count_branches)``. ``groups`` is a list of lists of
        branch names, ordered by first appearance in ``tree.keys``;
        ``count_branches`` holds the counter branch name of each group, in the
        same order as ``groups``.
    :raises NotImplementedError: If a selected branch has more than one leaf.
    """
    groups = []
    count_branches = []
    # id(counter leaf) -> index into ``groups``. Counters are matched by identity,
    # exactly like the ``is`` comparison this replaces.
    group_index = {}
    # Keeps every counter leaf referenced so its id() cannot be reused while
    # ``group_index`` is in use.
    seen_counters = []

    # Single pass over a list that is never mutated: removing entries from the
    # list being iterated made the loop skip branches (and their counters).
    for branch in tree.keys(filter_name=keep_branches):
        leaves = tree[branch].member("fLeaves")
        if len(leaves) > 1:
            msg = "Cannot handle split objects."
            raise NotImplementedError(msg)
        leaf_count = leaves[0].member("fLeafCount")
        if leaf_count is None:
            continue
        index = group_index.get(id(leaf_count))
        if index is None:
            # First branch seen for this counter: open a new group and record
            # the counter branch name once per group.
            group_index[id(leaf_count)] = len(groups)
            seen_counters.append(leaf_count)
            groups.append([branch])
            count_branches.append(tree[branch].count_branch.name)
        else:
            groups[index].append(branch)
    return groups, count_branches


def get_counter_branches(tree):
    """
    Collects the names of the counter branches of ``tree`` (used to remove them
    in merge etc.).

    A branch contributes its ``count_branch.name`` when its first leaf has a
    non-None ``fLeafCount`` member.

    :param tree: Object providing ``keys()`` and item access by branch name.
    :returns: Result of ``np.unique`` over the collected names, i.e. a NumPy
        array of the distinct counter names in sorted order.
    """
    counter_names = [
        tree[key].count_branch.name
        for key in tree.keys()  # noqa: SIM118
        if tree[key].member("fLeaves")[0].member("fLeafCount") is not None
    ]
    return np.unique(counter_names, axis=0)


def _resolve_branch_selection(tree, selection):
    """
    Resolves a keep/drop selection into the set of matching key names of ``tree``.

    Returns None when nothing is selected for this tree: ``selection`` is empty,
    or it is a dict without an entry for ``tree.name``.
    """
    if isinstance(selection, dict):
        # Per-tree selections: only the entry for this tree applies. A dict that
        # names other trees must not be used as a branch filter here.
        selection = selection.get(tree.name)
    if not selection:
        return None
    # Delegate matching to the tree for strings and lists of any length alike,
    # so a pattern is treated the same way however it was passed in.
    # NOTE(review): relies on tree.keys(filter_name=...) accepting a list of
    # names/patterns as well as a single string -- confirm against uproot docs.
    return set(tree.keys(filter_name=selection))


def filter_branches(tree, keep_branches, drop_branches, count_branches):
    """
    Builds the list of branch names to read from ``tree``, based on
    keep_branches or drop_branches.

    Counter branches are always excluded. ``drop_branches`` takes precedence:
    if it selects anything for this tree, ``keep_branches`` is ignored.

    :param tree: Object providing ``name``, ``branches`` (each with ``name``) and
        ``keys(filter_name=...)``.
    :param keep_branches: str, list of str, or dict mapping tree name to either;
        only matching branches are returned. May be None.
    :param drop_branches: str, list of str, or dict mapping tree name to either;
        matching branches are removed. May be None.
    :param count_branches: Iterable of counter branch names to exclude.
    :returns: List of branch names, in the order of ``tree.branches``. When a
        dict selection has no entry for this tree, it is treated as "no
        selection" for this tree.
    """
    # A set gives plain hashed membership tests regardless of whether the
    # counters arrive as a list or as a NumPy array of names.
    counters = set(count_branches)
    candidates = [b.name for b in tree.branches if b.name not in counters]

    dropped = _resolve_branch_selection(tree, drop_branches)
    if dropped is not None:
        return [name for name in candidates if name not in dropped]

    kept = _resolve_branch_selection(tree, keep_branches)
    if kept is not None:
        return [name for name in candidates if name in kept]

    return candidates
129 changes: 37 additions & 92 deletions src/hepconvert/copy_root.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import awkward as ak
import uproot

from hepconvert._utils import filter_branches, get_counter_branches, group_branches
from hepconvert.histogram_adding import _hadd_1d, _hadd_2d, _hadd_3d

# ruff: noqa: B023
Expand All @@ -14,11 +15,14 @@ def copy_root(
destination,
file,
*,
keep_branches=None,
drop_branches=None,
# add_branches=None, #TO-DO: add functionality for this, just specify about the counter issue
keep_trees=None,
drop_trees=None,
force=True,
force=False,
fieldname_separator="_",
# fix_duplicate_counters=False, #TO-DO: ask about this?
title="",
field_name=lambda outer, inner: inner if outer == "" else outer + "_" + inner,
initial_basket_capacity=10,
Expand All @@ -40,6 +44,8 @@ def copy_root(
:param drop_trees: To remove a ttree from a file, pass a list of names of branches to remove.
Defaults to None. Command line option: ``--drop-trees``.
:type drop_trees: str or list of str, optional
:param force: If true, replaces file if it already exists. Default is False. Command line options ``-f`` or ``--force``.
:type force: Bool, optional
:param fieldname_separator: If data includes jagged arrays, pass the character that separates
TBranch names for columns, used for grouping columns (to avoid duplicate counters in ROOT file). Defaults to "_".
:type fieldname_separator: str, optional
Expand Down Expand Up @@ -128,21 +134,34 @@ def copy_root(
filter_classname=["TH*", "TProfile"], cycle=False, recursive=False
)

for key in f.keys(cycle=False, recursive=False):
if key in hist_keys:
if len(f[key].axes) == 1:
h_sum = _hadd_1d(destination, f, key, True)
# if isinstance(h_sum, uproot.models.TH.Model_TH1F_v3):
# print(h_sum.member('fXaxis'))
out_file[key] = h_sum
elif len(f[key].axes) == 2:
out_file[key] = _hadd_2d(destination, f, key, True)
else:
out_file[key] = _hadd_3d(destination, f, key, True)
for key in hist_keys: # just pass to hadd??
if len(f[key].axes) == 1:
out_file[key] = _hadd_1d(destination, f, key, True)
elif len(f[key].axes) == 2:
out_file[key] = _hadd_2d(destination, f, key, True)
else:
out_file[key] = _hadd_3d(destination, f, key, True)

trees = f.keys(filter_classname="TTree", cycle=False, recursive=False)

# Check that drop_trees keys are valid/refer to a tree:
if keep_trees:
if isinstance(keep_trees, list):
for key in keep_trees:
if key not in trees:
msg = (
"Key '"
+ key
+ "' does not match any TTree in ROOT file"
+ str(file)
)
raise ValueError(msg)
if isinstance(keep_trees, str):
keep_trees = f.keys(filter_name=keep_trees, cycle=False)
if len(keep_trees) != 1:
drop_trees = [tree for tree in trees if tree not in keep_trees]
else:
drop_trees = [tree for tree in trees if tree != keep_trees[0]]
if drop_trees:
if isinstance(drop_trees, list):
for key in drop_trees:
Expand Down Expand Up @@ -172,85 +191,14 @@ def copy_root(

for t in trees:
tree = f[t]
histograms = tree.keys(filter_typename=["TH*", "TProfile"], recursive=False)
groups = []
count_branches = []
temp_branches = [branch.name for branch in tree.branches]
temp_branches1 = [branch.name for branch in tree.branches]
cur_group = 0
for branch in temp_branches:
if len(tree[branch].member("fLeaves")) > 1:
msg = "Cannot handle split objects."
raise NotImplementedError(msg)
if tree[branch].member("fLeaves")[0].member("fLeafCount") is None:
continue
groups.append([])
groups[cur_group].append(branch)
for branch1 in temp_branches1:
if tree[branch].member("fLeaves")[0].member("fLeafCount") is tree[
branch1
].member("fLeaves")[0].member("fLeafCount") and (
tree[branch].name != tree[branch1].name
):
groups[cur_group].append(branch1)
temp_branches.remove(branch1)
count_branches.append(tree[branch].count_branch.name)
temp_branches.remove(tree[branch].count_branch.name)
temp_branches.remove(branch)
cur_group += 1

if drop_branches:
if isinstance(drop_branches, dict) and t in drop_branches:
rm = drop_branches.get(t)
else:
rm = drop_branches
if isinstance(rm, list):
keep_branches = [
branch.name
for branch in tree.branches
if branch.name not in rm and branch.name not in count_branches
]
elif isinstance(rm, str):
keep_branches = [
branch.name
for branch in tree.branches
if branch.name != rm and branch.name not in count_branches
]
else:
keep_branches = [
branch.name
for branch in tree.branches
if branch.name not in count_branches
]

writable_hists = {}
if len(histograms) > 1:
for key in histograms:
if len(f[key].axes) == 1:
writable_hists[key] = _hadd_1d(destination, f, key, True)

elif len(f[key].axes) == 2:
writable_hists[key] = _hadd_2d(destination, f, key, True)

else:
writable_hists[key] = _hadd_3d(destination, f, key, True)

elif len(histograms) == 1:
if len(f[histograms[0]].axes) == 1:
writable_hists = _hadd_1d(destination, f, histograms[0], True)

elif len(f[histograms[0]].axes) == 2:
writable_hists = _hadd_2d(destination, f, histograms[0], True)

else:
writable_hists = _hadd_3d(destination, f, histograms[0], True)

count_branches = get_counter_branches(tree)
kb = filter_branches(tree, keep_branches, drop_branches, count_branches)
groups, count_branches = group_branches(tree, kb)
first = True
for chunk in uproot.iterate(
tree,
for chunk in tree.iterate(
step_size=step_size,
how=dict,
filter_branch=lambda b: b.name in keep_branches,
filter_name=lambda b: b in kb,
):
for group in groups:
if (len(group)) > 1:
Expand All @@ -270,7 +218,7 @@ def copy_root(
}
)
for key in group:
if key in keep_branches:
if key in kb:
del chunk[key]
if first:
if drop_branches:
Expand Down Expand Up @@ -303,7 +251,4 @@ def copy_root(
except AssertionError:
msg = "Are the branch-names correct?"

for i, _value in enumerate(histograms):
out_file[histograms[i]] = writable_hists[i]

f.close()
Loading

0 comments on commit 7d9d2ea

Please sign in to comment.