Skip to content

Commit

Permalink
Cache references to edges_parent & edges_child
Browse files Browse the repository at this point in the history
And remove the need to specify which nodes to constrain, which was never used anyway.
  • Loading branch information
hyanwong committed Sep 6, 2022
1 parent 974038d commit 6b523bd
Show file tree
Hide file tree
Showing 4 changed files with 35 additions and 43 deletions.
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
tskit>=0.4.0
tskit>=0.5.2
tsinfer>=0.2.0
flake8
numpy
Expand Down
13 changes: 1 addition & 12 deletions tests/test_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -1530,18 +1530,7 @@ def test_constrain_ages_topo(self):
ts = utility_functions.two_tree_ts()
post_mn = np.array([0.0, 0.0, 0.0, 2.0, 1.0, 3.0])
eps = 1e-6
nodes_to_date = np.array([3, 4, 5])
constrained_ages = constrain_ages_topo(ts, post_mn, eps, nodes_to_date)
assert np.array_equal(
np.array([0.0, 0.0, 0.0, 2.0, 2.000001, 3.0]), constrained_ages
)

def test_constrain_ages_topo_no_nodes_to_date(self):
ts = utility_functions.two_tree_ts()
post_mn = np.array([0.0, 0.0, 0.0, 2.0, 1.0, 3.0])
eps = 1e-6
nodes_to_date = None
constrained_ages = constrain_ages_topo(ts, post_mn, eps, nodes_to_date)
constrained_ages = constrain_ages_topo(ts, post_mn, eps)
assert np.array_equal(
np.array([0.0, 0.0, 0.0, 2.0, 2.000001, 3.0]), constrained_ages
)
Expand Down
25 changes: 13 additions & 12 deletions tsdate/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -936,22 +936,23 @@ def constrain_ages_topo(ts, node_times, eps, progress=False):
If node_times violate topology, return increased node_times so that each node is
guaranteed to be older than any of its children.
"""
tables = ts.tables
edges_parent = ts.edges_parent
edges_child = ts.edges_child

new_node_times = np.copy(node_times)
# Traverse through the ARG, ensuring children come before parents.
# This can be done by iterating over groups of edges with the same parent
new_parent_edge_idx = np.concatenate(
(
[0],
np.where(np.diff(tables.edges.parent) != 0)[0] + 1,
[tables.edges.num_rows],
)
)
for edges_start, edges_end in zip(
new_parent_edge_idx[:-1], new_parent_edge_idx[1:]
new_parent_edge_idx = np.where(np.diff(edges_parent) != 0)[0] + 1
for edges_start, edges_end in tqdm(
zip(
itertools.chain([0], new_parent_edge_idx),
itertools.chain(new_parent_edge_idx, [len(edges_parent)]),
),
desc="Constrain Ages",
disable=not progress,
):
parent = tables.edges.parent[edges_start]
child_ids = tables.edges.child[edges_start:edges_end] # May contain dups
parent = edges_parent[edges_start]
child_ids = edges_child[edges_start:edges_end] # May contain dups
oldest_child_time = np.max(new_node_times[child_ids])
if oldest_child_time >= new_node_times[parent]:
new_node_times[parent] = oldest_child_time + eps
Expand Down
38 changes: 20 additions & 18 deletions tsdate/prior.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
"""
Routines and classes for creating priors and timeslices for use in tsdate
"""
import itertools
import logging
import os
from collections import defaultdict
Expand Down Expand Up @@ -1030,10 +1031,8 @@ def _truncate_priors(ts, priors, progress=False):
Truncate priors for all nonfixed nodes
so they conform to the age of fixed nodes in the tree sequence
"""
tables = ts.tables

fixed_nodes = priors.fixed_node_ids()
fixed_times = tables.nodes.time[fixed_nodes]
fixed_times = ts.nodes_time[fixed_nodes]

grid_data = np.copy(priors.grid_data[:])
timepoints = priors.timepoints
Expand All @@ -1043,24 +1042,25 @@ def _truncate_priors(ts, priors, progress=False):
zero_value = 0
elif priors.probability_space == "logarithmic":
zero_value = -np.inf
constrained_min_times = np.zeros_like(tables.nodes.time)
constrained_min_times = np.zeros_like(ts.nodes_time)
# Set the min times of fixed nodes to those in the tree sequence
constrained_min_times[fixed_nodes] = fixed_times

# Traverse through the ARG, ensuring children come before parents.
# This can be done by iterating over groups of edges with the same parent
new_parent_edge_idx = np.concatenate(
(
[0],
np.where(np.diff(tables.edges.parent) != 0)[0] + 1,
[tables.edges.num_rows],
)
)
for edges_start, edges_end in zip(
new_parent_edge_idx[:-1], new_parent_edge_idx[1:]
edges_parent = ts.edges_parent
edges_child = ts.edges_child
new_parent_edge_idx = np.where(np.diff(edges_parent) != 0)[0] + 1
for edges_start, edges_end in tqdm(
zip(
itertools.chain([0], new_parent_edge_idx),
itertools.chain(new_parent_edge_idx, [len(edges_parent)]),
),
desc="Trunc priors",
disable=not progress,
):
parent = tables.edges.parent[edges_start]
child_ids = tables.edges.child[edges_start:edges_end] # May contain dups
parent = edges_parent[edges_start]
child_ids = edges_child[edges_start:edges_end] # May contain dups
oldest_child_time = np.max(constrained_min_times[child_ids])
if oldest_child_time > constrained_min_times[parent]:
if priors.is_fixed(parent):
Expand Down Expand Up @@ -1198,15 +1198,17 @@ def build_grid(
node_var_override=node_var_override,
progress=progress,
)
tables = tree_sequence.tables
if np.any(tables.nodes.time[tree_sequence.samples()] > 0):
if np.any(tree_sequence.nodes_time[tree_sequence.samples()] > 0):
if not allow_historical_samples:
raise ValueError(
"There are samples at non-zero times, invalidating the conditional "
"coalescent prior. You can set allow_historical_samples=True to carry "
"on regardless, calculating a prior as if all samples were "
"contemporaneous (reasonable if you only have a few ancient samples)"
)
if np.any(tables.nodes.time[priors.fixed_node_ids()] > 0) and truncate_priors:
if (
np.any(tree_sequence.nodes_time[priors.fixed_node_ids()] > 0)
and truncate_priors
):
priors = _truncate_priors(tree_sequence, priors, progress=progress)
return priors

0 comments on commit 6b523bd

Please sign in to comment.