From 4896becfda76b9fc7140d689bd3485223c942d04 Mon Sep 17 00:00:00 2001 From: Ao Zhang Date: Tue, 7 May 2024 17:31:56 +0100 Subject: [PATCH] Update numba_njit decorator with parameter wrapper --- phylokit/balance.py | 14 ++++++++------ phylokit/distance.py | 4 ++-- phylokit/inference.py | 2 +- phylokit/jit.py | 19 ++++++++++++++----- phylokit/transform.py | 2 +- phylokit/traversal.py | 4 ++-- phylokit/util.py | 11 ++++------- 7 files changed, 32 insertions(+), 24 deletions(-) diff --git a/phylokit/balance.py b/phylokit/balance.py index 3d3e42b..6fe2a99 100644 --- a/phylokit/balance.py +++ b/phylokit/balance.py @@ -7,7 +7,7 @@ from . import util -@jit.numba_njit +@jit.numba_njit() def _sackin_index(virtual_root, left_child, right_sib): stack = [] root = left_child[virtual_root] @@ -43,7 +43,7 @@ def sackin_index(ds): return _sackin_index(-1, ds.node_left_child.data, ds.node_right_sib.data) -@jit.numba_njit +@jit.numba_njit() def _colless_index(postorder, left_child, right_sib): num_leaves = np.zeros_like(left_child) total = 0.0 @@ -83,11 +83,13 @@ def colless_index(ds): if util.get_num_roots(ds) != 1: raise ValueError("Colless index not defined for multiroot trees") return _colless_index( - ds.traversal_postorder.data, ds.node_left_child.data, ds.node_right_sib.data + ds.traversal_postorder.data, + ds.node_left_child.data, + ds.node_right_sib.data, ) -@jit.numba_njit +@jit.numba_njit() def _b1_index(postorder, left_child, right_sib, parent): max_path_length = np.zeros_like(postorder) total = 0.0 @@ -121,7 +123,7 @@ def b1_index(ds): ) -@jit.numba_njit +@jit.numba_njit() def general_log(x, base): """ Compute the logarithm of x in base `base`. @@ -134,7 +136,7 @@ def general_log(x, base): return math.log(x) / math.log(base) -@jit.numba_njit +@jit.numba_njit() def _b2_index(virtual_root, left_child, right_sib, base): root = left_child[virtual_root] stack = [(root, 1)] diff --git a/phylokit/distance.py b/phylokit/distance.py index a4ccbdd..dfd818d 100644 --- a/phylokit/distance.py +++ b/phylokit/distance.py @@ -5,7 +5,7 @@ from . import util -@jit.numba_njit +@jit.numba_njit() def _mrca(parent, time, u, v): tu = time[u] tv = time[v] @@ -52,7 +52,7 @@ def mrca(ds, u, v): return _mrca(ds.node_parent.data, ds.node_time.data, u, v) -@jit.numba_njit +@jit.numba_njit() def _kc_distance(samples, ds1, ds2): # ds1 and ds2 are tuples of the form (parent_array, time_array, branch_length, root) n = samples.shape[0] diff --git a/phylokit/inference.py b/phylokit/inference.py index 70a25bb..93fa271 100644 --- a/phylokit/inference.py +++ b/phylokit/inference.py @@ -7,7 +7,7 @@ from . import jit -@jit.numba_njit +@jit.numba_njit() def _linkage_matrix_to_dataset(Z): n = Z.shape[0] + 1 N = 2 * n diff --git a/phylokit/jit.py b/phylokit/jit.py index a2fd9b9..a6f61b4 100644 --- a/phylokit/jit.py +++ b/phylokit/jit.py @@ -1,3 +1,4 @@ +import functools import logging import os @@ -30,8 +31,16 @@ } -def numba_njit(func, **kwargs): - if ENABLE_NUMBA: # pragma: no cover - return numba.jit(func, **{**DEFAULT_NUMBA_ARGS, **kwargs}) - else: - return func +def numba_njit(**numba_kwargs): + def _numba_njit(func): + @functools.wraps(func) + def wrapper(*args, **kwargs): + return func(*args, **kwargs) # pragma: no cover + + if ENABLE_NUMBA: # pragma: no cover + combined_kwargs = {**DEFAULT_NUMBA_ARGS, **numba_kwargs} + return numba.jit(**combined_kwargs)(func) + else: + return func + + return _numba_njit diff --git a/phylokit/transform.py b/phylokit/transform.py index 5ae8982..8181b7c 100644 --- a/phylokit/transform.py +++ b/phylokit/transform.py @@ -5,7 +5,7 @@ from . import util -@jit.numba_njit +@jit.numba_njit() def _permute_node_seq(nodes, ordering, reversed_map): ret = np.zeros_like(nodes, dtype=np.int32) for u, v in enumerate(ordering): diff --git a/phylokit/traversal.py b/phylokit/traversal.py index 22f47e6..8ef5e58 100644 --- a/phylokit/traversal.py +++ b/phylokit/traversal.py @@ -3,7 +3,7 @@ from . import jit -@jit.numba_njit +@jit.numba_njit() def _postorder(left_child, right_sib, root): # Another implementation with python stack operations such as `pop` # makes the same function about 2X slower. @@ -59,7 +59,7 @@ def postorder(ds, root=None): ) -@jit.numba_njit +@jit.numba_njit() def _preorder(parent, left_child, right_sib, root): # Another implementation with python stack operations such as `pop` # makes the same function about 2X slower. diff --git a/phylokit/util.py b/phylokit/util.py index d367d6c..289d223 100644 --- a/phylokit/util.py +++ b/phylokit/util.py @@ -4,7 +4,7 @@ from . import jit -@jit.numba_njit +@jit.numba_njit() def _is_unary(postorder, left_child, right_sib): for u in postorder: v = left_child[u] @@ -45,7 +45,7 @@ def check_node_bounds(ds, *args): raise ValueError(f"Node {u} is not in the tree") -@jit.numba_njit +@jit.numba_njit() def _get_num_roots(left_child, right_sib): u = left_child[-1] num_roots = 0 @@ -66,7 +66,7 @@ def get_num_roots(ds): return _get_num_roots(ds.node_left_child.data, ds.node_right_sib.data) -@jit.numba_njit +@jit.numba_njit() def _branch_length(parent, time, u): ret = 0 p = parent[u] @@ -89,7 +89,7 @@ def branch_length(ds, u): return _branch_length(ds.node_parent.data, ds.node_time.data, u) -@jit.numba_njit +@jit.numba_njit() def _get_node_branch_length(parent, time): ret = np.zeros_like(parent, dtype=np.float64) for i in range(parent.shape[0]): @@ -132,6 +132,3 @@ def base_mapping(base_matrix, mapper_matrix): result_matrix[i] = j break return result_matrix.reshape(base_shape) - - -base_mapping = jit.numba_njit(base_mapping, parallel=True)