Skip to content

Commit

Permalink
Python: simplify merge_x_ancestors
Browse files Browse the repository at this point in the history
Python: Factor out common logic from merge_ancestors and merge_two_ancestors
  • Loading branch information
jeromekelleher committed Aug 1, 2024
1 parent 58c5616 commit eece1e0
Showing 1 changed file with 82 additions and 71 deletions.
153 changes: 82 additions & 71 deletions algorithms.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ class Segment:
next: Segment = None # noqa: A003
lineage: Lineage = None

def __repr__(self):
def __str__(self):
return repr((self.left, self.right, self.node))

@staticmethod
Expand Down Expand Up @@ -1115,7 +1115,12 @@ def alloc_segment(

def alloc_lineage(self, head, population, *, label=0, tail=None):
lineage = Lineage(head, population=population, label=label, tail=tail)
lineage.reset_segments()
assert tail is None
if head is not None:
# If we're allocating a new lineage for a given head segment, then we
# have no choice but to iterate over the rest of the chain to update
# the lineage reference, and determine the tail.
lineage.reset_segments()
return lineage

def copy_segment(self, segment):
Expand Down Expand Up @@ -1690,6 +1695,7 @@ def dtwf_climb_pedigree(self):
for ploid in range(ind.ploidy):
self.process_pedigree_common_ancestors(ind, ploid)

# TODO change to accept a lineage
def store_arg_edges(self, segment, u=-1):
if u == -1:
u = len(self.tables.nodes) - 1
Expand Down Expand Up @@ -1997,6 +2003,7 @@ def wiuf_gene_conversion_within_event(self, label):
elif head is not None:
new_individual_head = head
if new_individual_head is not None:
# FIXME when doing the smc_k update
lineage.reset_segments()
new_lineage = self.alloc_lineage(new_individual_head, pop)
if self.model == "smc_k":
Expand Down Expand Up @@ -2254,13 +2261,12 @@ def store_additional_nodes_edges(self, flag, new_node_id, z):
return new_node_id

def merge_ancestors(self, H, pop_id, label, new_node_id=-1):
pop = self.P[pop_id]
defrag_required = False
coalescence = False
pass_through = len(H) == 1
alpha = None
z = None
new_lineage = None
new_lineage = self.alloc_lineage(None, pop_id, label=label)

while len(H) > 0:
alpha = None
left = H[0][0]
Expand Down Expand Up @@ -2321,10 +2327,15 @@ def merge_ancestors(self, H, pop_id, label, new_node_id=-1):

# loop tail; update alpha and integrate it into the state.
if alpha is not None:
if z is None:
new_lineage = self.alloc_lineage(alpha, pop_id)
pop.add(new_lineage, label)
alpha.lineage = new_lineage
alpha.prev = new_lineage.tail
self.set_segment_mass(alpha)
if new_lineage.head is None:
new_lineage.head = alpha
assert new_lineage.tail is None
else:
new_lineage.tail.next = alpha
z = new_lineage.tail
if (coalescence and not self.coalescing_segments_only) or (
self.additional_nodes.value & msprime.NODE_IS_CA_EVENT > 0
):
Expand All @@ -2333,38 +2344,18 @@ def merge_ancestors(self, H, pop_id, label, new_node_id=-1):
defrag_required |= (
z.right == alpha.left and z.node == alpha.node
)
z.next = alpha
alpha.prev = z
alpha.lineage = new_lineage
self.set_segment_mass(alpha)
z = alpha
if coalescence:
if not self.coalescing_segments_only:
self.store_arg_edges(z, new_node_id)
else:
if not pass_through:
if self.additional_nodes.value & msprime.NODE_IS_CA_EVENT > 0:
new_node_id = self.store_additional_nodes_edges(
msprime.NODE_IS_CA_EVENT, new_node_id, z
)
else:
if self.additional_nodes.value & msprime.NODE_IS_PASS_THROUGH > 0:
assert new_node_id != -1
assert self.model == "fixed_pedigree"
new_node_id = self.store_additional_nodes_edges(
msprime.NODE_IS_PASS_THROUGH, new_node_id, z
)

if defrag_required:
self.defrag_segment_chain(z)
if coalescence:
self.defrag_breakpoints()
if new_lineage is not None:
# FIXME do this more efficiently!
new_lineage.reset_segments()
return new_lineage
new_lineage.tail = alpha

return self.insert_merged_lineage(
new_lineage,
new_node_id,
coalescence=coalescence,
pass_through=pass_through,
defrag_required=defrag_required,
)

def defrag_segment_chain(self, z):
def defrag_segment_chain(self, lineage):
z = lineage.tail
y = z
while y.prev is not None:
x = y.prev
Expand All @@ -2374,6 +2365,9 @@ def defrag_segment_chain(self, z):
if y.next is not None:
y.next.prev = x
self.set_segment_mass(x)
if y == lineage.tail:
lineage.tail = x
assert y != lineage.head
self.free_segment(y)
y = x

Expand Down Expand Up @@ -2442,12 +2436,11 @@ def common_ancestor_event(self, population_index, label):
self.merge_two_ancestors(population_index, label, x, y)

def merge_two_ancestors(self, population_index, label, x, y, u=-1):
pop = self.P[population_index]
self.num_ca_events += 1
z = None
new_lineage = None
new_lineage = self.alloc_lineage(None, population_index, label=label)
coalescence = False
defrag_required = False

while x is not None or y is not None:
alpha = None
if x is None or y is None:
Expand Down Expand Up @@ -2476,8 +2469,7 @@ def merge_two_ancestors(self, population_index, label, x, y, u=-1):
if not coalescence:
coalescence = True
if u == -1:
self.store_node(population_index)
u = len(self.tables.nodes) - 1
u = self.store_node(population_index)
# Put in breakpoints for the outer edges of the coalesced
# segment
left = x.left
Expand All @@ -2501,7 +2493,6 @@ def merge_two_ancestors(self, population_index, label, x, y, u=-1):
left=left,
right=right,
node=u,
population=population_index,
)
if x.node != u: # required for dtwf and fixed_pedigree
self.store_edge(left, right, u, x.node)
Expand All @@ -2521,11 +2512,15 @@ def merge_two_ancestors(self, population_index, label, x, y, u=-1):

# loop tail; update alpha and integrate it into the state.
if alpha is not None:
if z is None:
new_lineage = self.alloc_lineage(
alpha, population_index, label=label
)
alpha.lineage = new_lineage
alpha.prev = new_lineage.tail
self.set_segment_mass(alpha)
if new_lineage.head is None:
new_lineage.head = alpha
assert new_lineage.tail is None
else:
new_lineage.tail.next = alpha
z = new_lineage.tail
if (coalescence and not self.coalescing_segments_only) or (
self.additional_nodes.value & msprime.NODE_IS_CA_EVENT > 0
):
Expand All @@ -2534,42 +2529,57 @@ def merge_two_ancestors(self, population_index, label, x, y, u=-1):
defrag_required |= (
z.right == alpha.left and z.node == alpha.node
)
z.next = alpha
alpha.prev = z
alpha.lineage = new_lineage
self.set_segment_mass(alpha)
z = alpha
new_lineage.tail = alpha

return self.insert_merged_lineage(
new_lineage, u, coalescence=coalescence, defrag_required=defrag_required
)

def insert_merged_lineage(
self, new_lineage, u, *, coalescence, defrag_required, pass_through=False
):
z = new_lineage.tail

if coalescence:
if not self.coalescing_segments_only:
self.store_arg_edges(z, u)
else:
if self.additional_nodes.value & msprime.NODE_IS_CA_EVENT > 0:
self.store_additional_nodes_edges(msprime.NODE_IS_CA_EVENT, u, z)
if not pass_through:
if self.additional_nodes.value & msprime.NODE_IS_CA_EVENT > 0:
u = self.store_additional_nodes_edges(
msprime.NODE_IS_CA_EVENT, u, z
)
else:
if self.additional_nodes.value & msprime.NODE_IS_PASS_THROUGH > 0:
assert u != -1
assert self.model == "fixed_pedigree"
u = self.store_additional_nodes_edges(
msprime.NODE_IS_PASS_THROUGH, u, z
)

if defrag_required:
self.defrag_segment_chain(z)
self.defrag_segment_chain(new_lineage)
if coalescence:
self.defrag_breakpoints()

if new_lineage is not None:
x = new_lineage.head
# TODO do this more efficiently
while x is not None:
if new_lineage.head is not None:
# Use up any uncoalesced segments at the end of the chain
while (x := new_lineage.tail.next) is not None:
x.lineage = new_lineage
new_lineage.tail = x
x = x.next
self.add_lineage(new_lineage)

if new_lineage is not None and self.model == "smc_k":
merged_head = new_lineage.head
assert merged_head.prev is None
hull = self.alloc_hull(merged_head.left, merged_head.right, new_lineage)
while merged_head is not None:
right = merged_head.right
merged_head = merged_head.next
hull.right = min(right + self.hull_offset, self.L)
pop.add_hull(label, hull)
if self.model == "smc_k":
merged_head = new_lineage.head
assert merged_head.prev is None
hull = self.alloc_hull(merged_head.left, merged_head.right, new_lineage)
while merged_head is not None:
right = merged_head.right
merged_head = merged_head.next
hull.right = min(right + self.hull_offset, self.L)
pop = self.P[new_lineage.population]
pop.add_hull(new_lineage.label, hull)
return new_lineage

def print_state(self, verify=False):
print("State @ time ", self.t)
Expand Down Expand Up @@ -2640,6 +2650,7 @@ def verify_segments(self):
for pop_index, pop in enumerate(self.P):
for label in range(self.num_labels):
for lineage in pop.iter_label(label):
# print("LIN", lineage)
assert isinstance(lineage, Lineage)
assert lineage.label == label
assert lineage.population == pop_index
Expand Down

0 comments on commit eece1e0

Please sign in to comment.