Skip to content

Commit

Permalink
Merge pull request #204 from icecube/photospline_update
Browse files Browse the repository at this point in the history
Fall back to `evaluate_simple` in case the `search_centers` call fails.
  • Loading branch information
tomaskontrimas authored Feb 2, 2024
2 parents 3c0cbf8 + 2b3e25d commit 4022be0
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 23 deletions.
42 changes: 30 additions & 12 deletions skyllh/core/pdf.py
Original file line number Diff line number Diff line change
Expand Up @@ -1430,21 +1430,39 @@ def get_pd_with_eventdata(

if self.basis_function_indices is None:
with TaskTimer(tl, 'Get basis function indices from photospline.'):
self.basis_function_indices = self._pdf.search_centers(
[eventdata[i] for i in range(0, V)]
)
try:
self.basis_function_indices = self._pdf.search_centers(
[eventdata[i] for i in range(0, V)]
)
except ValueError:
# In case the photospline `search_centers` call fails
# (when `eventdata` is outside photospline boundaries)
# we can fall back to the slower photospline evaluation.
logger.info(
"Falling back to the slower photospline evaluation.")

with TaskTimer(tl, 'Get pd from photospline fit.'):
if evt_mask is None:
pd = self._pdf.evaluate(
[eventdata[i] for i in range(0, V)],
[self.basis_function_indices[i] for i in range(0, V)],
)
if self.basis_function_indices is not None:
if evt_mask is None:
pd = self._pdf.evaluate(
[eventdata[i] for i in range(0, V)],
[self.basis_function_indices[i] for i in range(0, V)],
)
else:
pd = self._pdf.evaluate(
[eventdata[i][evt_mask] for i in range(0, V)],
[self.basis_function_indices[i][evt_mask] for i in range(0, V)],
)
else:
pd = self._pdf.evaluate(
[eventdata[i][evt_mask] for i in range(0, V)],
[self.basis_function_indices[i][evt_mask] for i in range(0, V)],
)
# Falling back to the slower photospline evaluation.
if evt_mask is None:
pd = self._pdf.evaluate_simple(
[eventdata[i] for i in range(0, V)],
)
else:
pd = self._pdf.evaluate_simple(
[eventdata[i][evt_mask] for i in range(0, V)],
)

with TaskTimer(tl, 'Normalize MultiDimGridPDF with norm factor.'):
norm = self._norm_factor_func(
Expand Down
30 changes: 19 additions & 11 deletions skyllh/core/signalpdf.py
Original file line number Diff line number Diff line change
Expand Up @@ -833,11 +833,15 @@ def initialize_for_new_trial(
tl,
'Get and set basis function indices for all PDFs.'):
V = self._cache_eventdata.shape[0]
bfi = pdf.pdf.search_centers(
[self._cache_eventdata[i] for i in range(0, V)]
)
for (_, pdf) in self.items():
pdf.basis_function_indices = bfi
try:
bfi = pdf.pdf.search_centers(
[self._cache_eventdata[i] for i in range(0, V)]
)
for (_, pdf) in self.items():
pdf.basis_function_indices = bfi
except ValueError:
logger = get_logger(f'{__name__}.{classname(self)}.initialize_for_new_trial')
logger.info("Falling back to the slower photospline evaluation.")

def get_pd(
self,
Expand Down Expand Up @@ -1054,11 +1058,15 @@ def initialize_for_new_trial(
tl,
'Get and set basis function indices for all PDFs.'):
V = self._cache_eventdata.shape[0]
bfi = pdf.pdf.search_centers(
[self._cache_eventdata[i] for i in range(0, V)]
)
for (_, pdf) in self.items():
pdf.basis_function_indices = bfi
try:
bfi = pdf.pdf.search_centers(
[self._cache_eventdata[i] for i in range(0, V)]
)
for (_, pdf) in self.items():
pdf.basis_function_indices = bfi
except ValueError:
logger = get_logger(f'{__name__}.{classname(self)}.initialize_for_new_trial')
logger.info("Falling back to the slower photospline evaluation.")

def get_pd(
self,
Expand Down Expand Up @@ -1112,7 +1120,7 @@ def get_pd(
pdf_key = self.make_key({'shg_idxs': shg_idxs})
pdf = self.get_pdf(pdf_key)

with TaskTimer(tl, f'Get PD values for PDF of SHG {shg_idx}.'):
with TaskTimer(tl, f'Get PD values for PDF of SHG {shg_idxs}.'):
pd_pdf = pdf.get_pd_with_eventdata(
tdm=tdm,
params_recarray=params_recarray,
Expand Down

0 comments on commit 4022be0

Please sign in to comment.