Do many groups at once to avoid reloading sampledata
benjeffery committed Jun 4, 2024
1 parent 87e2810 commit 5e2f59e
Showing 2 changed files with 50 additions and 11 deletions.
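For context, a minimal sketch of the batched workflow this commit enables, based on the new test_equivalance_many_at_once test below. The paths, the work-per-job value of 1000, and the two-way split of the groups are illustrative, not part of the commit:

import tsinfer

# Illustrative paths; substitute your own VCF-Zarr dataset and work directory.
zarr_path = "samples.vcf.zarr"
work_dir = "work"

samples = tsinfer.SgkitSampleData(zarr_path)
# Ancestors are written to the given path for the batch run to pick up.
tsinfer.generate_ancestors(samples, path="ancestors.zarr")

# Plan the batch run; 1000 mirrors the work-per-job value used in the tests.
metadata = tsinfer.match_ancestors_batch_init(
    work_dir, zarr_path, "ancestors.zarr", 1000
)
num_groups = len(metadata["ancestor_grouping"])

# Match many groups per call (here the first and second halves) so the
# sample data is loaded once per call rather than once per group.
tsinfer.match_ancestors_batch_groups(work_dir, 0, num_groups // 2, 2)
tsinfer.match_ancestors_batch_groups(work_dir, num_groups // 2, num_groups, 2)

ancestors_ts = tsinfer.match_ancestors_batch_finalise(work_dir)
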
31 changes: 29 additions & 2 deletions tests/test_inference.py
@@ -1418,7 +1418,32 @@ def test_equivalance(self, tmp_path, tmpdir):
             tmpdir / "work", zarr_path, tmpdir / "ancestors.zarr", 1000
         )
         for group_index, _ in enumerate(metadata["ancestor_grouping"]):
-            tsinfer.match_ancestors_batch_group(tmpdir / "work", group_index, 2)
+            tsinfer.match_ancestors_batch_groups(
+                tmpdir / "work", group_index, group_index + 1, 2
+            )
         ts = tsinfer.match_ancestors_batch_finalise(tmpdir / "work")
         ts2 = tsinfer.match_ancestors(samples, ancestors)
         ts.tables.assert_equals(ts2.tables, ignore_provenance=True)
+
+    def test_equivalance_many_at_once(self, tmp_path, tmpdir):
+        ts, zarr_path = tsutil.make_ts_and_zarr(tmp_path)
+        samples = tsinfer.SgkitSampleData(zarr_path)
+        ancestors = tsinfer.generate_ancestors(
+            samples, path=str(tmpdir / "ancestors.zarr")
+        )
+        metadata = tsinfer.match_ancestors_batch_init(
+            tmpdir / "work", zarr_path, tmpdir / "ancestors.zarr", 1000
+        )
+        tsinfer.match_ancestors_batch_groups(
+            tmpdir / "work", 0, len(metadata["ancestor_grouping"]) // 2, 2
+        )
+        tsinfer.match_ancestors_batch_groups(
+            tmpdir / "work",
+            len(metadata["ancestor_grouping"]) // 2,
+            len(metadata["ancestor_grouping"]),
+            2,
+        )
+        # TODO Check which ones written to disk
+        ts = tsinfer.match_ancestors_batch_finalise(tmpdir / "work")
+        ts2 = tsinfer.match_ancestors(samples, ancestors)
+        ts.tables.assert_equals(ts2.tables, ignore_provenance=True)
@@ -1434,7 +1459,9 @@ def test_equivalance_with_partitions(self, tmp_path, tmpdir):
         )
         for group_index, group in enumerate(metadata["ancestor_grouping"]):
             if group["partitions"] is None:
-                tsinfer.match_ancestors_batch_group(tmpdir / "work", group_index)
+                tsinfer.match_ancestors_batch_groups(
+                    tmpdir / "work", group_index, group_index + 1
+                )
             else:
                 for p_index, _ in enumerate(group["partitions"]):
                     tsinfer.match_ancestors_batch_group_partition(
30 changes: 21 additions & 9 deletions tsinfer/inference.py
@@ -597,10 +597,14 @@ def match_ancestors_batch_init(
         if group_index == 0:
             partitions.append(group_ancestors)
         else:
+            total_work = sum(ancestor_lengths[ancestor] for ancestor in group_ancestors)
+            min_work_per_job_group = min_work_per_job
+            if total_work / 1000 > min_work_per_job:
+                min_work_per_job_group = total_work / 1000
             for ancestor in group_ancestors:
                 if (
                     current_partition_work + ancestor_lengths[ancestor]
-                    > min_work_per_job
+                    > min_work_per_job_group
                 ):
                     partitions.append(current_partition)
                     current_partition = [ancestor]
@@ -667,22 +671,30 @@ def initialize_matcher(metadata, ancestors_ts=None, **kwargs):
     )
 
 
-def match_ancestors_batch_group(work_dir, group_index, num_threads=0):
+def match_ancestors_batch_groups(
+    work_dir, group_index_start, group_index_end, num_threads=0
+):
     metadata_path = os.path.join(work_dir, "metadata.json")
     with open(metadata_path) as f:
         metadata = json.load(f)
-    if group_index >= len(metadata["ancestor_grouping"]) or group_index < 0:
-        raise ValueError(f"Group {group_index} is out of range")
-    group = metadata["ancestor_grouping"][group_index]
-    if group_index == 0:
+    if group_index_start >= len(metadata["ancestor_grouping"]) or group_index_start < 0:
+        raise ValueError(f"Group {group_index_start} is out of range")
+    if group_index_end > len(metadata["ancestor_grouping"]) or group_index_end < 1:
+        raise ValueError(f"Group {group_index_end} is out of range")
+    if group_index_start == 0:
         ancestors_ts = None
     else:
         ancestors_ts = tskit.load(
-            os.path.join(work_dir, f"ancestors_{group_index-1}.trees")
+            os.path.join(work_dir, f"ancestors_{group_index_start-1}.trees")
         )
     matcher = initialize_matcher(metadata, ancestors_ts, num_threads=num_threads)
-    ts = matcher.match_ancestors({group_index: group["ancestors"]})
-    path = os.path.join(work_dir, f"ancestors_{group_index}.trees")
+    ts = matcher.match_ancestors(
+        {
+            group_index: metadata["ancestor_grouping"][group_index]["ancestors"]
+            for group_index in range(group_index_start, group_index_end)
+        }
+    )
+    path = os.path.join(work_dir, f"ancestors_{group_index_end-1}.trees")
     logger.info(f"Dumping to {path}")
     ts.dump(path)
     return ts
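A simplified sketch (not the library function itself) of the partition-sizing rule added in the first tsinfer/inference.py hunk above: the per-partition work threshold is raised whenever a group's total work would otherwise split into more than roughly 1000 partitions. The helper name and surrounding bookkeeping are illustrative:

def partition_group(group_ancestors, ancestor_lengths, min_work_per_job):
    # Raise the threshold so a single group never produces more than
    # roughly 1000 partitions, mirroring the new min_work_per_job_group logic.
    total_work = sum(ancestor_lengths[a] for a in group_ancestors)
    threshold = max(min_work_per_job, total_work / 1000)

    partitions = []
    current, current_work = [], 0
    for ancestor in group_ancestors:
        # Close the current partition when adding this ancestor would
        # push it over the work threshold.
        if current and current_work + ancestor_lengths[ancestor] > threshold:
            partitions.append(current)
            current, current_work = [], 0
        current.append(ancestor)
        current_work += ancestor_lengths[ancestor]
    if current:
        partitions.append(current)
    return partitions

For example, partition_group(range(5), {i: 10 for i in range(5)}, 25) yields [[0, 1], [2, 3], [4]], since each partition is closed once it holds 20 units of work.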
