diff --git a/src/quemb/shared/typing.py b/src/quemb/shared/typing.py index 643eeb97..8457420d 100644 --- a/src/quemb/shared/typing.py +++ b/src/quemb/shared/typing.py @@ -11,12 +11,18 @@ import numpy as np -T_co = TypeVar("T_co", bound=np.generic, covariant=True) +# We want the dtype to behave covariant, i.e. if a +# Vector[float] is allowed, then the more specific +# Vector[float64] should also be allowed. +# +# Also see here: +# https://stackoverflow.com/questions/61568462/what-does-typevara-b-covariant-true-mean +T_dtype_co = TypeVar("T_dtype_co", bound=np.generic, covariant=True) -Vector = np.ndarray[Tuple[int], np.dtype[T_co]] -Matrix = np.ndarray[Tuple[int, int], np.dtype[T_co]] -Tensor3D = np.ndarray[Tuple[int, int, int], np.dtype[T_co]] -Tensor4D = np.ndarray[Tuple[int, int, int, int], np.dtype[T_co]] -Tensor5D = np.ndarray[Tuple[int, int, int, int, int], np.dtype[T_co]] -Tensor6D = np.ndarray[Tuple[int, int, int, int, int, int], np.dtype[T_co]] -Tensor = np.ndarray[Tuple[int, ...], np.dtype[T_co]] +Vector = np.ndarray[Tuple[int], np.dtype[T_dtype_co]] +Matrix = np.ndarray[Tuple[int, int], np.dtype[T_dtype_co]] +Tensor3D = np.ndarray[Tuple[int, int, int], np.dtype[T_dtype_co]] +Tensor4D = np.ndarray[Tuple[int, int, int, int], np.dtype[T_dtype_co]] +Tensor5D = np.ndarray[Tuple[int, int, int, int, int], np.dtype[T_dtype_co]] +Tensor6D = np.ndarray[Tuple[int, int, int, int, int, int], np.dtype[T_dtype_co]] +Tensor = np.ndarray[Tuple[int, ...], np.dtype[T_dtype_co]]