diff --git a/xrft/tests/test_xrft.py b/xrft/tests/test_xrft.py index f63144c9..ae248deb 100644 --- a/xrft/tests/test_xrft.py +++ b/xrft/tests/test_xrft.py @@ -1344,3 +1344,21 @@ def test_nondim_coords(): xrft.power_spectrum(da) xrft.power_spectrum(da, dim=["time", "y"]) + + +def test_non_numerical_or_datetime_coords(): + """Error should be raised if there are non-numerical or non-datetime coordinate""" + da = xr.DataArray( + np.random.rand(2, 5, 3), + dims=["time", "x", "y"], + coords={ + "time": np.array(["2019-04-18", "2019-04-19"], dtype="datetime64"), + "x": range(5), + "y": ["a", "b", "c"], + }, + ) + + with pytest.raises(ValueError): + xrft.power_spectrum(da) + + xrft.power_spectrum(da, dim=["time", "x"]) diff --git a/xrft/xrft.py b/xrft/xrft.py index 30d738e8..9ddb3c19 100644 --- a/xrft/xrft.py +++ b/xrft/xrft.py @@ -15,7 +15,7 @@ import scipy.linalg as spl from .detrend import detrend as _detrend - +from pandas.api.types import is_numeric_dtype, is_datetime64_any_dtype __all__ = [ "fft", @@ -230,9 +230,9 @@ def _lag_coord(coord): decoded_time = cftime.date2num(lag, ref_units, calendar) return decoded_time elif pd.api.types.is_datetime64_dtype(v0): - return lag.astype("timedelta64[s]").astype("f8").data + return lag.astype("timedelta64[s]").astype("f8") else: - return lag.data + return lag def dft( @@ -330,7 +330,6 @@ def fft( daft : `xarray.DataArray` The output of the Fourier transformation, with appropriate dimensions. """ - if dim is None: dim = list(da.dims) else: @@ -352,6 +351,20 @@ def fft( real_dim ] # real dim has to be moved or added at the end ! + if not np.all( + [ + ( + is_numeric_dtype(da.coords[d]) + or is_datetime64_any_dtype(da.coords[d]) + or bool(getattr(da.coords[d][0].item(), "calendar", False)) + ) + for d in dim + ] + ): # checking if coodinates are numerical or datetime + raise ValueError( + "All transformed dimensions coordinates must be numerical or datetime." + ) + if chunks_to_segments: da = _stack_chunks(da, dim) @@ -452,7 +465,7 @@ def fft( dims=up_dim, coords={up_dim: newcoords[up_dim]}, ) # taking advantage of xarray broadcasting and ordered coordinates - daft[up_dim].attrs.update({"direct_lag": lag.obj}) + daft[up_dim].attrs.update({"direct_lag": lag}) if true_amplitude: daft = daft * np.prod(delta_x) @@ -520,7 +533,6 @@ def ifft( da : `xarray.DataArray` The output of the Inverse Fourier transformation, with appropriate dimensions. """ - if dim is None: dim = list(daft.dims) else: @@ -540,6 +552,21 @@ def ifft( dim = [d for d in dim if d != real_dim] + [ real_dim ] # real dim has to be moved or added at the end ! + + if not np.all( + [ + ( + is_numeric_dtype(daft.coords[d]) + or is_datetime64_any_dtype(daft.coords[d]) + or bool(getattr(daft.coords[d][0].item(), "calendar", False)) + ) + for d in dim + ] + ): # checking if coodinates are numerical or datetime + raise ValueError( + "All transformed dimensions coordinates must be numerical or datetime." + ) + if lag is None: lag = [daft[d].attrs.get("direct_lag", 0.0) for d in dim] msg = "Default ifft's behaviour (lag=None) changed! Default value of lag was zero (centered output coordinates) and is now set to transformed coordinate's attribute: 'direct_lag'." @@ -898,8 +925,10 @@ def cross_phase(da1, da2, dim=None, true_phase=True, **kwargs): kwargs : dict : see xrft.fft for argument list """ - cp = xr.ufuncs.angle( - cross_spectrum(da1, da2, dim=dim, true_phase=true_phase, **kwargs) + cp = xr.apply_ufunc( + np.angle, + cross_spectrum(da1, da2, dim=dim, true_phase=true_phase, **kwargs), + dask="allowed", ) if da1.name and da2.name: