Add partition matching
benjeffery committed May 29, 2024
1 parent e03cc88 commit 2114d80
Showing 2 changed files with 130 additions and 22 deletions.
24 changes: 24 additions & 0 deletions tests/test_inference.py
@@ -1423,6 +1423,30 @@ def test_equivalance(self, tmp_path, tmpdir):
         ts2 = tsinfer.match_ancestors(samples, ancestors)
         ts.tables.assert_equals(ts2.tables, ignore_provenance=True)
 
+    def test_equivalance_with_partitions(self, tmp_path, tmpdir):
+        ts, zarr_path = tsutil.make_ts_and_zarr(tmp_path)
+        samples = tsinfer.SgkitSampleData(zarr_path)
+        ancestors = tsinfer.generate_ancestors(
+            samples, path=str(tmpdir / "ancestors.zarr")
+        )
+        metadata = tsinfer.match_ancestors_batch_init(
+            tmpdir / "work", zarr_path, tmpdir / "ancestors.zarr", 1000
+        )
+        for group_index, group in enumerate(metadata["ancestor_grouping"]):
+            if group["partitions"] is None:
+                tsinfer.match_ancestors_batch_group(tmpdir / "work", group_index)
+            else:
+                for p_index, _ in enumerate(group["partitions"]):
+                    tsinfer.match_ancestors_batch_group_partition(
+                        tmpdir / "work", group_index, p_index
+                    )
+                ts = tsinfer.match_ancestors_batch_group_finalise(
+                    tmpdir / "work", group_index
+                )
+        ts = tsinfer.match_ancestors_batch_finalise(tmpdir / "work")
+        ts2 = tsinfer.match_ancestors(samples, ancestors)
+        ts.tables.assert_equals(ts2.tables, ignore_provenance=True)
+
 
 class TestAncestorGeneratorsEquivalant:
     """
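Since every partition within a group is matched independently, the serial inner loop in the test above could be fanned out across worker processes, with only the group-to-group step kept sequential. A minimal driver sketch using only the public batch functions added in this commit; the run_batch helper and the ProcessPoolExecutor usage are illustrative, not part of the commit:

import concurrent.futures

import tsinfer


def run_batch(work_dir, metadata):
    # Groups must run in order: group i consumes ancestors_{i-1}.trees.
    for group_index, group in enumerate(metadata["ancestor_grouping"]):
        if group["partitions"] is None:
            tsinfer.match_ancestors_batch_group(work_dir, group_index)
        else:
            # Partitions within a group are independent, so run them in parallel.
            with concurrent.futures.ProcessPoolExecutor() as pool:
                futures = [
                    pool.submit(
                        tsinfer.match_ancestors_batch_group_partition,
                        work_dir,
                        group_index,
                        p_index,
                    )
                    for p_index in range(len(group["partitions"]))
                ]
                for future in futures:
                    future.result()  # surface any worker exception
            tsinfer.match_ancestors_batch_group_finalise(work_dir, group_index)
    return tsinfer.match_ancestors_batch_finalise(work_dir)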
128 changes: 106 additions & 22 deletions tsinfer/inference.py
Expand Up @@ -579,23 +579,32 @@ def match_ancestors_batch_init(
)
ancestor_grouping = []
ancestor_lengths = ancestors.ancestors_length
for group_ancestors in matcher.group_by_linesweep().values():
for group_index, group_ancestors in matcher.group_by_linesweep().items():
# Make ancestor_ids JSON serialisable
group_ancestors = list(map(int, group_ancestors))
partitions = []
current_partition = []
current_partition_work = 0
# TODO: Can do better here by packing ancestors
# into as equal sized partitions as possible
for ancestor in group_ancestors:
if current_partition_work + ancestor_lengths[ancestor] > min_work_per_job:
partitions.append(current_partition)
current_partition = [ancestor]
current_partition_work = ancestor_lengths[ancestor]
else:
current_partition.append(ancestor)
current_partition_work += ancestor_lengths[ancestor]
partitions.append(current_partition)
if group_index == 0:
partitions.append(group_ancestors)
else:
for ancestor in group_ancestors:
if (
current_partition_work + ancestor_lengths[ancestor]
> min_work_per_job
):
partitions.append(current_partition)
current_partition = [ancestor]
current_partition_work = ancestor_lengths[ancestor]
else:
current_partition.append(ancestor)
current_partition_work += ancestor_lengths[ancestor]
partitions.append(current_partition)
# Make directories for the path data
if len(partitions) > 1:
os.mkdir(os.path.join(working_dir, f"group_{group_index}"))
group = {
"ancestors": group_ancestors,
"partitions": partitions if len(partitions) > 1 else None,
@@ -623,22 +632,12 @@ def match_ancestors_batch_init(
     return metadata
 
 
-def match_ancestors_batch_group(work_dir, group_index):
-    metadata_path = os.path.join(work_dir, "metadata.json")
-    with open(metadata_path) as f:
-        metadata = json.load(f)
-    group = metadata["ancestor_grouping"][group_index]
+def initialize_matcher(metadata, ancestors_ts=None):
     sample_data = formats.SgkitSampleData(metadata["sample_data_path"])
     ancestors = formats.AncestorData.load(metadata["ancestor_data_path"])
     sample_data._check_finalised()
     ancestors._check_finalised()
-    if group_index == 0:
-        ancestors_ts = None
-    else:
-        ancestors_ts = tskit.load(
-            os.path.join(work_dir, f"ancestors_{group_index-1}.trees")
-        )
-    matcher = AncestorMatcher(
+    return AncestorMatcher(
         sample_data,
         ancestors,
         ancestors_ts=ancestors_ts,
@@ -652,11 +651,73 @@ def match_ancestors_batch_group(work_dir, group_index):
         extended_checks=metadata["extended_checks"],
         engine=metadata["engine"],
     )
+
+
+def match_ancestors_batch_group(work_dir, group_index):
+    metadata_path = os.path.join(work_dir, "metadata.json")
+    with open(metadata_path) as f:
+        metadata = json.load(f)
+    group = metadata["ancestor_grouping"][group_index]
+    if group_index == 0:
+        ancestors_ts = None
+    else:
+        ancestors_ts = tskit.load(
+            os.path.join(work_dir, f"ancestors_{group_index-1}.trees")
+        )
+    matcher = initialize_matcher(metadata, ancestors_ts)
     ts = matcher.match_ancestors({group_index: group["ancestors"]})
     ts.dump(os.path.join(work_dir, f"ancestors_{group_index}.trees"))
     return ts
+
+
+def match_ancestors_batch_group_partition(work_dir, group_index, partition_index):
+    metadata_path = os.path.join(work_dir, "metadata.json")
+    with open(metadata_path) as f:
+        metadata = json.load(f)
+    group = metadata["ancestor_grouping"][group_index]
+    if group["partitions"] is None:
+        raise ValueError(f"Group {group_index} has no partitions")
+    if partition_index >= len(group["partitions"]) or partition_index < 0:
+        raise ValueError(f"Partition {partition_index} is out of range")
+
+    ancestors_ts = tskit.load(
+        os.path.join(work_dir, f"ancestors_{group_index-1}.trees")
+    )
+    matcher = initialize_matcher(metadata, ancestors_ts)
+    ancestors_to_match = group["partitions"][partition_index]
+
+    results = matcher.match_partition(ancestors_to_match, group_index, partition_index)
+    partition_path = os.path.join(
+        work_dir, f"group_{group_index}", f"partition_{partition_index}.pkl"
+    )
+    with open(partition_path, "wb") as f:
+        pickle.dump(results, f)
+
+
+def match_ancestors_batch_group_finalise(work_dir, group_index):
+    metadata_path = os.path.join(work_dir, "metadata.json")
+    with open(metadata_path) as f:
+        metadata = json.load(f)
+    group = metadata["ancestor_grouping"][group_index]
+    ancestors_ts = tskit.load(
+        os.path.join(work_dir, f"ancestors_{group_index-1}.trees")
+    )
+    matcher = initialize_matcher(metadata, ancestors_ts)
+    logger.info(
+        f"Finalising group {group_index}, loading {len(group['partitions'])} partitions"
+    )
+    results = []
+    for partition_index in range(len(group["partitions"])):
+        partition_path = os.path.join(
+            work_dir, f"group_{group_index}", f"partition_{partition_index}.pkl"
+        )
+        with open(partition_path, "rb") as f:
+            results.extend(pickle.load(f))
+    ts = matcher.finalise_group(group, results, group_index)
+    ts.dump(os.path.join(work_dir, f"ancestors_{group_index}.trees"))
+    return ts
 
 
 def match_ancestors_batch_finalise(work_dir):
     metadata_path = os.path.join(work_dir, "metadata.json")
     with open(metadata_path) as f:
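Pieced together from the paths used above, the working directory after a run would look something like this (a group_N directory exists only for partitioned groups, and group 0 is never partitioned):

work_dir/
    metadata.json          # written by match_ancestors_batch_init
    ancestors_0.trees      # one tree sequence per completed group
    ancestors_1.trees
    ...
    group_1/               # only for groups with partitions
        partition_0.pkl    # pickled results from one partition job
        partition_1.pkl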
@@ -1916,6 +1977,29 @@ def match_ancestors(self, ancestor_grouping):
         logger.info("Finished ancestor matching")
         return ts
 
+    def match_partition(self, ancestors_to_match, group_index, partition_index):
+        logger.info(
+            f"Matching group {group_index} partition {partition_index} "
+            f"with {len(ancestors_to_match)} ancestors"
+        )
+        t = time_.time()
+        self.__start_group(group_index, ancestors_to_match)
+        self.match_progress = self.progress_monitor.get(
+            "ma_match", len(ancestors_to_match)
+        )
+        results = self.match_locally(ancestors_to_match)
+        self.match_progress.close()
+        logger.info(f"Matching took {time_.time() - t:.2f} seconds")
+        return results
+
+    def finalise_group(self, group, results, group_index):
+        logger.info(f"Finalising group {group_index}")
+        self.__start_group(group_index, group["ancestors"])
+        self.__complete_group(group_index, group["ancestors"], results)
+        ts = self.store_output()
+        logger.info(f"Finalised group {group_index}")
+        return ts
+
     def get_ancestors_tables(self):
         """
         Return the ancestors tree sequence tables. Only inference sites are included in
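Because match_ancestors_batch_group_finalise assumes every partition pickle for its group is already on disk, a scheduler driving these jobs as separate tasks might verify completion before submitting the finalise step. A minimal sketch using only the metadata structure and file layout established above; partitions_complete is hypothetical, not part of this commit:

import json
import os


def partitions_complete(work_dir, group_index):
    # True when every partition job for the group has written its results.
    with open(os.path.join(work_dir, "metadata.json")) as f:
        metadata = json.load(f)
    partitions = metadata["ancestor_grouping"][group_index]["partitions"]
    if partitions is None:
        return True  # unpartitioned groups have nothing to wait for
    return all(
        os.path.exists(
            os.path.join(work_dir, f"group_{group_index}", f"partition_{i}.pkl")
        )
        for i in range(len(partitions))
    )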
