Skip to content

Commit

Permalink
Update numba_njit decorator with parameter wrapper
Browse files Browse the repository at this point in the history
  • Loading branch information
Billyzhang1229 committed May 8, 2024
1 parent d0d8d18 commit 4896bec
Show file tree
Hide file tree
Showing 7 changed files with 32 additions and 24 deletions.
14 changes: 8 additions & 6 deletions phylokit/balance.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from . import util


@jit.numba_njit
@jit.numba_njit()
def _sackin_index(virtual_root, left_child, right_sib):
stack = []
root = left_child[virtual_root]
Expand Down Expand Up @@ -43,7 +43,7 @@ def sackin_index(ds):
return _sackin_index(-1, ds.node_left_child.data, ds.node_right_sib.data)


@jit.numba_njit
@jit.numba_njit()
def _colless_index(postorder, left_child, right_sib):
num_leaves = np.zeros_like(left_child)
total = 0.0
Expand Down Expand Up @@ -83,11 +83,13 @@ def colless_index(ds):
if util.get_num_roots(ds) != 1:
raise ValueError("Colless index not defined for multiroot trees")
return _colless_index(
ds.traversal_postorder.data, ds.node_left_child.data, ds.node_right_sib.data
ds.traversal_postorder.data,
ds.node_left_child.data,
ds.node_right_sib.data,
)


@jit.numba_njit
@jit.numba_njit()
def _b1_index(postorder, left_child, right_sib, parent):
max_path_length = np.zeros_like(postorder)
total = 0.0
Expand Down Expand Up @@ -121,7 +123,7 @@ def b1_index(ds):
)


@jit.numba_njit
@jit.numba_njit()
def general_log(x, base):
"""
Compute the logarithm of x in base `base`.
Expand All @@ -134,7 +136,7 @@ def general_log(x, base):
return math.log(x) / math.log(base)


@jit.numba_njit
@jit.numba_njit()
def _b2_index(virtual_root, left_child, right_sib, base):
root = left_child[virtual_root]
stack = [(root, 1)]
Expand Down
4 changes: 2 additions & 2 deletions phylokit/distance.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from . import util


@jit.numba_njit
@jit.numba_njit()
def _mrca(parent, time, u, v):
tu = time[u]
tv = time[v]
Expand Down Expand Up @@ -52,7 +52,7 @@ def mrca(ds, u, v):
return _mrca(ds.node_parent.data, ds.node_time.data, u, v)


@jit.numba_njit
@jit.numba_njit()
def _kc_distance(samples, ds1, ds2):
# ds1 and ds2 are tuples of the form (parent_array, time_array, branch_length, root)
n = samples.shape[0]
Expand Down
2 changes: 1 addition & 1 deletion phylokit/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from . import jit


@jit.numba_njit
@jit.numba_njit()
def _linkage_matrix_to_dataset(Z):
n = Z.shape[0] + 1
N = 2 * n
Expand Down
19 changes: 14 additions & 5 deletions phylokit/jit.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import functools
import logging
import os

Expand Down Expand Up @@ -30,8 +31,16 @@
}


def numba_njit(func, **kwargs):
if ENABLE_NUMBA: # pragma: no cover
return numba.jit(func, **{**DEFAULT_NUMBA_ARGS, **kwargs})
else:
return func
def numba_njit(**numba_kwargs):
def _numba_njit(func):
@functools.wraps(func)
def wrapper(*args, **kwargs):
return func(*args, **kwargs) # pragma: no cover

if ENABLE_NUMBA: # pragma: no cover
combined_kwargs = {**DEFAULT_NUMBA_ARGS, **numba_kwargs}
return numba.jit(**combined_kwargs)(func)
else:
return func

return _numba_njit
2 changes: 1 addition & 1 deletion phylokit/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from . import util


@jit.numba_njit
@jit.numba_njit()
def _permute_node_seq(nodes, ordering, reversed_map):
ret = np.zeros_like(nodes, dtype=np.int32)
for u, v in enumerate(ordering):
Expand Down
4 changes: 2 additions & 2 deletions phylokit/traversal.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from . import jit


@jit.numba_njit
@jit.numba_njit()
def _postorder(left_child, right_sib, root):
# Another implementation with python stack operations such as `pop`
# makes the same function about 2X slower.
Expand Down Expand Up @@ -59,7 +59,7 @@ def postorder(ds, root=None):
)


@jit.numba_njit
@jit.numba_njit()
def _preorder(parent, left_child, right_sib, root):
# Another implementation with python stack operations such as `pop`
# makes the same function about 2X slower.
Expand Down
11 changes: 4 additions & 7 deletions phylokit/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from . import jit


@jit.numba_njit
@jit.numba_njit()
def _is_unary(postorder, left_child, right_sib):
for u in postorder:
v = left_child[u]
Expand Down Expand Up @@ -45,7 +45,7 @@ def check_node_bounds(ds, *args):
raise ValueError(f"Node {u} is not in the tree")


@jit.numba_njit
@jit.numba_njit()
def _get_num_roots(left_child, right_sib):
u = left_child[-1]
num_roots = 0
Expand All @@ -66,7 +66,7 @@ def get_num_roots(ds):
return _get_num_roots(ds.node_left_child.data, ds.node_right_sib.data)


@jit.numba_njit
@jit.numba_njit()
def _branch_length(parent, time, u):
ret = 0
p = parent[u]
Expand All @@ -89,7 +89,7 @@ def branch_length(ds, u):
return _branch_length(ds.node_parent.data, ds.node_time.data, u)


@jit.numba_njit
@jit.numba_njit()
def _get_node_branch_length(parent, time):
ret = np.zeros_like(parent, dtype=np.float64)
for i in range(parent.shape[0]):
Expand Down Expand Up @@ -132,6 +132,3 @@ def base_mapping(base_matrix, mapper_matrix):
result_matrix[i] = j
break
return result_matrix.reshape(base_shape)


base_mapping = jit.numba_njit(base_mapping, parallel=True)

0 comments on commit 4896bec

Please sign in to comment.