From 5e2f59e3191b8455b1033e72c376c45621e416a0 Mon Sep 17 00:00:00 2001
From: Ben Jeffery
Date: Tue, 4 Jun 2024 10:37:16 +0100
Subject: [PATCH] Do many groups at once to avoid reloading sampledata

---
 tests/test_inference.py | 31 +++++++++++++++++++++++++++++--
 tsinfer/inference.py    | 30 +++++++++++++++++++++---------
 2 files changed, 50 insertions(+), 11 deletions(-)

diff --git a/tests/test_inference.py b/tests/test_inference.py
index 0252fd1a..e5bab983 100644
--- a/tests/test_inference.py
+++ b/tests/test_inference.py
@@ -1418,7 +1418,32 @@ def test_equivalance(self, tmp_path, tmpdir):
             tmpdir / "work", zarr_path, tmpdir / "ancestors.zarr", 1000
         )
         for group_index, _ in enumerate(metadata["ancestor_grouping"]):
-            tsinfer.match_ancestors_batch_group(tmpdir / "work", group_index, 2)
+            tsinfer.match_ancestors_batch_groups(
+                tmpdir / "work", group_index, group_index + 1, 2
+            )
         ts = tsinfer.match_ancestors_batch_finalise(tmpdir / "work")
         ts2 = tsinfer.match_ancestors(samples, ancestors)
         ts.tables.assert_equals(ts2.tables, ignore_provenance=True)
+
+    def test_equivalance_many_at_once(self, tmp_path, tmpdir):
+        ts, zarr_path = tsutil.make_ts_and_zarr(tmp_path)
+        samples = tsinfer.SgkitSampleData(zarr_path)
+        ancestors = tsinfer.generate_ancestors(
+            samples, path=str(tmpdir / "ancestors.zarr")
+        )
+        metadata = tsinfer.match_ancestors_batch_init(
+            tmpdir / "work", zarr_path, tmpdir / "ancestors.zarr", 1000
+        )
+        tsinfer.match_ancestors_batch_groups(
+            tmpdir / "work", 0, len(metadata["ancestor_grouping"]) // 2, 2
+        )
+        tsinfer.match_ancestors_batch_groups(
+            tmpdir / "work",
+            len(metadata["ancestor_grouping"]) // 2,
+            len(metadata["ancestor_grouping"]),
+            2,
+        )
+        # TODO Check which ones written to disk
+        ts = tsinfer.match_ancestors_batch_finalise(tmpdir / "work")
+        ts2 = tsinfer.match_ancestors(samples, ancestors)
+        ts.tables.assert_equals(ts2.tables, ignore_provenance=True)
@@ -1434,7 +1459,9 @@ def test_equivalance_with_partitions(self, tmp_path, tmpdir):
         )
         for group_index, group in enumerate(metadata["ancestor_grouping"]):
             if group["partitions"] is None:
-                tsinfer.match_ancestors_batch_group(tmpdir / "work", group_index)
+                tsinfer.match_ancestors_batch_groups(
+                    tmpdir / "work", group_index, group_index + 1
+                )
             else:
                 for p_index, _ in enumerate(group["partitions"]):
                     tsinfer.match_ancestors_batch_group_partition(
diff --git a/tsinfer/inference.py b/tsinfer/inference.py
index 1d620856..1c78ec9c 100644
--- a/tsinfer/inference.py
+++ b/tsinfer/inference.py
@@ -597,10 +597,14 @@ def match_ancestors_batch_init(
         if group_index == 0:
             partitions.append(group_ancestors)
         else:
+            total_work = sum(ancestor_lengths[ancestor] for ancestor in group_ancestors)
+            min_work_per_job_group = min_work_per_job
+            if total_work / 1000 > min_work_per_job:
+                min_work_per_job_group = total_work / 1000
             for ancestor in group_ancestors:
                 if (
                     current_partition_work + ancestor_lengths[ancestor]
-                    > min_work_per_job
+                    > min_work_per_job_group
                 ):
                     partitions.append(current_partition)
                     current_partition = [ancestor]
@@ -667,22 +671,30 @@ def initialize_matcher(metadata, ancestors_ts=None, **kwargs):
     )


-def match_ancestors_batch_group(work_dir, group_index, num_threads=0):
+def match_ancestors_batch_groups(
+    work_dir, group_index_start, group_index_end, num_threads=0
+):
     metadata_path = os.path.join(work_dir, "metadata.json")
     with open(metadata_path) as f:
         metadata = json.load(f)
-    if group_index >= len(metadata["ancestor_grouping"]) or group_index < 0:
-        raise ValueError(f"Group {group_index} is out of range")
-    group = metadata["ancestor_grouping"][group_index]
-    if group_index == 0:
+    if group_index_start >= len(metadata["ancestor_grouping"]) or group_index_start < 0:
+        raise ValueError(f"Group {group_index_start} is out of range")
+    if group_index_end > len(metadata["ancestor_grouping"]) or group_index_end < 1:
+        raise ValueError(f"Group {group_index_end} is out of range")
+    if group_index_start == 0:
         ancestors_ts = None
     else:
         ancestors_ts = tskit.load(
-            os.path.join(work_dir, f"ancestors_{group_index-1}.trees")
+            os.path.join(work_dir, f"ancestors_{group_index_start-1}.trees")
         )
     matcher = initialize_matcher(metadata, ancestors_ts, num_threads=num_threads)
-    ts = matcher.match_ancestors({group_index: group["ancestors"]})
-    path = os.path.join(work_dir, f"ancestors_{group_index}.trees")
+    ts = matcher.match_ancestors(
+        {
+            group_index: metadata["ancestor_grouping"][group_index]["ancestors"]
+            for group_index in range(group_index_start, group_index_end)
+        }
+    )
+    path = os.path.join(work_dir, f"ancestors_{group_index_end-1}.trees")
     logger.info(f"Dumping to {path}")
     ts.dump(path)
     return ts