Skip to content

Commit

Permalink
Limit linesweep to 1M ancestors
Browse files Browse the repository at this point in the history
  • Loading branch information
benjeffery committed Jan 10, 2024
1 parent 7db1d38 commit c78492c
Showing 1 changed file with 25 additions and 3 deletions.
28 changes: 25 additions & 3 deletions tsinfer/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -1644,6 +1644,19 @@ def group_by_linesweep(self):
epoch_end = np.hstack([breaks + 1, [self.num_ancestors]])
time_slices = np.vstack([epoch_start, epoch_end]).T
epoch_sizes = time_slices[:, 1] - time_slices[:, 0]
# Find the epoch where the sum of ancestors has reached 1M as a cutoff
if np.sum(epoch_sizes) <= 1e6:
over_1M_epoch = len(time_slices)
over_1M_epoch_first_ancestor = self.num_ancestors
else:
over_1M_epoch = np.where(np.cumsum(epoch_sizes) > 1e6)[0][0]
over_1M_epoch_first_ancestor = time_slices[over_1M_epoch, 0]
logger.info(
f"1M ancestors reached at {over_1M_epoch} epoch and ancestor "
f"{over_1M_epoch_first_ancestor}"
)

# Find the first epoch with more than a 500 times the median epoch size
median_size = np.median(epoch_sizes)
cutoff = 500 * median_size
# Zero out the first half so that an initial large epoch doesn't
Expand All @@ -1653,13 +1666,17 @@ def group_by_linesweep(self):
# the median epoch size. For a large set of human genomes the median epoch
# size is around 10, so we'll stop grouping by linesweep at 5000.
if np.max(epoch_sizes) <= cutoff:
large_epoch = len(time_slices)
large_epoch_first_ancestor = self.num_ancestors
large_epoch = over_1M_epoch
large_epoch_first_ancestor = over_1M_epoch_first_ancestor
logger.info("No large epochs found, using count cutoff")
else:
large_epoch = np.where(epoch_sizes > cutoff)[0][0]
large_epoch_first_ancestor = time_slices[large_epoch, 0]
logger.info(
f"Large epoch found at {large_epoch} with {epoch_sizes[large_epoch]} "
f"ancestors and ancestor {large_epoch_first_ancestor}"
)
logger.info(f"{len(time_slices)} epochs with {median_size} median size.")
logger.info(f"First large (>{cutoff}) epoch is {large_epoch}")
logger.info(f"Grouping {large_epoch_first_ancestor} ancestors by linesweep")
ancestor_grouping = ancestors.group_ancestors_by_linesweep(
start[:large_epoch_first_ancestor],
Expand All @@ -1672,6 +1689,11 @@ def group_by_linesweep(self):
ancestor_grouping[next_epoch] = np.arange(*time_slices[epoch])
next_epoch += 1

# Assert that every ancestor appears once in ancestor grouping
assert (
len(set(np.hstack(list(ancestor_grouping.values())))) == self.num_ancestors
)

# Remove the "virtual root" ancestor
try:
assert 0 in ancestor_grouping[0]
Expand Down

0 comments on commit c78492c

Please sign in to comment.