diff --git a/requirements.txt b/requirements.txt index 8dd19e46..163741d0 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,4 @@ -tskit>=0.4.0 +tskit>=0.5.2 tsinfer>=0.2.0 flake8 numpy diff --git a/tests/test_functions.py b/tests/test_functions.py index 1181bcbf..4c9582d9 100644 --- a/tests/test_functions.py +++ b/tests/test_functions.py @@ -1530,18 +1530,7 @@ def test_constrain_ages_topo(self): ts = utility_functions.two_tree_ts() post_mn = np.array([0.0, 0.0, 0.0, 2.0, 1.0, 3.0]) eps = 1e-6 - nodes_to_date = np.array([3, 4, 5]) - constrained_ages = constrain_ages_topo(ts, post_mn, eps, nodes_to_date) - assert np.array_equal( - np.array([0.0, 0.0, 0.0, 2.0, 2.000001, 3.0]), constrained_ages - ) - - def test_constrain_ages_topo_no_nodes_to_date(self): - ts = utility_functions.two_tree_ts() - post_mn = np.array([0.0, 0.0, 0.0, 2.0, 1.0, 3.0]) - eps = 1e-6 - nodes_to_date = None - constrained_ages = constrain_ages_topo(ts, post_mn, eps, nodes_to_date) + constrained_ages = constrain_ages_topo(ts, post_mn, eps) assert np.array_equal( np.array([0.0, 0.0, 0.0, 2.0, 2.000001, 3.0]), constrained_ages ) diff --git a/tsdate/core.py b/tsdate/core.py index f895058b..18f5f653 100644 --- a/tsdate/core.py +++ b/tsdate/core.py @@ -936,22 +936,23 @@ def constrain_ages_topo(ts, node_times, eps, progress=False): If node_times violate topology, return increased node_times so that each node is guaranteed to be older than any of its their children. """ - tables = ts.tables + edges_parent = ts.edges_parent + edges_child = ts.edges_child + new_node_times = np.copy(node_times) # Traverse through the ARG, ensuring children come before parents. # This can be done by iterating over groups of edges with the same parent - new_parent_edge_idx = np.concatenate( - ( - [0], - np.where(np.diff(tables.edges.parent) != 0)[0] + 1, - [tables.edges.num_rows], - ) - ) - for edges_start, edges_end in zip( - new_parent_edge_idx[:-1], new_parent_edge_idx[1:] + new_parent_edge_idx = np.where(np.diff(edges_parent) != 0)[0] + 1 + for edges_start, edges_end in tqdm( + zip( + itertools.chain([0], new_parent_edge_idx), + itertools.chain(new_parent_edge_idx, [len(edges_parent)]), + ), + desc="Constrain Ages", + disable=not progress, ): - parent = tables.edges.parent[edges_start] - child_ids = tables.edges.child[edges_start:edges_end] # May contain dups + parent = edges_parent[edges_start] + child_ids = edges_child[edges_start:edges_end] # May contain dups oldest_child_time = np.max(new_node_times[child_ids]) if oldest_child_time >= new_node_times[parent]: new_node_times[parent] = oldest_child_time + eps diff --git a/tsdate/prior.py b/tsdate/prior.py index b72658c1..9f989692 100644 --- a/tsdate/prior.py +++ b/tsdate/prior.py @@ -22,6 +22,7 @@ """ Routines and classes for creating priors and timeslices for use in tsdate """ +import itertools import logging import os from collections import defaultdict @@ -1030,10 +1031,8 @@ def _truncate_priors(ts, priors, progress=False): Truncate priors for all nonfixed nodes so they conform to the age of fixed nodes in the tree sequence """ - tables = ts.tables - fixed_nodes = priors.fixed_node_ids() - fixed_times = tables.nodes.time[fixed_nodes] + fixed_times = ts.nodes_time[fixed_nodes] grid_data = np.copy(priors.grid_data[:]) timepoints = priors.timepoints @@ -1043,24 +1042,25 @@ def _truncate_priors(ts, priors, progress=False): zero_value = 0 elif priors.probability_space == "logarithmic": zero_value = -np.inf - constrained_min_times = np.zeros_like(tables.nodes.time) + constrained_min_times = np.zeros_like(ts.nodes_time) # Set the min times of fixed nodes to those in the tree sequence constrained_min_times[fixed_nodes] = fixed_times # Traverse through the ARG, ensuring children come before parents. # This can be done by iterating over groups of edges with the same parent - new_parent_edge_idx = np.concatenate( - ( - [0], - np.where(np.diff(tables.edges.parent) != 0)[0] + 1, - [tables.edges.num_rows], - ) - ) - for edges_start, edges_end in zip( - new_parent_edge_idx[:-1], new_parent_edge_idx[1:] + edges_parent = ts.edges_parent + edges_child = ts.edges_child + new_parent_edge_idx = np.where(np.diff(edges_parent) != 0)[0] + 1 + for edges_start, edges_end in tqdm( + zip( + itertools.chain([0], new_parent_edge_idx), + itertools.chain(new_parent_edge_idx, [len(edges_parent)]), + ), + desc="Trunc priors", + disable=not progress, ): - parent = tables.edges.parent[edges_start] - child_ids = tables.edges.child[edges_start:edges_end] # May contain dups + parent = edges_parent[edges_start] + child_ids = edges_child[edges_start:edges_end] # May contain dups oldest_child_time = np.max(constrained_min_times[child_ids]) if oldest_child_time > constrained_min_times[parent]: if priors.is_fixed(parent): @@ -1198,8 +1198,7 @@ def build_grid( node_var_override=node_var_override, progress=progress, ) - tables = tree_sequence.tables - if np.any(tables.nodes.time[tree_sequence.samples()] > 0): + if np.any(tree_sequence.nodes_time[tree_sequence.samples()] > 0): if not allow_historical_samples: raise ValueError( "There are samples at non-zero times, invalidating the conditional " @@ -1207,6 +1206,9 @@ def build_grid( "on regardless, calculating a prior as if all samples were " "contemporaneous (reasonable if you only have a few ancient samples)" ) - if np.any(tables.nodes.time[priors.fixed_node_ids()] > 0) and truncate_priors: + if ( + np.any(tree_sequence.nodes_time[priors.fixed_node_ids()] > 0) + and truncate_priors + ): priors = _truncate_priors(tree_sequence, priors, progress=progress) return priors