diff --git a/.github/dependabot.yml b/.github/dependabot.yml index b4b3fa44..2eacebc2 100644 --- a/.github/dependabot.yml +++ b/.github/dependabot.yml @@ -1,11 +1,11 @@ version: 2 updates: - # - package-ecosystem: pip - # directory: "/" - # schedule: - # interval: daily - - package-ecosystem: 'github-actions' - directory: '/' - schedule: - # Check for updates once a week - interval: 'weekly' + # - package-ecosystem: pip + # directory: "/" + # schedule: + # interval: daily + - package-ecosystem: 'github-actions' + directory: '/' + schedule: + # Check for updates once a week + interval: 'weekly' diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index cc84cb70..9b970a3c 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -1,94 +1,94 @@ name: CI on: - push: - branches: - - master - pull_request: - branches: - - '*' - schedule: - - cron: '0 0 * * *' # Daily “At 00:00” - workflow_dispatch: # allows you to trigger manually + push: + branches: + - master + pull_request: + branches: + - '*' + schedule: + - cron: '0 0 * * *' # Daily “At 00:00” + workflow_dispatch: # allows you to trigger manually jobs: - build: - runs-on: ubuntu-latest - defaults: - run: - shell: bash -l {0} - strategy: - fail-fast: false - matrix: - include: - # Warning: Unless in quotations, numbers below are read as floats. 3.10 < 3.2 - - python-version: '3.8' - esmf-version: 8.2 - - python-version: '3.9' - esmf-version: 8.3 - - python-version: '3.10' - esmf-version: 8.4 - - python-version: '3.11' - esmf-version: 8.4 - steps: - - name: Cancel previous runs - uses: styfle/cancel-workflow-action@0.12.0 - with: - access_token: ${{ github.token }} - - name: Checkout source - uses: actions/checkout@v4 - - name: Create conda environment - uses: mamba-org/provision-with-micromamba@main - with: - cache-downloads: true - micromamba-version: 'latest' - environment-file: ci/environment.yml - extra-specs: | - python=${{ matrix.python-version }} - esmpy=${{ matrix.esmf-version }} - - name: Install Xesmf (editable) - run: | - python -m pip install --no-deps -e . - - name: Conda list information - run: | - conda env list - conda list - - name: Run tests - run: | - python -m pytest --cov=./ --cov-report=xml --verbose - - name: Upload coverage to Codecov - uses: codecov/codecov-action@v3.1.3 - with: - file: ./coverage.xml - fail_ci_if_error: false + build: + runs-on: ubuntu-latest + defaults: + run: + shell: bash -l {0} + strategy: + fail-fast: false + matrix: + include: + # Warning: Unless in quotations, numbers below are read as floats. 3.10 < 3.2 + - python-version: '3.8' + esmf-version: 8.2 + - python-version: '3.9' + esmf-version: 8.3 + - python-version: '3.10' + esmf-version: 8.4 + - python-version: '3.11' + esmf-version: 8.4 + steps: + - name: Cancel previous runs + uses: styfle/cancel-workflow-action@0.12.0 + with: + access_token: ${{ github.token }} + - name: Checkout source + uses: actions/checkout@v4 + - name: Create conda environment + uses: mamba-org/provision-with-micromamba@main + with: + cache-downloads: true + micromamba-version: 'latest' + environment-file: ci/environment.yml + extra-specs: | + python=${{ matrix.python-version }} + esmpy=${{ matrix.esmf-version }} + - name: Install Xesmf (editable) + run: | + python -m pip install --no-deps -e . 
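The float pitfall flagged in the matrix comment above ("Unless in quotations, numbers below are read as floats. 3.10 < 3.2") is easy to reproduce. A minimal sketch using PyYAML, which is only an assumed stand-in here (GitHub Actions has its own parser, but the YAML scalar rules agree on this point):

    # Why the python-version entries are quoted: unquoted 3.10 parses as a float.
    import yaml  # PyYAML, assumed installed for this illustration

    parsed = yaml.safe_load("quoted: '3.10'\nunquoted: 3.10")
    print(parsed)                      # {'quoted': '3.10', 'unquoted': 3.1}
    assert parsed['quoted'] == '3.10'  # the intended version string survives
    assert parsed['unquoted'] == 3.1   # silently truncated, and 3.1 < 3.2

Single-decimal values such as the esmf-version entries happen to round-trip unharmed, which is why only the Python versions strictly need quoting.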
+ - name: Conda list information + run: | + conda env list + conda list + - name: Run tests + run: | + python -m pytest --cov=./ --cov-report=xml --verbose + - name: Upload coverage to Codecov + uses: codecov/codecov-action@v3.1.3 + with: + file: ./coverage.xml + fail_ci_if_error: false - upstream-dev: - name: upstream-dev - runs-on: ubuntu-latest - defaults: - run: - shell: bash -l {0} - steps: - - name: Cancel previous runs - uses: styfle/cancel-workflow-action@0.12.0 - with: - access_token: ${{ github.token }} - - uses: actions/checkout@v4 - - name: Create conda environment - uses: mamba-org/provision-with-micromamba@v16 - with: - cache-downloads: true - micromamba-version: 'latest' - environment-file: ci/environment-upstream-dev.yml - extra-specs: | - python=3.10 - - name: Install Xesmf (editable) - run: | - python -m pip install -e . - - name: Conda list information - run: | - conda env list - conda list - - name: Run tests - run: | - python -m pytest --cov=./ --cov-report=xml --verbose + upstream-dev: + name: upstream-dev + runs-on: ubuntu-latest + defaults: + run: + shell: bash -l {0} + steps: + - name: Cancel previous runs + uses: styfle/cancel-workflow-action@0.12.0 + with: + access_token: ${{ github.token }} + - uses: actions/checkout@v4 + - name: Create conda environment + uses: mamba-org/provision-with-micromamba@v16 + with: + cache-downloads: true + micromamba-version: 'latest' + environment-file: ci/environment-upstream-dev.yml + extra-specs: | + python=3.10 + - name: Install Xesmf (editable) + run: | + python -m pip install -e . + - name: Conda list information + run: | + conda env list + conda list + - name: Run tests + run: | + python -m pytest --cov=./ --cov-report=xml --verbose diff --git a/.github/workflows/linting.yaml b/.github/workflows/linting.yaml index 2391406f..c7c018ea 100644 --- a/.github/workflows/linting.yaml +++ b/.github/workflows/linting.yaml @@ -1,18 +1,18 @@ name: linting on: - push: - branches: - - master - pull_request: - branches: '*' + push: + branches: + - master + pull_request: + branches: '*' jobs: - linting: - runs-on: ubuntu-latest - steps: - - uses: actions/checkout@v4 - - uses: actions/setup-python@v4 - with: - python-version: '3.x' - - uses: pre-commit/action@v3.0.0 + linting: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - uses: actions/setup-python@v4 + with: + python-version: '3.x' + - uses: pre-commit/action@v3.0.0 diff --git a/.github/workflows/pypi.yaml b/.github/workflows/pypi.yaml index fde7fd5c..962a1b7b 100644 --- a/.github/workflows/pypi.yaml +++ b/.github/workflows/pypi.yaml @@ -1,55 +1,55 @@ name: Publish to PyPI on: - pull_request: - push: - branches: - - master - release: - types: - - published + pull_request: + push: + branches: + - master + release: + types: + - published defaults: - run: - shell: bash + run: + shell: bash jobs: - packages: - runs-on: ubuntu-latest - steps: - - uses: actions/checkout@v4 - - - name: Set up Python - uses: actions/setup-python@v4 - with: - python-version: '3.x' - - - name: Get tags - run: git fetch --depth=1 origin +refs/tags/*:refs/tags/* - - - name: Install build tools - run: | - python -m pip install --upgrade build - - - name: Build binary wheel - run: python -m build --sdist --wheel . --outdir dist - - - name: CheckFiles - run: | - ls dist - python -m pip install --upgrade check-manifest - check-manifest --verbose - - - name: Test wheels - run: | - # We cannot run this step b/c esmpy is not available on PyPI - # cd dist && python -m pip install *.whl && cd .. 
- python -m pip install --upgrade build twine - python -m twine check dist/* - - - name: Publish a Python distribution to PyPI - if: success() && github.event_name == 'release' - uses: pypa/gh-action-pypi-publish@release/v1 - with: - user: __token__ - password: ${{ secrets.PYPI_TOKEN }} + packages: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + + - name: Set up Python + uses: actions/setup-python@v4 + with: + python-version: '3.x' + + - name: Get tags + run: git fetch --depth=1 origin +refs/tags/*:refs/tags/* + + - name: Install build tools + run: | + python -m pip install --upgrade build + + - name: Build binary wheel + run: python -m build --sdist --wheel . --outdir dist + + - name: CheckFiles + run: | + ls dist + python -m pip install --upgrade check-manifest + check-manifest --verbose + + - name: Test wheels + run: | + # We cannot run this step b/c esmpy is not available on PyPI + # cd dist && python -m pip install *.whl && cd .. + python -m pip install --upgrade build twine + python -m twine check dist/* + + - name: Publish a Python distribution to PyPI + if: success() && github.event_name == 'release' + uses: pypa/gh-action-pypi-publish@release/v1 + with: + user: __token__ + password: ${{ secrets.PYPI_TOKEN }} diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index a190da46..f148794f 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,70 +1,70 @@ default_language_version: - python: python3 + python: python3 repos: - - repo: https://github.com/pre-commit/pre-commit-hooks - rev: v4.4.0 - hooks: - - id: trailing-whitespace - - id: end-of-file-fixer - - id: check-docstring-first - - id: check-json - - id: check-yaml - - id: double-quote-string-fixer + - repo: https://github.com/pre-commit/pre-commit-hooks + rev: v4.4.0 + hooks: + - id: trailing-whitespace + - id: end-of-file-fixer + - id: check-docstring-first + - id: check-json + - id: check-yaml + - id: double-quote-string-fixer - - repo: https://github.com/psf/black - rev: 23.9.1 - hooks: - - id: black + - repo: https://github.com/psf/black + rev: 23.9.1 + hooks: + - id: black - - repo: https://github.com/keewis/blackdoc - rev: v0.3.8 - hooks: - - id: blackdoc + - repo: https://github.com/keewis/blackdoc + rev: v0.3.8 + hooks: + - id: blackdoc - - repo: https://github.com/PyCQA/flake8 - rev: 6.1.0 - hooks: - - id: flake8 + - repo: https://github.com/PyCQA/flake8 + rev: 6.1.0 + hooks: + - id: flake8 - - repo: https://github.com/asottile/seed-isort-config - rev: v2.2.0 - hooks: - - id: seed-isort-config - - repo: https://github.com/PyCQA/isort - rev: 5.12.0 - hooks: - - id: isort + - repo: https://github.com/asottile/seed-isort-config + rev: v2.2.0 + hooks: + - id: seed-isort-config + - repo: https://github.com/PyCQA/isort + rev: 5.12.0 + hooks: + - id: isort - - repo: https://github.com/pre-commit/mirrors-prettier - rev: v3.0.3 - hooks: - - id: prettier + - repo: https://github.com/pre-commit/mirrors-prettier + rev: v3.0.3 + hooks: + - id: prettier - - repo: https://github.com/deathbeds/prenotebook - rev: f5bdb72a400f1a56fe88109936c83aa12cc349fa - hooks: - - id: prenotebook - args: - [ - '--keep-output', - '--keep-metadata', - '--keep-execution-count', - '--keep-empty', - ] + - repo: https://github.com/deathbeds/prenotebook + rev: f5bdb72a400f1a56fe88109936c83aa12cc349fa + hooks: + - id: prenotebook + args: + [ + '--keep-output', + '--keep-metadata', + '--keep-execution-count', + '--keep-empty', + ] - - repo: https://github.com/tox-dev/pyproject-fmt - rev: 1.2.0 - hooks: - - id: 
pyproject-fmt + - repo: https://github.com/tox-dev/pyproject-fmt + rev: 1.2.0 + hooks: + - id: pyproject-fmt ci: - autofix_commit_msg: | - [pre-commit.ci] auto fixes from pre-commit.com hooks + autofix_commit_msg: | + [pre-commit.ci] auto fixes from pre-commit.com hooks - for more information, see https://pre-commit.ci - autofix_prs: true - autoupdate_commit_msg: '[pre-commit.ci] pre-commit autoupdate' - autoupdate_schedule: monthly - skip: [] - submodules: false + for more information, see https://pre-commit.ci + autofix_prs: true + autoupdate_commit_msg: '[pre-commit.ci] pre-commit autoupdate' + autoupdate_schedule: monthly + skip: [] + submodules: false diff --git a/.prettierrc.toml b/.prettierrc.toml index addd6d36..24a4663a 100644 --- a/.prettierrc.toml +++ b/.prettierrc.toml @@ -1,3 +1,3 @@ -tabWidth = 2 +tabWidth = 4 semi = false singleQuote = true diff --git a/binder/environment.yml b/binder/environment.yml index bcb6ee6e..42e1e14b 100644 --- a/binder/environment.yml +++ b/binder/environment.yml @@ -1,15 +1,15 @@ channels: - - conda-forge + - conda-forge dependencies: - - python=3.7 - - esmpy==7.1.0r - - xarray - - dask - - numpy - - scipy - - shapely - - matplotlib - - cartopy - - cf_xarray>=0.3.1 - - pip: - - xesmf==0.2.2 + - python=3.7 + - esmpy==7.1.0r + - xarray + - dask + - numpy + - scipy + - shapely + - matplotlib + - cartopy + - cf_xarray>=0.3.1 + - pip: + - xesmf==0.2.2 diff --git a/ci/environment-upstream-dev.yml b/ci/environment-upstream-dev.yml index 659aee2a..3d0f2590 100644 --- a/ci/environment-upstream-dev.yml +++ b/ci/environment-upstream-dev.yml @@ -1,20 +1,20 @@ name: xesmf channels: - - conda-forge + - conda-forge dependencies: - - cftime - - codecov - - dask - - esmpy - - numba - - numpy - - pip - - pre-commit - - pydap - - pytest - - pytest-cov - - shapely - - sparse>=0.8.0 - - pip: - - git+https://github.com/pydata/xarray.git - - git+https://github.com/xarray-contrib/cf-xarray.git + - cftime + - codecov + - dask + - esmpy + - numba + - numpy + - pip + - pre-commit + - pydap + - pytest + - pytest-cov + - shapely + - sparse>=0.8.0 + - pip: + - git+https://github.com/pydata/xarray.git + - git+https://github.com/xarray-contrib/cf-xarray.git diff --git a/ci/environment.yml b/ci/environment.yml index b596fcaa..8f57d1ba 100644 --- a/ci/environment.yml +++ b/ci/environment.yml @@ -1,19 +1,19 @@ name: xesmf channels: - - conda-forge + - conda-forge dependencies: - - cf_xarray>=0.3.1 - - cftime - - codecov - - dask - - esmpy - - numba - - numpy - - pip - - pre-commit - - pydap - - pytest - - pytest-cov - - shapely - - sparse>=0.8.0 - - xarray>=0.17.0 + - cf_xarray>=0.3.1 + - cftime + - codecov + - dask + - esmpy + - numba + - numpy + - pip + - pre-commit + - pydap + - pytest + - pytest-cov + - shapely + - sparse>=0.8.0 + - xarray>=0.17.0 diff --git a/codecov.yml b/codecov.yml index 1e11ad52..d241151b 100644 --- a/codecov.yml +++ b/codecov.yml @@ -1,20 +1,20 @@ codecov: - require_ci_to_pass: no - max_report_age: off + require_ci_to_pass: no + max_report_age: off comment: false ignore: - - 'xesmf/tests/*' - - 'setup.py' + - 'xesmf/tests/*' + - 'setup.py' coverage: - precision: 2 - round: down - status: - project: - default: - target: 95 - informational: true - patch: off - changes: off + precision: 2 + round: down + status: + project: + default: + target: 95 + informational: true + patch: off + changes: off diff --git a/doc/notebooks/Compare_algorithms.ipynb b/doc/notebooks/Compare_algorithms.ipynb index 0e194893..f5dcb746 100644 --- 
a/doc/notebooks/Compare_algorithms.ipynb +++ b/doc/notebooks/Compare_algorithms.ipynb @@ -8,12 +8,12 @@ "\n", "xESMF exposes five different regridding algorithms from the ESMF library:\n", "\n", - "- `bilinear`: `ESMF.RegridMethod.BILINEAR`\n", - "- `conservative`: `ESMF.RegridMethod.CONSERVE`\n", - "- `conservative_normed`: `ESMF.RegridMethod.CONSERVE`\n", - "- `patch`: `ESMF.RegridMethod.PATCH`\n", - "- `nearest_s2d`: `ESMF.RegridMethod.NEAREST_STOD`\n", - "- `nearest_d2s`: `ESMF.RegridMethod.NEAREST_DTOS`\n", + "- `bilinear`: `ESMF.RegridMethod.BILINEAR`\n", + "- `conservative`: `ESMF.RegridMethod.CONSERVE`\n", + "- `conservative_normed`: `ESMF.RegridMethod.CONSERVE`\n", + "- `patch`: `ESMF.RegridMethod.PATCH`\n", + "- `nearest_s2d`: `ESMF.RegridMethod.NEAREST_STOD`\n", + "- `nearest_d2s`: `ESMF.RegridMethod.NEAREST_DTOS`\n", "\n", "where `conservative_normed` is just the `conservative` method with the\n", "normalization set to `ESMF.NormType.FRACAREA` instead of the default\n", @@ -23,26 +23,27 @@ "\n", "## Notes\n", "\n", - "- `bilinear` and `conservative` should be the most commonly used methods. They\n", - " are both monotonic (i.e. will not create new maximum/minimum).\n", - "- Nearest neighbour methods, either source to destination (s2d) or destination\n", - " to source (d2s), could be useful in special cases. Keep in mind that d2s is\n", - " highly non-monotonic.\n", - "- Patch is ESMF's unique method, producing highly smooth results but quite slow.\n", - "- From the ESMF documentation:\n", + "- `bilinear` and `conservative` should be the most commonly used methods. They\n", + " are both monotonic (i.e. will not create new maximum/minimum).\n", + "- Nearest neighbour methods, either source to destination (s2d) or destination\n", + " to source (d2s), could be useful in special cases. Keep in mind that d2s is\n", + " highly non-monotonic.\n", + "- Patch is ESMF's unique method, producing highly smooth results but quite\n", + " slow.\n", + "- From the ESMF documentation:\n", "\n", - " > The weight $w_{ij}$ for a particular source cell $i$ and destination cell\n", - " > $j$ are calculated as $w_{ij}=f_{ij} * A_{si}/A_{dj}$. In this equation\n", - " > $f_{ij}$ is the fraction of the source cell $i$ contributing to destination\n", - " > cell $j$, and $A_{si}$ and $A_{dj}$ are the areas of the source and\n", - " > destination cells.\n", + " > The weight $w_{ij}$ for a particular source cell $i$ and destination cell\n", + " > $j$ are calculated as $w_{ij}=f_{ij} * A_{si}/A_{dj}$. In this equation\n", + " > $f_{ij}$ is the fraction of the source cell $i$ contributing to\n", + " > destination cell $j$, and $A_{si}$ and $A_{dj}$ are the areas of the\n", + " > source and destination cells.\n", "\n", - " For `conservative_normed`,\n", + " For `conservative_normed`,\n", "\n", - " > ... then the weights are further divided by the destination fraction. In\n", - " > other words, in that case $w_{ij}=f_{ij} * A_{si}/(A_{dj}*D_j)$ where $D_j$\n", - " > is fraction of the destination cell that intersects the unmasked source\n", - " > grid.\n", + " > ... then the weights are further divided by the destination fraction. 
In\n", + " > other words, in that case $w_{ij}=f_{ij} * A_{si}/(A_{dj}*D_j)$ where\n", + " > $D_j$ is fraction of the destination cell that intersects the unmasked\n", + " > source grid.\n", "\n", "Detailed explanations are available on\n", "[ESMPy documentation](http://www.earthsystemmodeling.org/esmf_releases/last_built/esmpy_doc/html/api.html#regridding).\n", diff --git a/doc/notebooks/Masking.ipynb b/doc/notebooks/Masking.ipynb index 0dcc213d..8c134190 100644 --- a/doc/notebooks/Masking.ipynb +++ b/doc/notebooks/Masking.ipynb @@ -499,15 +499,15 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "- mask can only be 2D (ESMF design) so regridding a 3D field requires to\n", - " generate regridding weights for each vertical level.\n", + "- mask can only be 2D (ESMF design) so regridding a 3D field requires to\n", + " generate regridding weights for each vertical level.\n", "\n", - "- conservative method will give you a normalization by the total area of the\n", - " target cell. Except for some specific cases, you probably want to use\n", - " conservative_normed.\n", + "- conservative method will give you a normalization by the total area of the\n", + " target cell. Except for some specific cases, you probably want to use\n", + " conservative_normed.\n", "\n", - "- results with other methods (e.g. bilinear) may not give masks consistent with\n", - " the coarse grid.\n" + "- results with other methods (e.g. bilinear) may not give masks consistent\n", + " with the coarse grid.\n" ] }, { diff --git a/doc/notebooks/Pure_numpy.ipynb b/doc/notebooks/Pure_numpy.ipynb index a04ab748..308eef62 100644 --- a/doc/notebooks/Pure_numpy.ipynb +++ b/doc/notebooks/Pure_numpy.ipynb @@ -236,10 +236,10 @@ "We use the previous input data, but now assume it is on a curvilinear grid\n", "described by 2D arrays. We also computed the cell corners, for two purposes:\n", "\n", - "- Visualization with `plt.pcolormesh` (using cell centers will miss one\n", - " row&column)\n", - "- Conservative regridding with xESMF (corner information is required for\n", - " conservative method)\n" + "- Visualization with `plt.pcolormesh` (using cell centers will miss one\n", + " row&column)\n", + "- Conservative regridding with xESMF (corner information is required for\n", + " conservative method)\n" ] }, { @@ -446,9 +446,9 @@ "source": [ "All $2 \\times 2\\times 2 = 8$ combinations would work:\n", "\n", - "- Input grid: `xarray.DataSet` or `dict`\n", - "- Output grid: `xarray.DataSet` or `dict`\n", - "- Input data: `xarray.DataArray` or `numpy.ndarray`\n", + "- Input grid: `xarray.DataSet` or `dict`\n", + "- Output grid: `xarray.DataSet` or `dict`\n", + "- Input data: `xarray.DataArray` or `numpy.ndarray`\n", "\n", "The output data type will be the same as input data.\n" ] diff --git a/readthedocs.yml b/readthedocs.yml index f0c4a4ae..c605ae18 100644 --- a/readthedocs.yml +++ b/readthedocs.yml @@ -1,12 +1,12 @@ version: 2 build: - os: ubuntu-22.04 - tools: - python: '3.9' + os: ubuntu-22.04 + tools: + python: '3.9' python: - install: - - requirements: doc/requirements.txt - - method: pip - path: . + install: + - requirements: doc/requirements.txt + - method: pip + path: . 
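The weight formulas quoted from the ESMF documentation in Compare_algorithms.ipynb above are compact enough to spell out directly. A minimal numpy sketch, not xESMF code, with all overlap fractions and areas invented for illustration:

    # w_ij = f_ij * A_si / A_dj            (conservative)
    # w_ij = f_ij * A_si / (A_dj * D_j)    (conservative_normed)
    import numpy as np

    f = np.array([[0.6, 0.4], [0.2, 0.8]])  # fraction of source cell i inside destination cell j
    A_s = np.array([1.0, 2.0])              # source cell areas (assumed)
    A_d = np.array([1.5, 1.5])              # destination cell areas (assumed)
    D = np.array([1.0, 0.5])                # unmasked fraction of each destination cell (assumed)

    w = f * A_s[:, None] / A_d[None, :]     # conservative weights
    w_normed = w / D[None, :]               # further divided by the destination fraction
    print(w, w_normed, sep='\n')

The normalized variant only differs where D_j < 1, i.e. where a destination cell partially overlaps the masked source grid, which matches the notebook's advice to prefer conservative_normed in masked cases.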
diff --git a/xesmf/backend.py b/xesmf/backend.py
index 55294181..1ce05424 100644
--- a/xesmf/backend.py
+++ b/xesmf/backend.py
@@ -17,16 +17,20 @@
 import os
 import warnings
+from typing import List, Literal, Optional, Sequence, Union
 
 try:
     import esmpy as ESMF
 except ImportError:
     import ESMF
+
 import numpy as np
 import numpy.lib.recfunctions as nprec
+import numpy.typing as npt
+from shapely.geometry import Polygon
 
 
-def warn_f_contiguous(a):
+def warn_f_contiguous(a: npt.NDArray) -> None:
     """
     Give a warning if input array is not Fortran-ordered.
 
@@ -41,7 +45,7 @@ def warn_f_contiguous(a):
         warnings.warn('Input array is not F_CONTIGUOUS. ' 'Will affect performance.')
 
 
-def warn_lat_range(lat):
+def warn_lat_range(lat: npt.NDArray) -> None:
     """
     Give a warning if latitude is outside of [-90, 90]
 
@@ -58,19 +62,25 @@ def warn_lat_range(lat):
 class Grid(ESMF.Grid):
     @classmethod
-    def from_xarray(cls, lon, lat, periodic=False, mask=None):
+    def from_xarray(
+        cls,
+        lon: npt.NDArray,
+        lat: npt.NDArray,
+        periodic: bool = False,
+        mask: Optional[npt.NDArray] = None,
+    ):
         """
         Create an ESMF.Grid object, for constructing ESMF.Field and ESMF.Regrid.
 
         Parameters
         ----------
         lon, lat : 2D numpy array
-            Longitute/Latitude of cell centers.
+            Longitude/Latitude of cell centers.
 
-            Recommend Fortran-ordering to match ESMPy internal.
+            Recommend Fortran-ordering to match ESMPy internal.
 
-            Shape should be ``(Nlon, Nlat)`` for rectilinear grid,
-            or ``(Nx, Ny)`` for general quadrilateral grid.
+            Shape should be ``(Nlon, Nlat)`` for rectilinear grid,
+            or ``(Nx, Ny)`` for general quadrilateral grid.
 
         periodic : bool, optional
             Periodic in longitude? Default to False.
@@ -136,8 +146,8 @@ def from_xarray(cls, lon, lat, periodic=False, mask=None):
             grid_mask = mask.astype(np.int32)
             if not (grid_mask.shape == lon.shape):
                 raise ValueError(
-                    'mask must have the same shape as the latitude/longitude'
-                    'coordinates, got: mask.shape = %s, lon.shape = %s' % (mask.shape, lon.shape)
+                    'mask must have the same shape as the latitude/longitude coordinates, '
+                    f'got: mask.shape = {mask.shape}, lon.shape = {lon.shape}'
                 )
             grid.add_item(ESMF.GridItem.MASK, staggerloc=ESMF.StaggerLoc.CENTER, from_file=False)
             grid.mask[0][:] = grid_mask
@@ -151,14 +161,18 @@ def get_shape(self, loc=ESMF.StaggerLoc.CENTER):
 
 class LocStream(ESMF.LocStream):
     @classmethod
-    def from_xarray(cls, lon, lat):
+    def from_xarray(
+        cls,
+        lon: npt.NDArray,
+        lat: npt.NDArray,
+    ) -> ESMF.LocStream:
         """
         Create an ESMF.LocStream object, for constructing ESMF.Field and ESMF.Regrid
 
         Parameters
         ----------
         lon, lat : 1D numpy array
-            Longitute/Latitude of cell centers.
+            Longitude/Latitude of cell centers.
 
         Returns
         -------
@@ -186,7 +200,7 @@ def get_shape(self):
         return (self.size, 1)
 
 
-def add_corner(grid, lon_b, lat_b):
+def add_corner(grid: Grid, lon_b, lat_b):
     """
     Add corner information to ESMF.Grid for conservative regridding.
 
@@ -204,7 +218,7 @@ def add_corner(grid, lon_b, lat_b):
     """
 
     # codes here are almost the same as Grid.from_xarray(),
-    # except for the "staggerloc" keyword
+    # except for the 'staggerloc' keyword
     staggerloc = ESMF.StaggerLoc.CORNER  # actually just integer 3
 
     for a in [lon_b, lat_b]:
@@ -230,7 +244,11 @@
 
 class Mesh(ESMF.Mesh):
     @classmethod
-    def from_polygons(cls, polys, element_coords='centroid'):
+    def from_polygons(
+        cls,
+        polys: Sequence[Polygon],
+        element_coords: Union[Literal['centroid'], npt.NDArray] = 'centroid',
+    ):
         """
         Create an ESMF.Mesh object from a list of polygons.
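A hedged usage sketch for the two constructors annotated in the hunks above; the toy lon/lat values are invented, and esmpy must be importable for this to run:

    import numpy as np
    from xesmf.backend import Grid, LocStream

    # 2-D cell centers with shape (Nlon, Nlat); Fortran order, as the docstring recommends
    lon, lat = np.meshgrid(np.arange(0.0, 360.0, 30.0),
                           np.arange(-75.0, 90.0, 30.0), indexing='ij')
    grid = Grid.from_xarray(np.asfortranarray(lon), np.asfortranarray(lat), periodic=True)

    # 1-D station coordinates become a LocStream
    points = LocStream.from_xarray(np.array([10.0, 20.0]), np.array([45.0, 50.0]))
    print(grid.get_shape(), points.get_shape())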
@@ -240,9 +258,9 @@ def from_polygons(cls, polys, element_coords='centroid'): Parameters ---------- polys : sequence of shapely Polygon - Holes are not represented by the Mesh. - element_coords : array or "centroid", optional - If "centroid", the polygon centroids will be used (default) + Holes are not represented by the Mesh. + element_coords : array or 'centroid', optional + If 'centroid', the polygon centroids will be used (default) If an array of shape (len(polys), 2) : the element coordinates of the mesh. If None, the Mesh's elements will not have coordinates. @@ -313,15 +331,22 @@ def get_shape(self, loc=ESMF.MeshLoc.ELEMENT): def esmf_regrid_build( - sourcegrid, - destgrid, - method, - filename=None, - extra_dims=None, - extrap_method=None, - extrap_dist_exponent=None, - extrap_num_src_pnts=None, - ignore_degenerate=None, + sourcegrid: Union[Grid, Mesh], + destgrid: Union[Grid, Mesh], + method: Literal[ + 'bilinear', + 'conservative', + 'conservative_normed', + 'patch', + 'nearest_s2d', + 'nearest_d2s', + ], + filename: Union[str, None] = None, + extra_dims: Union[List[int], None] = None, + extrap_method: Union[Literal['inverse_dist', 'nearest_s2d'], None] = None, + extrap_dist_exponent: float = 2.0, + extrap_num_src_pnts: int = 8, + ignore_degenerate: bool = False, ): """ Create an ESMF.Regrid object, containing regridding weights. @@ -387,7 +412,7 @@ def esmf_regrid_build( """ # use shorter, clearer names for options in ESMF.RegridMethod - method_dict = { + method_dict: dict[str, int] = { 'bilinear': ESMF.RegridMethod.BILINEAR, 'conservative': ESMF.RegridMethod.CONSERVE, 'conservative_normed': ESMF.RegridMethod.CONSERVE, @@ -398,7 +423,7 @@ def esmf_regrid_build( try: esmf_regrid_method = method_dict[method] except Exception: - raise ValueError('method should be chosen from ' '{}'.format(list(method_dict.keys()))) + raise ValueError(f'method should be chosen from {list(method_dict.keys())}') # use shorter, clearer names for options in ESMF.ExtrapMethod extrap_dict = { @@ -409,9 +434,7 @@ def esmf_regrid_build( try: esmf_extrap_method = extrap_dict[extrap_method] except KeyError: - raise KeyError( - '`extrap_method` should be chosen from ' '{}'.format(list(extrap_dict.keys())) - ) + raise KeyError(f'`extrap_method` should be chosen from {list(extrap_dict.keys())}') # until ESMPy updates ESMP_FieldRegridStoreFile, extrapolation is not possible # if files are written on disk @@ -420,11 +443,11 @@ def esmf_regrid_build( # conservative regridding needs cell corner information if method in ['conservative', 'conservative_normed']: - if not isinstance(sourcegrid, ESMF.Mesh) and not sourcegrid.has_corners: + if not isinstance(sourcegrid, Mesh) and not sourcegrid.has_corners: raise ValueError( 'source grid has no corner information. ' 'cannot use conservative regridding.' ) - if not isinstance(destgrid, ESMF.Mesh) and not destgrid.has_corners: + if not isinstance(destgrid, Mesh) and not destgrid.has_corners: raise ValueError( 'destination grid has no corner information. ' 'cannot use conservative regridding.' ) @@ -432,11 +455,11 @@ def esmf_regrid_build( # ESMF.Regrid requires Field (Grid+data) as input, not just Grid. # Extra dimensions are specified when constructing the Field objects, # not when constructing the Regrid object later on. 
-    if isinstance(sourcegrid, ESMF.Mesh):
+    if isinstance(sourcegrid, Mesh):
         sourcefield = ESMF.Field(sourcegrid, meshloc=ESMF.MeshLoc.ELEMENT, ndbounds=extra_dims)
     else:
         sourcefield = ESMF.Field(sourcegrid, ndbounds=extra_dims)
-    if isinstance(destgrid, ESMF.Mesh):
+    if isinstance(destgrid, Mesh):
         destfield = ESMF.Field(destgrid, meshloc=ESMF.MeshLoc.ELEMENT, ndbounds=extra_dims)
     else:
         destfield = ESMF.Field(destgrid, ndbounds=extra_dims)
@@ -485,7 +508,7 @@ def esmf_regrid_build(
     return regrid
 
 
-def esmf_regrid_apply(regrid, indata):
+def esmf_regrid_apply(regrid: ESMF.Regrid, indata):
     """
     Apply existing regridding weights to the data field,
     using ESMPy's built-in functionality.
@@ -533,7 +556,7 @@ def esmf_regrid_apply(regrid, indata):
     return destfield.data
 
 
-def esmf_regrid_finalize(regrid):
+def esmf_regrid_finalize(regrid: ESMF.Regrid):
     """
     Free the underlying Fortran array to avoid memory leak.
 
@@ -563,7 +586,10 @@ def esmf_regrid_finalize(regrid):
 
 
 # Deprecated as of version 0.5.0
-def esmf_locstream(lon, lat):
+def esmf_locstream(
+    lon: npt.NDArray,
+    lat: npt.NDArray,
+) -> LocStream:
     warnings.warn(
         '`esmf_locstream` is being deprecated in favor of `LocStream.from_xarray`',
         DeprecationWarning,
@@ -571,8 +597,14 @@ def esmf_locstream(lon, lat):
     return LocStream.from_xarray(lon, lat)
 
 
-def esmf_grid(lon, lat, periodic=False, mask=None):
+def esmf_grid(
+    lon: npt.NDArray,
+    lat: npt.NDArray,
+    periodic: bool = False,
+    mask: Optional[npt.NDArray] = None,
+) -> Grid:
     warnings.warn(
-        '`esmf_grid` is being deprecated in favor of `Grid.from_xarray`', DeprecationWarning
+        '`esmf_grid` is being deprecated in favor of `Grid.from_xarray`',
+        DeprecationWarning,
     )
     return Grid.from_xarray(lon, lat)
diff --git a/xesmf/data.py b/xesmf/data.py
index e534c3a1..a20918ac 100644
--- a/xesmf/data.py
+++ b/xesmf/data.py
@@ -2,43 +2,49 @@
 Standard test data for regridding benchmark.
 """
 
+from typing import Union
+
 import numpy as np
+import numpy.typing as npt
+import xarray
 
 
-def wave_smooth(lon, lat):
-    r"""
+def wave_smooth(  # type: ignore
+    lon: Union[npt.NDArray, xarray.DataArray],
+    lat: Union[npt.NDArray, xarray.DataArray],
+) -> Union[npt.NDArray, xarray.DataArray]:
+    r"""
     Spherical harmonic with low frequency.
 
     Parameters
     ----------
     lon, lat : 2D numpy array or xarray DataArray
-        Longitute/Latitude of cell centers
+        Longitude/Latitude of cell centers
 
     Returns
     -------
-    f : 2D numpy array or xarray DataArray depending on input
-        2D wave field
+    f : 2D numpy array or xarray DataArray, depending on input
+        2D wave field
 
     Notes
     -------
     Equation from [1]_ [2]_:
 
-    .. math:: Y_2^2 = 2 + \cos^2(\\theta) \cos(2 \phi)
+    .. math:: Y_2^2 = 2 + \cos^2(\theta) \cos(2 \phi)
 
     References
     ----------
     .. [1] Jones, P. W. (1999). First-and second-order conservative remapping
-       schemes for grids in spherical coordinates. Monthly Weather Review,
-       127(9), 2204-2210.
+           schemes for grids in spherical coordinates. Monthly Weather Review,
+           127(9), 2204-2210.
 
     .. [2] Ullrich, P. A., Lauritzen, P. H., & Jablonowski, C. (2009).
-       Geometrically exact conservative remapping (GECoRe): regular
-       latitude–longitude and cubed-sphere grids. Monthly Weather Review,
-       137(6), 1721-1741.
+           Geometrically exact conservative remapping (GECoRe): regular
+           latitude-longitude and cubed-sphere grids. Monthly Weather Review,
+           137(6), 1721-1741.
     """
-    # degree to radius, make a copy
-    lat = lat / 180.0 * np.pi
-    lon = lon / 180.0 * np.pi
+    # degrees to radians, make a copy
+    lat = lat * (np.pi / 180.0)
+    lon = lon * (np.pi / 180.0)
 
     f = 2 + np.cos(lat) ** 2 * np.cos(2 * lon)
     return f
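A quick sanity sketch of the test field above (the 10-degree mesh is arbitrary). Since the field is 2 + cos^2(theta) * cos(2*phi), its values must stay within [1, 3]:

    import numpy as np
    from xesmf.data import wave_smooth

    lon, lat = np.meshgrid(np.arange(0.0, 360.0, 10.0), np.arange(-85.0, 90.0, 10.0))
    f = wave_smooth(lon, lat)
    print(float(f.min()), float(f.max()))     # bounded by 2 - 1 and 2 + 1
    assert 1.0 <= f.min() and f.max() <= 3.0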
""" # degree to radius, make a copy - lat = lat / 180.0 * np.pi - lon = lon / 180.0 * np.pi + lat *= np.pi / 180.0 + lon *= np.pi / 180.0 - f = 2 + np.cos(lat) ** 2 * np.cos(2 * lon) + f = 2 + pow(np.cos(lat), 2) * np.cos(2 * lon) return f diff --git a/xesmf/frontend.py b/xesmf/frontend.py index 267ed3d9..91570a16 100644 --- a/xesmf/frontend.py +++ b/xesmf/frontend.py @@ -3,12 +3,14 @@ """ import warnings +from typing import Any, Dict, Hashable, List, Literal, Optional, Sequence, Tuple, Union import cf_xarray as cfxr import numpy as np +import numpy.typing as npt import sparse as sps import xarray as xr -from shapely.geometry import LineString +from shapely.geometry import LineString, MultiPolygon, Polygon from xarray import DataArray, Dataset from .backend import Grid, LocStream, Mesh, add_corner, esmf_regrid_build, esmf_regrid_finalize @@ -31,7 +33,22 @@ def subset_regridder( - ds_out, ds_in, method, in_dims, out_dims, locstream_in, locstream_out, periodic, **kwargs + ds_in: Union[DataArray, Dataset, Dict[str, DataArray]], + ds_out: Union[DataArray, Dataset, Dict[str, DataArray]], + method: Literal[ + 'bilinear', + 'conservative', + 'conservative_normed', + 'patch', + 'nearest_s2d', + 'nearest_d2s', + ], + in_dims, + out_dims, + locstream_in: bool, + locstream_out: bool, + periodic: bool, + **kwargs, ): """Compute subset of weights""" kwargs.pop('filename', None) # Don't save subset of weights @@ -49,12 +66,22 @@ def subset_regridder( ds_out = ds_out.rename({'y_out': out_dims[0], 'x_out': out_dims[1]}) regridder = Regridder( - ds_in, ds_out, method, locstream_in, locstream_out, periodic, parallel=False, **kwargs + ds_in=ds_in, + ds_out=ds_out, + method=method, + locstream_in=locstream_in, + locstream_out=locstream_out, + periodic=periodic, + parallel=False, + **kwargs, ) return regridder.w -def as_2d_mesh(lon, lat): +def as_2d_mesh( + lon: Union[DataArray, npt.NDArray], + lat: Union[DataArray, npt.NDArray], +) -> Tuple[Union[DataArray, npt.NDArray], Union[DataArray, npt.NDArray]]: if (lon.ndim, lat.ndim) == (2, 2): assert lon.shape == lat.shape, 'lon and lat should have same shape' elif (lon.ndim, lat.ndim) == (1, 1): @@ -65,7 +92,9 @@ def as_2d_mesh(lon, lat): return lon, lat -def _get_lon_lat(ds): +def _get_lon_lat( + ds: Union[Dataset, Dict[str, npt.NDArray]] +) -> Tuple[Union[DataArray, npt.NDArray], Union[DataArray, npt.NDArray]]: """Return lon and lat extracted from ds.""" if ('lat' in ds and 'lon' in ds) or ('lat' in ds.coords and 'lon' in ds.coords): # Old way. @@ -82,7 +111,9 @@ def _get_lon_lat(ds): return lon, lat -def _get_lon_lat_bounds(ds): +def _get_lon_lat_bounds( + ds: Union[Dataset, Dict[str, npt.NDArray]] +) -> Tuple[Union[DataArray, npt.NDArray], Union[DataArray, npt.NDArray]]: """Return bounds of lon and lat extracted from ds.""" if 'lat_b' in ds and 'lon_b' in ds: # Old way. @@ -96,7 +127,7 @@ def _get_lon_lat_bounds(ds): lat_bnds = ds.cf.get_bounds('latitude') except KeyError: # bounds are not already present if ds.cf['longitude'].ndim > 1: - # We cannot infer 2D bounds, raise KeyError as custom "lon_b" is missing. + # We cannot infer 2D bounds, raise KeyError as custom 'lon_b' is missing. 
raise KeyError('lon_b') lon_name = ds.cf['longitude'].name lat_name = ds.cf['latitude'].name @@ -111,7 +142,12 @@ def _get_lon_lat_bounds(ds): return lon_b, lat_b -def ds_to_ESMFgrid(ds, need_bounds=False, periodic=None, append=None): +def ds_to_ESMFgrid( + ds: Union[Dataset, Dict[str, npt.NDArray]], + need_bounds: bool = False, + periodic: bool = None, + append=None, +): """ Convert xarray DataSet or dictionary to ESMF.Grid object. @@ -205,7 +241,7 @@ def ds_to_ESMFlocstream(ds): return locstream, (1,) + lon.shape, dim_names -def polys_to_ESMFmesh(polys): +def polys_to_ESMFmesh(polys) -> Tuple[Mesh, Tuple[Literal[1], int]]: """ Convert a sequence of shapely Polygons to a ESMF.Mesh object. @@ -234,21 +270,21 @@ def polys_to_ESMFmesh(polys): class BaseRegridder(object): def __init__( self, - grid_in, - grid_out, - method, - filename=None, - reuse_weights=False, - extrap_method=None, - extrap_dist_exponent=None, - extrap_num_src_pnts=None, - weights=None, - ignore_degenerate=None, - input_dims=None, - output_dims=None, - unmapped_to_nan=False, - parallel=False, - ): + grid_in: Union[Grid, LocStream, Mesh], + grid_out: Union[Grid, LocStream, Mesh], + method: str, + filename: Optional[str] = None, + reuse_weights: bool = False, + extrap_method: Optional[Literal['inverse_dist', 'nearest_s2d']] = None, + extrap_dist_exponent: Optional[float] = None, + extrap_num_src_pnts: Optional[int] = None, + weights: Optional[Any] = None, + ignore_degenerate: bool = False, + input_dims: Optional[Tuple[str, ...]] = None, + output_dims: Optional[Tuple[str, ...]] = None, + unmapped_to_nan: bool = False, + parallel: bool = False, + ) -> None: """ Base xESMF regridding class supporting ESMF objects: `Grid`, `Mesh` and `LocStream`. @@ -298,10 +334,10 @@ def __init__( weights : None, coo_matrix, dict, str, Dataset, Path, Regridding weights, stored as - - a scipy.sparse COO matrix, - - a dictionary with keys `row_dst`, `col_src` and `weights`, - - an xarray Dataset with data variables `col`, `row` and `S`, - - or a path to a netCDF file created by ESMF. + - a scipy.sparse COO matrix, + - a dictionary with keys `row_dst`, `col_src` and `weights`, + - an xarray Dataset with data variables `col`, `row` and `S`, + - or a path to a netCDF file created by ESMF. If None, compute the weights. ignore_degenerate : bool, optional @@ -392,7 +428,7 @@ def __init__( self.filename = self._get_default_filename() if filename is None else filename @property - def A(self): + def A(self) -> DataArray: message = ( 'regridder.A is deprecated and will be removed in future versions. ' 'Use regridder.weights instead.' @@ -418,7 +454,7 @@ def w(self) -> xr.DataArray: dims = 'y_out', 'x_out', 'y_in', 'x_in' return xr.DataArray(data, dims=dims) - def _get_default_filename(self): + def _get_default_filename(self) -> str: # e.g. bilinear_400x600_300x400.nc filename = '{0}_{1}x{2}_{3}x{4}'.format( self.method, @@ -450,7 +486,14 @@ def _compute_weights(self): esmf_regrid_finalize(regrid) # only need weights, not regrid object return w - def __call__(self, indata, keep_attrs=False, skipna=False, na_thres=1.0, output_chunks=None): + def __call__( + self, + indata: Union[npt.NDArray, 'da.Array', xr.DataArray, xr.Dataset], + keep_attrs: bool = False, + skipna: bool = False, + na_thres: float = 1.0, + output_chunks: Optional[Union[Dict[str, int], Tuple[int, ...]]] = None, + ): """ Apply regridding to input data. 
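A hedged end-to-end sketch of the options typed on `__call__` above, built from the `grid_global` and `wave_smooth` helpers that appear later in this diff; the resolutions are arbitrary:

    import xesmf as xe
    from xesmf.data import wave_smooth

    ds_in = xe.util.grid_global(5, 4)      # 5 deg x 4 deg source grid
    ds_out = xe.util.grid_global(2.5, 2)   # finer destination grid
    ds_in['data'] = wave_smooth(ds_in['lon'], ds_in['lat'])

    regridder = xe.Regridder(ds_in, ds_out, method='bilinear', periodic=True)
    out = regridder(ds_in['data'], skipna=True, na_thres=0.5)
    print(out.shape)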
@@ -553,10 +596,18 @@ def __call__(self, indata, keep_attrs=False, skipna=False, na_thres=1.0, output_ raise TypeError('input must be numpy array, dask array, xarray DataArray or Dataset!') @staticmethod - def _regrid(indata, weights, *, shape_in, shape_out, skipna, na_thres): + def _regrid( + indata: npt.NDArray, + weights: sps.COO, + *, + shape_in: Tuple[int, int], + shape_out: Tuple[int, int], + skipna: bool, + na_thres: float, + ) -> npt.NDArray: # skipna: set missing values to zero if skipna: - missing = np.isnan(indata) + missing: npt.NDArray = np.isnan(indata) indata = np.where(missing, 0.0, indata) # apply weights @@ -572,14 +623,21 @@ def _regrid(indata, weights, *, shape_in, shape_out, skipna, na_thres): return outdata - def regrid_array(self, indata, weights, skipna=False, na_thres=1.0, output_chunks=None): + def regrid_array( + self, + indata: Union[npt.NDArray, 'da.Array'], + weights: sps.COO, + skipna: bool = False, + na_thres: float = 1.0, + output_chunks: Optional[Union[Tuple[int, ...], Dict[str, int]]] = None, + ): """See __call__().""" if self.sequence_in: indata = np.reshape(indata, (*indata.shape[:-1], 1, indata.shape[-1])) # If output_chunk is dict, order output chunks to match order of out_horiz_dims and convert to tuple if isinstance(output_chunks, dict): - output_chunks = tuple([output_chunks.get(key) for key in self.out_horiz_dims]) + output_chunks = tuple(map(output_chunks.get, self.out_horiz_dims)) kwargs = { 'shape_in': self.shape_in, @@ -594,7 +652,7 @@ def regrid_array(self, indata, weights, skipna=False, na_thres=1.0, output_chunk if isinstance(indata, dask_array_type): # dask if output_chunks is None: output_chunks = tuple( - [min(shp, inchnk) for shp, inchnk in zip(self.shape_out, indata.chunksize[-2:])] + min(shp, inchnk) for shp, inchnk in zip(self.shape_out, indata.chunksize[-2:]) ) if len(output_chunks) != len(self.shape_out): if len(output_chunks) == 1 and self.sequence_out: @@ -611,14 +669,14 @@ def regrid_array(self, indata, weights, skipna=False, na_thres=1.0, output_chunk outdata = self._regrid(indata, weights, **kwargs) return outdata - def regrid_numpy(self, indata, **kwargs): + def regrid_numpy(self, indata: npt.NDArray, **kwargs) -> npt.NDArray: warnings.warn( '`regrid_numpy()` will be removed in xESMF 0.7, please use `regrid_array` instead.', category=FutureWarning, ) return self.regrid_array(indata, self.weights.data, **kwargs) - def regrid_dask(self, indata, **kwargs): + def regrid_dask(self, indata: npt.NDArray, **kwargs) -> npt.NDArray: warnings.warn( '`regrid_dask()` will be removed in xESMF 0.7, please use `regrid_array` instead.', category=FutureWarning, @@ -626,8 +684,13 @@ def regrid_dask(self, indata, **kwargs): return self.regrid_array(indata, self.weights.data, **kwargs) def regrid_dataarray( - self, dr_in, keep_attrs=False, skipna=False, na_thres=1.0, output_chunks=None - ): + self, + dr_in: xr.DataArray, + keep_attrs: bool = False, + skipna: bool = False, + na_thres: float = 1.0, + output_chunks: Optional[Union[Dict[str, int], Tuple[int, ...]]] = None, + ) -> Union[DataArray, Dataset]: """See __call__().""" input_horiz_dims, temp_horiz_dims = self._parse_xrinput(dr_in) @@ -646,8 +709,13 @@ def regrid_dataarray( return self._format_xroutput(dr_out, temp_horiz_dims) def regrid_dataset( - self, ds_in, keep_attrs=False, skipna=False, na_thres=1.0, output_chunks=None - ): + self, + ds_in: xr.Dataset, + keep_attrs: bool = False, + skipna: bool = False, + na_thres: float = 1.0, + output_chunks: Optional[Union[Dict[str, int], Tuple[int, 
...]]] = None, + ) -> Union[DataArray, Dataset]: """See __call__().""" # get the first data variable to infer input_core_dims @@ -675,7 +743,9 @@ def regrid_dataset( return self._format_xroutput(ds_out, temp_horiz_dims) - def _parse_xrinput(self, dr_in): + def _parse_xrinput( + self, dr_in: Union[xr.DataArray, xr.Dataset] + ) -> Tuple[Tuple[Hashable, ...], List[str]]: # dr could be a DataArray or a Dataset # Get input horiz dim names and set output horiz dim names if self.in_horiz_dims is not None and all(dim in dr_in.dims for dim in self.in_horiz_dims): @@ -702,46 +772,45 @@ def _parse_xrinput(self, dr_in): ) if self.sequence_out: - temp_horiz_dims = ['dummy', 'locations'] + temp_horiz_dims: List[str] = ['dummy', 'locations'] else: - temp_horiz_dims = [s + '_new' for s in input_horiz_dims] + temp_horiz_dims: List[str] = [s + '_new' for s in input_horiz_dims] if self.sequence_in and not self.sequence_out: temp_horiz_dims = ['dummy_new'] + temp_horiz_dims return input_horiz_dims, temp_horiz_dims - def _format_xroutput(self, out, new_dims=None): + def _format_xroutput( + self, out: Union[xr.DataArray, xr.Dataset], new_dims: Optional[List[str]] = None + ) -> Union[xr.DataArray, xr.Dataset]: out.attrs['regrid_method'] = self.method return out - def __repr__(self): + def __repr__(self) -> str: info = ( 'xESMF Regridder \n' - 'Regridding algorithm: {} \n' - 'Weight filename: {} \n' - 'Reuse pre-computed weights? {} \n' - 'Input grid shape: {} \n' - 'Output grid shape: {} \n' - 'Periodic in longitude? {}'.format( - self.method, - self.filename, - self.reuse_weights, - self.shape_in, - self.shape_out, - self.periodic, - ) + f'Regridding algorithm: {self.method} \n' + f'Weight filename: {self.filename} \n' + f'Reuse pre-computed weights? {self.reuse_weights} \n' + f'Input grid shape: {self.shape_in} \n' + f'Output grid shape: {self.shape_out} \n' + f'Periodic in longitude? {self.periodic}' ) return info - def to_netcdf(self, filename=None): + def to_netcdf(self, filename: Optional[str] = None) -> str: """Save weights to disk as a netCDF file.""" if filename is None: filename = self.filename w = self.weights.data dim = 'n_s' ds = xr.Dataset( - {'S': (dim, w.data), 'col': (dim, w.coords[1, :] + 1), 'row': (dim, w.coords[0, :] + 1)} + { + 'S': (dim, w.data), + 'col': (dim, w.coords[1, :] + 1), + 'row': (dim, w.coords[0, :] + 1), + } ) ds.to_netcdf(filename) return filename @@ -750,15 +819,22 @@ def to_netcdf(self, filename=None): class Regridder(BaseRegridder): def __init__( self, - ds_in, - ds_out, - method, - locstream_in=False, - locstream_out=False, - periodic=False, - parallel=False, + ds_in: Union[xr.DataArray, xr.Dataset, Dict[str, xr.DataArray]], + ds_out: Union[xr.DataArray, xr.Dataset, Dict[str, xr.DataArray]], + method: Literal[ + 'bilinear', + 'conservative', + 'conservative_normed', + 'patch', + 'nearest_s2d', + 'nearest_d2s', + ], + locstream_in: bool = False, + locstream_out: bool = False, + periodic: bool = False, + parallel: bool = False, **kwargs, - ): + ) -> None: """ Make xESMF regridder @@ -833,10 +909,10 @@ def __init__( weights : None, coo_matrix, dict, str, Dataset, Path, Regridding weights, stored as - - a scipy.sparse COO matrix, - - a dictionary with keys `row_dst`, `col_src` and `weights`, - - an xarray Dataset with data variables `col`, `row` and `S`, - - or a path to a netCDF file created by ESMF. 
+ - a scipy.sparse COO matrix, + - a dictionary with keys `row_dst`, `col_src` and `weights`, + - an xarray Dataset with data variables `col`, `row` and `S`, + - or a path to a netCDF file created by ESMF. If None, compute the weights. @@ -957,7 +1033,12 @@ def __init__( if parallel: self._init_para_regrid(ds_in, ds_out, kwargs) - def _init_para_regrid(self, ds_in, ds_out, kwargs): + def _init_para_regrid( + self, + ds_in: xr.Dataset, + ds_out: xr.Dataset, + kwargs: dict, + ): # Check if we have bounds as variable and not coords, and add them to coords in both datasets if 'lon_b' in ds_out.data_vars: ds_out = ds_out.set_coords(['lon_b', 'lat_b']) @@ -1084,15 +1165,15 @@ def _format_xroutput(self, out, new_dims=None): class SpatialAverager(BaseRegridder): def __init__( self, - ds_in, - polys, - ignore_holes=False, - periodic=False, - filename=None, - reuse_weights=False, - weights=None, - ignore_degenerate=False, - geom_dim_name='geom', + ds_in: Union[xr.DataArray, xr.Dataset, dict], + polys: Sequence[Union[Polygon, MultiPolygon]], + ignore_holes: bool = False, + periodic: bool = False, + filename: Optional[str] = None, + reuse_weights: bool = False, + weights: Optional[Union[sps.COO, dict, str, Dataset]] = None, + ignore_degenerate: bool = False, + geom_dim_name: str = 'geom', ): """Compute the exact average of a gridded array over a geometry. @@ -1215,7 +1296,7 @@ def __init__( ) @staticmethod - def _check_polys_length(polys, threshold=1): + def _check_polys_length(polys: List[Polygon], threshold: int = 1) -> None: # Check length of polys segments, issue warning if too long check_polys, check_holes, _, _ = split_polygons_and_holes(polys) check_polys.extend(check_holes) @@ -1231,7 +1312,7 @@ def _check_polys_length(polys, threshold=1): stacklevel=2, ) - def _compute_weights_and_area(self, mesh_out): + def _compute_weights_and_area(self, mesh_out: Mesh) -> Tuple[DataArray, Any]: """Return the weights and the area of the destination mesh cells.""" # Build the regrid object @@ -1253,12 +1334,12 @@ def _compute_weights_and_area(self, mesh_out): esmf_regrid_finalize(regrid) return w, dstarea - def _compute_weights(self): + def _compute_weights(self) -> DataArray: """Return weight sparse matrix. This function first explodes the geometries into a flat list of Polygon exterior objects: - - Polygon -> polygon.exterior - - MultiPolygon -> list of polygon.exterior + - Polygon -> polygon.exterior + - MultiPolygon -> list of polygon.exterior and a list of Polygon.interiors (holes). @@ -1310,7 +1391,7 @@ def w(self) -> xr.DataArray: dims = self.geom_dim_name, 'y_in', 'x_in' return xr.DataArray(data, dims=dims) - def _get_default_filename(self): + def _get_default_filename(self) -> str: # e.g. bilinear_400x600_300x400.nc filename = 'spatialavg_{0}x{1}_{2}.nc'.format( self.shape_in[0], self.shape_in[1], self.n_out @@ -1318,20 +1399,20 @@ def _get_default_filename(self): return filename - def __repr__(self): + def __repr__(self) -> str: info = ( - 'xESMF SpatialAverager \n' - 'Weight filename: {} \n' - 'Reuse pre-computed weights? 
{} \n' - 'Input grid shape: {} \n' - 'Output list length: {} \n'.format( - self.filename, self.reuse_weights, self.shape_in, self.n_out - ) + f'xESMF SpatialAverager \n' + f'Weight filename: {self.filename} \n' + f'Reuse pre-computed weights: {self.reuse_weights} \n' + f'Input grid shape: {self.shape_in} \n' + f'Output list length: {self.n_out} \n' ) return info - def _format_xroutput(self, out, new_dims=None): + def _format_xroutput( + self, out: Union[DataArray, Dataset], new_dims=None + ) -> Union[DataArray, Dataset]: out = out.squeeze(dim='dummy') # rename dimension name to match output grid diff --git a/xesmf/smm.py b/xesmf/smm.py index a94bacb6..9531fc45 100644 --- a/xesmf/smm.py +++ b/xesmf/smm.py @@ -3,21 +3,27 @@ """ import warnings from pathlib import Path +from typing import Any, Dict, Tuple, Union -import numba as nb +import numba as nb # type: ignore[import] import numpy as np -import sparse as sps +import numpy.typing as npt +import sparse as sps # type: ignore[import] import xarray as xr -def read_weights(weights, n_in, n_out): +def read_weights( + weights: Union[str, Path, xr.Dataset, xr.DataArray, sps.COO, Dict[str, Any]], + n_in: int, + n_out: int, +) -> xr.DataArray: """ Read regridding weights into a DataArray (sparse COO matrix). Parameters ---------- weights : str, Path, xr.Dataset, xr.DataArray, sparse.COO - Weights generated by ESMF. Can be a path to a netCDF file generated by ESMF, an xarray.Dataset, + Weights generated by ESMF. Can be a path to a netCDF file generated by ESMF, an xr.Dataset, a dictionary created by `ESMPy.api.Regrid.get_weights_dict` or directly the sparse array as returned by this function. @@ -25,8 +31,8 @@ def read_weights(weights, n_in, n_out): ``(N_out, N_in)`` will be the shape of the returning sparse matrix. They are the total number of grid boxes in input and output grids:: - N_in = Nx_in * Ny_in - N_out = Nx_out * Ny_out + N_in = Nx_in * Ny_in + N_out = Nx_out * Ny_out We need them because the shape cannot always be inferred from the largest column and row indices, due to unmapped grid boxes. @@ -34,63 +40,67 @@ def read_weights(weights, n_in, n_out): Returns ------- xr.DataArray - A DataArray backed by a sparse.COO array, with dims ('out_dim', 'in_dim') - and size (n_out, n_in). + A DataArray backed by a sparse.COO array, with dims ('out_dim', 'in_dim') + and size (n_out, n_in). """ if isinstance(weights, (str, Path, xr.Dataset, dict)): - weights = _parse_coords_and_values(weights, n_in, n_out) + return _parse_coords_and_values(weights, n_in, n_out) - elif isinstance(weights, sps.COO): - weights = xr.DataArray(weights, dims=('out_dim', 'in_dim'), name='weights') + if isinstance(weights, sps.COO): + return xr.DataArray(weights, dims=('out_dim', 'in_dim'), name='weights') - elif not isinstance(weights, xr.DataArray): - raise ValueError(f'Weights of type {type(weights)} not understood.') + if isinstance(weights, xr.DataArray): # type: ignore[no-untyped-def] + return weights - return weights + raise ValueError(f'Weights of type {type(weights)} not understood.') -def _parse_coords_and_values(indata, n_in, n_out): +def _parse_coords_and_values( + indata: Union[str, Path, xr.Dataset, Dict[str, Any]], + n_in: int, + n_out: int, +) -> xr.DataArray: """Creates a sparse.COO array from weights stored in a dict-like fashion. Parameters ---------- indata: str, Path, xr.Dataset or dict - A dictionary as returned by ESMF.Regrid.get_weights_dict - or an xarray Dataset (or its path) as saved by xESMF. 
+ A dictionary as returned by ESMF.Regrid.get_weights_dict + or an xarray Dataset (or its path) as saved by xESMF. n_in : int - The number of points in the input grid. + The number of points in the input grid. n_out : int - The number of points in the output grid. + The number of points in the output grid. Returns ------- sparse.COO - Sparse array in the COO format. + Sparse array in the COO format. """ if isinstance(indata, (str, Path, xr.Dataset)): if not isinstance(indata, xr.Dataset): if not Path(indata).exists(): raise IOError(f'Weights file not found on disk.\n{indata}') - ds_w = xr.open_dataset(indata) + ds_w = xr.open_dataset(indata) # type: ignore[no-untyped-def] else: ds_w = indata if not {'col', 'row', 'S'}.issubset(ds_w.variables): raise ValueError( - 'Weights dataset should have variables `col`, `row` and `S` storing the indices and ' - 'values of weights.' + 'Weights dataset should have variables `col`, `row` and `S` storing the indices ' + 'and values of weights.' ) - col = ds_w['col'].values - 1 # Python starts with 0 - row = ds_w['row'].values - 1 - s = ds_w['S'].values + col = ds_w['col'].values - 1 # type: ignore[no-untyped-def] + row = ds_w['row'].values - 1 # type: ignore[no-untyped-def] + s = ds_w['S'].values # type: ignore[no-untyped-def] - elif isinstance(indata, dict): + elif isinstance(indata, dict): # type: ignore if not {'col_src', 'row_dst', 'weights'}.issubset(indata.keys()): raise ValueError( - 'Weights dictionary should have keys `col_src`, `row_dst` and `weights` storing the ' - 'indices and values of weights.' + 'Weights dictionary should have keys `col_src`, `row_dst` and `weights` storing ' + 'the indices and values of weights.' ) col = indata['col_src'] - 1 row = indata['row_dst'] - 1 @@ -100,28 +110,33 @@ def _parse_coords_and_values(indata, n_in, n_out): return xr.DataArray(sps.COO(crds, s, (n_out, n_in)), dims=('out_dim', 'in_dim'), name='weights') -def check_shapes(indata, weights, shape_in, shape_out): +def check_shapes( + indata: npt.NDArray, + weights: npt.NDArray, + shape_in: Tuple[int, int], + shape_out: Tuple[int, int], +) -> None: """Compare the shapes of the input array, the weights and the regridder and raises potential errors. Parameters ---------- indata : array - Input array with the two spatial dimensions at the end, - which should fit shape_in. + Input array with the two spatial dimensions at the end, + which should fit shape_in. weights : array - Weights 2D array of shape (out_dim, in_dim). - First element should be the product of shape_out. - Second element should be the product of shape_in. + Weights 2D array of shape (out_dim, in_dim). + First element should be the product of shape_out. + Second element should be the product of shape_in. shape_in : 2-tuple of int - Shape of the input of the Regridder. + Shape of the input of the Regridder. shape_out : 2-tuple of int - Shape of the output of the Regridder. + Shape of the output of the Regridder. Raises ------ ValueError - If any of the conditions is not respected. + If any of the conditions is not respected. """ # COO matrix is fast with F-ordered array but slow with C-array, so we # take in a C-ordered and then transpose) @@ -131,9 +146,9 @@ def check_shapes(indata, weights, shape_in, shape_out): # Limitation from numba : some big-endian dtypes are not supported. 
try: - nb.from_dtype(indata.dtype) - nb.from_dtype(weights.dtype) - except (NotImplementedError, nb.core.errors.NumbaError): + nb.from_dtype(indata.dtype) # type: ignore + nb.from_dtype(weights.dtype) # type: ignore + except (NotImplementedError, nb.core.errors.NumbaError): # type: ignore warnings.warn( 'Input array has a dtype not supported by sparse and numba.' 'Computation will fall back to scipy.' @@ -155,16 +170,21 @@ def check_shapes(indata, weights, shape_in, shape_out): raise ValueError('ny_out * nx_out should equal to weights.shape[0]') -def apply_weights(weights, indata, shape_in, shape_out): +def apply_weights( + weights: sps.COO, + indata: npt.NDArray, + shape_in: Tuple[int, int], + shape_out: Tuple[int, int], +) -> npt.NDArray: """ Apply regridding weights to data. Parameters ---------- weights : sparse COO matrix - Regridding weights. + Regridding weights. indata : numpy array of shape ``(..., n_lat, n_lon)`` or ``(..., n_y, n_x)``. - Should be C-ordered. Will be then tranposed to F-ordered. + Should be C-ordered. Will be then transposed to F-ordered. shape_in, shape_out : tuple of two integers Input/output data shape. For rectilinear grid, it is just ``(n_lat, n_lon)``. @@ -180,9 +200,9 @@ def apply_weights(weights, indata, shape_in, shape_out): # Limitation from numba : some big-endian dtypes are not supported. indata_dtype = indata.dtype try: - nb.from_dtype(indata.dtype) - nb.from_dtype(weights.dtype) - except (NotImplementedError, nb.core.errors.NumbaError): + nb.from_dtype(indata.dtype) # type: ignore + nb.from_dtype(weights.dtype) # type: ignore + except (NotImplementedError, nb.core.errors.NumbaError): # type: ignore indata = indata.astype(' xr.DataArray: """Add NaN in empty rows of the regridding weights sparse matrix. By default, empty rows in the weights sparse matrix are interpreted as zeroes. This can become problematic @@ -210,12 +230,12 @@ def add_nans_to_weights(weights): Parameters ---------- weights : DataArray backed by a sparse.COO array - Sparse weights matrix. + Sparse weights matrix. Returns ------- DataArray backed by a sparse.COO array - Sparse weights matrix. + Sparse weights matrix. """ # Taken from @trondkr and adapted by @raphaeldussin to use `lil`. @@ -226,12 +246,17 @@ def add_nans_to_weights(weights): for krow in range(len(m.rows)): m.rows[krow] = [0] if m.rows[krow] == [] else m.rows[krow] m.data[krow] = [np.NaN] if m.data[krow] == [] else m.data[krow] + # update regridder weights (in COO) - weights = weights.copy(data=sps.COO.from_scipy_sparse(m)) + weights = weights.copy(data=sps.COO.from_scipy_sparse(m)) # type: ignore return weights -def _combine_weight_multipoly(weights, areas, indexes): +def _combine_weight_multipoly( # type: ignore + weights: xr.DataArray, + areas: npt.NDArray, + indexes: npt.NDArray, +) -> xr.DataArray: """Reduce a weight sparse matrix (csc format) by combining (adding) columns. This is used to sum individual weight matrices from multi-part geometries. @@ -239,17 +264,17 @@ def _combine_weight_multipoly(weights, areas, indexes): Parameters ---------- weights : DataArray - Usually backed by a sparse.COO array, with dims ('out_dim', 'in_dim') + Usually backed by a sparse.COO array, with dims ('out_dim', 'in_dim') areas : np.array - Array of destination areas, following same order as weights. + Array of destination areas, following same order as weights. indexes : array of integers - Columns with the same "index" will be summed into a single column at this - index in the output matrix. 
@@ -226,12 +246,17 @@
     for krow in range(len(m.rows)):
         m.rows[krow] = [0] if m.rows[krow] == [] else m.rows[krow]
         m.data[krow] = [np.NaN] if m.data[krow] == [] else m.data[krow]
+
     # update regridder weights (in COO)
-    weights = weights.copy(data=sps.COO.from_scipy_sparse(m))
+    weights = weights.copy(data=sps.COO.from_scipy_sparse(m))  # type: ignore
     return weights


-def _combine_weight_multipoly(weights, areas, indexes):
+def _combine_weight_multipoly(  # type: ignore
+    weights: xr.DataArray,
+    areas: npt.NDArray,
+    indexes: npt.NDArray,
+) -> xr.DataArray:
     """Reduce a weight sparse matrix (csc format) by combining (adding) columns.

     This is used to sum individual weight matrices from multi-part geometries.

@@ -239,17 +264,17 @@
     Parameters
     ----------
     weights : DataArray
-      Usually backed by a sparse.COO array, with dims ('out_dim', 'in_dim')
+        Usually backed by a sparse.COO array, with dims ('out_dim', 'in_dim')
     areas : np.array
-      Array of destination areas, following same order as weights.
+        Array of destination areas, following same order as weights.
     indexes : array of integers
-      Columns with the same "index" will be summed into a single column at this
-      index in the output matrix.
+        Columns with the same 'index' will be summed into a single column at this
+        index in the output matrix.

     Returns
     -------
     sparse matrix (CSC)
-      Sum of weights from individual geometries.
+        Sum of weights from individual geometries.
     """
     sub_weights = weights.rename(out_dim='subgeometries')
diff --git a/xesmf/util.py b/xesmf/util.py
index 4790b112..fa921b45 100644
--- a/xesmf/util.py
+++ b/xesmf/util.py
@@ -1,6 +1,8 @@
 import warnings
+from typing import Any, Generator, List, Literal, Tuple, Union

 import numpy as np
+import numpy.typing as npt
 import xarray as xr
 from shapely.geometry import MultiPolygon, Polygon

@@ -8,7 +10,7 @@
 LAT_CF_ATTRS = {'standard_name': 'latitude', 'units': 'degrees_north'}


-def _grid_1d(start_b, end_b, step):
+def _grid_1d(start_b: float, end_b: float, step: float) -> Tuple[npt.NDArray, npt.NDArray]:
     """
     1D grid centers and bounds

@@ -33,7 +35,14 @@ def _grid_1d(start_b, end_b, step):
     return centers, bounds


-def grid_2d(lon0_b, lon1_b, d_lon, lat0_b, lat1_b, d_lat):
+def grid_2d(
+    lon0_b: float,
+    lon1_b: float,
+    d_lon: float,
+    lat0_b: float,
+    lat1_b: float,
+    d_lat: float,
+) -> xr.Dataset:
     """
     2D rectilinear grid centers and bounds

@@ -75,7 +84,14 @@ def grid_2d(lon0_b, lon1_b, d_lon, lat0_b, lat1_b, d_lat):
     return ds


-def cf_grid_2d(lon0_b, lon1_b, d_lon, lat0_b, lat1_b, d_lat):
+def cf_grid_2d(
+    lon0_b: float,
+    lon1_b: float,
+    d_lon: float,
+    lat0_b: float,
+    lat1_b: float,
+    d_lat: float,
+) -> xr.Dataset:
     """
     CF compliant 2D rectilinear grid centers and bounds.
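As a usage sketch for `grid_2d` (the bounds below are arbitrary): it returns an xarray Dataset with 2D center coordinates `lon`/`lat` and the (n+1)-point cell bounds `lon_b`/`lat_b` that conservative regridding needs.

```python
import xesmf as xe

# A hypothetical 1-degree grid spanning 0-10E, 40-50N.
ds = xe.util.grid_2d(0, 10, 1.0, 40, 50, 1.0)
print(ds.lon.shape, ds.lon_b.shape)  # (10, 10) (11, 11)
```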
@@ -126,21 +142,26 @@ def cf_grid_2d(lon0_b, lon1_b, d_lon, lat0_b, lat1_b, d_lat):
     return ds


-def grid_global(d_lon, d_lat, cf=False, lon1=180):
+def grid_global(
+    d_lon: float,
+    d_lat: float,
+    cf: bool = False,
+    lon1: Literal[180, 360] = 180,
+) -> xr.Dataset:
     """
     Global 2D rectilinear grid centers and bounds

     Parameters
     ----------
     d_lon : float
-      Longitude step size, i.e. grid resolution
+        Longitude step size, i.e. grid resolution
     d_lat : float
-      Latitude step size, i.e. grid resolution
+        Latitude step size, i.e. grid resolution
     cf : bool
-      Return a CF compliant grid.
+        Return a CF compliant grid.
     lon1 : {180, 360}
-      Right longitude bound. According to which convention is used longitudes will
-      vary from -180 to 180 or from 0 to 360.
+        Right longitude bound. Depending on which convention is used, longitudes
+        will vary from -180 to 180 or from 0 to 360.

     Returns
     -------
@@ -150,14 +171,12 @@ def grid_global(d_lon, d_lat, cf=False, lon1=180):

     if not np.isclose(360 / d_lon, 360 // d_lon):
         warnings.warn(
-            '360 cannot be divided by d_lon = {}, '
-            'might not cover the globe uniformly'.format(d_lon)
+            f'360 cannot be divided by d_lon = {d_lon}, might not cover the globe uniformly'
         )

     if not np.isclose(180 / d_lat, 180 // d_lat):
         warnings.warn(
-            '180 cannot be divided by d_lat = {}, '
-            'might not cover the globe uniformly'.format(d_lat)
+            f'180 cannot be divided by d_lat = {d_lat}, might not cover the globe uniformly'
         )

     lon0 = lon1 - 360
@@ -168,7 +187,9 @@
     return grid_2d(lon0, lon1, d_lon, -90, 90, d_lat)


-def _flatten_poly_list(polys):
+def _flatten_poly_list(
+    polys: List[Polygon],
+) -> Generator[Union[Tuple[int, Any], Tuple[int, Polygon]], Any, None]:
     """Iterator flattening MultiPolygons."""
     for i, poly in enumerate(polys):
         if isinstance(poly, MultiPolygon):
@@ -178,7 +199,9 @@
             yield (i, poly)


-def split_polygons_and_holes(polys):
+def split_polygons_and_holes(
+    polys: List[Polygon],
+) -> Tuple[List[Polygon], List[Polygon], List[int], List[int]]:
     """Split the exterior boundaries and the holes for a list of polygons.

     If MultiPolygons are encountered in the list, they are flattened out
@@ -195,14 +218,14 @@
     holes : list of Polygons
         Holes of the polygons as polygons
     i_ext : list of integers
-      The index in `polys` of each polygon in `exteriors`.
+        The index in `polys` of each polygon in `exteriors`.
     i_hol : list of integers
-      The index in `polys` of the owner of each hole in `holes`.
+        The index in `polys` of the owner of each hole in `holes`.
     """
-    exteriors = []
-    holes = []
-    i_ext = []
-    i_hol = []
+    exteriors: List[Polygon] = []
+    holes: List[Polygon] = []
+    i_ext: List[int] = []
+    i_hol: List[int] = []
     for i, poly in _flatten_poly_list(polys):
         exteriors.append(Polygon(poly.exterior))
         i_ext.append(i)
@@ -218,19 +241,24 @@
 HUGE = 1.0e30


-def simple_tripolar_grid(nlons, nlats, lat_cap=60, lon_cut=-300):
+def simple_tripolar_grid(
+    nlons: int,
+    nlats: int,
+    lat_cap: float = 60,
+    lon_cut: float = -300,
+) -> Tuple[npt.NDArray, npt.NDArray]:
     """Generate a simple tripolar grid, regular under `lat_cap`.

     Parameters
     ----------
     nlons: int
-      Number of longitude points.
+        Number of longitude points.
     nlats: int
-      Number of latitude points.
+        Number of latitude points.
     lat_cap: float
-      Latitude of the northern cap.
+        Latitude of the northern cap.
     lon_cut: float
-      Longitude of the periodic boundary.
+        Longitude of the periodic boundary.
     """

@@ -258,7 +286,9 @@
 # rather than using the package as a dependency


-def _bipolar_projection(lamg, phig, lon_bp, rp, metrics_only=False):
+def _bipolar_projection(
+    lamg: float, phig: float, lon_bp: float, rp: float, metrics_only: bool = False
+):
     """Makes a stereographic bipolar projection of the input coordinate mesh (lamg,phig)
     Returns the projected coordinate mesh and their metric coefficients (h^-1).
     The input mesh must be a regular spherical grid capping the pole with:
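A small sketch of the polygon splitting (geometry hypothetical): a 'donut' is decomposed into its exterior ring and its hole, and the two index lists map both parts back to entry 0 of the input list.

```python
import xesmf as xe
from shapely.geometry import Polygon

# Hypothetical square with a square hole in the middle.
donut = Polygon(
    shell=[(0, 0), (4, 0), (4, 4), (0, 4)],
    holes=[[(1, 1), (3, 1), (3, 3), (1, 3)]],
)

exteriors, holes, i_ext, i_hol = xe.util.split_polygons_and_holes([donut])
print(len(exteriors), len(holes))  # 1 1
print(i_ext, i_hol)                # [0] [0]
```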
@@ -280,7 +310,7 @@
     B = np.where(np.abs(beta2_inv) > HUGE, 0.0, B)
     lamc = np.arcsin(B) / PI_180
     # But this equation accepts 4 solutions for a given B, {l, 180-l, l+180, 360-l }
-    # We have to pickup the "correct" root.
+    # We have to pick the 'correct' root.
     # One way is simply to demand lamc to be continuous with lam on the equator phi=0
     # I am sure there is a more mathematically concrete way to do this.
     lamc = np.where((lamg - lon_bp > 90) & (lamg - lon_bp <= 180), 180 - lamc, lamc)
@@ -328,14 +358,20 @@
     return h_i_inv, h_j_inv


-def _generate_bipolar_cap_mesh(Ni, Nj_ncap, lat0_bp, lon_bp, ensure_nj_even=True):
+def _generate_bipolar_cap_mesh(
+    Ni: float,
+    Nj_ncap: float,
+    lat0_bp: float,
+    lon_bp: float,
+    ensure_nj_even: bool = True,
+):
     # Define a (lon,lat) coordinate mesh on the Northern hemisphere of the globe sphere
     # such that the resolution of latg matches the desired resolution of the final grid along the symmetry meridian
     print('Generating bipolar grid bounded at latitude ', lat0_bp)
     if Nj_ncap % 2 != 0 and ensure_nj_even:
         print('   Supergrid has an odd number of area cells!')
         if ensure_nj_even:
-            print("   The number of j's is not even. Fixing this by cutting one row.")
+            print('   The number of j\'s is not even. Fixing this by cutting one row.')
             Nj_ncap = Nj_ncap - 1

     lon_g = lon_bp + np.arange(Ni + 1) * 360.0 / float(Ni)
@@ -350,7 +386,7 @@ def _generate_bipolar_cap_mesh(Ni, Nj_ncap, lat0_bp, lon_bp, ensure_nj_even=True
     return lams, phis, h_i_inv, h_j_inv


-def _mdist(x1, x2):
+def _mdist(x1: float, x2: float) -> float:
     """Returns positive distance modulo 360."""
     return np.minimum(np.mod(x1 - x2, 360.0), np.mod(x2 - x1, 360.0))
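Finally, `_mdist` picks the shorter way around the circle; a standalone check of that behaviour (values arbitrary):

```python
import numpy as np

def mdist(x1, x2):
    # Positive distance modulo 360, mirroring _mdist above.
    return np.minimum(np.mod(x1 - x2, 360.0), np.mod(x2 - x1, 360.0))

print(mdist(350.0, 10.0))    # 20.0 -- wraps across 0/360
print(mdist(-170.0, 170.0))  # 20.0 -- negative-longitude conventions work too
```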