Skip to content

Commit

Permalink
adress comments
Browse files Browse the repository at this point in the history
  • Loading branch information
RuneDominik committed Apr 8, 2024
1 parent 584ca51 commit 26a9d8b
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 19 deletions.
20 changes: 12 additions & 8 deletions pyirf/interpolation/component_estimators.py
Original file line number Diff line number Diff line change
Expand Up @@ -541,6 +541,12 @@ def __init__(
else:
self.fill_val = fill_value

# Raise error if fill-values should be handled in >=3 dims
if self.fill_val and self.grid_dim >= 3:
raise ValueError(
"Fill-value handling only supported in up to two grid dimensions."
)

# If fill-values should be handled in 2D, construct a trinangulation
# to later determine in which simplex the target values is
if self.fill_val and (self.grid_dim == 2):
Expand All @@ -566,19 +572,15 @@ def __call__(self, target_point):
# First, construct estimation without handling fill-values
full_estimation = super().__call__(target_point)
# Safeguard against extreme extrapolation cases
full_estimation[full_estimation < 0] = 0
np.clip(full_estimation, 0, None, out=full_estimation)

# Early exit if fill_values should not be handled
if not self.fill_val:
return full_estimation

# Raise error if fill-values should be handled in >=3 dims
if self.grid_dim >= 3:
raise ValueError(
"Fill-value handling only supported in up to two grid dimensions."
)

# Early exit if a nearest neighbor estimation would be overwritten
# Complex setup is needed to catch settings where the user mixes approaches and
# e.g. uses nearest neighbors for extrapolation and an actual interpolation otherwise
if self.grid_dim == 1:
if (
(target_point < self.grid_points.min())
Expand Down Expand Up @@ -631,7 +633,9 @@ def __call__(self, target_point):

# This collected mask now counts for each entry in the estimation how many
# of the entries used for extrapolation contained fill-values
intermediate_mask = mask0.astype("int") + mask1.astype("int") + mask2.astype("int")
intermediate_mask = (
mask0.astype("int") + mask1.astype("int") + mask2.astype("int")
)
mask = np.full_like(intermediate_mask, True, dtype=bool)

# Simplest cases: All or none entries were fill-values, so either return
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -358,18 +358,16 @@ def test_RadMaxEstimator_fill_val_handling_3D():

rad_max = np.array([[0.95], [0.95], [0.95], [0.95]])

estim = RadMaxEstimator(
grid_points=grid_points_3D,
rad_max=rad_max,
fill_value=0.95,
interpolator_cls=GridDataInterpolator,
interpolator_kwargs=None,
extrapolator_cls=None,
extrapolator_kwargs=None,
)

with pytest.raises(
ValueError,
match="Fill-value handling only supported in up to two grid dimensions.",
):
estim(np.array([0.25, 0.25, 0.25]))
RadMaxEstimator(
grid_points=grid_points_3D,
rad_max=rad_max,
fill_value=0.95,
interpolator_cls=GridDataInterpolator,
interpolator_kwargs=None,
extrapolator_cls=None,
extrapolator_kwargs=None,
)

0 comments on commit 26a9d8b

Please sign in to comment.