Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: wildcard branch names for adding and dropping #67

Merged
merged 10 commits into from
Feb 12, 2024
25 changes: 9 additions & 16 deletions src/hepconvert/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,9 @@ def main() -> None:
help="Use a compression level particular to the chosen compressor. By default the compression level is 1.",
)
@click.option(
"-f",
"--force",
default=True,
default=False,
help="If True, overwrites destination file if it already exists.",
)
def parquet_to_root(
Expand All @@ -57,7 +58,7 @@ def parquet_to_root(
resize_factor=10.0,
compression="zlib",
compression_level=1,
force=True,
force=False,
):
"""
Convert Parquet file to ROOT file.
Expand Down Expand Up @@ -97,8 +98,9 @@ def parquet_to_root(
help="When the TTree metadata needs to be rewritten, this specifies how many more TBasket slots to allocate as a multiplicative factor.",
)
@click.option(
"-f",
"--force",
default=True,
default=False,
help="If True, overwrites destination file if it already exists.",
)
def copy_root(
Expand All @@ -107,7 +109,7 @@ def copy_root(
*,
drop_branches=None,
drop_trees=None,
force=True,
force=False,
title="",
field_name=lambda outer, inner: inner if outer == "" else outer + "_" + inner,
initial_basket_capacity=10,
Expand Down Expand Up @@ -145,7 +147,7 @@ def copy_root(
@click.option(
"-f",
"--force",
default=True,
default=False,
help="Overwrite destination file if it already exists",
)
@click.option("--append", default=False, help="Append histograms to an existing file")
Expand Down Expand Up @@ -178,7 +180,7 @@ def add(
destination,
files,
*,
force=True,
force=False,
append=False,
compression="zlib",
compression_level=1,
Expand Down Expand Up @@ -207,13 +209,6 @@ def add(
@main.command()
@click.argument("destination")
@click.argument("files")
@click.option(
"--branch-types",
default=None,
type=dict,
required=False,
help="Manually enter branch names and types to improve performance slightly.",
)
@click.option("--title", required=False, default="", help="Set title of new TTree.")
@click.option(
"--initial-basket-capacity",
Expand Down Expand Up @@ -254,7 +249,6 @@ def merge_root(
files,
*,
fieldname_separator="_",
branch_types=None,
title="",
field_name=lambda outer, inner: inner if outer == "" else outer + "_" + inner,
initial_basket_capacity=10,
Expand All @@ -276,7 +270,6 @@ def merge_root(
destination,
files,
fieldname_separator=fieldname_separator,
branch_types=branch_types,
title=title,
field_name=field_name,
initial_basket_capacity=initial_basket_capacity,
Expand Down Expand Up @@ -304,7 +297,7 @@ def merge_root(
@click.option(
"-f",
"--force",
default=True,
default=False,
type=bool,
help="If a file already exists at specified path, it gets replaced",
)
Expand Down
86 changes: 86 additions & 0 deletions src/hepconvert/_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
# def get_counter_branches(tree):
# counter_branches =
from __future__ import annotations

import numpy as np


def group_branches(tree, keep_branches):
    """
    Group jagged branches that share a counter so they can be zipped together.

    Branches are placed in the same group when their single leaf points at the
    same ``fLeafCount`` object. Zipping each group (e.g. with ``ak.zip``) then
    produces one counter per group instead of one duplicate counter per branch.

    :param tree: TTree-like object supporting ``keys(filter_name=...)`` and
        item access by branch key.
    :param keep_branches: Filter forwarded to ``tree.keys(filter_name=...)``;
        only matching branches are considered.
    :return: ``(groups, count_branches)``. ``groups`` is a list of lists of
        branch keys, one list per distinct counter, in order of first
        appearance. ``count_branches`` holds the name of the counter branch of
        each group, in the same order.
    :raises NotImplementedError: If a considered branch has more than one leaf
        (split objects).
    """
    groups = []
    count_branches = []
    # id(fLeafCount) -> (fLeafCount, group, name of the branch that opened it).
    # The counter object itself is stored so its id stays valid while in use.
    # A single pass over an unmodified key list replaces the previous
    # remove-while-iterating scheme, which silently skipped branches.
    by_counter = {}
    for key in tree.keys(filter_name=keep_branches):
        branch = tree[key]
        leaves = branch.member("fLeaves")
        if len(leaves) > 1:
            msg = "Cannot handle split objects."
            raise NotImplementedError(msg)
        counter = leaves[0].member("fLeafCount")
        if counter is None:
            # Not a counted (jagged) branch: it needs no grouping.
            continue
        entry = by_counter.get(id(counter))
        if entry is None:
            group = [key]
            groups.append(group)
            by_counter[id(counter)] = (counter, group, branch.name)
            count_branches.append(branch.count_branch.name)
        elif branch.name != entry[2]:
            # Same counter as an existing group; skip keys that merely repeat
            # the name of the branch that opened the group.
            entry[1].append(key)
    return groups, count_branches


def get_counter_branches(tree):
    """
    Collect the names of all counter branches referenced by ``tree``.

    A branch contributes its ``count_branch`` name when its first leaf has a
    non-``None`` ``fLeafCount`` member. The names are de-duplicated and
    returned via ``np.unique``.
    """
    counter_names = [
        tree[key].count_branch.name
        for key in tree.keys()  # noqa: SIM118
        if tree[key].member("fLeaves")[0].member("fLeafCount") is not None
    ]
    return np.unique(counter_names, axis=0)


def _resolve_branch_selection(tree, selection):
    """
    Expand a keep/drop selection into the set of branch keys it matches.

    Returns ``None`` when the selection says nothing about ``tree``: either it
    is empty, or it is a per-tree dict with no (or an empty) entry for
    ``tree.name``.
    """
    if isinstance(selection, dict):
        selection = selection.get(tree.name)
    if not selection:
        return None
    # Let the tree do the matching so a string and every element of a list
    # are handled the same way, wildcards included.
    return set(tree.keys(filter_name=selection))


def filter_branches(tree, keep_branches, drop_branches, count_branches):
    """
    Return the names of the branches of ``tree`` that should be kept.

    Counter branches are always excluded. ``drop_branches`` takes precedence:
    when it is given, ``keep_branches`` is ignored.

    :param tree: TTree-like object exposing ``name``, ``branches`` and
        ``keys(filter_name=...)``.
    :param keep_branches: Branch name, wildcard pattern, list of those, or a
        dict ``{tree_name: selection}``. Only matching branches are kept.
    :param drop_branches: Same forms as ``keep_branches``; matching branches
        are removed.
    :param count_branches: Names of counter branches to exclude.
    :return: List of branch names, in the order of ``tree.branches``.

    A dict selection with no entry for this tree leaves the tree unfiltered,
    rather than being misread as a list of branch names.
    """
    counters = set() if count_branches is None else set(count_branches)
    candidates = [b.name for b in tree.branches if b.name not in counters]

    if drop_branches:
        dropped = _resolve_branch_selection(tree, drop_branches)
        if dropped is None:
            return candidates
        return [name for name in candidates if name not in dropped]
    if keep_branches:
        kept = _resolve_branch_selection(tree, keep_branches)
        if kept is None:
            return candidates
        return [name for name in candidates if name in kept]
    return candidates
129 changes: 37 additions & 92 deletions src/hepconvert/copy_root.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import awkward as ak
import uproot

from hepconvert._utils import filter_branches, get_counter_branches, group_branches
from hepconvert.histogram_adding import _hadd_1d, _hadd_2d, _hadd_3d

# ruff: noqa: B023
Expand All @@ -14,11 +15,14 @@ def copy_root(
destination,
file,
*,
keep_branches=None,
drop_branches=None,
# add_branches=None, #TO-DO: add functionality for this, just specify about the counter issue
keep_trees=None,
drop_trees=None,
force=True,
force=False,
fieldname_separator="_",
# fix_duplicate_counters=False, #TO-DO: ask about this?
title="",
field_name=lambda outer, inner: inner if outer == "" else outer + "_" + inner,
initial_basket_capacity=10,
Expand All @@ -40,6 +44,8 @@ def copy_root(
:param drop_trees: To remove a TTree from a file, pass a list of names of trees to remove.
Defaults to None. Command line option: ``--drop-trees``.
:type drop_trees: str or list of str, optional
:param force: If true, replaces file if it already exists. Default is False. Command line options ``-f`` or ``--force``.
:type force: Bool, optional
:param fieldname_separator: If data includes jagged arrays, pass the character that separates
TBranch names for columns, used for grouping columns (to avoid duplicate counters in ROOT file). Defaults to "_".
:type fieldname_separator: str, optional
Expand Down Expand Up @@ -128,21 +134,34 @@ def copy_root(
filter_classname=["TH*", "TProfile"], cycle=False, recursive=False
)

for key in f.keys(cycle=False, recursive=False):
if key in hist_keys:
if len(f[key].axes) == 1:
h_sum = _hadd_1d(destination, f, key, True)
# if isinstance(h_sum, uproot.models.TH.Model_TH1F_v3):
# print(h_sum.member('fXaxis'))
out_file[key] = h_sum
elif len(f[key].axes) == 2:
out_file[key] = _hadd_2d(destination, f, key, True)
else:
out_file[key] = _hadd_3d(destination, f, key, True)
for key in hist_keys: # just pass to hadd??
if len(f[key].axes) == 1:
out_file[key] = _hadd_1d(destination, f, key, True)
elif len(f[key].axes) == 2:
out_file[key] = _hadd_2d(destination, f, key, True)
else:
out_file[key] = _hadd_3d(destination, f, key, True)

trees = f.keys(filter_classname="TTree", cycle=False, recursive=False)

# Check that keep_trees/drop_trees keys are valid/refer to a tree:
if keep_trees:
if isinstance(keep_trees, list):
for key in keep_trees:
if key not in trees:
msg = (
"Key '"
+ key
+ "' does not match any TTree in ROOT file"
+ str(file)
)
raise ValueError(msg)
if isinstance(keep_trees, str):
keep_trees = f.keys(filter_name=keep_trees, cycle=False)
if len(keep_trees) != 1:
drop_trees = [tree for tree in trees if tree not in keep_trees]
else:
drop_trees = [tree for tree in trees if tree != keep_trees[0]]
if drop_trees:
if isinstance(drop_trees, list):
for key in drop_trees:
Expand Down Expand Up @@ -172,85 +191,14 @@ def copy_root(

for t in trees:
tree = f[t]
histograms = tree.keys(filter_typename=["TH*", "TProfile"], recursive=False)
groups = []
count_branches = []
temp_branches = [branch.name for branch in tree.branches]
temp_branches1 = [branch.name for branch in tree.branches]
cur_group = 0
for branch in temp_branches:
if len(tree[branch].member("fLeaves")) > 1:
msg = "Cannot handle split objects."
raise NotImplementedError(msg)
if tree[branch].member("fLeaves")[0].member("fLeafCount") is None:
continue
groups.append([])
groups[cur_group].append(branch)
for branch1 in temp_branches1:
if tree[branch].member("fLeaves")[0].member("fLeafCount") is tree[
branch1
].member("fLeaves")[0].member("fLeafCount") and (
tree[branch].name != tree[branch1].name
):
groups[cur_group].append(branch1)
temp_branches.remove(branch1)
count_branches.append(tree[branch].count_branch.name)
temp_branches.remove(tree[branch].count_branch.name)
temp_branches.remove(branch)
cur_group += 1

if drop_branches:
if isinstance(drop_branches, dict) and t in drop_branches:
rm = drop_branches.get(t)
else:
rm = drop_branches
if isinstance(rm, list):
keep_branches = [
branch.name
for branch in tree.branches
if branch.name not in rm and branch.name not in count_branches
]
elif isinstance(rm, str):
keep_branches = [
branch.name
for branch in tree.branches
if branch.name != rm and branch.name not in count_branches
]
else:
keep_branches = [
branch.name
for branch in tree.branches
if branch.name not in count_branches
]

writable_hists = {}
if len(histograms) > 1:
for key in histograms:
if len(f[key].axes) == 1:
writable_hists[key] = _hadd_1d(destination, f, key, True)

elif len(f[key].axes) == 2:
writable_hists[key] = _hadd_2d(destination, f, key, True)

else:
writable_hists[key] = _hadd_3d(destination, f, key, True)

elif len(histograms) == 1:
if len(f[histograms[0]].axes) == 1:
writable_hists = _hadd_1d(destination, f, histograms[0], True)

elif len(f[histograms[0]].axes) == 2:
writable_hists = _hadd_2d(destination, f, histograms[0], True)

else:
writable_hists = _hadd_3d(destination, f, histograms[0], True)

count_branches = get_counter_branches(tree)
kb = filter_branches(tree, keep_branches, drop_branches, count_branches)
groups, count_branches = group_branches(tree, kb)
first = True
for chunk in uproot.iterate(
tree,
for chunk in tree.iterate(
step_size=step_size,
how=dict,
filter_branch=lambda b: b.name in keep_branches,
filter_name=lambda b: b in kb,
):
for group in groups:
if (len(group)) > 1:
Expand All @@ -270,7 +218,7 @@ def copy_root(
}
)
for key in group:
if key in keep_branches:
if key in kb:
del chunk[key]
if first:
if drop_branches:
Expand Down Expand Up @@ -303,7 +251,4 @@ def copy_root(
except AssertionError:
msg = "Are the branch-names correct?"

for i, _value in enumerate(histograms):
out_file[histograms[i]] = writable_hists[i]

f.close()
Loading
Loading