From f9fba9e2d6ef95733d8dc841470d3963941e08f5 Mon Sep 17 00:00:00 2001 From: Ben Jeffery Date: Thu, 13 Feb 2025 01:23:44 +0000 Subject: [PATCH 1/6] Remove max_num_partitions for sample batch matching --- docs/large_scale.md | 3 +-- tests/test_inference.py | 3 +-- tsinfer/inference.py | 9 +-------- 3 files changed, 3 insertions(+), 12 deletions(-) diff --git a/docs/large_scale.md b/docs/large_scale.md index 1ae68d71..16f708a6 100644 --- a/docs/large_scale.md +++ b/docs/large_scale.md @@ -195,8 +195,7 @@ of ancestors. There are three API methods that work together to enable distribut 3. {meth}`match_samples_batch_finalise` {meth}`match_samples_batch_init` should be called to set up the batch matching and to determine the -groupings of samples. Similar to {meth}`match_ancestors_batch_init` is has a `min_work_per_job` and -`max_num_partitions` arguments to control the level of parallelism. The method writes a file +groupings of samples. Similar to {meth}`match_ancestors_batch_init` it has a `min_work_per_job` argument to control the level of parallelism. The method writes a file `metadata.json` to the directory `work_dir` that contains a JSON encoded dictionary with configuration for later steps. This is also returned by the call. The `num_partitions` key in this dictionary is the number of times {meth}`match_samples_batch_partition` will need diff --git a/tests/test_inference.py b/tests/test_inference.py index 724b997b..9e39ef32 100644 --- a/tests/test_inference.py +++ b/tests/test_inference.py @@ -1543,8 +1543,8 @@ def test_match_samples_batch(self, tmp_path, tmpdir): ancestral_state="variant_ancestral_allele", ancestor_ts_path=tmpdir / "mat_anc.trees", min_work_per_job=1, - max_num_partitions=10, ) + assert mat_wd.num_partitions == mat_sd.num_samples for i in range(mat_wd.num_partitions): tsinfer.match_samples_batch_partition( work_dir=tmpdir / "working_mat", @@ -1564,7 +1564,6 @@ def test_match_samples_batch(self, tmp_path, tmpdir): ancestral_state="variant_ancestral_allele", ancestor_ts_path=tmpdir / "mask_anc.trees", min_work_per_job=1, - max_num_partitions=10, site_mask="variant_mask_foobar", sample_mask="samples_mask_foobar", ) diff --git a/tsinfer/inference.py b/tsinfer/inference.py index fdcb1b09..40b9c305 100644 --- a/tsinfer/inference.py +++ b/tsinfer/inference.py @@ -1186,7 +1186,6 @@ def match_samples_batch_init( ancestor_ts_path, min_work_per_job, *, - max_num_partitions=None, sample_mask=None, site_mask=None, recombination_rate=None, @@ -1206,7 +1205,7 @@ def match_samples_batch_init( ): """ match_samples_batch_init(work_dir, sample_data_path, ancestral_state, - ancestor_ts_path, min_work_per_job, \\*, max_num_partitions=None, + ancestor_ts_path, min_work_per_job, \\*, sample_mask=None, site_mask=None, recombination_rate=None, mismatch_ratio=None, path_compression=True, indexes=None, post_process=None, force_sample_times=False) @@ -1237,9 +1236,6 @@ def match_samples_batch_init( genotypes) to allocate to a single parallel job. If the amount of work in a group of samples exceeds this level it will be broken up into parallel partitions, subject to the constraint of `max_num_partitions`. - :param int max_num_partitions: The maximum number of partitions to split a - group of samples into. Useful for limiting the number of jobs in a - workflow to avoid job overhead. Defaults to 1000. :param Union(array, str) sample_mask: A numpy array of booleans specifying which samples to mask out (exclude) from the dataset. 
Alternatively, a string can be provided, giving the name of an array in the input dataset @@ -1277,9 +1273,6 @@ def match_samples_batch_init( :return: A dictionary of the job metadata, as written to `metadata.json` in `work_dir`. """ - if max_num_partitions is None: - max_num_partitions = 1000 - # Convert working_dir to pathlib.Path work_dir = pathlib.Path(work_dir) From bde205847858059f5b682ec18218518e8416b624 Mon Sep 17 00:00:00 2001 From: Ben Jeffery Date: Thu, 13 Feb 2025 14:01:49 +0000 Subject: [PATCH 2/6] Bin pack ancestors in partitions --- tests/test_inference.py | 45 +++++++++++++++++++++++++---- tsinfer/inference.py | 63 +++++++++++++++++++++++------------------ 2 files changed, 75 insertions(+), 33 deletions(-) diff --git a/tests/test_inference.py b/tests/test_inference.py index 9e39ef32..5f9cd0f7 100644 --- a/tests/test_inference.py +++ b/tests/test_inference.py @@ -1401,16 +1401,16 @@ def test_equivalance_many_at_once(self, tmp_path, tmpdir): tmpdir / "ancestors.zarr", 1000, ) - tsinfer.match_ancestors_batch_groups( - tmpdir / "work", 0, len(metadata["ancestor_grouping"]) // 2, 2 - ) + num_groupings = len(metadata["ancestor_grouping"]) + tsinfer.match_ancestors_batch_groups(tmpdir / "work", 0, num_groupings // 2, 2) tsinfer.match_ancestors_batch_groups( tmpdir / "work", - len(metadata["ancestor_grouping"]) // 2, - len(metadata["ancestor_grouping"]), + num_groupings // 2, + num_groupings, 2, ) - # TODO Check which ones written to disk + assert (tmpdir / "work" / f"ancestors_{(num_groupings//2)-1}.trees").exists() + assert (tmpdir / "work" / f"ancestors_{num_groupings-1}.trees").exists() ts = tsinfer.match_ancestors_batch_finalise(tmpdir / "work") ts2 = tsinfer.match_ancestors(samples, ancestors) ts.tables.assert_equals(ts2.tables, ignore_provenance=True) @@ -1438,6 +1438,11 @@ def test_equivalance_with_partitions(self, tmp_path, tmpdir): tsinfer.match_ancestors_batch_group_partition( tmpdir / "work", group_index, p_index ) + with pytest.raises(ValueError, match="out of range"): + tsinfer.match_ancestors_batch_group_partition( + tmpdir / "work", group_index, p_index + 1000 + ) + ts = tsinfer.match_ancestors_batch_group_finalise( tmpdir / "work", group_index ) @@ -1523,6 +1528,34 @@ def test_errors(self, tmp_path, tmpdir): with pytest.raises(ValueError, match="sequence length is different"): tsinfer.match_ancestors_batch_groups(tmpdir / "work", 2, 3) + def test_low_min_work_per_job(self, tmp_path, tmpdir): + ts, zarr_path = tsutil.make_ts_and_zarr(tmp_path) + samples = tsinfer.VariantData(zarr_path, "variant_ancestral_allele") + _ = tsinfer.generate_ancestors(samples, path=str(tmpdir / "ancestors.zarr")) + metadata = tsinfer.match_ancestors_batch_init( + tmpdir / "work", + zarr_path, + "variant_ancestral_allele", + tmpdir / "ancestors.zarr", + min_work_per_job=1, + max_num_partitions=2, + ) + for group in metadata["ancestor_grouping"]: + assert group["partitions"] is None or len(group["partitions"]) <= 2 + + metadata = tsinfer.match_ancestors_batch_init( + tmpdir / "work2", + zarr_path, + "variant_ancestral_allele", + tmpdir / "ancestors.zarr", + min_work_per_job=1, + max_num_partitions=20000, + ) + for group in metadata["ancestor_grouping"]: + if group["partitions"] is not None: + for partition in group["partitions"]: + assert len(partition) == 1 + @pytest.mark.skipif(sys.platform == "win32", reason="No cyvcf2 on windows") class TestBatchSampleMatching: diff --git a/tsinfer/inference.py b/tsinfer/inference.py index 40b9c305..6e179032 100644 --- a/tsinfer/inference.py +++ 
b/tsinfer/inference.py @@ -28,6 +28,7 @@ import json import logging import math +import operator import os import pathlib import pickle @@ -714,34 +715,41 @@ def match_ancestors_batch_init( for group_index, group_ancestors in matcher.group_by_linesweep().items(): # Make ancestor_ids JSON serialisable group_ancestors = list(map(int, group_ancestors)) - partitions = [] - current_partition = [] - current_partition_work = 0 - # TODO: Can do better here by packing ancestors - # into as equal sized partitions as possible + # The first group is trivial so never partition if group_index == 0: - partitions.append(group_ancestors) + partitions = [ + group_ancestors, + ] else: total_work = sum(ancestor_lengths[ancestor] for ancestor in group_ancestors) - min_work_per_job_group = min_work_per_job - if total_work / max_num_partitions > min_work_per_job: - min_work_per_job_group = total_work / max_num_partitions - for ancestor in group_ancestors: - if ( - current_partition_work + ancestor_lengths[ancestor] - > min_work_per_job_group - ): - partitions.append(current_partition) - current_partition = [ancestor] - current_partition_work = ancestor_lengths[ancestor] - else: - current_partition.append(ancestor) - current_partition_work += ancestor_lengths[ancestor] - partitions.append(current_partition) + parition_count = math.ceil(total_work / min_work_per_job) + if parition_count > max_num_partitions: + parition_count = max_num_partitions + + # Partition into roughly equal sized bins (by work) + sorted_ancestors = sorted( + group_ancestors, key=lambda x: ancestor_lengths[x], reverse=True + ) + partitions = [] + partition_lengths = [] + for _ in range(parition_count): + partitions.append([]) + partition_lengths.append(0) + + # Use greedy bin packing - place each ancestor in the bin with + # lowest total length + for ancestor in sorted_ancestors: + min_length_idx = min( + range(len(partition_lengths)), key=lambda i: partition_lengths[i] + ) + partitions[min_length_idx].append(ancestor) + partition_lengths[min_length_idx] += ancestor_lengths[ancestor] + partitions = [sorted(p) for p in partitions if len(p) > 0] + if len(partitions) > 1: group_dir = work_dir / f"group_{group_index}" group_dir.mkdir() - # TODO: Should be a dataclass + group = { "ancestors": group_ancestors, "partitions": partitions if len(partitions) > 1 else None, @@ -902,7 +910,7 @@ def match_ancestors_batch_group_partition(work_dir, group_index, partition_index ) logger.info(f"Dumping to {partition_path}") with open(partition_path, "wb") as f: - pickle.dump((start_time, timing.metrics, results), f) + pickle.dump((start_time, timing.metrics, ancestors_to_match, results), f) def match_ancestors_batch_group_finalise(work_dir, group_index): @@ -935,17 +943,18 @@ def match_ancestors_batch_group_finalise(work_dir, group_index): ) start_times = [] timings = [] - results = [] + results = {} for partition_index in range(len(group["partitions"])): partition_path = os.path.join( work_dir, f"group_{group_index}", f"partition_{partition_index}.pkl" ) with open(partition_path, "rb") as f: - start_time, part_timing, result = pickle.load(f) + start_time, part_timing, ancestors, result = pickle.load(f) start_times.append(start_time) - results.extend(result) + for ancestor, r in zip(ancestors, result): + results[ancestor] = r timings.append(part_timing) - + results = list(map(operator.itemgetter(1), sorted(results.items()))) ts = matcher.finalise_group(group, results, group_index) path = os.path.join(work_dir, f"ancestors_{group_index}.trees") ts.dump(path) 
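
The batch ancestor-matching API exercised by the tests above can be driven end to end roughly as follows. This is a minimal sketch rather than the project's documented workflow: `work_dir`, `zarr_path` and `ancestors_zarr_path` are placeholder paths, the dataset is assumed to contain a `variant_ancestral_allele` array, and the driver loop simply mirrors the calls made in `test_equivalance_with_partitions` and `test_low_min_work_per_job`.

    import tsinfer

    # Plan the work. This writes metadata.json into work_dir and returns the
    # same dictionary; a group's "partitions" entry is None when the group is
    # small enough to run as a single job.
    metadata = tsinfer.match_ancestors_batch_init(
        work_dir,
        zarr_path,
        "variant_ancestral_allele",
        ancestors_zarr_path,
        min_work_per_job=1000,
    )

    # Groups must be processed in order; partitions within a group are
    # independent and would normally be submitted as separate cluster jobs.
    for group_index, group in enumerate(metadata["ancestor_grouping"]):
        if group["partitions"] is None:
            tsinfer.match_ancestors_batch_groups(
                work_dir, group_index, group_index + 1
            )
        else:
            for p_index in range(len(group["partitions"])):
                tsinfer.match_ancestors_batch_group_partition(
                    work_dir, group_index, p_index
                )
            tsinfer.match_ancestors_batch_group_finalise(work_dir, group_index)

    # Combine the per-group results into the ancestors tree sequence.
    ancestors_ts = tsinfer.match_ancestors_batch_finalise(work_dir)
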
From d9f8c71e91c7c7968023f4a232f72099037d51ff Mon Sep 17 00:00:00 2001 From: Ben Jeffery Date: Fri, 14 Feb 2025 11:48:37 +0000 Subject: [PATCH 3/6] Test force_sample_times --- tests/test_inference.py | 38 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 38 insertions(+) diff --git a/tests/test_inference.py b/tests/test_inference.py index 5f9cd0f7..38ee8975 100644 --- a/tests/test_inference.py +++ b/tests/test_inference.py @@ -35,8 +35,10 @@ import msprime import numpy as np import pytest +import sgkit import tskit import tsutil +import xarray as xr from tskit import MetadataSchema import _tsinfer @@ -1620,6 +1622,42 @@ def test_match_samples_batch(self, tmp_path, tmpdir): mat_ts_batch.tables, ignore_timestamps=True, ignore_provenance=True ) + def test_force_sample_times(self, tmp_path, tmpdir): + ts, zarr_path = tsutil.make_ts_and_zarr(tmp_path) + ds = sgkit.load_dataset(zarr_path) + array = [0.0001] * ts.num_individuals + ds.update( + { + "individuals_time": xr.DataArray( + data=array, dims=["sample"], name="individuals_time" + ) + } + ) + sgkit.save_dataset( + ds.drop_vars(set(ds.data_vars) - {"individuals_time"}), zarr_path, mode="a" + ) + samples = tsinfer.VariantData(zarr_path, "variant_ancestral_allele") + anc = tsinfer.generate_ancestors(samples, path=str(tmpdir / "ancestors.zarr")) + anc_ts = tsinfer.match_ancestors(samples, anc) + anc_ts.dump(tmpdir / "anc.trees") + + wd = tsinfer.match_samples_batch_init( + work_dir=tmpdir / "working", + sample_data_path=samples.path, + ancestral_state="variant_ancestral_allele", + ancestor_ts_path=tmpdir / "anc.trees", + min_work_per_job=1e6, + force_sample_times=True, + ) + for i in range(wd.num_partitions): + tsinfer.match_samples_batch_partition( + work_dir=tmpdir / "working", + partition_index=i, + ) + ts_batch = tsinfer.match_samples_batch_finalise(tmpdir / "working") + ts = tsinfer.match_samples(samples, anc_ts, force_sample_times=True) + ts.tables.assert_equals(ts_batch.tables, ignore_provenance=True) + class TestAncestorGeneratorsEquivalant: """ From 629419b565929acb542fd0f2a40e5a083b3c1640 Mon Sep 17 00:00:00 2001 From: Ben Jeffery Date: Fri, 14 Feb 2025 11:56:57 +0000 Subject: [PATCH 4/6] Simplify num_samples_per_partition calc --- tsinfer/inference.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tsinfer/inference.py b/tsinfer/inference.py index 6e179032..b138ceb5 100644 --- a/tsinfer/inference.py +++ b/tsinfer/inference.py @@ -1331,9 +1331,9 @@ def match_samples_batch_init( sample_times = sample_times.tolist() wd.sample_indexes = sample_indexes wd.sample_times = sample_times - num_samples_per_partition = int(min_work_per_job // variant_data.num_sites) - if num_samples_per_partition == 0: - num_samples_per_partition = 1 + num_samples_per_partition = max( + 1, math.ceil(min_work_per_job // variant_data.num_sites) + ) wd.num_samples_per_partition = num_samples_per_partition wd.num_partitions = math.ceil(len(sample_indexes) / num_samples_per_partition) wd_path = work_dir / "metadata.json" From f0f3ef0fb09d0d8d400f06664bc34e0be3dd18bf Mon Sep 17 00:00:00 2001 From: Ben Jeffery Date: Fri, 14 Feb 2025 12:20:30 +0000 Subject: [PATCH 5/6] Test stored numpy arrays in batch match --- tests/test_inference.py | 40 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 40 insertions(+) diff --git a/tests/test_inference.py b/tests/test_inference.py index 38ee8975..8ad38bc4 100644 --- a/tests/test_inference.py +++ b/tests/test_inference.py @@ -1658,6 +1658,46 @@ def test_force_sample_times(self, tmp_path, tmpdir): 
ts = tsinfer.match_samples(samples, anc_ts, force_sample_times=True) ts.tables.assert_equals(ts_batch.tables, ignore_provenance=True) + def test_array_args(self, tmp_path, tmpdir): + ts, zarr_path = tsutil.make_ts_and_zarr(tmp_path) + sample_mask = np.zeros(ts.num_individuals, dtype=bool) + sample_mask[42] = True + site_mask = np.zeros(ts.num_sites, dtype=bool) + site_mask[42] = True + rng = np.random.RandomState(42) + sites_time = rng.uniform(0, 1, ts.num_sites - 1) + samples = tsinfer.VariantData( + zarr_path, + "variant_ancestral_allele", + sample_mask=sample_mask, + site_mask=site_mask, + sites_time=sites_time, + ) + anc = tsinfer.generate_ancestors(samples, path=str(tmpdir / "ancestors.zarr")) + anc_ts = tsinfer.match_ancestors(samples, anc) + anc_ts.dump(tmpdir / "anc.trees") + + wd = tsinfer.match_samples_batch_init( + work_dir=tmpdir / "working", + sample_data_path=samples.path, + sample_mask=sample_mask, + site_mask=site_mask, + ancestral_state="variant_ancestral_allele", + ancestor_ts_path=tmpdir / "anc.trees", + min_work_per_job=1e6, + ) + for i in range(wd.num_partitions): + tsinfer.match_samples_batch_partition( + work_dir=tmpdir / "working", + partition_index=i, + ) + ts_batch = tsinfer.match_samples_batch_finalise(tmpdir / "working") + ts = tsinfer.match_samples( + samples, + anc_ts, + ) + ts.tables.assert_equals(ts_batch.tables, ignore_provenance=True) + class TestAncestorGeneratorsEquivalant: """ From a063456583b3077304b05695f9a3887914f11fc1 Mon Sep 17 00:00:00 2001 From: Ben Jeffery Date: Fri, 14 Feb 2025 13:38:01 +0000 Subject: [PATCH 6/6] Use a heap for ancestor packing --- tsinfer/inference.py | 30 ++++++++++++------------------ 1 file changed, 12 insertions(+), 18 deletions(-) diff --git a/tsinfer/inference.py b/tsinfer/inference.py index b138ceb5..f300d7f6 100644 --- a/tsinfer/inference.py +++ b/tsinfer/inference.py @@ -717,35 +717,29 @@ def match_ancestors_batch_init( group_ancestors = list(map(int, group_ancestors)) # The first group is trivial so never partition if group_index == 0: - partitions = [ - group_ancestors, - ] + partitions = [group_ancestors] else: total_work = sum(ancestor_lengths[ancestor] for ancestor in group_ancestors) - parition_count = math.ceil(total_work / min_work_per_job) - if parition_count > max_num_partitions: - parition_count = max_num_partitions + partition_count = math.ceil(total_work / min_work_per_job) + if partition_count > max_num_partitions: + partition_count = max_num_partitions # Partition into roughly equal sized bins (by work) sorted_ancestors = sorted( group_ancestors, key=lambda x: ancestor_lengths[x], reverse=True ) - partitions = [] - partition_lengths = [] - for _ in range(parition_count): - partitions.append([]) - partition_lengths.append(0) # Use greedy bin packing - place each ancestor in the bin with # lowest total length + heap = [(0, []) for _ in range(partition_count)] for ancestor in sorted_ancestors: - min_length_idx = min( - range(len(partition_lengths)), key=lambda i: partition_lengths[i] - ) - partitions[min_length_idx].append(ancestor) - partition_lengths[min_length_idx] += ancestor_lengths[ancestor] - partitions = [sorted(p) for p in partitions if len(p) > 0] - + sum_len, partition = heapq.heappop(heap) + partition.append(ancestor) + sum_len += ancestor_lengths[ancestor] + heapq.heappush(heap, (sum_len, partition)) + partitions = [ + sorted(partition) for sum_len, partition in heap if sum_len > 0 + ] if len(partitions) > 1: group_dir = work_dir / f"group_{group_index}" group_dir.mkdir()
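
The packing strategy introduced in this last patch is standard greedy longest-first bin packing, with a heap tracking the least-loaded partition. A self-contained sketch of the same idea, using made-up ancestor lengths rather than real matcher output (the function name and inputs here are illustrative only), is:

    import heapq
    import math

    def pack(lengths, min_work_per_job, max_num_partitions=1000):
        # Number of bins needed for the requested work per job, capped to
        # keep the number of cluster jobs (and their overhead) bounded.
        total = sum(lengths.values())
        count = min(max_num_partitions, max(1, math.ceil(total / min_work_per_job)))
        # The heap always exposes the least-loaded bin at its root; the bin
        # index breaks ties so the item lists are never compared directly.
        heap = [(0.0, i, []) for i in range(count)]
        for item in sorted(lengths, key=lengths.get, reverse=True):
            load, i, bin_items = heapq.heappop(heap)
            bin_items.append(item)
            heapq.heappush(heap, (load + lengths[item], i, bin_items))
        return [sorted(b) for _, _, b in heap if b]

    # Six ancestors with uneven lengths, aiming for ~2 units of work per job.
    print(pack({0: 1.0, 1: 0.2, 2: 0.9, 3: 0.4, 4: 0.6, 5: 0.1}, 2.0))

Placing the longest items first and always into the currently emptiest bin keeps the bins close to balanced, which is why the patch sorts `group_ancestors` by length in reverse before entering the heap loop.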