diff --git a/braintaichi/_eventop/_event_csrmm.py b/braintaichi/_eventop/_event_csrmm.py index 2fb3516..5cc361b 100644 --- a/braintaichi/_eventop/_event_csrmm.py +++ b/braintaichi/_eventop/_event_csrmm.py @@ -15,9 +15,7 @@ # -*- coding: utf-8 -*- -from __future__ import annotations - -from typing import Tuple +from typing import Tuple, Union import brainunit as u import jax @@ -34,7 +32,7 @@ def raw_event_csrmm_taichi( - data: jax.typing.ArrayLike | u.Quantity, + data: Union[jax.typing.ArrayLike, u.Quantity], indices: jax.typing.ArrayLike, indptr: jax.typing.ArrayLike, matrix: jax.typing.ArrayLike, diff --git a/braintaichi/_eventop/_event_csrmv.py b/braintaichi/_eventop/_event_csrmv.py index 6da3d7f..cb9e1f5 100644 --- a/braintaichi/_eventop/_event_csrmv.py +++ b/braintaichi/_eventop/_event_csrmv.py @@ -25,8 +25,6 @@ """ -from __future__ import annotations - from typing import Union, Tuple import jax diff --git a/braintaichi/_eventop/main.py b/braintaichi/_eventop/main.py index 2abc0d1..3d9a747 100644 --- a/braintaichi/_eventop/main.py +++ b/braintaichi/_eventop/main.py @@ -14,8 +14,6 @@ # ============================================================================== -from __future__ import annotations - from typing import Union, Tuple import brainunit as u @@ -33,7 +31,7 @@ def event_csrmm( - data: jax.typing.ArrayLike | u.Quantity, + data: Union[jax.typing.ArrayLike, u.Quantity], indices: jax.typing.ArrayLike, indptr: jax.typing.ArrayLike, matrix: jax.typing.ArrayLike, @@ -61,7 +59,7 @@ def event_csrmm( def event_csrmv( - data: Union[float, jax.Array], + data: Union[jax.typing.ArrayLike, u.Quantity], indices: jax.Array, indptr: jax.Array, events: jax.Array, diff --git a/braintaichi/_jitconnop/_jit_csrmv.py b/braintaichi/_jitconnop/_jit_csrmv.py index 3bceabb..63a0fd1 100644 --- a/braintaichi/_jitconnop/_jit_csrmv.py +++ b/braintaichi/_jitconnop/_jit_csrmv.py @@ -15,8 +15,6 @@ # -*- coding: utf-8 -*- -from __future__ import annotations - import numbers from typing import Tuple, Optional diff --git a/braintaichi/_jitconnop/_jit_event_csrmv.py b/braintaichi/_jitconnop/_jit_event_csrmv.py index e81134b..bf556a1 100644 --- a/braintaichi/_jitconnop/_jit_event_csrmv.py +++ b/braintaichi/_jitconnop/_jit_event_csrmv.py @@ -16,8 +16,6 @@ # -*- coding: utf-8 -*- -from __future__ import annotations - from typing import Tuple import jax diff --git a/braintaichi/_jitconnop/_taichi_rand.py b/braintaichi/_jitconnop/_taichi_rand.py index 116f6e6..2edc796 100644 --- a/braintaichi/_jitconnop/_taichi_rand.py +++ b/braintaichi/_jitconnop/_taichi_rand.py @@ -15,8 +15,6 @@ # -*- coding: utf-8 -*- -from __future__ import annotations - import brainstate as bst import jax import taichi as ti diff --git a/braintaichi/_jitconnop/main.py b/braintaichi/_jitconnop/main.py index 7dc9356..7d26e8e 100644 --- a/braintaichi/_jitconnop/main.py +++ b/braintaichi/_jitconnop/main.py @@ -13,8 +13,6 @@ # limitations under the License. # ============================================================================== -from __future__ import annotations - from typing import Tuple, Optional import jax diff --git a/braintaichi/_primitive/_ad_support.py b/braintaichi/_primitive/_ad_support.py index 8566739..33c27bf 100644 --- a/braintaichi/_primitive/_ad_support.py +++ b/braintaichi/_primitive/_ad_support.py @@ -15,8 +15,6 @@ # -*- coding: utf-8 -*- -from __future__ import annotations - import functools from functools import partial diff --git a/braintaichi/_primitive/_batch_utils.py b/braintaichi/_primitive/_batch_utils.py index 42f9408..fd36c66 100644 --- a/braintaichi/_primitive/_batch_utils.py +++ b/braintaichi/_primitive/_batch_utils.py @@ -15,8 +15,6 @@ # -*- coding: utf-8 -*- -from __future__ import annotations - from functools import partial import jax.numpy as jnp diff --git a/braintaichi/_primitive/_mlir_translation_rule.py b/braintaichi/_primitive/_mlir_translation_rule.py index 93c3c0e..550b764 100644 --- a/braintaichi/_primitive/_mlir_translation_rule.py +++ b/braintaichi/_primitive/_mlir_translation_rule.py @@ -15,8 +15,6 @@ # -*- coding: utf-8 -*- -from __future__ import annotations - import contextlib import hashlib import inspect diff --git a/braintaichi/_primitive/_xla_custom_op.py b/braintaichi/_primitive/_xla_custom_op.py index cee5fe0..abc6c97 100644 --- a/braintaichi/_primitive/_xla_custom_op.py +++ b/braintaichi/_primitive/_xla_custom_op.py @@ -15,8 +15,6 @@ # -*- coding: utf-8 -*- -from __future__ import annotations - from functools import partial from typing import Callable, Sequence, Tuple, Protocol, Optional, Union diff --git a/braintaichi/_sparseop/_sparse_coomv.py b/braintaichi/_sparseop/_sparse_coomv.py index 307e933..7434006 100644 --- a/braintaichi/_sparseop/_sparse_coomv.py +++ b/braintaichi/_sparseop/_sparse_coomv.py @@ -15,8 +15,6 @@ # -*- coding: utf-8 -*- -from __future__ import annotations - import warnings from functools import partial diff --git a/braintaichi/_sparseop/_sparse_csrmm.py b/braintaichi/_sparseop/_sparse_csrmm.py index cacafef..2f02f98 100644 --- a/braintaichi/_sparseop/_sparse_csrmm.py +++ b/braintaichi/_sparseop/_sparse_csrmm.py @@ -15,9 +15,7 @@ # -*- coding: utf-8 -*- -from __future__ import annotations - -from typing import Tuple +from typing import Tuple, Union import brainunit as u import jax @@ -32,7 +30,7 @@ def raw_csrmm_taichi( - data: jax.typing.ArrayLike | u.Quantity, + data: Union[jax.typing.ArrayLike, u.Quantity], indices: jax.typing.ArrayLike, indptr: jax.typing.ArrayLike, matrix: jax.typing.ArrayLike, diff --git a/braintaichi/_sparseop/_sparse_csrmv.py b/braintaichi/_sparseop/_sparse_csrmv.py index 1a17cfe..7b78131 100644 --- a/braintaichi/_sparseop/_sparse_csrmv.py +++ b/braintaichi/_sparseop/_sparse_csrmv.py @@ -16,9 +16,7 @@ # -*- coding: utf-8 -*- -from __future__ import annotations - -from typing import Tuple +from typing import Tuple, Union import brainunit as u import jax @@ -33,7 +31,7 @@ def raw_csrmv_taichi( - data: jax.typing.ArrayLike | u.Quantity, + data: Union[jax.typing.ArrayLike, u.Quantity], indices: jax.typing.ArrayLike, indptr: jax.typing.ArrayLike, vector: jax.typing.ArrayLike, diff --git a/braintaichi/_sparseop/_sparse_utils.py b/braintaichi/_sparseop/_sparse_utils.py index 4817b1d..5a7ba31 100644 --- a/braintaichi/_sparseop/_sparse_utils.py +++ b/braintaichi/_sparseop/_sparse_utils.py @@ -15,8 +15,6 @@ # -*- coding: utf-8 -*- -from __future__ import annotations - import warnings from typing import Tuple diff --git a/braintaichi/_sparseop/main.py b/braintaichi/_sparseop/main.py index 660b722..16a51b6 100644 --- a/braintaichi/_sparseop/main.py +++ b/braintaichi/_sparseop/main.py @@ -13,10 +13,7 @@ # limitations under the License. # ============================================================================== - -from __future__ import annotations - -from typing import Tuple +from typing import Tuple, Union import brainunit as u import jax @@ -36,7 +33,7 @@ @set_module_as('braintaichi') def coomv( - data: jax.typing.ArrayLike | u.Quantity, + data: Union[jax.typing.ArrayLike, u.Quantity], row: jax.typing.ArrayLike, col: jax.typing.ArrayLike, vector: jax.typing.ArrayLike, @@ -110,7 +107,7 @@ def coomv( @set_module_as('braintaichi') def csrmm( - data: jax.typing.ArrayLike | u.Quantity, + data: Union[jax.typing.ArrayLike, u.Quantity], indices: jax.typing.ArrayLike, indptr: jax.typing.ArrayLike, matrix: jax.typing.ArrayLike, @@ -140,7 +137,7 @@ def csrmm( @set_module_as('braintaichi') def csrmv( - data: jax.typing.ArrayLike | u.Quantity, + data: Union[jax.typing.ArrayLike, u.Quantity], indices: jax.typing.ArrayLike, indptr: jax.typing.ArrayLike, vector: jax.typing.ArrayLike,