diff --git a/src/quemb/shared/typing.py b/src/quemb/shared/typing.py new file mode 100644 index 00000000..5f40d527 --- /dev/null +++ b/src/quemb/shared/typing.py @@ -0,0 +1,27 @@ +"""Enable barebone typechecking for the shape of numpy arrays + +Inspired by +https://stackoverflow.com/questions/75495212/type-hinting-numpy-arrays-and-batches + +Note that most numpy functions return `ndarray[Any, Any]` +i.e. the type is mostly useful to document intent to the developer. +""" + +from typing import Tuple, TypeVar + +import numpy as np + +# We want the dtype to behave covariant, i.e. if a +# Vector[float] is allowed, then the more specific +# Vector[float64] should also be allowed. +# Also see here: +# https://stackoverflow.com/questions/61568462/what-does-typevara-b-covariant-true-mean +T_dtype_co = TypeVar("T_dtype_co", bound=np.generic, covariant=True) + +Vector = np.ndarray[Tuple[int], np.dtype[T_dtype_co]] +Matrix = np.ndarray[Tuple[int, int], np.dtype[T_dtype_co]] +Tensor3D = np.ndarray[Tuple[int, int, int], np.dtype[T_dtype_co]] +Tensor4D = np.ndarray[Tuple[int, int, int, int], np.dtype[T_dtype_co]] +Tensor5D = np.ndarray[Tuple[int, int, int, int, int], np.dtype[T_dtype_co]] +Tensor6D = np.ndarray[Tuple[int, int, int, int, int, int], np.dtype[T_dtype_co]] +Tensor = np.ndarray[Tuple[int, ...], np.dtype[T_dtype_co]]