Skip to content

Commit

Permalink
Merge pull request #580 from legend-exp/spms-updates
Browse files Browse the repository at this point in the history
Remove outdated hacks to make `evt.modules.spms` work with older data production versions
  • Loading branch information
gipert authored May 10, 2024
2 parents 1f9f82f + 9185cfe commit 629e655
Show file tree
Hide file tree
Showing 3 changed files with 63 additions and 110 deletions.
9 changes: 5 additions & 4 deletions src/pygama/evt/modules/larveto.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

import awkward as ak
import numpy as np
import scipy
import scipy.stats
from numpy.typing import ArrayLike


Expand Down Expand Up @@ -37,16 +37,17 @@ def l200_combined_test_stat(
probability for a pulse coming from some uncorrelated physics (uniform
distribution). needed for the LAr scintillation time pdf.
rc_density
array of densities (probabilities) of uncorrelated number of photoelectrons in a 6µs window.
density array of the random coincidence LAr energy distribution (total
energy summed over all channels, in p.e.). Derived from forced trigger
data.
"""
# flatten the data in the last axis (i.e. merge all channels together)
# TODO: implement channel distinction
t0 = ak.flatten(t0, axis=-1)
amp = ak.flatten(amp, axis=-1)

# subtract the HPGe t0 from the SiPM pulse t0s
# HACK: remove 16 when units will be fixed
rel_t0 = 16 * t0 - geds_t0
rel_t0 = t0 - geds_t0

return l200_test_stat(rel_t0, amp, ts_bkg_prob, rc_density)

Expand Down
97 changes: 24 additions & 73 deletions src/pygama/evt/modules/spms.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,13 +254,6 @@ def make_pulse_data_mask(
drop_empty=False,
)

# HACK: handle units
# HACK: remove me once units are fixed in the dsp tier
if "units" in pulse_t0.attrs and pulse_t0.attrs["units"] == "ns":
pulse_t0_ns = pulse_t0.view_as("ak")
else:
pulse_t0_ns = pulse_t0.view_as("ak") * 16

pulse_amp = gather_pulse_data(
datainfo,
tcm,
Expand All @@ -283,6 +276,7 @@ def make_pulse_data_mask(
t_loc_ns = ak.fill_none(ak.nan_to_none(t_loc_ns), t_loc_default_ns)

# start with all-true mask
pulse_t0_ns = pulse_t0.view_as("ak")
mask = pulse_t0_ns == pulse_t0_ns

# apply p.e. threshold
Expand All @@ -308,6 +302,8 @@ def geds_coincidence_classifier(
tcm: utils.TCMData,
table_names: Sequence[str],
*,
spms_t0: types.VectorOfVectors,
spms_amp: types.VectorOfVectors,
geds_t0_ns: types.Array,
ts_bkg_prob: float,
rc_density: Sequence[float] | None = None,
Expand All @@ -321,71 +317,26 @@ def geds_coincidence_classifier(
----------
datainfo, tcm, table_names
positional arguments automatically supplied by :func:`.build_evt`.
t0
arrival times of pulses in ns, split by channel.
amp
amplitude of pulses in p.e., split by channel.
geds_t0_ns
t0 (ns) of the HPGe signal.
ts_bkg_prob
probability for a pulse coming from some uncorrelated physics (uniform
distribution). needed for the LAr scintillation time pdf.
rc_density
density array of the random coincidence LAr energy distribution (total
energy summed over all channels, in p.e.). Derived from forced trigger
data.
"""
# mask for windowing data around the HPGe t0
pulse_mask = make_pulse_data_mask(
datainfo,
tcm,
table_names,
a_thr_pe=None,
t_loc_ns=geds_t0_ns,
dt_range_ns=(-1_000, 5_000),
t_loc_default_ns=48_000,
)

# we'll need to remove pulses below noise threshold later
is_good_pulse = gather_is_valid_hit(datainfo, tcm, table_names).view_as("ak")

# load the data
data = {}
for k, obs in {"amp": "hit.energy_in_pe", "t0": "hit.trigger_pos"}.items():
all_data = gather_pulse_data(
datainfo,
tcm,
table_names,
observable=obs,
pulse_mask=None,
drop_empty=False,
).view_as("ak")

# remove pulses below noise threshold and outside the HPGe trigger window
data[k] = all_data[is_good_pulse & pulse_mask.view_as("ak")]

# load the channel info
# rawids = spms.gather_tcm_id_data(
# datainfo,
# tcm,
# table_names,
# pulse_mask=pulse_mask,
# drop_empty=True,
# )

# (HPGe) trigger position can vary among events!
if isinstance(geds_t0_ns, types.Array):
geds_t0_ns = geds_t0_ns.view_as("ak")

ts_data = larveto.l200_combined_test_stat(
data["t0"], data["amp"], geds_t0_ns, ts_bkg_prob, rc_density
)

return types.Array(ts_data)


# REMOVE ME: not needed anymore with VectorOfVectors DSP outputs
def gather_is_valid_hit(datainfo, tcm, table_names):
data = {}
for field in ("is_valid_hit", "trigger_pos"):
data[field] = gather_pulse_data(
datainfo,
tcm,
table_names,
observable=f"hit.{field}",
pulse_mask=None,
drop_empty=False,
).view_as("ak")

return types.VectorOfVectors(
data["is_valid_hit"][
ak.local_index(data["is_valid_hit"]) < ak.num(data["trigger_pos"], axis=-1)
]
return types.Array(
larveto.l200_combined_test_stat(
spms_t0.view_as("ak"),
spms_amp.view_as("ak"),
geds_t0_ns.view_as("ak"),
ts_bkg_prob,
rc_density,
)
)
67 changes: 34 additions & 33 deletions tests/evt/test_build_evt.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,50 +123,51 @@ def test_field_nesting(lgnd_test_data, files_config):
assert sorted(evt.sub2.keys()) == ["dummy", "multiplicity"]


def test_spms_module(lgnd_test_data, files_config):
build_evt(
files_config,
config=f"{config_dir}/spms-module-config.yaml",
wo_mode="of",
)
# FIXME: this can't be properly tested until proper testdata is available
# def test_spms_module(lgnd_test_data, files_config):
# build_evt(
# files_config,
# config=f"{config_dir}/spms-module-config.yaml",
# wo_mode="of",
# )

outfile = files_config["evt"][0]
# outfile = files_config["evt"][0]

evt = lh5.read("/evt", outfile)
# evt = lh5.read("/evt", outfile)

t0 = ak.fill_none(ak.nan_to_none(evt.t0.view_as("ak")), 48_000)
tr_pos = evt.trigger_pos.view_as("ak") * 16
assert ak.all(tr_pos > t0 - 30_000)
assert ak.all(tr_pos < t0 + 30_000)
# t0 = ak.fill_none(ak.nan_to_none(evt.t0.view_as("ak")), 48_000)
# tr_pos = evt.trigger_pos.view_as("ak") * 16
# assert ak.all(tr_pos > t0 - 30_000)
# assert ak.all(tr_pos < t0 + 30_000)

mask = evt._pulse_mask
assert isinstance(mask, VectorOfVectors)
assert len(mask) == 10
assert mask.ndim == 3
# mask = evt._pulse_mask
# assert isinstance(mask, VectorOfVectors)
# assert len(mask) == 10
# assert mask.ndim == 3

full = evt.spms_amp_full.view_as("ak")
amp = evt.spms_amp.view_as("ak")
assert ak.all(amp > 0.1)
# full = evt.spms_amp_full.view_as("ak")
# amp = evt.spms_amp.view_as("ak")
# assert ak.all(amp > 0.1)

assert ak.all(full[mask.view_as("ak")] == amp)
# assert ak.all(full[mask.view_as("ak")] == amp)

wo_empty = evt.spms_amp_wo_empty.view_as("ak")
assert ak.all(wo_empty == amp[ak.count(amp, axis=-1) > 0])
# wo_empty = evt.spms_amp_wo_empty.view_as("ak")
# assert ak.all(wo_empty == amp[ak.count(amp, axis=-1) > 0])

rawids = evt.rawid.view_as("ak")
assert rawids.ndim == 2
assert ak.count(rawids) == 30
# rawids = evt.rawid.view_as("ak")
# assert rawids.ndim == 2
# assert ak.count(rawids) == 30

idx = evt.hit_idx.view_as("ak")
assert idx.ndim == 2
assert ak.count(idx) == 30
# idx = evt.hit_idx.view_as("ak")
# assert idx.ndim == 2
# assert ak.count(idx) == 30

rawids_wo_empty = evt.rawid_wo_empty.view_as("ak")
assert ak.count(rawids_wo_empty) == 7
# rawids_wo_empty = evt.rawid_wo_empty.view_as("ak")
# assert ak.count(rawids_wo_empty) == 7

vhit = evt.is_valid_hit.view_as("ak")
vhit.show()
assert ak.all(ak.num(vhit, axis=-1) == ak.num(full, axis=-1))
# vhit = evt.is_valid_hit.view_as("ak")
# vhit.show()
# assert ak.all(ak.num(vhit, axis=-1) == ak.num(full, axis=-1))


def test_vov(lgnd_test_data, files_config):
Expand Down

0 comments on commit 629e655

Please sign in to comment.