Merge pull request #998 from benjeffery/batch-coverage
Improve batch-match coverage
benjeffery authored Feb 17, 2025
2 parents d2048c3 + a063456 commit 1aa0233
Showing 3 changed files with 153 additions and 48 deletions.
3 changes: 1 addition & 2 deletions docs/large_scale.md
@@ -195,8 +195,7 @@ of ancestors. There are three API methods that work together to enable distribut
3. {meth}`match_samples_batch_finalise`

{meth}`match_samples_batch_init` should be called to set up the batch matching and to determine the
groupings of samples. Similar to {meth}`match_ancestors_batch_init` is has a `min_work_per_job` and
`max_num_partitions` arguments to control the level of parallelism. The method writes a file
groupings of samples. Similar to {meth}`match_ancestors_batch_init` it has a `min_work_per_job` argument to control the level of parallelism. The method writes a file
`metadata.json` to the directory `work_dir` that contains a JSON encoded dictionary with
configuration for later steps. This is also returned by the call. The `num_partitions` key in
this dictionary is the number of times {meth}`match_samples_batch_partition` will need
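
For orientation, here is a minimal sketch of how the three batch sample-matching methods fit together; the paths and the `min_work_per_job` value are illustrative, and the call signatures follow the tests changed below.

```python
import tsinfer

# Illustrative paths: a VariantData zarr and an ancestors tree sequence
# produced by an earlier match_ancestors step.
wd = tsinfer.match_samples_batch_init(
    work_dir="work",
    sample_data_path="data.vcz",
    ancestral_state="variant_ancestral_allele",
    ancestor_ts_path="ancestors.trees",
    min_work_per_job=100_000,
)
# Each partition is an independent unit of work, e.g. one cluster job per index.
for i in range(wd.num_partitions):
    tsinfer.match_samples_batch_partition(work_dir="work", partition_index=i)
ts = tsinfer.match_samples_batch_finalise("work")
```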
126 changes: 118 additions & 8 deletions tests/test_inference.py
@@ -35,8 +35,10 @@
import msprime
import numpy as np
import pytest
import sgkit
import tskit
import tsutil
import xarray as xr
from tskit import MetadataSchema

import _tsinfer
@@ -1401,16 +1403,16 @@ def test_equivalance_many_at_once(self, tmp_path, tmpdir):
tmpdir / "ancestors.zarr",
1000,
)
tsinfer.match_ancestors_batch_groups(
tmpdir / "work", 0, len(metadata["ancestor_grouping"]) // 2, 2
)
num_groupings = len(metadata["ancestor_grouping"])
tsinfer.match_ancestors_batch_groups(tmpdir / "work", 0, num_groupings // 2, 2)
tsinfer.match_ancestors_batch_groups(
tmpdir / "work",
len(metadata["ancestor_grouping"]) // 2,
len(metadata["ancestor_grouping"]),
num_groupings // 2,
num_groupings,
2,
)
# TODO Check which ones written to disk
assert (tmpdir / "work" / f"ancestors_{(num_groupings//2)-1}.trees").exists()
assert (tmpdir / "work" / f"ancestors_{num_groupings-1}.trees").exists()
ts = tsinfer.match_ancestors_batch_finalise(tmpdir / "work")
ts2 = tsinfer.match_ancestors(samples, ancestors)
ts.tables.assert_equals(ts2.tables, ignore_provenance=True)
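
As a reading aid, the group-based ancestor batch workflow exercised above reduces to the sketch below; paths are illustrative, and the trailing `2` simply mirrors the final positional argument used in the test (its meaning is not shown in this diff).

```python
import tsinfer

metadata = tsinfer.match_ancestors_batch_init(
    "work",                      # work_dir
    "data.vcz",                  # sample data path (illustrative)
    "variant_ancestral_allele",  # ancestral state array name
    "ancestors.zarr",            # generated ancestors (illustrative)
    1000,                        # min_work_per_job
)
n = len(metadata["ancestor_grouping"])
# Split the groups across two jobs: [0, n // 2) and [n // 2, n).
tsinfer.match_ancestors_batch_groups("work", 0, n // 2, 2)
tsinfer.match_ancestors_batch_groups("work", n // 2, n, 2)
# Each call checkpoints an ancestors_{last_group_index}.trees file in work_dir.
ancestors_ts = tsinfer.match_ancestors_batch_finalise("work")
```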
@@ -1438,6 +1440,11 @@ def test_equivalance_with_partitions(self, tmp_path, tmpdir):
tsinfer.match_ancestors_batch_group_partition(
tmpdir / "work", group_index, p_index
)
with pytest.raises(ValueError, match="out of range"):
tsinfer.match_ancestors_batch_group_partition(
tmpdir / "work", group_index, p_index + 1000
)

ts = tsinfer.match_ancestors_batch_group_finalise(
tmpdir / "work", group_index
)
@@ -1523,6 +1530,34 @@ def test_errors(self, tmp_path, tmpdir):
with pytest.raises(ValueError, match="sequence length is different"):
tsinfer.match_ancestors_batch_groups(tmpdir / "work", 2, 3)

def test_low_min_work_per_job(self, tmp_path, tmpdir):
ts, zarr_path = tsutil.make_ts_and_zarr(tmp_path)
samples = tsinfer.VariantData(zarr_path, "variant_ancestral_allele")
_ = tsinfer.generate_ancestors(samples, path=str(tmpdir / "ancestors.zarr"))
metadata = tsinfer.match_ancestors_batch_init(
tmpdir / "work",
zarr_path,
"variant_ancestral_allele",
tmpdir / "ancestors.zarr",
min_work_per_job=1,
max_num_partitions=2,
)
for group in metadata["ancestor_grouping"]:
assert group["partitions"] is None or len(group["partitions"]) <= 2

metadata = tsinfer.match_ancestors_batch_init(
tmpdir / "work2",
zarr_path,
"variant_ancestral_allele",
tmpdir / "ancestors.zarr",
min_work_per_job=1,
max_num_partitions=20000,
)
for group in metadata["ancestor_grouping"]:
if group["partitions"] is not None:
for partition in group["partitions"]:
assert len(partition) == 1
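
The expectations in `test_low_min_work_per_job` follow from the partition-count rule introduced in `tsinfer/inference.py` below: the work-based count is capped at `max_num_partitions`. A rough sketch of that rule, with made-up numbers:

```python
import math

def expected_partition_count(total_work, min_work_per_job, max_num_partitions):
    # Work-based count, capped by the maximum number of partitions per group.
    return min(math.ceil(total_work / min_work_per_job), max_num_partitions)

# min_work_per_job=1 would give one partition per unit of work, but the cap
# of 2 dominates, so each group has at most 2 partitions.
assert expected_partition_count(50, 1, 2) == 2
# With a very large cap, the count is work-limited and the greedy packer can
# end up putting each ancestor in its own partition.
assert expected_partition_count(50, 1, 20_000) == 50
```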


@pytest.mark.skipif(sys.platform == "win32", reason="No cyvcf2 on windows")
class TestBatchSampleMatching:
@@ -1543,8 +1578,8 @@ def test_match_samples_batch(self, tmp_path, tmpdir):
ancestral_state="variant_ancestral_allele",
ancestor_ts_path=tmpdir / "mat_anc.trees",
min_work_per_job=1,
max_num_partitions=10,
)
assert mat_wd.num_partitions == mat_sd.num_samples
for i in range(mat_wd.num_partitions):
tsinfer.match_samples_batch_partition(
work_dir=tmpdir / "working_mat",
@@ -1564,7 +1599,6 @@ def test_match_samples_batch(self, tmp_path, tmpdir):
ancestral_state="variant_ancestral_allele",
ancestor_ts_path=tmpdir / "mask_anc.trees",
min_work_per_job=1,
max_num_partitions=10,
site_mask="variant_mask_foobar",
sample_mask="samples_mask_foobar",
)
@@ -1588,6 +1622,82 @@ def test_match_samples_batch(self, tmp_path, tmpdir):
mat_ts_batch.tables, ignore_timestamps=True, ignore_provenance=True
)

def test_force_sample_times(self, tmp_path, tmpdir):
ts, zarr_path = tsutil.make_ts_and_zarr(tmp_path)
ds = sgkit.load_dataset(zarr_path)
array = [0.0001] * ts.num_individuals
ds.update(
{
"individuals_time": xr.DataArray(
data=array, dims=["sample"], name="individuals_time"
)
}
)
sgkit.save_dataset(
ds.drop_vars(set(ds.data_vars) - {"individuals_time"}), zarr_path, mode="a"
)
samples = tsinfer.VariantData(zarr_path, "variant_ancestral_allele")
anc = tsinfer.generate_ancestors(samples, path=str(tmpdir / "ancestors.zarr"))
anc_ts = tsinfer.match_ancestors(samples, anc)
anc_ts.dump(tmpdir / "anc.trees")

wd = tsinfer.match_samples_batch_init(
work_dir=tmpdir / "working",
sample_data_path=samples.path,
ancestral_state="variant_ancestral_allele",
ancestor_ts_path=tmpdir / "anc.trees",
min_work_per_job=1e6,
force_sample_times=True,
)
for i in range(wd.num_partitions):
tsinfer.match_samples_batch_partition(
work_dir=tmpdir / "working",
partition_index=i,
)
ts_batch = tsinfer.match_samples_batch_finalise(tmpdir / "working")
ts = tsinfer.match_samples(samples, anc_ts, force_sample_times=True)
ts.tables.assert_equals(ts_batch.tables, ignore_provenance=True)

def test_array_args(self, tmp_path, tmpdir):
ts, zarr_path = tsutil.make_ts_and_zarr(tmp_path)
sample_mask = np.zeros(ts.num_individuals, dtype=bool)
sample_mask[42] = True
site_mask = np.zeros(ts.num_sites, dtype=bool)
site_mask[42] = True
rng = np.random.RandomState(42)
sites_time = rng.uniform(0, 1, ts.num_sites - 1)
samples = tsinfer.VariantData(
zarr_path,
"variant_ancestral_allele",
sample_mask=sample_mask,
site_mask=site_mask,
sites_time=sites_time,
)
anc = tsinfer.generate_ancestors(samples, path=str(tmpdir / "ancestors.zarr"))
anc_ts = tsinfer.match_ancestors(samples, anc)
anc_ts.dump(tmpdir / "anc.trees")

wd = tsinfer.match_samples_batch_init(
work_dir=tmpdir / "working",
sample_data_path=samples.path,
sample_mask=sample_mask,
site_mask=site_mask,
ancestral_state="variant_ancestral_allele",
ancestor_ts_path=tmpdir / "anc.trees",
min_work_per_job=1e6,
)
for i in range(wd.num_partitions):
tsinfer.match_samples_batch_partition(
work_dir=tmpdir / "working",
partition_index=i,
)
ts_batch = tsinfer.match_samples_batch_finalise(tmpdir / "working")
ts = tsinfer.match_samples(
samples,
anc_ts,
)
ts.tables.assert_equals(ts_batch.tables, ignore_provenance=True)


class TestAncestorGeneratorsEquivalant:
"""
72 changes: 34 additions & 38 deletions tsinfer/inference.py
@@ -28,6 +28,7 @@
import json
import logging
import math
import operator
import os
import pathlib
import pickle
@@ -714,34 +715,35 @@ def match_ancestors_batch_init(
for group_index, group_ancestors in matcher.group_by_linesweep().items():
# Make ancestor_ids JSON serialisable
group_ancestors = list(map(int, group_ancestors))
partitions = []
current_partition = []
current_partition_work = 0
# TODO: Can do better here by packing ancestors
# into as equal sized partitions as possible
# The first group is trivial so never partition
if group_index == 0:
partitions.append(group_ancestors)
partitions = [group_ancestors]
else:
total_work = sum(ancestor_lengths[ancestor] for ancestor in group_ancestors)
min_work_per_job_group = min_work_per_job
if total_work / max_num_partitions > min_work_per_job:
min_work_per_job_group = total_work / max_num_partitions
for ancestor in group_ancestors:
if (
current_partition_work + ancestor_lengths[ancestor]
> min_work_per_job_group
):
partitions.append(current_partition)
current_partition = [ancestor]
current_partition_work = ancestor_lengths[ancestor]
else:
current_partition.append(ancestor)
current_partition_work += ancestor_lengths[ancestor]
partitions.append(current_partition)
partition_count = math.ceil(total_work / min_work_per_job)
if partition_count > max_num_partitions:
partition_count = max_num_partitions

# Partition into roughly equal sized bins (by work)
sorted_ancestors = sorted(
group_ancestors, key=lambda x: ancestor_lengths[x], reverse=True
)

# Use greedy bin packing - place each ancestor in the bin with
# lowest total length
heap = [(0, []) for _ in range(partition_count)]
for ancestor in sorted_ancestors:
sum_len, partition = heapq.heappop(heap)
partition.append(ancestor)
sum_len += ancestor_lengths[ancestor]
heapq.heappush(heap, (sum_len, partition))
partitions = [
sorted(partition) for sum_len, partition in heap if sum_len > 0
]
if len(partitions) > 1:
group_dir = work_dir / f"group_{group_index}"
group_dir.mkdir()
# TODO: Should be a dataclass

group = {
"ancestors": group_ancestors,
"partitions": partitions if len(partitions) > 1 else None,
@@ -902,7 +904,7 @@ def match_ancestors_batch_group_partition(work_dir, group_index, partition_index
)
logger.info(f"Dumping to {partition_path}")
with open(partition_path, "wb") as f:
pickle.dump((start_time, timing.metrics, results), f)
pickle.dump((start_time, timing.metrics, ancestors_to_match, results), f)


def match_ancestors_batch_group_finalise(work_dir, group_index):
Expand Down Expand Up @@ -935,17 +937,18 @@ def match_ancestors_batch_group_finalise(work_dir, group_index):
)
start_times = []
timings = []
results = []
results = {}
for partition_index in range(len(group["partitions"])):
partition_path = os.path.join(
work_dir, f"group_{group_index}", f"partition_{partition_index}.pkl"
)
with open(partition_path, "rb") as f:
start_time, part_timing, result = pickle.load(f)
start_time, part_timing, ancestors, result = pickle.load(f)
start_times.append(start_time)
results.extend(result)
for ancestor, r in zip(ancestors, result):
results[ancestor] = r
timings.append(part_timing)

results = list(map(operator.itemgetter(1), sorted(results.items())))
ts = matcher.finalise_group(group, results, group_index)
path = os.path.join(work_dir, f"ancestors_{group_index}.trees")
ts.dump(path)
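
The finalise step above now keys results by ancestor ID, so partition files can be written and read back in any order; a tiny illustration of the reordering (values are placeholders):

```python
import operator

results = {}
results.update({7: "r7", 3: "r3"})  # results from partition 0
results.update({5: "r5", 1: "r1"})  # results from partition 1
ordered = list(map(operator.itemgetter(1), sorted(results.items())))
print(ordered)  # ['r1', 'r3', 'r5', 'r7'] -- back in ancestor order for finalise_group
```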
@@ -1186,7 +1189,6 @@ def match_samples_batch_init(
ancestor_ts_path,
min_work_per_job,
*,
max_num_partitions=None,
sample_mask=None,
site_mask=None,
recombination_rate=None,
Expand All @@ -1206,7 +1208,7 @@ def match_samples_batch_init(
):
"""
match_samples_batch_init(work_dir, sample_data_path, ancestral_state,
ancestor_ts_path, min_work_per_job, \\*, max_num_partitions=None,
ancestor_ts_path, min_work_per_job, \\*,
sample_mask=None, site_mask=None, recombination_rate=None, mismatch_ratio=None,
path_compression=True, indexes=None, post_process=None, force_sample_times=False)
@@ -1237,9 +1239,6 @@ def match_samples_batch_init(
genotypes) to allocate to a single parallel job. If the amount of work in
a group of samples exceeds this level it will be broken up into parallel
partitions, subject to the constraint of `max_num_partitions`.
:param int max_num_partitions: The maximum number of partitions to split a
group of samples into. Useful for limiting the number of jobs in a
workflow to avoid job overhead. Defaults to 1000.
:param Union(array, str) sample_mask: A numpy array of booleans specifying
which samples to mask out (exclude) from the dataset. Alternatively, a
string can be provided, giving the name of an array in the input dataset
@@ -1277,9 +1276,6 @@ def match_samples_batch_init(
:return: A dictionary of the job metadata, as written to `metadata.json` in
`work_dir`.
"""
if max_num_partitions is None:
max_num_partitions = 1000

# Convert working_dir to pathlib.Path
work_dir = pathlib.Path(work_dir)

Expand Down Expand Up @@ -1329,9 +1325,9 @@ def match_samples_batch_init(
sample_times = sample_times.tolist()
wd.sample_indexes = sample_indexes
wd.sample_times = sample_times
num_samples_per_partition = int(min_work_per_job // variant_data.num_sites)
if num_samples_per_partition == 0:
num_samples_per_partition = 1
num_samples_per_partition = max(
1, math.ceil(min_work_per_job // variant_data.num_sites)
)
wd.num_samples_per_partition = num_samples_per_partition
wd.num_partitions = math.ceil(len(sample_indexes) / num_samples_per_partition)
wd_path = work_dir / "metadata.json"
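
The sample partition sizing above reduces to two lines of arithmetic; a worked example under assumed numbers (1,000 samples, 20,000 sites, `min_work_per_job=100_000`):

```python
import math

num_samples, num_sites = 1000, 20_000
min_work_per_job = 100_000

# At least one sample per partition; otherwise min_work_per_job // num_sites samples.
num_samples_per_partition = max(1, math.ceil(min_work_per_job // num_sites))
num_partitions = math.ceil(num_samples / num_samples_per_partition)
print(num_samples_per_partition, num_partitions)  # 5 200
```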
