Skip to content

Commit

Permalink
feat: add support for more RNTuple types (#1223)
Browse files Browse the repository at this point in the history
* Add support for std::atomic<T>

* Add support for std::bitset

* Don't prune RNTuple records

* Fixed bug reading page description

* Attempt at supporting invalid variants

* Added tests

* Treat invalid variants as None

* Specify dtype for np.arange
  • Loading branch information
ariostas authored Aug 9, 2024
1 parent 90039ea commit 80e7803
Show file tree
Hide file tree
Showing 2 changed files with 130 additions and 15 deletions.
58 changes: 43 additions & 15 deletions src/uproot/models/RNTuple.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
_rntuple_cluster_summary_format = struct.Struct("<QQ")
_rntuple_checksum_format = struct.Struct("<Q")
_rntuple_envlink_size_format = struct.Struct("<Q")
_rntuple_page_num_elements_format = struct.Struct("<I")
_rntuple_page_num_elements_format = struct.Struct("<i")
_rntuple_column_group_id_format = struct.Struct("<I")
_rntuple_first_ele_index_format = struct.Struct("<I")

Expand Down Expand Up @@ -338,15 +338,26 @@ def field_form(self, this_id, seen):
structural_role == uproot.const.rntuple_role_leaf
and this_record.repetition == 0
):
# deal with std::atomic
# they have no associated column, but exactly one subfield containing the underlying data
tmp_id = self._alias_columns_dict.get(this_id, this_id)
if (
tmp_id not in self._column_records_dict
and len(self._related_ids[tmp_id]) == 1
):
this_id = self._related_ids[tmp_id][0]
seen.add(this_id)
# base case of recursion
# n.b. the split may happen in column
return self.col_form(this_id)
elif structural_role == uproot.const.rntuple_role_leaf:
# std::array it only has one child
if this_id in self._related_ids:
# std::array has only one subfield
child_id = self._related_ids[this_id][0]

inner = self.field_form(child_id, seen)
inner = self.field_form(child_id, seen)
else:
# std::bitset has no subfields, so we use it directly
inner = self.col_form(this_id)
keyname = f"RegularForm-{this_id}"
return ak.forms.RegularForm(inner, this_record.repetition, form_key=keyname)
elif structural_role == uproot.const.rntuple_role_vector:
Expand Down Expand Up @@ -387,7 +398,10 @@ def field_form(self, this_id, seen):
if this_id in self._related_ids:
newids = self._related_ids[this_id]
recordlist = [self.field_form(i, seen) for i in newids]
return ak.forms.UnionForm("i8", "i64", recordlist, form_key=keyname)
inner = ak.forms.UnionForm(
"i8", "i64", recordlist, form_key=keyname + "-union"
)
return ak.forms.IndexedOptionForm("i64", inner, form_key=keyname)
else:
# everything should recurse above this branch
raise AssertionError("this should be unreachable")
Expand Down Expand Up @@ -538,13 +552,15 @@ def arrays(
[c.num_entries for c in clusters[start_cluster_idx:stop_cluster_idx]]
)

form = self.to_akform().select_columns(filter_names)
form = self.to_akform().select_columns(
filter_names, prune_unions_and_records=False
)
# only read columns mentioned in the awkward form
target_cols = []
container_dict = {}
_recursive_find(form, target_cols)
for key in target_cols:
if "column" in key:
if "column" in key and "union" not in key:
key_nr = int(key.split("-")[1])
dtype_byte = self.column_records[key_nr].type
content = self.read_col_pages(
Expand All @@ -556,18 +572,30 @@ def arrays(
content = numpy.diff(content)
if dtype_byte == uproot.const.rntuple_col_type_to_num_dict["switch"]:
kindex, tags = _split_switch_bits(content)
container_dict[f"{key}-index"] = kindex
container_dict[f"{key}-tags"] = tags
# Find invalid variants and adjust buffers accordingly
invalid = numpy.flatnonzero(tags == -1)
if len(invalid) > 0:
kindex = numpy.delete(kindex, invalid)
tags = numpy.delete(tags, invalid)
invalid -= numpy.arange(len(invalid))
optional_index = numpy.insert(
numpy.arange(len(kindex), dtype=numpy.int64), invalid, -1
)
else:
optional_index = numpy.arange(len(kindex), dtype=numpy.int64)
container_dict[f"{key}-index"] = optional_index
container_dict[f"{key}-union-index"] = kindex
container_dict[f"{key}-union-tags"] = tags
else:
# don't distinguish data and offsets
container_dict[f"{key}-data"] = content
container_dict[f"{key}-offsets"] = content
cluster_offset = cluster_starts[start_cluster_idx]
entry_start -= cluster_offset
entry_stop -= cluster_offset
return ak.from_buffers(form, cluster_num_entries, container_dict)[
entry_start:entry_stop
]
return ak.from_buffers(
form, cluster_num_entries, container_dict, allow_noncanonical_form=True
)[entry_start:entry_stop]


# Supporting function and classes
Expand All @@ -592,9 +620,9 @@ def _recursive_find(form, res):
class PageDescription:
def read(self, chunk, cursor, context):
out = MetaData(type(self).__name__)
out.num_elements = cursor.field(
chunk, _rntuple_page_num_elements_format, context
)
num_elements = cursor.field(chunk, _rntuple_page_num_elements_format, context)
out.has_checksum = num_elements < 0
out.num_elements = abs(num_elements)
out.locator = LocatorReader().read(chunk, cursor, context)
return out

Expand Down
87 changes: 87 additions & 0 deletions tests/test_1223_more_rntuple_types.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
# BSD 3-Clause License; see https://github.com/scikit-hep/uproot5/blob/main/LICENSE

import pytest
import skhep_testdata

import uproot


def test_atomic():
    """Read the ``atomic_int`` field of the test RNTuple and check its values."""
    path = skhep_testdata.data_path("test_ntuple_atomic_bitset.root")
    with uproot.open(path) as root_file:
        ntuple = root_file["ntuple"]
        arrays = ntuple.arrays("atomic_int")

        # The field must come back as plain integers, one per entry.
        values = arrays.atomic_int.tolist()
        assert values == [1, 2, 3]


def test_bitset():
    """Read the ``bitset`` field and check each entry's leading bits and zero tail."""
    path = skhep_testdata.data_path("test_ntuple_atomic_bitset.root")
    with uproot.open(path) as root_file:
        ntuple = root_file["ntuple"]
        arrays = ntuple.arrays("bitset")

        assert len(arrays.bitset) == 3
        assert len(arrays.bitset[0]) == 42

        # Expected leading bits for each of the three entries; every bit
        # after the listed prefix must be zero.
        expected_prefixes = (
            [0, 1] * 3,
            [0, 1] * 8,
            [0, 0, 0, 1] * 4,
        )
        for entry, prefix in enumerate(expected_prefixes):
            width = len(prefix)
            assert arrays.bitset[entry].tolist()[:width] == prefix
            assert all(arrays.bitset[entry][width:] == 0)


def test_empty_struct():
    """Read the ``empty_struct`` field and check each entry is an empty record."""
    path = skhep_testdata.data_path("test_ntuple_emptystruct_invalidvar.root")
    with uproot.open(path) as root_file:
        ntuple = root_file["ntuple"]
        arrays = ntuple.arrays("empty_struct")

        # Three entries, each a record with no fields.
        records = arrays.empty_struct.tolist()
        assert records == [{}, {}, {}]


def test_invalid_variant():
    """Read the ``variant`` field and check the last entry comes back as ``None``."""
    path = skhep_testdata.data_path("test_ntuple_emptystruct_invalidvar.root")
    with uproot.open(path) as root_file:
        ntuple = root_file["ntuple"]
        arrays = ntuple.arrays("variant")

        # Two populated entries followed by one missing value.
        values = arrays.variant.tolist()
        assert values == [1, 1, None]

0 comments on commit 80e7803

Please sign in to comment.