Skip to content

Commit

Permalink
initial typing
Browse files Browse the repository at this point in the history
  • Loading branch information
hollymandel committed Oct 4, 2024
1 parent 5c64182 commit c3e7a48
Showing 1 changed file with 54 additions and 29 deletions.
83 changes: 54 additions & 29 deletions xarray/core/missing.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from collections.abc import Callable, Hashable, Sequence
from functools import partial
from numbers import Number
from typing import TYPE_CHECKING, Any, get_args
from typing import TYPE_CHECKING, Any, Optional, get_args

import numpy as np
import pandas as pd
Expand All @@ -29,11 +29,12 @@
if TYPE_CHECKING:
from xarray.core.dataarray import DataArray
from xarray.core.dataset import Dataset
from xarray.core.variable import IndexVariable


def _get_nan_block_lengths(
obj: Dataset | DataArray | Variable, dim: Hashable, index: Variable
):
obj: Dataset | DataArray, dim: Hashable, index: Variable
) -> Any:
"""
Return an object where each NaN element in 'obj' is replaced by the
length of the gap the element is in.
Expand Down Expand Up @@ -66,12 +67,12 @@ class BaseInterpolator:
cons_kwargs: dict[str, Any]
call_kwargs: dict[str, Any]
f: Callable
method: str
method: str | int

def __call__(self, x):
def __call__(self, x: np.ndarray) -> np.ndarray:
return self.f(x, **self.call_kwargs)

def __repr__(self):
def __repr__(self) -> str:
return f"{self.__class__.__name__}: method={self.method}"


Expand All @@ -83,7 +84,14 @@ class NumpyInterpolator(BaseInterpolator):
numpy.interp
"""

def __init__(self, xi, yi, method="linear", fill_value=None, period=None):
def __init__(
self,
xi: Variable,
yi: np.ndarray,
method: Optional[str] = "linear",
fill_value=None,
period=None,
):
if method != "linear":
raise ValueError("only method `linear` is valid for the NumpyInterpolator")

Expand All @@ -104,8 +112,8 @@ def __init__(self, xi, yi, method="linear", fill_value=None, period=None):
self._left = fill_value[0]
self._right = fill_value[1]
elif is_scalar(fill_value):
self._left = fill_value
self._right = fill_value
self._left = fill_value # type: ignore[assignment]
self._right = fill_value # type: ignore[assignment]
else:
raise ValueError(f"{fill_value} is not a valid fill_value")

Expand All @@ -130,14 +138,14 @@ class ScipyInterpolator(BaseInterpolator):

def __init__(
self,
xi,
yi,
method=None,
fill_value=None,
assume_sorted=True,
copy=False,
bounds_error=False,
order=None,
xi: Variable,
yi: np.ndarray,
method: Optional[str | int] = None,
fill_value: Optional[float | complex] = None,
assume_sorted: bool = True,
copy: bool = False,
bounds_error: bool = False,
order: Optional[int] = None,
axis=-1,
**kwargs,
):
Expand All @@ -154,18 +162,13 @@ def __init__(
raise ValueError("order is required when method=polynomial")
method = order

self.method = method
self.method: str | int = method

self.cons_kwargs = kwargs
self.call_kwargs = {}

nan = np.nan if yi.dtype.kind != "c" else np.nan + np.nan * 1j

if fill_value is None and method == "linear":
fill_value = nan, nan
elif fill_value is None:
fill_value = nan

self.f = interp1d(
xi,
yi,
Expand Down Expand Up @@ -601,7 +604,12 @@ def _floatize_x(x, new_x):
return x, new_x


def interp(var, indexes_coords, method: InterpOptions, **kwargs):
def interp(
var: Variable,
indexes_coords: dict[str, IndexVariable],
method: InterpOptions,
**kwargs,
) -> Variable:
"""Make an interpolation of Variable
Parameters
Expand Down Expand Up @@ -662,7 +670,13 @@ def interp(var, indexes_coords, method: InterpOptions, **kwargs):
return result


def interp_func(var, x, new_x, method: InterpOptions, kwargs):
def interp_func(
var: np.ndarray,
x: list[IndexVariable],
new_x: list[IndexVariable],
method: InterpOptions,
kwargs: dict,
) -> np.ndarray:
"""
multi-dimensional interpolation for array-like. Interpolated axes should be
located in the last position.
Expand Down Expand Up @@ -766,9 +780,14 @@ def interp_func(var, x, new_x, method: InterpOptions, kwargs):
return _interpnd(var, x, new_x, func, kwargs)


def _interp1d(var, x, new_x, func, kwargs):
def _interp1d(
var: np.ndarray,
x: IndexVariable,
new_x: IndexVariable,
func: Callable,
kwargs: dict,
) -> np.ndarray:
# x, new_x are tuples of size 1.
x, new_x = x[0], new_x[0]
rslt = func(x, var, **kwargs)(np.ravel(new_x))
if new_x.ndim > 1:
return reshape(rslt, (var.shape[:-1] + new_x.shape))
Expand All @@ -777,11 +796,17 @@ def _interp1d(var, x, new_x, func, kwargs):
return rslt


def _interpnd(var, x, new_x, func, kwargs):
def _interpnd(
var: np.ndarray,
x: list[IndexVariable],
new_x: list[IndexVariable],
func: Callable,
kwargs: dict,
) -> np.ndarray:
x, new_x = _floatize_x(x, new_x)

if len(x) == 1:
return _interp1d(var, x, new_x, func, kwargs)
return _interp1d(var, x[0], new_x[0], func, kwargs)

# move the interpolation axes to the start position
var = var.transpose(range(-len(x), var.ndim - len(x)))
Expand Down

0 comments on commit c3e7a48

Please sign in to comment.