Skip to content

Commit

Permalink
Fix for aggregators in evt to index output with evt_idx instead o…
Browse files Browse the repository at this point in the history
…f `ch_idx` (#551)

* fix for first to last for cal data where different rows for each table
* fix filedb to use new lgdo
* evaluate_at_channel_vov does not need cumulength argument, fixed cumulengths _> cumulengt, add cumulength to evaluate_to_vector and fix int channels
* searchsorted needs to be 'right' to match cumulative lengths, updated evaluate_at_channel_vov to use evt_ids_ch
  • Loading branch information
ggmarshall authored Jan 31, 2024
1 parent 724655c commit 79d47bd
Show file tree
Hide file tree
Showing 3 changed files with 53 additions and 12 deletions.
54 changes: 43 additions & 11 deletions src/pygama/evt/aggregators.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@


def evaluate_to_first_or_last(
cumulength: NDArray,
idx: NDArray,
ids: NDArray,
f_hit: str,
Expand Down Expand Up @@ -86,6 +87,11 @@ def evaluate_to_first_or_last(
for ch in chns:
# get index list for this channel to be loaded
idx_ch = idx[ids == utils.get_tcm_id_by_pattern(tcm_id_table_pattern, ch)]
evt_ids_ch = np.searchsorted(
cumulength,
np.where(ids == utils.get_tcm_id_by_pattern(tcm_id_table_pattern, ch))[0],
"right",
)

# evaluate at channel
res = utils.get_data_at_channel(
Expand Down Expand Up @@ -131,18 +137,27 @@ def evaluate_to_first_or_last(
if ch == chns[0]:
outt[:] = np.inf

out[idx_ch] = np.where((t0 < outt) & (limarr), res, out[idx_ch])
outt[idx_ch] = np.where((t0 < outt) & (limarr), t0, outt[idx_ch])
out[evt_ids_ch] = np.where(
(t0 < outt[evt_ids_ch]) & (limarr), res, out[evt_ids_ch]
)
outt[evt_ids_ch] = np.where(
(t0 < outt[evt_ids_ch]) & (limarr), t0, outt[evt_ids_ch]
)

else:
out[idx_ch] = np.where((t0 > outt) & (limarr), res, out[idx_ch])
outt[idx_ch] = np.where((t0 > outt) & (limarr), t0, outt[idx_ch])
out[evt_ids_ch] = np.where(
(t0 > outt[evt_ids_ch]) & (limarr), res, out[evt_ids_ch]
)
outt[evt_ids_ch] = np.where(
(t0 > outt[evt_ids_ch]) & (limarr), t0, outt[evt_ids_ch]
)

return Array(nda=out)


def evaluate_to_scalar(
mode: str,
cumulength: NDArray,
idx: NDArray,
ids: NDArray,
f_hit: str,
Expand Down Expand Up @@ -207,6 +222,11 @@ def evaluate_to_scalar(
for ch in chns:
# get index list for this channel to be loaded
idx_ch = idx[ids == utils.get_tcm_id_by_pattern(tcm_id_table_pattern, ch)]
evt_ids_ch = np.searchsorted(
cumulength,
np.where(ids == utils.get_tcm_id_by_pattern(tcm_id_table_pattern, ch))[0],
"right",
)

res = utils.get_data_at_channel(
ch=ch,
Expand Down Expand Up @@ -241,20 +261,21 @@ def evaluate_to_scalar(
if "sum" == mode:
if res.dtype == bool:
res = res.astype(int)
out[idx_ch] = np.where(limarr, res + out[idx_ch], out[idx_ch])
out[evt_ids_ch] = np.where(limarr, res + out[evt_ids_ch], out[evt_ids_ch])
if "any" == mode:
if res.dtype != bool:
res = res.astype(bool)
out[idx_ch] = out[idx_ch] | (res & limarr)
out[evt_ids_ch] = out[evt_ids_ch] | (res & limarr)
if "all" == mode:
if res.dtype != bool:
res = res.astype(bool)
out[idx_ch] = out[idx_ch] & res & limarr
out[evt_ids_ch] = out[evt_ids_ch] & res & limarr

return Array(nda=out)


def evaluate_at_channel(
cumulength: NDArray,
idx: NDArray,
ids: NDArray,
f_hit: str,
Expand Down Expand Up @@ -314,6 +335,7 @@ def evaluate_at_channel(
):
continue
idx_ch = idx[ids == ch]
evt_ids_ch = np.searchsorted(cumulength, np.where(ids == ch)[0], "right")
res = utils.get_data_at_channel(
ch=utils.get_table_name_by_pattern(tcm_id_table_pattern, ch),
ids=ids,
Expand All @@ -332,12 +354,13 @@ def evaluate_at_channel(
dsp_group=dsp_group,
)

out[idx_ch] = np.where(ch == ch_comp.nda[idx_ch], res, out[idx_ch])
out[evt_ids_ch] = np.where(ch == ch_comp.nda[idx_ch], res, out[evt_ids_ch])

return Array(nda=out)


def evaluate_at_channel_vov(
cumulength: NDArray,
idx: NDArray,
ids: NDArray,
f_hit: str,
Expand Down Expand Up @@ -397,7 +420,7 @@ def evaluate_at_channel_vov(

type_name = None
for ch in chns:
idx_ch = idx[ids == ch]
evt_ids_ch = np.searchsorted(cumulength, np.where(ids == ch)[0], "right")
res = utils.get_data_at_channel(
ch=utils.get_table_name_by_pattern(tcm_id_table_pattern, ch),
ids=ids,
Expand All @@ -419,7 +442,7 @@ def evaluate_at_channel_vov(
# see in which events the current channel is present
mask = ak.to_numpy(ak.any(ch_comp == ch, axis=-1), allow_missing=False)
cv = np.full(len(ch_comp), np.nan)
cv[idx_ch] = res
cv[evt_ids_ch] = res
cv[~mask] = np.nan
cv = ak.drop_none(ak.nan_to_none(ak.Array(cv)[:, None]))

Expand All @@ -432,6 +455,7 @@ def evaluate_at_channel_vov(


def evaluate_to_aoesa(
cumulength: NDArray,
idx: NDArray,
ids: NDArray,
f_hit: str,
Expand Down Expand Up @@ -501,6 +525,11 @@ def evaluate_to_aoesa(
i = 0
for ch in chns:
idx_ch = idx[ids == utils.get_tcm_id_by_pattern(tcm_id_table_pattern, ch)]
evt_ids_ch = np.searchsorted(
cumulength,
np.where(ids == utils.get_tcm_id_by_pattern(tcm_id_table_pattern, ch))[0],
"right",
)
res = utils.get_data_at_channel(
ch=ch,
ids=ids,
Expand Down Expand Up @@ -530,14 +559,15 @@ def evaluate_to_aoesa(
dsp_group=dsp_group,
)

out[idx_ch, i] = np.where(limarr, res, out[idx_ch, i])
out[evt_ids_ch, i] = np.where(limarr, res, out[evt_ids_ch, i])

i += 1

return ArrayOfEqualSizedArrays(nda=out)


def evaluate_to_vector(
cumulength: NDArray,
idx: NDArray,
ids: NDArray,
f_hit: str,
Expand Down Expand Up @@ -602,6 +632,7 @@ def evaluate_to_vector(
LH5 root group in `evt` file.
"""
out = evaluate_to_aoesa(
cumulength=cumulength,
idx=idx,
ids=ids,
f_hit=f_hit,
Expand All @@ -625,6 +656,7 @@ def evaluate_to_vector(
if sorter is not None:
md, fld = sorter.split(":")
s_val = evaluate_to_aoesa(
cumulength=cumulength,
idx=idx,
ids=ids,
f_hit=f_hit,
Expand Down
8 changes: 8 additions & 0 deletions src/pygama/evt/build_evt.py
Original file line number Diff line number Diff line change
Expand Up @@ -442,6 +442,9 @@ def evaluate_expression(
# load TCM data to define an event
ids = store.read(f"/{tcm_group}/array_id", f_tcm)[0].view_as("np")
idx = store.read(f"/{tcm_group}/array_idx", f_tcm)[0].view_as("np")
cumulength = store.read(f"/{tcm_group}/cumulative_length", f_tcm)[0].view_as(
"np"
)

# switch through modes
if table and (("keep_at_ch:" == mode[:11]) or ("keep_at_idx:" == mode[:12])):
Expand All @@ -466,6 +469,7 @@ def evaluate_expression(

if isinstance(ch_comp, Array):
return aggregators.evaluate_at_channel(
cumulength=cumulength,
idx=idx,
ids=ids,
f_hit=f_hit,
Expand All @@ -483,6 +487,7 @@ def evaluate_expression(
)
elif isinstance(ch_comp, VectorOfVectors):
return aggregators.evaluate_at_channel_vov(
cumulength=cumulength,
idx=idx,
ids=ids,
f_hit=f_hit,
Expand Down Expand Up @@ -511,6 +516,7 @@ def evaluate_expression(
)[0]
)
return aggregators.evaluate_to_first_or_last(
cumulength=cumulength,
idx=idx,
ids=ids,
f_hit=f_hit,
Expand All @@ -533,6 +539,7 @@ def evaluate_expression(
elif mode in ["sum", "any", "all"]:
return aggregators.evaluate_to_scalar(
mode=mode,
cumulength=cumulength,
idx=idx,
ids=ids,
f_hit=f_hit,
Expand All @@ -552,6 +559,7 @@ def evaluate_expression(
)
elif "gather" == mode:
return aggregators.evaluate_to_vector(
cumulength=cumulength,
idx=idx,
ids=ids,
f_hit=f_hit,
Expand Down
3 changes: 2 additions & 1 deletion src/pygama/flow/file_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@
import h5py
import numpy as np
import pandas as pd
from lgdo.lh5.store import LH5Store, ls
from lgdo.lh5 import ls
from lgdo.lh5.store import LH5Store
from lgdo.lh5.utils import expand_path, expand_vars
from lgdo.types import Array, Scalar, VectorOfVectors
from parse import parse
Expand Down

0 comments on commit 79d47bd

Please sign in to comment.