Skip to content

Commit

Permalink
update annotation
Browse files Browse the repository at this point in the history
  • Loading branch information
chaoming0625 committed Aug 25, 2024
1 parent df07835 commit f502aee
Show file tree
Hide file tree
Showing 19 changed files with 41 additions and 13 deletions.
1 change: 0 additions & 1 deletion braintaichi/_eventop/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,4 +20,3 @@
__all__ = _main_all

del _main_all

3 changes: 3 additions & 0 deletions braintaichi/_eventop/_event_csrmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@

# -*- coding: utf-8 -*-

from __future__ import annotations

from typing import Tuple

import brainunit as u
Expand All @@ -30,6 +32,7 @@
from braintaichi._sparseop._sparse_csrmm import raw_csrmm_taichi as normal_csrmm
from braintaichi._sparseop._sparse_utils import csr_to_coo


def raw_event_csrmm_taichi(
data: jax.typing.ArrayLike | u.Quantity,
indices: jax.typing.ArrayLike,
Expand Down
3 changes: 2 additions & 1 deletion braintaichi/_eventop/_event_csrmv.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,12 @@
"""

from __future__ import annotations

from typing import Union, Tuple

import jax
import jax.numpy as jnp
import numpy as np
import taichi as ti
from jax.interpreters import ad

Expand Down
2 changes: 2 additions & 0 deletions braintaichi/_eventop/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
# ==============================================================================


from __future__ import annotations

from typing import Union, Tuple

import brainunit as u
Expand Down
4 changes: 2 additions & 2 deletions braintaichi/_jitconnop/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,10 @@
# ==============================================================================


from .main import *
from .main import __all__ as _main_all
from ._taichi_rand import *
from ._taichi_rand import __all__ as _taichi_rand_all
from .main import *
from .main import __all__ as _main_all

__all__ = _main_all + _taichi_rand_all

Expand Down
5 changes: 3 additions & 2 deletions braintaichi/_jitconnop/_jit_csrmv.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@

# -*- coding: utf-8 -*-

from __future__ import annotations

import numbers
from typing import Tuple, Optional

Expand All @@ -24,8 +26,8 @@
from jax import numpy as jnp
from jax.interpreters import ad

from braintaichi._primitive._xla_custom_op import XLACustomOp
from braintaichi._misc import _get_dtype, set_module_as
from braintaichi._primitive._xla_custom_op import XLACustomOp
from ._taichi_rand import (lfsr88_key, lfsr88_random_integers, lfsr88_uniform, lfsr88_normal)

__all__ = [
Expand All @@ -35,7 +37,6 @@
]



@set_module_as('braintaichi')
def get_homo_weight_matrix(
weight: float,
Expand Down
2 changes: 2 additions & 0 deletions braintaichi/_jitconnop/_jit_event_csrmv.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
# -*- coding: utf-8 -*-


from __future__ import annotations

from typing import Tuple

import jax
Expand Down
2 changes: 2 additions & 0 deletions braintaichi/_jitconnop/_taichi_rand.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@

# -*- coding: utf-8 -*-

from __future__ import annotations

import brainstate as bst
import jax
import taichi as ti
Expand Down
2 changes: 2 additions & 0 deletions braintaichi/_jitconnop/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
# limitations under the License.
# ==============================================================================

from __future__ import annotations

from typing import Tuple, Optional

import jax
Expand Down
4 changes: 1 addition & 3 deletions braintaichi/_misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

import jax


def _get_dtype(v):
if hasattr(v, 'dtype'):
dtype = v.dtype
Expand All @@ -23,12 +24,9 @@ def _get_dtype(v):
return dtype



def set_module_as(name: str):
def decorator(f):
f.__module__ = name
return f

return decorator


2 changes: 2 additions & 0 deletions braintaichi/_primitive/_ad_support.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@

# -*- coding: utf-8 -*-

from __future__ import annotations

import functools
from functools import partial

Expand Down
2 changes: 2 additions & 0 deletions braintaichi/_primitive/_batch_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@

# -*- coding: utf-8 -*-

from __future__ import annotations

from functools import partial

import jax.numpy as jnp
Expand Down
6 changes: 4 additions & 2 deletions braintaichi/_primitive/_mlir_translation_rule.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@

# -*- coding: utf-8 -*-

from __future__ import annotations

import contextlib
import hashlib
import inspect
Expand All @@ -37,7 +39,7 @@

# --- REGISTER CUSTOM CALL TARGETS on CPU platforms ###
try:
from braintaichi import cpu_ops # noqa
from braintaichi import cpu_ops # noqa

for _name, _value in cpu_ops.registrations().items():
xla_client.register_custom_call_target(_name, _value, platform="cpu")
Expand All @@ -46,7 +48,7 @@

# --- REGISTER CUSTOM CALL TARGETS on GPU platforms ###
try:
from braintaichi import gpu_ops # noqa
from braintaichi import gpu_ops # noqa

for _name, _value in gpu_ops.registrations().items():
xla_client.register_custom_call_target(_name, _value, platform="gpu")
Expand Down
2 changes: 2 additions & 0 deletions braintaichi/_primitive/_xla_custom_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@

# -*- coding: utf-8 -*-

from __future__ import annotations

from functools import partial
from typing import Callable, Sequence, Tuple, Protocol, Optional, Union

Expand Down
2 changes: 2 additions & 0 deletions braintaichi/_sparseop/_sparse_coomv.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@

# -*- coding: utf-8 -*-

from __future__ import annotations

import warnings
from functools import partial

Expand Down
2 changes: 2 additions & 0 deletions braintaichi/_sparseop/_sparse_csrmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@

# -*- coding: utf-8 -*-

from __future__ import annotations

from typing import Tuple

import brainunit as u
Expand Down
3 changes: 3 additions & 0 deletions braintaichi/_sparseop/_sparse_csrmv.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,9 @@

# -*- coding: utf-8 -*-


from __future__ import annotations

from typing import Tuple

import brainunit as u
Expand Down
4 changes: 2 additions & 2 deletions braintaichi/_sparseop/_sparse_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,11 @@

# -*- coding: utf-8 -*-

from __future__ import annotations

import warnings
from typing import Tuple

import jax
import numpy as np
from jax import core, numpy as jnp
from jax.interpreters import mlir, ad
Expand All @@ -33,7 +34,6 @@
]



def coo_to_csr(
pre_ids: jnp.ndarray,
post_ids: jnp.ndarray,
Expand Down
3 changes: 3 additions & 0 deletions braintaichi/_sparseop/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,9 @@
# limitations under the License.
# ==============================================================================


from __future__ import annotations

from typing import Tuple

import brainunit as u
Expand Down

0 comments on commit f502aee

Please sign in to comment.