Skip to content

Commit

Permalink
Implement some polar-specific background substraction algorithms
Browse files Browse the repository at this point in the history
  • Loading branch information
viljarjf committed Apr 21, 2024
1 parent 5c6a458 commit 32dcbbf
Show file tree
Hide file tree
Showing 3 changed files with 97 additions and 1 deletion.
44 changes: 43 additions & 1 deletion pyxem/signals/polar_diffraction2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,11 @@
from pyxem.utils._correlations import _correlation, _power, _pearson_correlation
from pyxem.utils._deprecated import deprecated

from pyxem.utils._background_subtraction import (
_polar_subtract_radial_median,
_polar_subtract_radial_percentile,
)


class PolarDiffraction2D(CommonDiffraction, Signal2D):
"""Signal class for two-dimensional diffraction data in polar coordinates.
Expand Down Expand Up @@ -73,7 +78,7 @@ def get_angular_correlation(
mask=mask,
normalize=normalize,
inplace=inplace,
**kwargs
**kwargs,
)
s = self if inplace else correlation
theta_axis = s.axes_manager.signal_axes[0]
Expand Down Expand Up @@ -256,6 +261,43 @@ def get_resolved_pearson_correlation(

return correlation

def subtract_diffraction_background(
self, method="radial median", inplace=False, **kwargs
):
"""Background subtraction of the diffraction data.
Parameters
----------
method : str, optional
'radial median', 'radial percentile'
Default 'radial median'.
For 'radial median' no extra parameters are necessary.
For 'radial percentile' the 'percentile' argument decides
which percentile to substract.
**kwargs :
To be passed to the chosen method.
Returns
-------
s : PolarDiffraction2D or LazyPolarDiffraction2D signal
"""
method_dict = {
"radial median": _polar_subtract_radial_median,
"radial percentile": _polar_subtract_radial_percentile,
}
if method not in method_dict:
raise NotImplementedError(
f"The method specified, '{method}',"
f" is not implemented. The different methods are: "
f"{', '.join(method_dict.keys())}."
)
subtraction_function = method_dict[method]

return self.map(subtraction_function, inplace=inplace, **kwargs)


class LazyPolarDiffraction2D(LazySignal, PolarDiffraction2D):
pass
28 changes: 28 additions & 0 deletions pyxem/tests/signals/test_polar_diffraction2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,3 +249,31 @@ def test_decomposition_class_assignment(self, diffraction_pattern):
s = PolarDiffraction2D(diffraction_pattern)
s.decomposition()
assert isinstance(s, PolarDiffraction2D)


class TestSubtractingDiffractionBackground:
@pytest.fixture
def noisy_data(self):
data = np.random.rand(3, 2, 20, 15)
data[:, :, 10:12, 7:9] = 100
dp = PolarDiffraction2D(data)
return dp

@pytest.mark.parametrize(
["method", "kwargs"],
[
("radial median", {}),
("radial percentile", {"percentile": 40}),
pytest.param(
"this method does not exist",
{},
marks=pytest.mark.xfail(raises=NotImplementedError),
),
],
)
def test_subtract_backgrounds(self, method, kwargs, noisy_data):
kwargs["inplace"] = False

subtracted = noisy_data.subtract_diffraction_background(method=method, **kwargs)
assert isinstance(subtracted, PolarDiffraction2D)
assert subtracted.data.shape == noisy_data.data.shape
26 changes: 26 additions & 0 deletions pyxem/utils/_background_subtraction.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,3 +100,29 @@ def _subtract_hdome(frame, **kwargs):
)
bg_subtracted = bg_subtracted / np.max(bg_subtracted)
return bg_subtracted


def _polar_subtract_radial_median(frame):
"""Background removal using the radial median"""
median = np.nanmedian(frame, axis=1)
image = frame - median[:, np.newaxis]
image[image < 0] = 0
return image


def _polar_subtract_radial_percentile(frame, percentile: int):
"""Background removal using the specified radial percentile.
Parameters
----------
frame : NumPy 2D array
percentile : percentile, between 0 and 100
Note
-------
if `percentile` is 50, then this is equivalent to `_polar_subtract_radial_median`.
"""
percentile = np.nanpercentile(frame, percentile, axis=1)
image = frame - percentile[:, np.newaxis]
image[image < 0] = 0
return image

0 comments on commit 32dcbbf

Please sign in to comment.