Skip to content

Commit

Permalink
update codes
Browse files Browse the repository at this point in the history
  • Loading branch information
chaoming0625 committed Aug 25, 2024
1 parent f502aee commit f7fd900
Show file tree
Hide file tree
Showing 4 changed files with 23 additions and 23 deletions.
30 changes: 15 additions & 15 deletions braintaichi/_jitconnop/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,17 +26,17 @@
from ._jit_event_csrmv import raw_event_mv_prob_homo, raw_event_mv_prob_uniform, raw_event_mv_prob_normal

__all__ = [
'mv_prob_homo',
'mv_prob_uniform',
'mv_prob_normal',
'event_mv_prob_homo',
'event_mv_prob_uniform',
'event_mv_prob_normal',
'jitc_mv_prob_homo',
'jitc_mv_prob_uniform',
'jitc_mv_prob_normal',
'jitc_event_mv_prob_homo',
'jitc_event_mv_prob_uniform',
'jitc_event_mv_prob_normal',
]


@set_module_as('braintaichi')
def mv_prob_homo(
def jitc_mv_prob_homo(
vector: jax.typing.ArrayLike,
weight: float,
conn_prob: float,
Expand Down Expand Up @@ -110,7 +110,7 @@ def mv_prob_homo(


@set_module_as('braintaichi')
def mv_prob_uniform(
def jitc_mv_prob_uniform(
vector: jax.Array,
w_low: float,
w_high: float,
Expand Down Expand Up @@ -187,7 +187,7 @@ def mv_prob_uniform(


@set_module_as('braintaichi')
def mv_prob_normal(
def jitc_mv_prob_normal(
vector: jax.Array,
w_mu: float,
w_sigma: float,
Expand Down Expand Up @@ -264,7 +264,7 @@ def mv_prob_normal(


@set_module_as('braintaichi')
def event_mv_prob_homo(
def jitc_event_mv_prob_homo(
events: jax.Array,
weight: float,
conn_prob: float,
Expand All @@ -290,11 +290,11 @@ def event_mv_prob_homo(
outdim_parallel=outdim_parallel)[0]


event_mv_prob_homo.__doc__ = mv_prob_homo.__doc__
jitc_event_mv_prob_homo.__doc__ = jitc_mv_prob_homo.__doc__


@set_module_as('braintaichi')
def event_mv_prob_uniform(
def jitc_event_mv_prob_uniform(
events: jax.Array,
w_low: float,
w_high: float,
Expand All @@ -320,11 +320,11 @@ def event_mv_prob_uniform(
transpose=transpose, outdim_parallel=outdim_parallel)[0]


event_mv_prob_uniform.__doc__ = mv_prob_uniform.__doc__
jitc_event_mv_prob_uniform.__doc__ = jitc_mv_prob_uniform.__doc__


@set_module_as('braintaichi')
def event_mv_prob_normal(
def jitc_event_mv_prob_normal(
events: jax.Array,
w_mu: float,
w_sigma: float,
Expand All @@ -350,4 +350,4 @@ def event_mv_prob_normal(
transpose=transpose, outdim_parallel=outdim_parallel)[0]


event_mv_prob_normal.__doc__ = mv_prob_normal.__doc__
jitc_event_mv_prob_normal.__doc__ = jitc_mv_prob_normal.__doc__
12 changes: 6 additions & 6 deletions docs/apis/jitconn-operators.rst
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,11 @@ JIT Connectivity Operators
:nosignatures:
:template: classtemplate.rst

mv_prob_homo
mv_prob_uniform
mv_prob_normal
event_mv_prob_homo
event_mv_prob_uniform
event_mv_prob_normal
jitc_mv_prob_homo
jitc_mv_prob_uniform
jitc_mv_prob_normal
jitc_event_mv_prob_homo
jitc_event_mv_prob_uniform
jitc_event_mv_prob_normal


2 changes: 1 addition & 1 deletion examples/event_csrmv.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def try_example1():
events = bst.random.random((1000,)) < 0.1

# Create a sparse matrix
r = bti.event_mv_prob_homo(events, 1., conn_prob=0.1, shape=(1000, 1000), seed=123)
r = bti.jitc_event_mv_prob_homo(events, 1., conn_prob=0.1, shape=(1000, 1000), seed=123)
print(r.shape)
print(r)

Expand Down
2 changes: 1 addition & 1 deletion tests/test_event_csrmv.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ def test_example1():
events = bst.random.random((1000,)) < 0.1

# Create a sparse matrix
r = bti.event_mv_prob_homo(events, 1., conn_prob=0.1, shape=(1000, 1000), seed=123)
r = bti.jitc_event_mv_prob_homo(events, 1., conn_prob=0.1, shape=(1000, 1000), seed=123)
print(r.shape)
print(r)

Expand Down

0 comments on commit f7fd900

Please sign in to comment.