Skip to content

Commit

Permalink
Merge pull request pyxem#935 from CSSFrancis/filter_data
Browse files Browse the repository at this point in the history
Add nd-filtering
  • Loading branch information
CSSFrancis authored Nov 6, 2023
2 parents a88f82d + 9706224 commit 67d51c3
Show file tree
Hide file tree
Showing 5 changed files with 152 additions and 3 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@ Fixed
- Set skimage != to version 0.21.0 because of regression
- Do not reverse the y-axis of diffraction patterns when template matching (#925)

Added
-----
- Add n-d and 2-d filters for filtering datasets (#935)

2023-05-08 - version 0.15.1
===========================
Expand Down
72 changes: 72 additions & 0 deletions examples/processing/filtering_data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
"""
Filtering Data
==============
If you have a low number of counts in your data, you may want to filter the data
to remove noise. This can be done using the `filter` function which applies some
function to the entire dataset and returns a filtered dataset of the same shape.
"""

from scipy.ndimage import gaussian_filter
from dask_image.ndfilters import gaussian_filter as dask_gaussian_filter
import pyxem as pxm
import hyperspy.api as hs
import numpy as np

# MgO nanocrystals dataset
s = pxm.data.mgo_nanocrystals(allow_download=True)

# Gaussian filter with sigma=1.0
s_filtered = s.filter(gaussian_filter, sigma=1.0, inplace=False)

# Only filter in real space
real_space_sigma = (1.0, 1.0, 0, 0)
s_filtered2 = s.filter(gaussian_filter, sigma=real_space_sigma, inplace=False)

# Compare the same pattern from the original and the two filtered datasets.
gaussian_images = [s.inav[10, 10], s_filtered.inav[10, 10], s_filtered2.inav[10, 10]]
gaussian_labels = ["Original", "GaussFilt(all)", "GaussFilt(real space)"]
hs.plot.plot_images(
    gaussian_images,
    label=gaussian_labels,
    tight_layout=True,
    vmax="99th",
)

# %%
# The `filter` function can also be used with a custom function as long as the function
# takes a numpy array as input and returns a numpy array of the same shape.


def custom_filter(array):
    """Gaussian-smooth ``array`` and subtract the mean of the smoothed result.

    Parameters
    ----------
    array : numpy.ndarray
        Data to filter.

    Returns
    -------
    numpy.ndarray
        ``gaussian_filter(array, sigma=1.0)`` with its own mean subtracted;
        same shape as ``array``.
    """
    filtered = gaussian_filter(array, sigma=1.0)
    return filtered - np.mean(filtered)


# Custom filter; ``inplace=False`` returns a new signal.
s_filtered3 = s.filter(custom_filter, inplace=False)

custom_images = [s.inav[10, 10], s_filtered3.inav[10, 10]]
custom_labels = ["Original", "GaussFilt(Custom)"]
hs.plot.plot_images(
    custom_images,
    label=custom_labels,
    tight_layout=True,
    vmax="99th",
)
# %%
# For lazy datasets, functions which operate on dask arrays can be used. For example,
# the `gaussian_filter` function from `scipy.ndimage` is replaced with the `dask_image`
# version which operates on dask arrays.

# Convert to lazy dataset
s = s.as_lazy()

# Gaussian filter with sigma=1.0, using the dask_image implementation
s_filtered4 = s.filter(dask_gaussian_filter, sigma=1.0, inplace=False)

# Compare the earlier scipy-filtered result with the lazily filtered one.
lazy_images = [s_filtered.inav[10, 10], s_filtered4.inav[10, 10]]
lazy_labels = ["GaussFilt", "GaussFilt(Lazy)"]
hs.plot.plot_images(
    lazy_images,
    label=lazy_labels,
    tight_layout=True,
    vmax="99th",
)
# %%
44 changes: 41 additions & 3 deletions pyxem/signals/diffraction2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,6 @@


import numpy as np
from skimage import filters
from skimage.feature import match_template
from scipy.ndimage import rotate
from skimage import morphology
import dask.array as da
Expand All @@ -30,8 +28,8 @@
import hyperspy.api as hs
from hyperspy.signals import Signal2D, BaseSignal
from hyperspy._signals.lazy import LazySignal
from hyperspy._signals.signal2d import LazySignal2D
from hyperspy.misc.utils import isiterable
from importlib import import_module

from pyxem.signals import (
CommonDiffraction,
Expand Down Expand Up @@ -1124,6 +1122,45 @@ def template_match_ring(self, r_inner=5, r_outer=7, inplace=False, **kwargs):
normalize_template_match, template=ring, inplace=inplace, **kwargs
)

def filter(self, func, inplace=False, **kwargs):
    """Filter the entire dataset by applying a function to the data.

    The function must take a numpy or dask array as input and return a
    numpy or dask array as output which has the same shape and axes as
    the input.

    Parameters
    ----------
    func : callable
        Function to apply to the data. Must take a numpy or dask array as
        input and return a numpy or dask array as output which has the
        same shape as the input. It is called as
        ``func(self.data, **kwargs)``.
    inplace : bool, optional
        If True, the data is replaced by the filtered data. If False, a
        new signal is returned. Default False.
    **kwargs :
        Passed to the function.

    Returns
    -------
    signal or None
        A new signal holding the filtered data when ``inplace`` is False,
        otherwise None.

    Raises
    ------
    ValueError
        If the value returned by ``func`` does not have the same shape as
        the input data (this includes return values that have no
        ``shape`` attribute at all).

    Examples
    --------
    >>> import pyxem as pxm
    >>> from scipy.ndimage import gaussian_filter
    >>> s = pxm.dummy_data.get_cbed_signal()
    >>> s_filtered = s.filter(gaussian_filter, sigma=1)
    """
    new_data = func(self.data, **kwargs)

    # Use getattr so that a non-array return value (None, list, scalar, ...)
    # produces the documented ValueError instead of an AttributeError.
    expected_shape = self.data.shape
    returned_shape = getattr(new_data, "shape", None)
    if returned_shape != expected_shape:
        raise ValueError(
            "The function must return an array with the same shape as the input."
            f" Expected shape {expected_shape}, got {returned_shape}."
        )

    if inplace:
        self.data = new_data
        return None
    return self._deepcopy_with_new_data(data=new_data)

def template_match(self, template, inplace=False, **kwargs):
"""Template match the signal dimensions with a binary image.
Expand All @@ -1143,6 +1180,7 @@ def template_match(self, template, inplace=False, **kwargs):
Examples
--------
>>> import pyxem as pxm
>>> s = pxm.dummy_data.get_cbed_signal()
>>> binary_image = np.random.randint(0, 2, (6, 6))
>>> s_template = s.template_match_with_binary_image(
Expand Down
34 changes: 34 additions & 0 deletions pyxem/tests/signals/test_diffraction2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from matplotlib import pyplot as plt
from numpy.random import default_rng
from skimage.draw import circle_perimeter_aa
import scipy

from pyxem.signals import (
Diffraction1D,
Expand Down Expand Up @@ -1367,3 +1368,36 @@ def test_wrong_navigator_shape_kwarg(self):
s_nav = Diffraction2D(np.zeros((2, 19)))
s._navigator_probe = s_nav
s.plot()


class TestFilter:
    """Tests for ``Diffraction2D.filter``."""

    @pytest.fixture
    def three_section(self):
        """Return a (100, 50, 20, 20) signal of uniform noise in which three
        ranges along the first axis each have a different 2x2 patch of the
        last two axes raised by 10."""
        # Seeded generator so the fixture data is reproducible between runs.
        rng = default_rng(0)
        x = rng.random((100, 50, 20, 20))
        x[0:20, :, 5:7, 5:7] += 10
        x[20:60, :, 1:3, 14:16] += 10
        x[60:100, :, 6:8, 10:12] += 10
        return Diffraction2D(x)

    @pytest.mark.parametrize("lazy", [True, False])
    def test_filter(self, three_section, lazy):
        """In-place and out-of-place filtering must give the same data."""
        if lazy:  # pragma: no cover
            # dask_image is an optional dependency; skip when it is missing.
            pytest.importorskip("dask_image")
            from dask_image.ndfilters import gaussian_filter

            three_section = three_section.as_lazy()
        else:
            from scipy.ndimage import gaussian_filter

        sigma = (3, 3, 3, 3)
        new = three_section.filter(func=gaussian_filter, sigma=sigma, inplace=False)
        assert new is not three_section
        assert new.data.shape == three_section.data.shape

        result = three_section.filter(func=gaussian_filter, sigma=sigma, inplace=True)
        assert result is None
        np.testing.assert_array_almost_equal(new.data, three_section.data)

    def test_filter_fail(self, three_section):
        """A function that changes the shape of the data must be rejected."""

        def small_func(x):
            return x[1:, 1:, 1:, 1:]

        with pytest.raises(ValueError):
            three_section.filter(func=small_func, inplace=False)
2 changes: 2 additions & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
"sphinx-codeautolink",
"pydata-sphinx-theme",
"hyperspy-gui-ipywidgets",
"dask-image",
],
"tests": [
"pytest >= 5.0",
Expand All @@ -47,6 +48,7 @@
],
"dev": ["black", "pre-commit >=1.16"],
"gpu": ["cupy >= 9.0.0"],
"dask": ["dask-image", "distributed"],
}


Expand Down

0 comments on commit 67d51c3

Please sign in to comment.