Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fall back to evaluate_simple in case the search_centers call fails. #204

Merged
merged 3 commits into from
Feb 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 30 additions & 12 deletions skyllh/core/pdf.py
Original file line number Diff line number Diff line change
Expand Up @@ -1430,21 +1430,39 @@ def get_pd_with_eventdata(

if self.basis_function_indices is None:
with TaskTimer(tl, 'Get basis function indices from photospline.'):
self.basis_function_indices = self._pdf.search_centers(
[eventdata[i] for i in range(0, V)]
)
try:
self.basis_function_indices = self._pdf.search_centers(
[eventdata[i] for i in range(0, V)]
)
except ValueError:
# In case the photospline `search_centers` call fails
# (when `eventdata` is outside photospline boundaries)
# we can fall back to the slower photospline evaluation.
logger.info(
"Falling back to the slower photospline evaluation.")

with TaskTimer(tl, 'Get pd from photospline fit.'):
if evt_mask is None:
pd = self._pdf.evaluate(
[eventdata[i] for i in range(0, V)],
[self.basis_function_indices[i] for i in range(0, V)],
)
if self.basis_function_indices is not None:
if evt_mask is None:
pd = self._pdf.evaluate(
[eventdata[i] for i in range(0, V)],
[self.basis_function_indices[i] for i in range(0, V)],
)
else:
pd = self._pdf.evaluate(
[eventdata[i][evt_mask] for i in range(0, V)],
[self.basis_function_indices[i][evt_mask] for i in range(0, V)],
)
else:
pd = self._pdf.evaluate(
[eventdata[i][evt_mask] for i in range(0, V)],
[self.basis_function_indices[i][evt_mask] for i in range(0, V)],
)
# Falling back to the slower photospline evaluation.
if evt_mask is None:
pd = self._pdf.evaluate_simple(
[eventdata[i] for i in range(0, V)],
)
else:
pd = self._pdf.evaluate_simple(
[eventdata[i][evt_mask] for i in range(0, V)],
)

with TaskTimer(tl, 'Normalize MultiDimGridPDF with norm factor.'):
norm = self._norm_factor_func(
Expand Down
30 changes: 19 additions & 11 deletions skyllh/core/signalpdf.py
Original file line number Diff line number Diff line change
Expand Up @@ -833,11 +833,15 @@ def initialize_for_new_trial(
tl,
'Get and set basis function indices for all PDFs.'):
V = self._cache_eventdata.shape[0]
bfi = pdf.pdf.search_centers(
[self._cache_eventdata[i] for i in range(0, V)]
)
for (_, pdf) in self.items():
pdf.basis_function_indices = bfi
try:
bfi = pdf.pdf.search_centers(
[self._cache_eventdata[i] for i in range(0, V)]
)
for (_, pdf) in self.items():
pdf.basis_function_indices = bfi
except ValueError:
logger = get_logger(f'{__name__}.{classname(self)}.initialize_for_new_trial')
logger.info("Falling back to the slower photospline evaluation.")

def get_pd(
self,
Expand Down Expand Up @@ -1054,11 +1058,15 @@ def initialize_for_new_trial(
tl,
'Get and set basis function indices for all PDFs.'):
V = self._cache_eventdata.shape[0]
bfi = pdf.pdf.search_centers(
[self._cache_eventdata[i] for i in range(0, V)]
)
for (_, pdf) in self.items():
pdf.basis_function_indices = bfi
try:
bfi = pdf.pdf.search_centers(
[self._cache_eventdata[i] for i in range(0, V)]
)
for (_, pdf) in self.items():
pdf.basis_function_indices = bfi
except ValueError:
logger = get_logger(f'{__name__}.{classname(self)}.initialize_for_new_trial')
logger.info("Falling back to the slower photospline evaluation.")

def get_pd(
self,
Expand Down Expand Up @@ -1112,7 +1120,7 @@ def get_pd(
pdf_key = self.make_key({'shg_idxs': shg_idxs})
pdf = self.get_pdf(pdf_key)

with TaskTimer(tl, f'Get PD values for PDF of SHG {shg_idx}.'):
with TaskTimer(tl, f'Get PD values for PDF of SHG {shg_idxs}.'):
pd_pdf = pdf.get_pd_with_eventdata(
tdm=tdm,
params_recarray=params_recarray,
Expand Down
Loading