Skip to content

Commit

Permalink
Merge pull request jeromekelleher#202 from szhan/fix_added_back_sample_times
Browse files Browse the repository at this point in the history

Put added-back samples at the correct times
  • Loading branch information
jeromekelleher authored Jul 29, 2024
2 parents ab7c55c + 2fdab71 commit 7c110de
Showing 1 changed file with 34 additions and 6 deletions.
40 changes: 34 additions & 6 deletions sc2ts/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -1424,7 +1424,17 @@ def trim_branches(ts):
return tables.tree_sequence()


def attach_tree(parent_ts, parent_tables, attach_path, reversions, child_ts, date):
def attach_tree(
parent_ts,
parent_tables,
attach_path,
reversions,
child_ts,
date,
epsilon=None,
):
if epsilon is None:
epsilon = 1e-6 # In time units of days ago

root_time = min(parent_ts.nodes_time[seg.parent] for seg in attach_path)
if root_time == 0:
Expand All @@ -1444,19 +1454,37 @@ def attach_tree(parent_ts, parent_tables, attach_path, reversions, child_ts, dat
child_ts = add_root_edge(child_ts)
tree = child_ts.first()

# Add sample node times
current_date = parse_date(date)
node_time = {} # In time units of days ago
for u in tree.postorder():
if tree.is_sample(u):
node = child_ts.node(u)
sample_date = parse_date(node.metadata['date'])
node_time[u] = (current_date - sample_date).days
assert node_time[u] >= 0.0
max_sample_time = max(node_time.values())

node_id_map = {}
if child_ts.nodes_time[tree.root] != 1.0:
raise ValueError("Time must be scaled from 0 to 1.")
node_time = {}

num_internal_nodes_visited = 0
for u in tree.postorder()[:-1]:
node = child_ts.node(u)
# Tree branch length is scaled from 0 to 1.
time = node.time * root_time
node_time[u] = time
if tree.is_sample(u):
# All sample nodes are terminal
time = node_time[u]
else:
num_internal_nodes_visited += 1
time = max_sample_time + num_internal_nodes_visited * epsilon
node_time[u] = time
metadata = node.metadata
if tree.is_internal(u):
metadata = {"date_added": date}
new_id = parent_tables.nodes.append(node.replace(time=time, metadata=metadata))
new_id = parent_tables.nodes.append(
node.replace(time=time, metadata=metadata)
)
node_id_map[node.id] = new_id
for v in tree.children(u):
parent_tables.edges.add_row(
Expand Down

0 comments on commit 7c110de

Please sign in to comment.