diff --git a/sc2ts/inference.py b/sc2ts/inference.py index b02cdb2..af3abb3 100644 --- a/sc2ts/inference.py +++ b/sc2ts/inference.py @@ -1424,7 +1424,17 @@ def trim_branches(ts): return tables.tree_sequence() -def attach_tree(parent_ts, parent_tables, attach_path, reversions, child_ts, date): +def attach_tree( + parent_ts, + parent_tables, + attach_path, + reversions, + child_ts, + date, + epsilon=None, +): + if epsilon is None: + epsilon = 1e-6 # In time units of days ago root_time = min(parent_ts.nodes_time[seg.parent] for seg in attach_path) if root_time == 0: @@ -1444,19 +1454,37 @@ def attach_tree(parent_ts, parent_tables, attach_path, reversions, child_ts, dat child_ts = add_root_edge(child_ts) tree = child_ts.first() + # Add sample node times + current_date = parse_date(date) + node_time = {} # In time units of days ago + for u in tree.postorder(): + if tree.is_sample(u): + node = child_ts.node(u) + sample_date = parse_date(node.metadata['date']) + node_time[u] = (current_date - sample_date).days + assert node_time[u] >= 0.0 + max_sample_time = max(node_time.values()) + node_id_map = {} if child_ts.nodes_time[tree.root] != 1.0: raise ValueError("Time must be scaled from 0 to 1.") - node_time = {} + + num_internal_nodes_visited = 0 for u in tree.postorder()[:-1]: node = child_ts.node(u) - # Tree branch length is scaled from 0 to 1. - time = node.time * root_time - node_time[u] = time + if tree.is_sample(u): + # All sample nodes are terminal + time = node_time[u] + else: + num_internal_nodes_visited += 1 + time = max_sample_time + num_internal_nodes_visited * epsilon + node_time[u] = time metadata = node.metadata if tree.is_internal(u): metadata = {"date_added": date} - new_id = parent_tables.nodes.append(node.replace(time=time, metadata=metadata)) + new_id = parent_tables.nodes.append( + node.replace(time=time, metadata=metadata) + ) node_id_map[node.id] = new_id for v in tree.children(u): parent_tables.edges.add_row(