From 24b0e1d2cd1a40c65c3e950f4a39fae108cb2a90 Mon Sep 17 00:00:00 2001 From: Josh Borrow Date: Fri, 25 Oct 2024 09:55:22 -0400 Subject: [PATCH] Change FFT return units to be same as input --- unyt/_array_functions.py | 28 ++++++++++++++-------------- unyt/tests/test_array_functions.py | 4 ++-- 2 files changed, 16 insertions(+), 16 deletions(-) diff --git a/unyt/_array_functions.py b/unyt/_array_functions.py index 4b22dd20..646629e2 100644 --- a/unyt/_array_functions.py +++ b/unyt/_array_functions.py @@ -358,72 +358,72 @@ def block(arrays): @implements(np.fft.fft) def ftt_fft(a, *args, **kwargs): - return np.fft.fft._implementation(np.asarray(a), *args, **kwargs) / a.units + return np.fft.fft._implementation(np.asarray(a), *args, **kwargs) * a.units @implements(np.fft.fft2) def ftt_fft2(a, *args, **kwargs): - return np.fft.fft2._implementation(np.asarray(a), *args, **kwargs) / a.units + return np.fft.fft2._implementation(np.asarray(a), *args, **kwargs) * a.units @implements(np.fft.fftn) def ftt_fftn(a, *args, **kwargs): - return np.fft.fftn._implementation(np.asarray(a), *args, **kwargs) / a.units + return np.fft.fftn._implementation(np.asarray(a), *args, **kwargs) * a.units @implements(np.fft.hfft) def ftt_hfft(a, *args, **kwargs): - return np.fft.hfft._implementation(np.asarray(a), *args, **kwargs) / a.units + return np.fft.hfft._implementation(np.asarray(a), *args, **kwargs) * a.units @implements(np.fft.rfft) def ftt_rfft(a, *args, **kwargs): - return np.fft.rfft._implementation(np.asarray(a), *args, **kwargs) / a.units + return np.fft.rfft._implementation(np.asarray(a), *args, **kwargs) * a.units @implements(np.fft.rfft2) def ftt_rfft2(a, *args, **kwargs): - return np.fft.rfft2._implementation(np.asarray(a), *args, **kwargs) / a.units + return np.fft.rfft2._implementation(np.asarray(a), *args, **kwargs) * a.units @implements(np.fft.rfftn) def ftt_rfftn(a, *args, **kwargs): - return np.fft.rfftn._implementation(np.asarray(a), *args, **kwargs) / a.units + return np.fft.rfftn._implementation(np.asarray(a), *args, **kwargs) * a.units @implements(np.fft.ifft) def ftt_ifft(a, *args, **kwargs): - return np.fft.ifft._implementation(np.asarray(a), *args, **kwargs) / a.units + return np.fft.ifft._implementation(np.asarray(a), *args, **kwargs) * a.units @implements(np.fft.ifft2) def ftt_ifft2(a, *args, **kwargs): - return np.fft.ifft2._implementation(np.asarray(a), *args, **kwargs) / a.units + return np.fft.ifft2._implementation(np.asarray(a), *args, **kwargs) * a.units @implements(np.fft.ifftn) def ftt_ifftn(a, *args, **kwargs): - return np.fft.ifftn._implementation(np.asarray(a), *args, **kwargs) / a.units + return np.fft.ifftn._implementation(np.asarray(a), *args, **kwargs) * a.units @implements(np.fft.ihfft) def ftt_ihfft(a, *args, **kwargs): - return np.fft.ihfft._implementation(np.asarray(a), *args, **kwargs) / a.units + return np.fft.ihfft._implementation(np.asarray(a), *args, **kwargs) * a.units @implements(np.fft.irfft) def ftt_irfft(a, *args, **kwargs): - return np.fft.irfft._implementation(np.asarray(a), *args, **kwargs) / a.units + return np.fft.irfft._implementation(np.asarray(a), *args, **kwargs) * a.units @implements(np.fft.irfft2) def ftt_irfft2(a, *args, **kwargs): - return np.fft.irfft2._implementation(np.asarray(a), *args, **kwargs) / a.units + return np.fft.irfft2._implementation(np.asarray(a), *args, **kwargs) * a.units @implements(np.fft.irfftn) def ftt_irfftn(a, *args, **kwargs): - return np.fft.irfftn._implementation(np.asarray(a), *args, **kwargs) / a.units + return np.fft.irfftn._implementation(np.asarray(a), *args, **kwargs) * a.units @implements(np.fft.fftshift) diff --git a/unyt/tests/test_array_functions.py b/unyt/tests/test_array_functions.py index 4903e93c..bfd900ab 100644 --- a/unyt/tests/test_array_functions.py +++ b/unyt/tests/test_array_functions.py @@ -923,7 +923,7 @@ def test_fft_1D(func): x1 = [0, 1, 2] * cm res = func(x1) assert type(res) is unyt_array - assert res.units == (1 / cm).units + assert res.units == (1 * cm).units @pytest.mark.parametrize( @@ -943,7 +943,7 @@ def test_fft_ND(func): x1 = [[0, 1, 2], [0, 1, 2], [0, 1, 2]] * cm res = func(x1) assert type(res) is unyt_array - assert res.units == (1 / cm).units + assert res.units == (1 * cm).units @pytest.mark.parametrize("func", [np.fft.fftshift, np.fft.ifftshift])