diff --git a/.travis.yml b/.travis.yml
index 39c8ea2..bdc8b78 100644
--- a/.travis.yml
+++ b/.travis.yml
@@ -36,7 +36,7 @@ before_install:
     conda create --name TEST python=$PY --file requirements-dev.txt --quiet
     source activate TEST
     # Install after to ensure it will be downgraded when testing an older version.
-    conda install numpy=$NUMPY
+    conda install numpy=$NUMPY xarray dask
     conda info --all
 
     # Test source distribution.
@@ -64,7 +64,7 @@ script:
     pushd docs
     make clean html linkcheck
     popd
-    if [[ -z "$TRAVIS_TAG" ]]; then
+    if [[ -z "$TRAVIS_TAG" ]]; then
      python -m doctr deploy --build-tags --key-path github_deploy_key.enc --built-docs docs/_build/html dev
    else
      python -m doctr deploy --build-tags --key-path github_deploy_key.enc --built-docs docs/_build/html "version-$TRAVIS_TAG"
diff --git a/gsw/_utilities.py b/gsw/_utilities.py
index 09af840..56e2a2e 100644
--- a/gsw/_utilities.py
+++ b/gsw/_utilities.py
@@ -26,20 +26,35 @@ def wrapper(*args, **kw):
             args = list(args)
             args.append(p)
 
-        isarray = np.any([hasattr(a, '__iter__') for a in args])
-        ismasked = np.any([np.ma.isMaskedArray(a) for a in args])
+        isarray = [hasattr(a, '__iter__') for a in args]
+        ismasked = [np.ma.isMaskedArray(a) for a in args]
+        isduck = [hasattr(a, '__array_ufunc__')
+                  and not isinstance(a, np.ndarray) for a in args]
+
+        hasarray = np.any(isarray)
+        hasmasked = np.any(ismasked)
+        hasduck = np.any(isduck)
 
         def fixup(ret):
-            if ismasked:
+            if hasduck:
+                return ret
+            if hasmasked:
                 ret = np.ma.masked_invalid(ret)
-            if not isarray and isinstance(ret, np.ndarray):
-                ret = ret[0]
+            if not hasarray and isinstance(ret, np.ndarray) and ret.size == 1:
+                try:
+                    ret = ret[0]
+                except IndexError:
+                    pass
             return ret
 
-        if ismasked:
-            newargs = [masked_to_nan(a) for a in args]
-        else:
-            newargs = [np.asarray(a, dtype=float) for a in args]
+        newargs = []
+        for i, arg in enumerate(args):
+            if ismasked[i]:
+                newargs.append(masked_to_nan(arg))
+            elif isduck[i]:
+                newargs.append(arg)
+            else:
+                newargs.append(np.asarray(arg, dtype=float))
 
         if p is not None:
             kw['p'] = newargs.pop()
diff --git a/gsw/geostrophy.py b/gsw/geostrophy.py
index 89e104a..7c7104e 100644
--- a/gsw/geostrophy.py
+++ b/gsw/geostrophy.py
@@ -73,7 +73,10 @@ def geo_strf_dyn_height(SA, CT, p, p_ref=0, axis=0, max_dp=1.0,
     dh = np.empty(SA.shape, dtype=float)
     dh.fill(np.nan)
 
-    order = 'F' if SA.flags.fortran else 'C'
+    try:
+        order = 'F' if SA.flags.fortran else 'C'
+    except AttributeError:
+        order = 'C'  # e.g., xarray DataArray doesn't have flags
     for ind in indexer(SA.shape, axis, order=order):
         igood = goodmask[ind]
         # If p_ref is below the deepest value, skip the profile.
diff --git a/gsw/tests/check_functions.py b/gsw/tests/check_functions.py
index c6227ea..05bb56d 100644
--- a/gsw/tests/check_functions.py
+++ b/gsw/tests/check_functions.py
@@ -55,7 +55,7 @@ def find(x):
     """
     Numpy equivalent to Matlab find.
     """
-    return np.nonzero(x.flatten())[0]
+    return np.nonzero(np.asarray(x).flatten())[0]
 
 
 def group_or(line):
diff --git a/gsw/tests/test_xarray.py b/gsw/tests/test_xarray.py
new file mode 100644
index 0000000..f1b82a2
--- /dev/null
+++ b/gsw/tests/test_xarray.py
@@ -0,0 +1,143 @@
+"""
+Tests functions with xarray inputs.
+
+This version is a copy of the original test_check_functions but with
+an import of xarray, and conversion of the 3 main check cast arrays
+into DataArray objects.
+
+An additional xarray-dask test is added.
+""" + +import os +import pytest + +import numpy as np +from numpy.testing import assert_allclose + +import gsw +from gsw._utilities import Bunch +from check_functions import parse_check_functions + +xr = pytest.importorskip('xarray') + +# Most of the tests have some nan values, so we need to suppress the warning. +# Any more careful fix would likely require considerable effort. +np.seterr(invalid='ignore') + +root_path = os.path.abspath(os.path.dirname(__file__)) + +# Function checks that we can't handle automatically yet. +blacklist = ['deltaSA_atlas', # the test is complicated; doesn't fit the pattern. + 'geostrophic_velocity', # test elsewhere; we changed the API + #'CT_from_entropy', # needs prior entropy_from_CT; don't have it in C + #'CT_first_derivatives', # passes, but has trouble in "details"; + # see check_functions.py + #'entropy_second_derivatives', # OK now; handling extra parens. + #'melting_ice_into_seawater', # OK now; fixed nargs mismatch. + ] + +# We get an overflow from ct_from_enthalpy_exact, but the test passes. +cv = Bunch(np.load(os.path.join(root_path, 'gsw_cv_v3_0.npz'))) + +# Substitute new check values for the pchip interpolation version. +cv.geo_strf_dyn_height = np.load(os.path.join(root_path,'geo_strf_dyn_height.npy')) +cv.geo_strf_velocity = np.load(os.path.join(root_path,'geo_strf_velocity.npy')) + +for name in ['SA_chck_cast', 't_chck_cast', 'p_chck_cast']: + cv[name] = xr.DataArray(cv[name]) + +cf = Bunch() + +d = dir(gsw) +funcnames = [name for name in d if '__' not in name] + +mfuncs = parse_check_functions(os.path.join(root_path, 'gsw_check_functions_save.m')) +mfuncs = [mf for mf in mfuncs if mf.name in d and mf.name not in blacklist] +mfuncnames = [mf.name for mf in mfuncs] + + +@pytest.fixture(scope='session', params=mfuncs) +def cfcf(request): + return cv, cf, request.param + + +def test_check_function(cfcf): + cv, cf, mfunc = cfcf + mfunc.run(locals()) + if mfunc.exception is not None or not mfunc.passed: + print('\n', mfunc.name) + print(' ', mfunc.runline) + print(' ', mfunc.testline) + if mfunc.exception is None: + mfunc.exception = ValueError('Calculated values are different from the expected matlab results.') + raise mfunc.exception + else: + print(mfunc.name) + assert mfunc.passed + + +def test_dask_chunking(): + dsa = pytest.importorskip('dask.array') + + # define some input data + shape = (100, 1000) + chunks = (100, 200) + sp = xr.DataArray(dsa.full(shape, 35., chunks=chunks), dims=['time', 'depth']) + p = xr.DataArray(np.arange(shape[1]), dims=['depth']) + lon = 0 + lat = 45 + + sa = gsw.SA_from_SP(sp, p, lon, lat) + sa_dask = sa.compute() + + sa_numpy = gsw.SA_from_SP(np.full(shape, 35.0), p.values, lon, lat) + assert_allclose(sa_dask, sa_numpy) + + +# Additional tests from Graeme MacGilchrist +# https://nbviewer.jupyter.org/github/gmacgilchrist/wmt_bgc/blob/master/notebooks/test_gsw-xarray.ipynb + +# Define dimensions and coordinates +dims = ['y','z','t'] +# 2x2x2 +y = np.arange(0,2) +z = np.arange(0,2) +t = np.arange(0,2) +# Define numpy arrays of salinity, temperature and pressure +SA_vals = np.array([[[34.7,34.8],[34.9,35]],[[35.1,35.2],[35.3,35.4]]]) +CT_vals = np.array([[[7,8],[9,10]],[[11,12],[13,14]]]) +p_vals = np.array([10,20]) +lat_vals = np.array([0,10]) +# Plug in to xarray objects +SA = xr.DataArray(SA_vals,dims=dims,coords={'y':y,'z':z,'t':t}) +CT = xr.DataArray(CT_vals,dims=dims,coords={'y':y,'z':z,'t':t}) +p = xr.DataArray(p_vals,dims=['z'],coords={'z':z}) +lat = xr.DataArray(lat_vals,dims=['y'],coords={'y':y}) + + +def 
+def test_xarray_with_coords():
+    pytest.importorskip('dask')
+    SA_chunk = SA.chunk(chunks={'y':1,'t':1})
+    CT_chunk = CT.chunk(chunks={'y':1,'t':1})
+    lat_chunk = lat.chunk(chunks={'y':1})
+
+    # Dimensions and coordinates match:
+    expected = gsw.sigma0(SA_vals, CT_vals)
+    xarray = gsw.sigma0(SA, CT)
+    chunked = gsw.sigma0(SA_chunk, CT_chunk)
+    assert_allclose(xarray, expected)
+    assert_allclose(chunked, expected)
+
+    # Broadcasting along dimension required (dimensions known)
+    expected = gsw.alpha(SA_vals, CT_vals, p_vals[np.newaxis, :, np.newaxis])
+    xarray = gsw.alpha(SA, CT, p)
+    chunked = gsw.alpha(SA_chunk, CT_chunk, p)
+    assert_allclose(xarray, expected)
+    assert_allclose(chunked, expected)
+
+    # Broadcasting along dimension required (dimensions unknown/exclusive)
+    expected = gsw.z_from_p(p_vals[:, np.newaxis], lat_vals[np.newaxis, :])
+    xarray = gsw.z_from_p(p, lat)
+    chunked = gsw.z_from_p(p,lat_chunk)
+    assert_allclose(xarray, expected)
+    assert_allclose(chunked, expected)
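
Below is a minimal usage sketch of the behaviour this patch aims for; it is not part of the patch itself, and it assumes xarray and dask are installed. Because xarray.DataArray (and dask arrays) define __array_ufunc__, the updated wrapper in gsw/_utilities.py now passes them through to the underlying C ufuncs untouched, so results keep their xarray type and dask-backed inputs stay lazy until computed.

import numpy as np
import xarray as xr
import gsw

# Toy inputs; values mirror the small arrays used in test_xarray_with_coords.
SA = xr.DataArray(np.array([[34.7, 34.8], [35.1, 35.2]]), dims=['y', 'z'])
CT = xr.DataArray(np.array([[7.0, 8.0], [11.0, 12.0]]), dims=['y', 'z'])

sigma0 = gsw.sigma0(SA, CT)            # result is an xarray.DataArray
assert isinstance(sigma0, xr.DataArray)

# With dask-chunked inputs the result is lazy until .compute() is called.
sigma0_lazy = gsw.sigma0(SA.chunk({'y': 1}), CT.chunk({'y': 1}))
np.testing.assert_allclose(sigma0_lazy.compute(), sigma0)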