Improve batch-match coverage #998

Merged: 6 commits merged on Feb 17, 2025

Changes from 5 commits
3 changes: 1 addition & 2 deletions docs/large_scale.md
@@ -195,8 +195,7 @@ of ancestors. There are three API methods that work together to enable distribut
3. {meth}`match_samples_batch_finalise`

{meth}`match_samples_batch_init` should be called to set up the batch matching and to determine the
groupings of samples. Similar to {meth}`match_ancestors_batch_init` is has a `min_work_per_job` and
`max_num_partitions` arguments to control the level of parallelism. The method writes a file
groupings of samples. Similar to {meth}`match_ancestors_batch_init` it has a `min_work_per_job` argument to control the level of parallelism. The method writes a file
`metadata.json` to the directory `work_dir` that contains a JSON encoded dictionary with
configuration for later steps. This is also returned by the call. The `num_partitions` key in
this dictionary is the number of times {meth}`match_samples_batch_partition` will need
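For orientation, here is a minimal sketch of how the three batch sample-matching calls fit together, following the documentation above and the test usage below. The paths and the min_work_per_job value are illustrative only.

import tsinfer

# Plan the batch run; this writes metadata.json into work_dir and returns the
# same configuration, including the number of partitions to process.
wd = tsinfer.match_samples_batch_init(
    work_dir="work",
    sample_data_path="data.vcz",
    ancestral_state="variant_ancestral_allele",
    ancestor_ts_path="ancestors.trees",
    min_work_per_job=1e6,
)
# Each partition is independent, so these calls can run as separate cluster jobs.
for i in range(wd.num_partitions):
    tsinfer.match_samples_batch_partition(work_dir="work", partition_index=i)
# Combine the per-partition results into the final tree sequence.
ts = tsinfer.match_samples_batch_finalise("work")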
126 changes: 118 additions & 8 deletions tests/test_inference.py
@@ -35,8 +35,10 @@
import msprime
import numpy as np
import pytest
import sgkit
import tskit
import tsutil
import xarray as xr
from tskit import MetadataSchema

import _tsinfer
@@ -1401,16 +1403,16 @@ def test_equivalance_many_at_once(self, tmp_path, tmpdir):
tmpdir / "ancestors.zarr",
1000,
)
tsinfer.match_ancestors_batch_groups(
tmpdir / "work", 0, len(metadata["ancestor_grouping"]) // 2, 2
)
num_groupings = len(metadata["ancestor_grouping"])
tsinfer.match_ancestors_batch_groups(tmpdir / "work", 0, num_groupings // 2, 2)
tsinfer.match_ancestors_batch_groups(
tmpdir / "work",
len(metadata["ancestor_grouping"]) // 2,
len(metadata["ancestor_grouping"]),
num_groupings // 2,
num_groupings,
2,
)
# TODO Check which ones written to disk
assert (tmpdir / "work" / f"ancestors_{(num_groupings//2)-1}.trees").exists()
assert (tmpdir / "work" / f"ancestors_{num_groupings-1}.trees").exists()
ts = tsinfer.match_ancestors_batch_finalise(tmpdir / "work")
ts2 = tsinfer.match_ancestors(samples, ancestors)
ts.tables.assert_equals(ts2.tables, ignore_provenance=True)
@@ -1438,6 +1440,11 @@ def test_equivalance_with_partitions(self, tmp_path, tmpdir):
tsinfer.match_ancestors_batch_group_partition(
tmpdir / "work", group_index, p_index
)
with pytest.raises(ValueError, match="out of range"):
tsinfer.match_ancestors_batch_group_partition(
tmpdir / "work", group_index, p_index + 1000
)

ts = tsinfer.match_ancestors_batch_group_finalise(
tmpdir / "work", group_index
)
@@ -1523,6 +1530,34 @@ def test_errors(self, tmp_path, tmpdir):
with pytest.raises(ValueError, match="sequence length is different"):
tsinfer.match_ancestors_batch_groups(tmpdir / "work", 2, 3)

def test_low_min_work_per_job(self, tmp_path, tmpdir):
ts, zarr_path = tsutil.make_ts_and_zarr(tmp_path)
samples = tsinfer.VariantData(zarr_path, "variant_ancestral_allele")
_ = tsinfer.generate_ancestors(samples, path=str(tmpdir / "ancestors.zarr"))
metadata = tsinfer.match_ancestors_batch_init(
tmpdir / "work",
zarr_path,
"variant_ancestral_allele",
tmpdir / "ancestors.zarr",
min_work_per_job=1,
max_num_partitions=2,
)
for group in metadata["ancestor_grouping"]:
assert group["partitions"] is None or len(group["partitions"]) <= 2

metadata = tsinfer.match_ancestors_batch_init(
tmpdir / "work2",
zarr_path,
"variant_ancestral_allele",
tmpdir / "ancestors.zarr",
min_work_per_job=1,
max_num_partitions=20000,
)
for group in metadata["ancestor_grouping"]:
if group["partitions"] is not None:
for partition in group["partitions"]:
assert len(partition) == 1


@pytest.mark.skipif(sys.platform == "win32", reason="No cyvcf2 on windows")
class TestBatchSampleMatching:
@@ -1543,8 +1578,8 @@ def test_match_samples_batch(self, tmp_path, tmpdir):
ancestral_state="variant_ancestral_allele",
ancestor_ts_path=tmpdir / "mat_anc.trees",
min_work_per_job=1,
max_num_partitions=10,
)
assert mat_wd.num_partitions == mat_sd.num_samples
for i in range(mat_wd.num_partitions):
tsinfer.match_samples_batch_partition(
work_dir=tmpdir / "working_mat",
@@ -1564,7 +1599,6 @@ def test_match_samples_batch(self, tmp_path, tmpdir):
ancestral_state="variant_ancestral_allele",
ancestor_ts_path=tmpdir / "mask_anc.trees",
min_work_per_job=1,
max_num_partitions=10,
site_mask="variant_mask_foobar",
sample_mask="samples_mask_foobar",
)
@@ -1588,6 +1622,82 @@ def test_match_samples_batch(self, tmp_path, tmpdir):
mat_ts_batch.tables, ignore_timestamps=True, ignore_provenance=True
)

def test_force_sample_times(self, tmp_path, tmpdir):
ts, zarr_path = tsutil.make_ts_and_zarr(tmp_path)
ds = sgkit.load_dataset(zarr_path)
array = [0.0001] * ts.num_individuals
ds.update(
{
"individuals_time": xr.DataArray(
data=array, dims=["sample"], name="individuals_time"
)
}
)
sgkit.save_dataset(
ds.drop_vars(set(ds.data_vars) - {"individuals_time"}), zarr_path, mode="a"
)
samples = tsinfer.VariantData(zarr_path, "variant_ancestral_allele")
anc = tsinfer.generate_ancestors(samples, path=str(tmpdir / "ancestors.zarr"))
anc_ts = tsinfer.match_ancestors(samples, anc)
anc_ts.dump(tmpdir / "anc.trees")

wd = tsinfer.match_samples_batch_init(
work_dir=tmpdir / "working",
sample_data_path=samples.path,
ancestral_state="variant_ancestral_allele",
ancestor_ts_path=tmpdir / "anc.trees",
min_work_per_job=1e6,
force_sample_times=True,
)
for i in range(wd.num_partitions):
tsinfer.match_samples_batch_partition(
work_dir=tmpdir / "working",
partition_index=i,
)
ts_batch = tsinfer.match_samples_batch_finalise(tmpdir / "working")
ts = tsinfer.match_samples(samples, anc_ts, force_sample_times=True)
ts.tables.assert_equals(ts_batch.tables, ignore_provenance=True)

def test_array_args(self, tmp_path, tmpdir):
ts, zarr_path = tsutil.make_ts_and_zarr(tmp_path)
sample_mask = np.zeros(ts.num_individuals, dtype=bool)
sample_mask[42] = True
site_mask = np.zeros(ts.num_sites, dtype=bool)
site_mask[42] = True
rng = np.random.RandomState(42)
sites_time = rng.uniform(0, 1, ts.num_sites - 1)
samples = tsinfer.VariantData(
zarr_path,
"variant_ancestral_allele",
sample_mask=sample_mask,
site_mask=site_mask,
sites_time=sites_time,
)
anc = tsinfer.generate_ancestors(samples, path=str(tmpdir / "ancestors.zarr"))
anc_ts = tsinfer.match_ancestors(samples, anc)
anc_ts.dump(tmpdir / "anc.trees")

wd = tsinfer.match_samples_batch_init(
work_dir=tmpdir / "working",
sample_data_path=samples.path,
sample_mask=sample_mask,
site_mask=site_mask,
ancestral_state="variant_ancestral_allele",
ancestor_ts_path=tmpdir / "anc.trees",
min_work_per_job=1e6,
)
for i in range(wd.num_partitions):
tsinfer.match_samples_batch_partition(
work_dir=tmpdir / "working",
partition_index=i,
)
ts_batch = tsinfer.match_samples_batch_finalise(tmpdir / "working")
ts = tsinfer.match_samples(
samples,
anc_ts,
)
ts.tables.assert_equals(ts_batch.tables, ignore_provenance=True)


class TestAncestorGeneratorsEquivalant:
"""
78 changes: 40 additions & 38 deletions tsinfer/inference.py
@@ -28,6 +28,7 @@
import json
import logging
import math
import operator
import os
import pathlib
import pickle
@@ -714,34 +715,41 @@ def match_ancestors_batch_init(
for group_index, group_ancestors in matcher.group_by_linesweep().items():
# Make ancestor_ids JSON serialisable
group_ancestors = list(map(int, group_ancestors))
partitions = []
current_partition = []
current_partition_work = 0
# TODO: Can do better here by packing ancestors
# into as equal sized partitions as possible
# The first group is trivial so never partition
if group_index == 0:
partitions.append(group_ancestors)
partitions = [
group_ancestors,
Member: stray comma causing profligate whitespace
]
else:
total_work = sum(ancestor_lengths[ancestor] for ancestor in group_ancestors)
min_work_per_job_group = min_work_per_job
if total_work / max_num_partitions > min_work_per_job:
min_work_per_job_group = total_work / max_num_partitions
for ancestor in group_ancestors:
if (
current_partition_work + ancestor_lengths[ancestor]
> min_work_per_job_group
):
partitions.append(current_partition)
current_partition = [ancestor]
current_partition_work = ancestor_lengths[ancestor]
else:
current_partition.append(ancestor)
current_partition_work += ancestor_lengths[ancestor]
partitions.append(current_partition)
parition_count = math.ceil(total_work / min_work_per_job)
Member: typo "paritition" -> partition

Member Author: Fixed. I'm sprinkling these in now to prove that a free-range human wrote the code.

if parition_count > max_num_partitions:
parition_count = max_num_partitions

# Partition into roughly equal sized bins (by work)
sorted_ancestors = sorted(
group_ancestors, key=lambda x: ancestor_lengths[x], reverse=True
)
partitions = []
Member suggested change:
partitions = []
partitions = [[] for _ in range(partition_count)]
partition_lengths = [0 for _ in range(partition_count)]

Member Author: Superseded by the heap code.

partition_lengths = []
for _ in range(parition_count):
partitions.append([])
partition_lengths.append(0)

# Use greedy bin packing - place each ancestor in the bin with
# lowest total length
for ancestor in sorted_ancestors:
Member: How about we use a heapq for this?

heap = [(0, []) for _ in range(partition_count)]
for ancestor in sorted_ancestors:
    sum_len, partition = heapq.heappop(heap)
    partition.append(ancestor)
    sum_len += ancestor_lengths[ancestor]
    heapq.heappush(heap, (sum_len, partition))

I think this does the same thing, but avoids the quadratic time complexity here.

Member Author (@benjeffery, Feb 14, 2025): Very nice, I should have thought of this!

min_length_idx = min(
range(len(partition_lengths)), key=lambda i: partition_lengths[i]
)
partitions[min_length_idx].append(ancestor)
partition_lengths[min_length_idx] += ancestor_lengths[ancestor]
partitions = [sorted(p) for p in partitions if len(p) > 0]

if len(partitions) > 1:
group_dir = work_dir / f"group_{group_index}"
group_dir.mkdir()
# TODO: Should be a dataclass

group = {
"ancestors": group_ancestors,
"partitions": partitions if len(partitions) > 1 else None,
@@ -902,7 +910,7 @@ def match_ancestors_batch_group_partition(work_dir, group_index, partition_index
)
logger.info(f"Dumping to {partition_path}")
with open(partition_path, "wb") as f:
pickle.dump((start_time, timing.metrics, results), f)
pickle.dump((start_time, timing.metrics, ancestors_to_match, results), f)


def match_ancestors_batch_group_finalise(work_dir, group_index):
@@ -935,17 +943,18 @@ def match_ancestors_batch_group_finalise(work_dir, group_index):
)
start_times = []
timings = []
results = []
results = {}
for partition_index in range(len(group["partitions"])):
partition_path = os.path.join(
work_dir, f"group_{group_index}", f"partition_{partition_index}.pkl"
)
with open(partition_path, "rb") as f:
start_time, part_timing, result = pickle.load(f)
start_time, part_timing, ancestors, result = pickle.load(f)
start_times.append(start_time)
results.extend(result)
for ancestor, r in zip(ancestors, result):
results[ancestor] = r
timings.append(part_timing)

results = list(map(operator.itemgetter(1), sorted(results.items())))
ts = matcher.finalise_group(group, results, group_index)
path = os.path.join(work_dir, f"ancestors_{group_index}.trees")
ts.dump(path)
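A toy illustration of the reordering idiom used in the finalise step above: partitions can finish in any order, so each one records which ancestors it matched, and sorting the collected dictionary by ancestor ID restores the order expected by the matcher. The ancestor IDs and result strings here are invented.

import operator

# Pretend two partitions completed out of order, each carrying
# (ancestors_to_match, results) as pickled by the partition step.
partition_outputs = [
    ([7, 9], ["match-7", "match-9"]),
    ([2, 5], ["match-2", "match-5"]),
]
results = {}
for ancestors, result in partition_outputs:
    for ancestor, r in zip(ancestors, result):
        results[ancestor] = r
# Sort by ancestor ID, then keep only the result values.
ordered = list(map(operator.itemgetter(1), sorted(results.items())))
print(ordered)  # ['match-2', 'match-5', 'match-7', 'match-9']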
@@ -1186,7 +1195,6 @@ def match_samples_batch_init(
ancestor_ts_path,
min_work_per_job,
*,
max_num_partitions=None,
sample_mask=None,
site_mask=None,
recombination_rate=None,
@@ -1206,7 +1214,7 @@
):
"""
match_samples_batch_init(work_dir, sample_data_path, ancestral_state,
ancestor_ts_path, min_work_per_job, \\*, max_num_partitions=None,
ancestor_ts_path, min_work_per_job, \\*,
sample_mask=None, site_mask=None, recombination_rate=None, mismatch_ratio=None,
path_compression=True, indexes=None, post_process=None, force_sample_times=False)

@@ -1237,9 +1245,6 @@ def match_samples_batch_init(
genotypes) to allocate to a single parallel job. If the amount of work in
a group of samples exceeds this level it will be broken up into parallel
partitions, subject to the constraint of `max_num_partitions`.
:param int max_num_partitions: The maximum number of partitions to split a
group of samples into. Useful for limiting the number of jobs in a
workflow to avoid job overhead. Defaults to 1000.
:param Union(array, str) sample_mask: A numpy array of booleans specifying
which samples to mask out (exclude) from the dataset. Alternatively, a
string can be provided, giving the name of an array in the input dataset
@@ -1277,9 +1282,6 @@ def match_samples_batch_init(
:return: A dictionary of the job metadata, as written to `metadata.json` in
`work_dir`.
"""
if max_num_partitions is None:
max_num_partitions = 1000

# Convert working_dir to pathlib.Path
work_dir = pathlib.Path(work_dir)

@@ -1329,9 +1331,9 @@ def match_samples_batch_init(
sample_times = sample_times.tolist()
wd.sample_indexes = sample_indexes
wd.sample_times = sample_times
num_samples_per_partition = int(min_work_per_job // variant_data.num_sites)
if num_samples_per_partition == 0:
num_samples_per_partition = 1
num_samples_per_partition = max(
1, math.ceil(min_work_per_job // variant_data.num_sites)
)
wd.num_samples_per_partition = num_samples_per_partition
wd.num_partitions = math.ceil(len(sample_indexes) / num_samples_per_partition)
wd_path = work_dir / "metadata.json"
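As a rough worked example of the partition sizing at the end of this diff (all numbers invented): each partition receives enough samples to reach roughly min_work_per_job genotypes, with at least one sample per partition.

import math

min_work_per_job = 100_000  # genotypes per job (illustrative)
num_sites = 20_000
num_samples = 42

# Mirrors the updated calculation in match_samples_batch_init above.
num_samples_per_partition = max(1, math.ceil(min_work_per_job // num_sites))
num_partitions = math.ceil(num_samples / num_samples_per_partition)
print(num_samples_per_partition, num_partitions)  # 5 9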