Improve batch-match coverage #998
Changes from 5 commits: f9fba9e, bde2058, d9f8c71, 629419b, f0f3ef0, a063456
@@ -28,6 +28,7 @@
 import json
 import logging
 import math
+import operator
 import os
 import pathlib
 import pickle
@@ -714,34 +715,41 @@ def match_ancestors_batch_init(
     for group_index, group_ancestors in matcher.group_by_linesweep().items():
         # Make ancestor_ids JSON serialisable
         group_ancestors = list(map(int, group_ancestors))
-        partitions = []
-        current_partition = []
-        current_partition_work = 0
-        # TODO: Can do better here by packing ancestors
-        # into as equal sized partitions as possible
+        # The first group is trivial so never partition
         if group_index == 0:
-            partitions.append(group_ancestors)
+            partitions = [
+                group_ancestors,
+            ]
         else:
             total_work = sum(ancestor_lengths[ancestor] for ancestor in group_ancestors)
-            min_work_per_job_group = min_work_per_job
-            if total_work / max_num_partitions > min_work_per_job:
-                min_work_per_job_group = total_work / max_num_partitions
-            for ancestor in group_ancestors:
-                if (
-                    current_partition_work + ancestor_lengths[ancestor]
-                    > min_work_per_job_group
-                ):
-                    partitions.append(current_partition)
-                    current_partition = [ancestor]
-                    current_partition_work = ancestor_lengths[ancestor]
-                else:
-                    current_partition.append(ancestor)
-                    current_partition_work += ancestor_lengths[ancestor]
-            partitions.append(current_partition)
+            parition_count = math.ceil(total_work / min_work_per_job)
Review comment: typo "parition" -> "partition"
Reply: Fixed. I'm sprinkling these in now to prove that a free-range human wrote the code.
+            if parition_count > max_num_partitions:
+                parition_count = max_num_partitions
+
+            # Partition into roughly equal sized bins (by work)
+            sorted_ancestors = sorted(
+                group_ancestors, key=lambda x: ancestor_lengths[x], reverse=True
+            )
+            partitions = []
Review comment: Suggested change (content not shown)
Reply: Superseded by the heap code.
+            partition_lengths = []
+            for _ in range(parition_count):
+                partitions.append([])
+                partition_lengths.append(0)
+
+            # Use greedy bin packing - place each ancestor in the bin with
+            # lowest total length
+            for ancestor in sorted_ancestors:
Review comment: How about we use a heapq for this?

    heap = [(0, []) for _ in range(partition_count)]
    for ancestor in sorted_ancestors:
        sum_len, partition = heapq.heappop(heap)
        partition.append(ancestor)
        sum_len += ancestor_lengths[ancestor]
        heapq.heappush(heap, (sum_len, partition))

I think this does the same thing, but avoids the quadratic time complexity here.
Reply: Very nice, I should have thought of this!
+                min_length_idx = min(
+                    range(len(partition_lengths)), key=lambda i: partition_lengths[i]
+                )
+                partitions[min_length_idx].append(ancestor)
+                partition_lengths[min_length_idx] += ancestor_lengths[ancestor]
+            partitions = [sorted(p) for p in partitions if len(p) > 0]
         if len(partitions) > 1:
             group_dir = work_dir / f"group_{group_index}"
             group_dir.mkdir()
         # TODO: Should be a dataclass
         group = {
             "ancestors": group_ancestors,
             "partitions": partitions if len(partitions) > 1 else None,
@@ -902,7 +910,7 @@ def match_ancestors_batch_group_partition(work_dir, group_index, partition_index
     )
     logger.info(f"Dumping to {partition_path}")
     with open(partition_path, "wb") as f:
-        pickle.dump((start_time, timing.metrics, results), f)
+        pickle.dump((start_time, timing.metrics, ancestors_to_match, results), f)


def match_ancestors_batch_group_finalise(work_dir, group_index):
@@ -935,17 +943,18 @@ def match_ancestors_batch_group_finalise(work_dir, group_index):
     )
     start_times = []
     timings = []
-    results = []
+    results = {}
     for partition_index in range(len(group["partitions"])):
         partition_path = os.path.join(
             work_dir, f"group_{group_index}", f"partition_{partition_index}.pkl"
         )
         with open(partition_path, "rb") as f:
-            start_time, part_timing, result = pickle.load(f)
+            start_time, part_timing, ancestors, result = pickle.load(f)
         start_times.append(start_time)
-        results.extend(result)
+        for ancestor, r in zip(ancestors, result):
+            results[ancestor] = r
         timings.append(part_timing)

+    results = list(map(operator.itemgetter(1), sorted(results.items())))
     ts = matcher.finalise_group(group, results, group_index)
     path = os.path.join(work_dir, f"ancestors_{group_index}.trees")
     ts.dump(path)
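The change above is the heart of the fix: each partition pickle now records which ancestors it matched, so the finalise step can key results by ancestor id and restore global order however the ancestors were packed. A minimal sketch of that reassembly, with invented toy data standing in for the pickled tuples:

    import operator

    # Toy per-partition outputs (invented): ancestor ids stored alongside
    # results, in matching order, as the new pickle.dump tuple does.
    partition_outputs = [
        ([7, 9], ["match-7", "match-9"]),
        ([3, 12], ["match-3", "match-12"]),
    ]

    results = {}
    for ancestors, result in partition_outputs:
        for ancestor, r in zip(ancestors, result):
            results[ancestor] = r

    # sorted() orders the items by ancestor id; itemgetter(1) drops the keys.
    ordered = list(map(operator.itemgetter(1), sorted(results.items())))
    print(ordered)  # ['match-3', 'match-7', 'match-9', 'match-12']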
@@ -1186,7 +1195,6 @@ def match_samples_batch_init(
     ancestor_ts_path,
     min_work_per_job,
     *,
-    max_num_partitions=None,
     sample_mask=None,
     site_mask=None,
     recombination_rate=None,
@@ -1206,7 +1214,7 @@ def match_samples_batch_init(
 ):
     """
     match_samples_batch_init(work_dir, sample_data_path, ancestral_state,
-    ancestor_ts_path, min_work_per_job, \\*, max_num_partitions=None,
+    ancestor_ts_path, min_work_per_job, \\*,
     sample_mask=None, site_mask=None, recombination_rate=None, mismatch_ratio=None,
     path_compression=True, indexes=None, post_process=None, force_sample_times=False)
@@ -1237,9 +1245,6 @@ def match_samples_batch_init(
         genotypes) to allocate to a single parallel job. If the amount of work in
         a group of samples exceeds this level it will be broken up into parallel
         partitions, subject to the constraint of `max_num_partitions`.
-    :param int max_num_partitions: The maximum number of partitions to split a
-        group of samples into. Useful for limiting the number of jobs in a
-        workflow to avoid job overhead. Defaults to 1000.
     :param Union(array, str) sample_mask: A numpy array of booleans specifying
         which samples to mask out (exclude) from the dataset. Alternatively, a
         string can be provided, giving the name of an array in the input dataset
@@ -1277,9 +1282,6 @@ def match_samples_batch_init(
     :return: A dictionary of the job metadata, as written to `metadata.json` in
         `work_dir`.
     """
-    if max_num_partitions is None:
-        max_num_partitions = 1000
-
     # Convert working_dir to pathlib.Path
     work_dir = pathlib.Path(work_dir)
@@ -1329,9 +1331,9 @@ def match_samples_batch_init(
     sample_times = sample_times.tolist()
     wd.sample_indexes = sample_indexes
     wd.sample_times = sample_times
-    num_samples_per_partition = int(min_work_per_job // variant_data.num_sites)
-    if num_samples_per_partition == 0:
-        num_samples_per_partition = 1
+    num_samples_per_partition = max(
+        1, math.ceil(min_work_per_job / variant_data.num_sites)
+    )
     wd.num_samples_per_partition = num_samples_per_partition
     wd.num_partitions = math.ceil(len(sample_indexes) / num_samples_per_partition)
     wd_path = work_dir / "metadata.json"
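To make the sizing arithmetic concrete, a quick worked example with invented numbers (4,000 sites, 10,500 samples, a budget of one million genotypes per job):

    import math

    min_work_per_job = 1_000_000  # genotypes per job (invented)
    num_sites = 4_000             # invented
    num_samples = 10_500          # invented

    # Each sample contributes one genotype per site, so divide the work
    # budget by the site count, flooring at one sample per partition.
    num_samples_per_partition = max(1, math.ceil(min_work_per_job / num_sites))
    num_partitions = math.ceil(num_samples / num_samples_per_partition)
    print(num_samples_per_partition, num_partitions)  # 250 42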
Review comment: stray comma causing profligate whitespace