diff --git a/pyxem/signals/polar_diffraction2d.py b/pyxem/signals/polar_diffraction2d.py index f815ca4a8..88e7fa841 100644 --- a/pyxem/signals/polar_diffraction2d.py +++ b/pyxem/signals/polar_diffraction2d.py @@ -24,6 +24,11 @@ from pyxem.utils._correlations import _correlation, _power, _pearson_correlation from pyxem.utils._deprecated import deprecated +from pyxem.utils._background_subtraction import ( + _polar_subtract_radial_median, + _polar_subtract_radial_percentile, +) + class PolarDiffraction2D(CommonDiffraction, Signal2D): """Signal class for two-dimensional diffraction data in polar coordinates. @@ -73,7 +78,7 @@ def get_angular_correlation( mask=mask, normalize=normalize, inplace=inplace, - **kwargs + **kwargs, ) s = self if inplace else correlation theta_axis = s.axes_manager.signal_axes[0] @@ -256,6 +261,43 @@ def get_resolved_pearson_correlation( return correlation + def subtract_diffraction_background( + self, method="radial median", inplace=False, **kwargs + ): + """Background subtraction of the diffraction data. + + Parameters + ---------- + method : str, optional + 'radial median', 'radial percentile' + Default 'radial median'. + + For 'radial median' no extra parameters are necessary. + + For 'radial percentile' the 'percentile' argument decides + which percentile to substract. + **kwargs : + To be passed to the chosen method. + + Returns + ------- + s : PolarDiffraction2D or LazyPolarDiffraction2D signal + + """ + method_dict = { + "radial median": _polar_subtract_radial_median, + "radial percentile": _polar_subtract_radial_percentile, + } + if method not in method_dict: + raise NotImplementedError( + f"The method specified, '{method}'," + f" is not implemented. The different methods are: " + f"{', '.join(method_dict.keys())}." + ) + subtraction_function = method_dict[method] + + return self.map(subtraction_function, inplace=inplace, **kwargs) + class LazyPolarDiffraction2D(LazySignal, PolarDiffraction2D): pass diff --git a/pyxem/tests/signals/test_polar_diffraction2d.py b/pyxem/tests/signals/test_polar_diffraction2d.py index 8d8bc530f..639b7ec1c 100644 --- a/pyxem/tests/signals/test_polar_diffraction2d.py +++ b/pyxem/tests/signals/test_polar_diffraction2d.py @@ -249,3 +249,31 @@ def test_decomposition_class_assignment(self, diffraction_pattern): s = PolarDiffraction2D(diffraction_pattern) s.decomposition() assert isinstance(s, PolarDiffraction2D) + + +class TestSubtractingDiffractionBackground: + @pytest.fixture + def noisy_data(self): + data = np.random.rand(3, 2, 20, 15) + data[:, :, 10:12, 7:9] = 100 + dp = PolarDiffraction2D(data) + return dp + + @pytest.mark.parametrize( + ["method", "kwargs"], + [ + ("radial median", {}), + ("radial percentile", {"percentile": 40}), + pytest.param( + "this method does not exist", + {}, + marks=pytest.mark.xfail(raises=NotImplementedError), + ), + ], + ) + def test_subtract_backgrounds(self, method, kwargs, noisy_data): + kwargs["inplace"] = False + + subtracted = noisy_data.subtract_diffraction_background(method=method, **kwargs) + assert isinstance(subtracted, PolarDiffraction2D) + assert subtracted.data.shape == noisy_data.data.shape diff --git a/pyxem/utils/_background_subtraction.py b/pyxem/utils/_background_subtraction.py index f8322fbc8..2f202371a 100644 --- a/pyxem/utils/_background_subtraction.py +++ b/pyxem/utils/_background_subtraction.py @@ -100,3 +100,29 @@ def _subtract_hdome(frame, **kwargs): ) bg_subtracted = bg_subtracted / np.max(bg_subtracted) return bg_subtracted + + +def _polar_subtract_radial_median(frame): + """Background removal using the radial median""" + median = np.nanmedian(frame, axis=1) + image = frame - median[:, np.newaxis] + image[image < 0] = 0 + return image + + +def _polar_subtract_radial_percentile(frame, percentile: int): + """Background removal using the specified radial percentile. + + Parameters + ---------- + frame : NumPy 2D array + percentile : percentile, between 0 and 100 + + Note + ------- + if `percentile` is 50, then this is equivalent to `_polar_subtract_radial_median`. + """ + percentile = np.nanpercentile(frame, percentile, axis=1) + image = frame - percentile[:, np.newaxis] + image[image < 0] = 0 + return image