Skip to content

Commit

Permalink
fix: setting keep branches with a list wasn't working correctly (#100)
Browse files Browse the repository at this point in the history
* Small fixes for the linter and drop_branches

* More little fixes

* linter fix

* linter
  • Loading branch information
zbilodea authored Jul 2, 2024
1 parent 292e447 commit 06ee606
Show file tree
Hide file tree
Showing 8 changed files with 113 additions and 94 deletions.
47 changes: 21 additions & 26 deletions src/hepconvert/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,38 +54,33 @@ def filter_branches(tree, keep_branches, drop_branches, count_branches):
if drop_branches and keep_branches:
msg = "Can specify either drop_branches or keep_branches, not both."
raise ValueError(msg) from None

if drop_branches:
if isinstance(drop_branches, dict): # noqa: SIM102
branches = drop_branches if drop_branches else keep_branches
keys = []
if branches:
if isinstance(branches, dict): # noqa: SIM102
if (
len(drop_branches) > 1
and tree.name in drop_branches
or tree.name == next(iter(drop_branches.keys()))
len(branches) > 1
and tree.name in branches
or tree.name == next(iter(branches.keys()))
):
drop_branches = drop_branches.get(tree.name)
if isinstance(drop_branches, str) or len(drop_branches) == 1:
drop_branches = tree.keys(filter_name=drop_branches)
return [
b.name
for b in tree.branches
if b.name not in count_branches and b.name not in drop_branches
]
if keep_branches:
if isinstance(keep_branches, dict): # noqa: SIM102
if (
len(keep_branches) > 1
and tree.name in keep_branches
or tree.name == next(iter(keep_branches.keys()))
):
keep_branches = keep_branches.get(tree.name)
if isinstance(keep_branches, str) or len(keep_branches) == 1:
keep_branches = tree.keys(filter_name=keep_branches)
keys = branches.get(tree.name)
if isinstance(branches, str) or len(branches) == 1:
keys = tree.keys(filter_name=branches)
else:
for i in branches:
keys = np.union1d(keys, tree.keys(filter_name=i))
if drop_branches:
return [
b.name
for b in tree.branches
if b.name not in count_branches and b.name in keep_branches
if b.name not in count_branches and b.name not in keys
]
return [b.name for b in tree.branches if b.name not in count_branches]
return [
b.name
for b in tree.branches
if b.name not in count_branches and b.name in keys
]
return [b.name for b in tree.branches]


def check_tqdm():
Expand Down
53 changes: 26 additions & 27 deletions src/hepconvert/copy_root.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,11 @@


def copy_root(
destination,
file,
out_file,
in_file,
*,
keep_branches=None,
drop_branches=None,
# add_branches=None, #TODO: add functionality for this, just specify about the counter issue?
keep_trees=None,
drop_trees=None,
cut=None,
Expand All @@ -32,15 +31,15 @@ def copy_root(
initial_basket_capacity=10,
resize_factor=10.0,
counter_name=lambda counted: "n" + counted,
step_size=100,
step_size="100 MB",
compression="ZLIB",
compression_level=1,
):
"""
:param destination: Name of the output file or file path.
:type destination: path-like
:param file: Local ROOT file to copy.
:type file: str
:param out_file: Name of the output file or file path.
:type out_file: path-like
:param in_file: Local ROOT file to copy.
:type in_file: str
:param keep_branches: Keep only the specified branches and remove all others. To keep the same branches in every TTree in the file,
pass a list of branch names to keep; wildcards are accepted ("Jet_*"). To filter only one of multiple trees, pass a dict of structure: {tree: [branch1, branch2]}
to keep only branch1 and branch2 in TTree "tree". Defaults to None. Command line option: ``--keep-branches``.
Expand Down Expand Up @@ -114,7 +113,7 @@ def copy_root(
.. code-block:: bash
hepconvert copy-root [options] [OUT_FILE] [IN_FILE]
hepconvert copy-root [options] [OUT_FILE] [IN_FILE]
"""
if compression in ("ZLIB", "zlib"):
Expand All @@ -128,29 +127,29 @@ def copy_root(
else:
msg = f"unrecognized compression algorithm: {compression}. Only ZLIB, LZMA, LZ4, and ZSTD are accepted."
raise ValueError(msg)
path = Path(destination)
path = Path(out_file)
if Path.is_file(path):
if not force:
raise FileExistsError
out_file = uproot.recreate(
destination,
of = uproot.recreate(
out_file,
compression=uproot.compression.Compression.from_code_pair(
compression_code, compression_level
),
)
first = (True,)
else:
out_file = uproot.recreate(
destination,
of = uproot.recreate(
out_file,
compression=uproot.compression.Compression.from_code_pair(
compression_code, compression_level
),
)
first = (True,)
try:
f = uproot.open(file)
f = uproot.open(in_file)
except FileNotFoundError:
msg = "File: ", file, " does not exist or is corrupt."
msg = "file: ", in_file, " does not exist or is corrupt."
raise FileNotFoundError(msg) from None

hist_keys = f.keys(
Expand All @@ -159,11 +158,11 @@ def copy_root(

for key in hist_keys: # just pass to hadd??
if len(f[key].axes) == 1:
out_file[key] = _hadd_1d(destination, f, key, True)
of[key] = _hadd_1d(out_file, f, key, True)
elif len(f[key].axes) == 2:
out_file[key] = _hadd_2d(destination, f, key, True)
of[key] = _hadd_2d(out_file, f, key, True)
else:
out_file[key] = _hadd_3d(destination, f, key, True)
of[key] = _hadd_3d(out_file, f, key, True)

trees = f.keys(filter_classname="TTree", cycle=False, recursive=False)

Expand All @@ -179,7 +178,7 @@ def copy_root(
"Key '"
+ key
+ "' does not match any TTree in ROOT file"
+ str(file)
+ str(in_file)
)
raise ValueError(msg)
if isinstance(keep_trees, str):
Expand All @@ -197,7 +196,7 @@ def copy_root(
"Key '"
+ key
+ "' does not match any TTree in ROOT file"
+ str(file)
+ str(in_file)
)
raise ValueError(msg)
trees.remove(key)
Expand All @@ -212,16 +211,16 @@ def copy_root(
"TTree ",
key,
" does not match any TTree in ROOT file",
destination,
out_file,
)
raise ValueError(msg)

if len(trees) > 1 and progress_bar is not False:
if len(trees) > 1 and progress_bar is not False and progress_bar is not None:
number_of_items = len(trees)
if progress_bar is True:
tqdm = _utils.check_tqdm()
progress_bar = tqdm.tqdm(desc="Trees copied")
progress_bar.reset(total=number_of_items)
progress_bar.reset(total=number_of_items)
for t in trees:
tree = f[t]
count_branches = get_counter_branches(tree)
Expand Down Expand Up @@ -265,7 +264,7 @@ def copy_root(
}
else:
branch_types = {name: array.type for name, array in chunk.items()}
out_file.mktree(
of.mktree(
tree.name,
branch_types,
title=title,
Expand All @@ -277,9 +276,9 @@ def copy_root(

else:
try:
out_file[tree.name].extend(chunk)
of[tree.name].extend(chunk)
except AssertionError:
msg = "Are the branch-names correct?"
if len(trees) > 1 and progress_bar is not False:
if len(trees) > 1 and progress_bar is not False and progress_bar is not None:
progress_bar.update(n=1)
f.close()
9 changes: 5 additions & 4 deletions src/hepconvert/histogram_adding.py
Original file line number Diff line number Diff line change
Expand Up @@ -418,11 +418,11 @@ def add_histograms(
if not force and not append:
raise FileExistsError
if force and append:
msg = "Cannot append to a new file. Either force or append can be true."
msg = "Cannot append to a new file. Either force or append can be true, not both."
raise ValueError(msg)
if append:
out_file = uproot.update(destination)
elif force:
else:
out_file = uproot.recreate(
destination,
compression=uproot.compression.Compression.from_code_pair(
Expand Down Expand Up @@ -459,7 +459,8 @@ def add_histograms(
if progress_bar is True:
file_bar = tqdm.tqdm(desc="Files added")
hist_bar = tqdm.tqdm(desc="Histograms added")

else:
hist_bar = None
file_bar.reset(number_of_items)
if same_names:
if union:
Expand Down Expand Up @@ -490,7 +491,7 @@ def add_histograms(
msg = f"File: {input_file} does not exist or is corrupt."
raise FileNotFoundError(msg) from None
if same_names:
if progress_bar:
if progress_bar and hist_bar:
hist_bar.reset(len(keys))
for key in keys:
try:
Expand Down
3 changes: 1 addition & 2 deletions src/hepconvert/merge.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,10 +246,9 @@ def merge_root(
)
raise ValueError(msg)
if progress_bar is not False:
number_of_items = len(files)
if progress_bar is True:
tqdm = _utils.check_tqdm()
number_of_items = len(files)

progress_bar = tqdm.tqdm(desc="Files added")
progress_bar.reset(number_of_items)
for t in trees:
Expand Down
23 changes: 14 additions & 9 deletions src/hepconvert/root_to_parquet.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import awkward as ak
import uproot
from numpy import union1d


def root_to_parquet(
Expand Down Expand Up @@ -260,19 +261,23 @@ def _filter_branches(tree, keep_branches, drop_branches):
if drop_branches and keep_branches:
msg = "Can specify either drop_branches or keep_branches, not both."
raise ValueError(msg) from None

keys = []
if drop_branches:
if isinstance(drop_branches, str):
drop_branches = tree.keys(filter_name=drop_branches)
keys = tree.keys(filter_name=drop_branches)
if isinstance(drop_branches, dict) and tree.name in drop_branches:
drop_branches = drop_branches.get(tree.name)
return lambda b: b in [
b.name for b in tree.branches if b.name not in drop_branches
]
keys = drop_branches.get(tree.name)
else:
for i in drop_branches:
keys = union1d(keys, tree.keys(filter_name=i))
return lambda b: b in [b.name for b in tree.branches if b.name not in keys]
if keep_branches:
if isinstance(keep_branches, str):
keep_branches = tree.keys(filter_name=keep_branches)
keys = tree.keys(filter_name=keep_branches)
if isinstance(keep_branches, dict) and tree.name in keep_branches:
keep_branches = keep_branches.get(tree.name)
return lambda b: b in [b.name for b in tree.branches if b.name in keep_branches]
keys = keep_branches.get(tree.name)
else:
for i in keep_branches:
keys = union1d(keys, tree.keys(filter_name=i))
return lambda b: b in [b.name for b in tree.branches if b.name in keys]
return None
3 changes: 3 additions & 0 deletions tests/test_add_histograms.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,6 +266,9 @@ def simple_1dim_F(tmp_path):
).all


simple_1dim_F("tests/samples")


def mult_2D_hists(tmp_path):
h1 = ROOT.TH2F("name", "", 10, 0.0, 10.0, 8, 0.0, 8.0)
data1 = [
Expand Down
Loading

0 comments on commit 06ee606

Please sign in to comment.