diff --git a/braintaichi/_jitconnop/main.py b/braintaichi/_jitconnop/main.py index 3f24de4..7dc9356 100644 --- a/braintaichi/_jitconnop/main.py +++ b/braintaichi/_jitconnop/main.py @@ -26,17 +26,17 @@ from ._jit_event_csrmv import raw_event_mv_prob_homo, raw_event_mv_prob_uniform, raw_event_mv_prob_normal __all__ = [ - 'mv_prob_homo', - 'mv_prob_uniform', - 'mv_prob_normal', - 'event_mv_prob_homo', - 'event_mv_prob_uniform', - 'event_mv_prob_normal', + 'jitc_mv_prob_homo', + 'jitc_mv_prob_uniform', + 'jitc_mv_prob_normal', + 'jitc_event_mv_prob_homo', + 'jitc_event_mv_prob_uniform', + 'jitc_event_mv_prob_normal', ] @set_module_as('braintaichi') -def mv_prob_homo( +def jitc_mv_prob_homo( vector: jax.typing.ArrayLike, weight: float, conn_prob: float, @@ -110,7 +110,7 @@ def mv_prob_homo( @set_module_as('braintaichi') -def mv_prob_uniform( +def jitc_mv_prob_uniform( vector: jax.Array, w_low: float, w_high: float, @@ -187,7 +187,7 @@ def mv_prob_uniform( @set_module_as('braintaichi') -def mv_prob_normal( +def jitc_mv_prob_normal( vector: jax.Array, w_mu: float, w_sigma: float, @@ -264,7 +264,7 @@ def mv_prob_normal( @set_module_as('braintaichi') -def event_mv_prob_homo( +def jitc_event_mv_prob_homo( events: jax.Array, weight: float, conn_prob: float, @@ -290,11 +290,11 @@ def event_mv_prob_homo( outdim_parallel=outdim_parallel)[0] -event_mv_prob_homo.__doc__ = mv_prob_homo.__doc__ +jitc_event_mv_prob_homo.__doc__ = jitc_mv_prob_homo.__doc__ @set_module_as('braintaichi') -def event_mv_prob_uniform( +def jitc_event_mv_prob_uniform( events: jax.Array, w_low: float, w_high: float, @@ -320,11 +320,11 @@ def event_mv_prob_uniform( transpose=transpose, outdim_parallel=outdim_parallel)[0] -event_mv_prob_uniform.__doc__ = mv_prob_uniform.__doc__ +jitc_event_mv_prob_uniform.__doc__ = jitc_mv_prob_uniform.__doc__ @set_module_as('braintaichi') -def event_mv_prob_normal( +def jitc_event_mv_prob_normal( events: jax.Array, w_mu: float, w_sigma: float, @@ -350,4 +350,4 @@ def event_mv_prob_normal( transpose=transpose, outdim_parallel=outdim_parallel)[0] -event_mv_prob_normal.__doc__ = mv_prob_normal.__doc__ +jitc_event_mv_prob_normal.__doc__ = jitc_mv_prob_normal.__doc__ diff --git a/docs/apis/jitconn-operators.rst b/docs/apis/jitconn-operators.rst index c809613..6a73362 100644 --- a/docs/apis/jitconn-operators.rst +++ b/docs/apis/jitconn-operators.rst @@ -10,11 +10,11 @@ JIT Connectivity Operators :nosignatures: :template: classtemplate.rst - mv_prob_homo - mv_prob_uniform - mv_prob_normal - event_mv_prob_homo - event_mv_prob_uniform - event_mv_prob_normal + jitc_mv_prob_homo + jitc_mv_prob_uniform + jitc_mv_prob_normal + jitc_event_mv_prob_homo + jitc_event_mv_prob_uniform + jitc_event_mv_prob_normal diff --git a/examples/event_csrmv.py b/examples/event_csrmv.py index ae12eca..d20799d 100644 --- a/examples/event_csrmv.py +++ b/examples/event_csrmv.py @@ -23,7 +23,7 @@ def try_example1(): events = bst.random.random((1000,)) < 0.1 # Create a sparse matrix - r = bti.event_mv_prob_homo(events, 1., conn_prob=0.1, shape=(1000, 1000), seed=123) + r = bti.jitc_event_mv_prob_homo(events, 1., conn_prob=0.1, shape=(1000, 1000), seed=123) print(r.shape) print(r) diff --git a/tests/test_event_csrmv.py b/tests/test_event_csrmv.py index fd6aac7..b3b825c 100644 --- a/tests/test_event_csrmv.py +++ b/tests/test_event_csrmv.py @@ -22,7 +22,7 @@ def test_example1(): events = bst.random.random((1000,)) < 0.1 # Create a sparse matrix - r = bti.event_mv_prob_homo(events, 1., conn_prob=0.1, shape=(1000, 1000), seed=123) + r = bti.jitc_event_mv_prob_homo(events, 1., conn_prob=0.1, shape=(1000, 1000), seed=123) print(r.shape) print(r)