Skip to content

Commit

Permalink
update codes
Browse files Browse the repository at this point in the history
  • Loading branch information
chaoming0625 committed Aug 25, 2024
1 parent f7fd900 commit d49a814
Show file tree
Hide file tree
Showing 16 changed files with 12 additions and 45 deletions.
6 changes: 2 additions & 4 deletions braintaichi/_eventop/_event_csrmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,7 @@

# -*- coding: utf-8 -*-

from __future__ import annotations

from typing import Tuple
from typing import Tuple, Union

import brainunit as u
import jax
Expand All @@ -34,7 +32,7 @@


def raw_event_csrmm_taichi(
data: jax.typing.ArrayLike | u.Quantity,
data: Union[jax.typing.ArrayLike, u.Quantity],
indices: jax.typing.ArrayLike,
indptr: jax.typing.ArrayLike,
matrix: jax.typing.ArrayLike,
Expand Down
2 changes: 0 additions & 2 deletions braintaichi/_eventop/_event_csrmv.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,6 @@
"""

from __future__ import annotations

from typing import Union, Tuple

import jax
Expand Down
6 changes: 2 additions & 4 deletions braintaichi/_eventop/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,6 @@
# ==============================================================================


from __future__ import annotations

from typing import Union, Tuple

import brainunit as u
Expand All @@ -33,7 +31,7 @@


def event_csrmm(
data: jax.typing.ArrayLike | u.Quantity,
data: Union[jax.typing.ArrayLike, u.Quantity],
indices: jax.typing.ArrayLike,
indptr: jax.typing.ArrayLike,
matrix: jax.typing.ArrayLike,
Expand Down Expand Up @@ -61,7 +59,7 @@ def event_csrmm(


def event_csrmv(
data: Union[float, jax.Array],
data: Union[jax.typing.ArrayLike, u.Quantity],
indices: jax.Array,
indptr: jax.Array,
events: jax.Array,
Expand Down
2 changes: 0 additions & 2 deletions braintaichi/_jitconnop/_jit_csrmv.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,6 @@

# -*- coding: utf-8 -*-

from __future__ import annotations

import numbers
from typing import Tuple, Optional

Expand Down
2 changes: 0 additions & 2 deletions braintaichi/_jitconnop/_jit_event_csrmv.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,6 @@
# -*- coding: utf-8 -*-


from __future__ import annotations

from typing import Tuple

import jax
Expand Down
2 changes: 0 additions & 2 deletions braintaichi/_jitconnop/_taichi_rand.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,6 @@

# -*- coding: utf-8 -*-

from __future__ import annotations

import brainstate as bst
import jax
import taichi as ti
Expand Down
2 changes: 0 additions & 2 deletions braintaichi/_jitconnop/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,6 @@
# limitations under the License.
# ==============================================================================

from __future__ import annotations

from typing import Tuple, Optional

import jax
Expand Down
2 changes: 0 additions & 2 deletions braintaichi/_primitive/_ad_support.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,6 @@

# -*- coding: utf-8 -*-

from __future__ import annotations

import functools
from functools import partial

Expand Down
2 changes: 0 additions & 2 deletions braintaichi/_primitive/_batch_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,6 @@

# -*- coding: utf-8 -*-

from __future__ import annotations

from functools import partial

import jax.numpy as jnp
Expand Down
2 changes: 0 additions & 2 deletions braintaichi/_primitive/_mlir_translation_rule.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,6 @@

# -*- coding: utf-8 -*-

from __future__ import annotations

import contextlib
import hashlib
import inspect
Expand Down
2 changes: 0 additions & 2 deletions braintaichi/_primitive/_xla_custom_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,6 @@

# -*- coding: utf-8 -*-

from __future__ import annotations

from functools import partial
from typing import Callable, Sequence, Tuple, Protocol, Optional, Union

Expand Down
2 changes: 0 additions & 2 deletions braintaichi/_sparseop/_sparse_coomv.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,6 @@

# -*- coding: utf-8 -*-

from __future__ import annotations

import warnings
from functools import partial

Expand Down
6 changes: 2 additions & 4 deletions braintaichi/_sparseop/_sparse_csrmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,7 @@

# -*- coding: utf-8 -*-

from __future__ import annotations

from typing import Tuple
from typing import Tuple, Union

import brainunit as u
import jax
Expand All @@ -32,7 +30,7 @@


def raw_csrmm_taichi(
data: jax.typing.ArrayLike | u.Quantity,
data: Union[jax.typing.ArrayLike, u.Quantity],
indices: jax.typing.ArrayLike,
indptr: jax.typing.ArrayLike,
matrix: jax.typing.ArrayLike,
Expand Down
6 changes: 2 additions & 4 deletions braintaichi/_sparseop/_sparse_csrmv.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,7 @@
# -*- coding: utf-8 -*-


from __future__ import annotations

from typing import Tuple
from typing import Tuple, Union

import brainunit as u
import jax
Expand All @@ -33,7 +31,7 @@


def raw_csrmv_taichi(
data: jax.typing.ArrayLike | u.Quantity,
data: Union[jax.typing.ArrayLike, u.Quantity],
indices: jax.typing.ArrayLike,
indptr: jax.typing.ArrayLike,
vector: jax.typing.ArrayLike,
Expand Down
2 changes: 0 additions & 2 deletions braintaichi/_sparseop/_sparse_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,6 @@

# -*- coding: utf-8 -*-

from __future__ import annotations

import warnings
from typing import Tuple

Expand Down
11 changes: 4 additions & 7 deletions braintaichi/_sparseop/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,7 @@
# limitations under the License.
# ==============================================================================


from __future__ import annotations

from typing import Tuple
from typing import Tuple, Union

import brainunit as u
import jax
Expand All @@ -36,7 +33,7 @@

@set_module_as('braintaichi')
def coomv(
data: jax.typing.ArrayLike | u.Quantity,
data: Union[jax.typing.ArrayLike, u.Quantity],
row: jax.typing.ArrayLike,
col: jax.typing.ArrayLike,
vector: jax.typing.ArrayLike,
Expand Down Expand Up @@ -110,7 +107,7 @@ def coomv(

@set_module_as('braintaichi')
def csrmm(
data: jax.typing.ArrayLike | u.Quantity,
data: Union[jax.typing.ArrayLike, u.Quantity],
indices: jax.typing.ArrayLike,
indptr: jax.typing.ArrayLike,
matrix: jax.typing.ArrayLike,
Expand Down Expand Up @@ -140,7 +137,7 @@ def csrmm(

@set_module_as('braintaichi')
def csrmv(
data: jax.typing.ArrayLike | u.Quantity,
data: Union[jax.typing.ArrayLike, u.Quantity],
indices: jax.typing.ArrayLike,
indptr: jax.typing.ArrayLike,
vector: jax.typing.ArrayLike,
Expand Down

0 comments on commit d49a814

Please sign in to comment.