From f502aee5824ca820fba42b133000be17590ea38e Mon Sep 17 00:00:00 2001 From: Chaoming Wang Date: Sun, 25 Aug 2024 20:26:01 +0800 Subject: [PATCH] update annotation --- braintaichi/_eventop/__init__.py | 1 - braintaichi/_eventop/_event_csrmm.py | 3 +++ braintaichi/_eventop/_event_csrmv.py | 3 ++- braintaichi/_eventop/main.py | 2 ++ braintaichi/_jitconnop/__init__.py | 4 ++-- braintaichi/_jitconnop/_jit_csrmv.py | 5 +++-- braintaichi/_jitconnop/_jit_event_csrmv.py | 2 ++ braintaichi/_jitconnop/_taichi_rand.py | 2 ++ braintaichi/_jitconnop/main.py | 2 ++ braintaichi/_misc.py | 4 +--- braintaichi/_primitive/_ad_support.py | 2 ++ braintaichi/_primitive/_batch_utils.py | 2 ++ braintaichi/_primitive/_mlir_translation_rule.py | 6 ++++-- braintaichi/_primitive/_xla_custom_op.py | 2 ++ braintaichi/_sparseop/_sparse_coomv.py | 2 ++ braintaichi/_sparseop/_sparse_csrmm.py | 2 ++ braintaichi/_sparseop/_sparse_csrmv.py | 3 +++ braintaichi/_sparseop/_sparse_utils.py | 4 ++-- braintaichi/_sparseop/main.py | 3 +++ 19 files changed, 41 insertions(+), 13 deletions(-) diff --git a/braintaichi/_eventop/__init__.py b/braintaichi/_eventop/__init__.py index 1803e83..8f0ec11 100644 --- a/braintaichi/_eventop/__init__.py +++ b/braintaichi/_eventop/__init__.py @@ -20,4 +20,3 @@ __all__ = _main_all del _main_all - diff --git a/braintaichi/_eventop/_event_csrmm.py b/braintaichi/_eventop/_event_csrmm.py index 992261a..2fb3516 100644 --- a/braintaichi/_eventop/_event_csrmm.py +++ b/braintaichi/_eventop/_event_csrmm.py @@ -15,6 +15,8 @@ # -*- coding: utf-8 -*- +from __future__ import annotations + from typing import Tuple import brainunit as u @@ -30,6 +32,7 @@ from braintaichi._sparseop._sparse_csrmm import raw_csrmm_taichi as normal_csrmm from braintaichi._sparseop._sparse_utils import csr_to_coo + def raw_event_csrmm_taichi( data: jax.typing.ArrayLike | u.Quantity, indices: jax.typing.ArrayLike, diff --git a/braintaichi/_eventop/_event_csrmv.py b/braintaichi/_eventop/_event_csrmv.py index 0978605..6da3d7f 100644 --- a/braintaichi/_eventop/_event_csrmv.py +++ b/braintaichi/_eventop/_event_csrmv.py @@ -25,11 +25,12 @@ """ +from __future__ import annotations + from typing import Union, Tuple import jax import jax.numpy as jnp -import numpy as np import taichi as ti from jax.interpreters import ad diff --git a/braintaichi/_eventop/main.py b/braintaichi/_eventop/main.py index e7bc8a8..2abc0d1 100644 --- a/braintaichi/_eventop/main.py +++ b/braintaichi/_eventop/main.py @@ -14,6 +14,8 @@ # ============================================================================== +from __future__ import annotations + from typing import Union, Tuple import brainunit as u diff --git a/braintaichi/_jitconnop/__init__.py b/braintaichi/_jitconnop/__init__.py index 1157d92..91ccbad 100644 --- a/braintaichi/_jitconnop/__init__.py +++ b/braintaichi/_jitconnop/__init__.py @@ -14,10 +14,10 @@ # ============================================================================== -from .main import * -from .main import __all__ as _main_all from ._taichi_rand import * from ._taichi_rand import __all__ as _taichi_rand_all +from .main import * +from .main import __all__ as _main_all __all__ = _main_all + _taichi_rand_all diff --git a/braintaichi/_jitconnop/_jit_csrmv.py b/braintaichi/_jitconnop/_jit_csrmv.py index 4f847a7..3bceabb 100644 --- a/braintaichi/_jitconnop/_jit_csrmv.py +++ b/braintaichi/_jitconnop/_jit_csrmv.py @@ -15,6 +15,8 @@ # -*- coding: utf-8 -*- +from __future__ import annotations + import numbers from typing import Tuple, Optional @@ -24,8 +26,8 @@ from jax import numpy as jnp from jax.interpreters import ad -from braintaichi._primitive._xla_custom_op import XLACustomOp from braintaichi._misc import _get_dtype, set_module_as +from braintaichi._primitive._xla_custom_op import XLACustomOp from ._taichi_rand import (lfsr88_key, lfsr88_random_integers, lfsr88_uniform, lfsr88_normal) __all__ = [ @@ -35,7 +37,6 @@ ] - @set_module_as('braintaichi') def get_homo_weight_matrix( weight: float, diff --git a/braintaichi/_jitconnop/_jit_event_csrmv.py b/braintaichi/_jitconnop/_jit_event_csrmv.py index bf556a1..e81134b 100644 --- a/braintaichi/_jitconnop/_jit_event_csrmv.py +++ b/braintaichi/_jitconnop/_jit_event_csrmv.py @@ -16,6 +16,8 @@ # -*- coding: utf-8 -*- +from __future__ import annotations + from typing import Tuple import jax diff --git a/braintaichi/_jitconnop/_taichi_rand.py b/braintaichi/_jitconnop/_taichi_rand.py index 2edc796..116f6e6 100644 --- a/braintaichi/_jitconnop/_taichi_rand.py +++ b/braintaichi/_jitconnop/_taichi_rand.py @@ -15,6 +15,8 @@ # -*- coding: utf-8 -*- +from __future__ import annotations + import brainstate as bst import jax import taichi as ti diff --git a/braintaichi/_jitconnop/main.py b/braintaichi/_jitconnop/main.py index c72b7ec..3f24de4 100644 --- a/braintaichi/_jitconnop/main.py +++ b/braintaichi/_jitconnop/main.py @@ -13,6 +13,8 @@ # limitations under the License. # ============================================================================== +from __future__ import annotations + from typing import Tuple, Optional import jax diff --git a/braintaichi/_misc.py b/braintaichi/_misc.py index 8f7bc28..4e9dbc7 100644 --- a/braintaichi/_misc.py +++ b/braintaichi/_misc.py @@ -15,6 +15,7 @@ import jax + def _get_dtype(v): if hasattr(v, 'dtype'): dtype = v.dtype @@ -23,12 +24,9 @@ def _get_dtype(v): return dtype - def set_module_as(name: str): def decorator(f): f.__module__ = name return f return decorator - - diff --git a/braintaichi/_primitive/_ad_support.py b/braintaichi/_primitive/_ad_support.py index 33c27bf..8566739 100644 --- a/braintaichi/_primitive/_ad_support.py +++ b/braintaichi/_primitive/_ad_support.py @@ -15,6 +15,8 @@ # -*- coding: utf-8 -*- +from __future__ import annotations + import functools from functools import partial diff --git a/braintaichi/_primitive/_batch_utils.py b/braintaichi/_primitive/_batch_utils.py index fd36c66..42f9408 100644 --- a/braintaichi/_primitive/_batch_utils.py +++ b/braintaichi/_primitive/_batch_utils.py @@ -15,6 +15,8 @@ # -*- coding: utf-8 -*- +from __future__ import annotations + from functools import partial import jax.numpy as jnp diff --git a/braintaichi/_primitive/_mlir_translation_rule.py b/braintaichi/_primitive/_mlir_translation_rule.py index 34ffadf..93c3c0e 100644 --- a/braintaichi/_primitive/_mlir_translation_rule.py +++ b/braintaichi/_primitive/_mlir_translation_rule.py @@ -15,6 +15,8 @@ # -*- coding: utf-8 -*- +from __future__ import annotations + import contextlib import hashlib import inspect @@ -37,7 +39,7 @@ # --- REGISTER CUSTOM CALL TARGETS on CPU platforms ### try: - from braintaichi import cpu_ops # noqa + from braintaichi import cpu_ops # noqa for _name, _value in cpu_ops.registrations().items(): xla_client.register_custom_call_target(_name, _value, platform="cpu") @@ -46,7 +48,7 @@ # --- REGISTER CUSTOM CALL TARGETS on GPU platforms ### try: - from braintaichi import gpu_ops # noqa + from braintaichi import gpu_ops # noqa for _name, _value in gpu_ops.registrations().items(): xla_client.register_custom_call_target(_name, _value, platform="gpu") diff --git a/braintaichi/_primitive/_xla_custom_op.py b/braintaichi/_primitive/_xla_custom_op.py index abc6c97..cee5fe0 100644 --- a/braintaichi/_primitive/_xla_custom_op.py +++ b/braintaichi/_primitive/_xla_custom_op.py @@ -15,6 +15,8 @@ # -*- coding: utf-8 -*- +from __future__ import annotations + from functools import partial from typing import Callable, Sequence, Tuple, Protocol, Optional, Union diff --git a/braintaichi/_sparseop/_sparse_coomv.py b/braintaichi/_sparseop/_sparse_coomv.py index 7434006..307e933 100644 --- a/braintaichi/_sparseop/_sparse_coomv.py +++ b/braintaichi/_sparseop/_sparse_coomv.py @@ -15,6 +15,8 @@ # -*- coding: utf-8 -*- +from __future__ import annotations + import warnings from functools import partial diff --git a/braintaichi/_sparseop/_sparse_csrmm.py b/braintaichi/_sparseop/_sparse_csrmm.py index 0ebab82..cacafef 100644 --- a/braintaichi/_sparseop/_sparse_csrmm.py +++ b/braintaichi/_sparseop/_sparse_csrmm.py @@ -15,6 +15,8 @@ # -*- coding: utf-8 -*- +from __future__ import annotations + from typing import Tuple import brainunit as u diff --git a/braintaichi/_sparseop/_sparse_csrmv.py b/braintaichi/_sparseop/_sparse_csrmv.py index 598fd62..1a17cfe 100644 --- a/braintaichi/_sparseop/_sparse_csrmv.py +++ b/braintaichi/_sparseop/_sparse_csrmv.py @@ -15,6 +15,9 @@ # -*- coding: utf-8 -*- + +from __future__ import annotations + from typing import Tuple import brainunit as u diff --git a/braintaichi/_sparseop/_sparse_utils.py b/braintaichi/_sparseop/_sparse_utils.py index 8015957..4817b1d 100644 --- a/braintaichi/_sparseop/_sparse_utils.py +++ b/braintaichi/_sparseop/_sparse_utils.py @@ -15,10 +15,11 @@ # -*- coding: utf-8 -*- +from __future__ import annotations + import warnings from typing import Tuple -import jax import numpy as np from jax import core, numpy as jnp from jax.interpreters import mlir, ad @@ -33,7 +34,6 @@ ] - def coo_to_csr( pre_ids: jnp.ndarray, post_ids: jnp.ndarray, diff --git a/braintaichi/_sparseop/main.py b/braintaichi/_sparseop/main.py index 71ba01b..660b722 100644 --- a/braintaichi/_sparseop/main.py +++ b/braintaichi/_sparseop/main.py @@ -13,6 +13,9 @@ # limitations under the License. # ============================================================================== + +from __future__ import annotations + from typing import Tuple import brainunit as u