Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Limit linesweep to 1M ancestors #879

Closed
wants to merge 1 commit into from
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 25 additions & 3 deletions tsinfer/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -1644,6 +1644,19 @@ def group_by_linesweep(self):
epoch_end = np.hstack([breaks + 1, [self.num_ancestors]])
time_slices = np.vstack([epoch_start, epoch_end]).T
epoch_sizes = time_slices[:, 1] - time_slices[:, 0]
# Find the epoch where the sum of ancestors has reached 1M as a cutoff
if np.sum(epoch_sizes) <= 1e6:
over_1M_epoch = len(time_slices)
over_1M_epoch_first_ancestor = self.num_ancestors
else:
over_1M_epoch = np.where(np.cumsum(epoch_sizes) > 1e6)[0][0]
over_1M_epoch_first_ancestor = time_slices[over_1M_epoch, 0]
logger.info(
f"1M ancestors reached at {over_1M_epoch} epoch and ancestor "
f"{over_1M_epoch_first_ancestor}"
)

# Find the first epoch with more than a 500 times the median epoch size
median_size = np.median(epoch_sizes)
cutoff = 500 * median_size
# Zero out the first half so that an initial large epoch doesn't
Expand All @@ -1653,13 +1666,17 @@ def group_by_linesweep(self):
# the median epoch size. For a large set of human genomes the median epoch
# size is around 10, so we'll stop grouping by linesweep at 5000.
if np.max(epoch_sizes) <= cutoff:
large_epoch = len(time_slices)
large_epoch_first_ancestor = self.num_ancestors
large_epoch = over_1M_epoch
large_epoch_first_ancestor = over_1M_epoch_first_ancestor
logger.info("No large epochs found, using count cutoff")
else:
large_epoch = np.where(epoch_sizes > cutoff)[0][0]
large_epoch_first_ancestor = time_slices[large_epoch, 0]
logger.info(
f"Large epoch found at {large_epoch} with {epoch_sizes[large_epoch]} "
f"ancestors and ancestor {large_epoch_first_ancestor}"
)
logger.info(f"{len(time_slices)} epochs with {median_size} median size.")
logger.info(f"First large (>{cutoff}) epoch is {large_epoch}")
logger.info(f"Grouping {large_epoch_first_ancestor} ancestors by linesweep")
ancestor_grouping = ancestors.group_ancestors_by_linesweep(
start[:large_epoch_first_ancestor],
Expand All @@ -1672,6 +1689,11 @@ def group_by_linesweep(self):
ancestor_grouping[next_epoch] = np.arange(*time_slices[epoch])
next_epoch += 1

# Assert that every ancestor appears once in ancestor grouping
assert (
len(set(np.hstack(list(ancestor_grouping.values())))) == self.num_ancestors
)

# Remove the "virtual root" ancestor
try:
assert 0 in ancestor_grouping[0]
Expand Down
Loading