Skip to content

Commit

Permalink
Move to a dataclass
Browse files Browse the repository at this point in the history
  • Loading branch information
benjeffery committed Feb 10, 2025
1 parent fadaa35 commit 4743033
Show file tree
Hide file tree
Showing 5 changed files with 110 additions and 77 deletions.
2 changes: 1 addition & 1 deletion requirements/development.txt
Original file line number Diff line number Diff line change
Expand Up @@ -34,4 +34,4 @@ sgkit[vcf]
sphinx-book-theme
jupyter-book
sphinx-issues
ipywidgets
ipywidgets
10 changes: 2 additions & 8 deletions tests/test_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -1447,10 +1447,7 @@ def test_equivalance_with_partitions(self, tmp_path, tmpdir):
prov = json.loads(ts.provenances()[-1].record)
assert "resources" in prov
# Check that the time taken was longer than finalise took
assert (
prov["resources"]["elapsed_time"]
> final_timing.get_metrics()["elapsed_time"]
)
assert prov["resources"]["elapsed_time"] > final_timing.metrics.elapsed_time

ts2 = tsinfer.match_ancestors(samples, ancestors)
ts.tables.assert_equals(ts2.tables, ignore_provenance=True)
Expand Down Expand Up @@ -1559,10 +1556,7 @@ def test_match_samples_batch(self, tmp_path, tmpdir):
prov = json.loads(mat_ts_batch.provenances()[-1].record)
assert "resources" in prov
# Check that the time taken was longer than finalise took
assert (
prov["resources"]["elapsed_time"]
> final_timing.get_metrics()["elapsed_time"]
)
assert prov["resources"]["elapsed_time"] > final_timing.metrics.elapsed_time

mask_wd = tsinfer.match_samples_batch_init(
work_dir=tmpdir / "working_mask",
Expand Down
60 changes: 52 additions & 8 deletions tests/test_provenance.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,51 @@ def test_ancestors_file(self, small_sd_fixture):
self.validate_file(ancestor_data)


class TestResourceMetrics:
    """
    Tests for the ResourceMetrics dataclass.
    """

    def test_create_and_asdict(self):
        # Round-trip a metrics object through asdict().
        m = provenance.ResourceMetrics(
            elapsed_time=1.5, user_time=1.0, sys_time=0.5, max_memory=1000
        )
        expected = {
            "elapsed_time": 1.5,
            "user_time": 1.0,
            "sys_time": 0.5,
            "max_memory": 1000,
        }
        assert m.asdict() == expected

    def test_combine_metrics(self):
        first = provenance.ResourceMetrics(
            elapsed_time=1.0, user_time=0.5, sys_time=0.2, max_memory=1000
        )
        second = provenance.ResourceMetrics(
            elapsed_time=2.0, user_time=1.5, sys_time=0.3, max_memory=2000
        )
        total = provenance.ResourceMetrics.combine([first, second])
        # Times accumulate; memory is the peak across the inputs.
        assert total.elapsed_time == 3.0
        assert total.user_time == 2.0
        assert total.sys_time == 0.5
        assert total.max_memory == 2000

    def test_combine_empty_list(self):
        # Combining nothing is an error, not a zeroed-out result.
        with pytest.raises(ValueError):
            provenance.ResourceMetrics.combine([])

    def test_combine_single_metric(self):
        m = provenance.ResourceMetrics(
            elapsed_time=1.0, user_time=0.5, sys_time=0.2, max_memory=1000
        )
        result = provenance.ResourceMetrics.combine([m])
        # A single-element combine is the identity.
        assert result.elapsed_time == 1.0
        assert result.user_time == 0.5
        assert result.sys_time == 0.2
        assert result.max_memory == 1000


class TestIncludeProvenance:
"""
Test that we can include or exclude provenances
Expand Down Expand Up @@ -249,15 +294,14 @@ def test_timing_and_memory_context_manager():
math.sqrt(i)
_ = [0] * 1000000

metrics = timing.get_metrics()
assert metrics is not None
assert metrics["elapsed_time"] > 0.1
assert timing.metrics is not None
assert timing.metrics.elapsed_time > 0.1
# Check we have highres timing
assert metrics["elapsed_time"] < 1
assert metrics["user_time"] > 0
assert metrics["sys_time"] >= 0
assert metrics["max_memory"] > 100_000_000
assert timing.metrics.elapsed_time < 1
assert timing.metrics.user_time > 0
assert timing.metrics.sys_time >= 0
assert timing.metrics.max_memory > 100_000_000

# Test metrics are not available during context
with provenance.TimingAndMemory() as timing2:
assert timing2.get_metrics() is None
assert timing2.metrics is None
69 changes: 23 additions & 46 deletions tsinfer/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -381,7 +381,7 @@ def infer(
tables = inferred_ts.dump_tables()
record = provenance.get_provenance_dict(
command="infer",
resources=timing.get_metrics(),
resources=timing.metrics.asdict(),
mismatch_ratio=mismatch_ratio,
path_compression=path_compression,
precision=precision,
Expand Down Expand Up @@ -500,7 +500,7 @@ def generate_ancestors(
for timestamp, record in sample_data.provenances():
ancestor_data.add_provenance(timestamp, record)
if record_provenance:
ancestor_data.record_provenance("generate_ancestors", timing.get_metrics())
ancestor_data.record_provenance("generate_ancestors", timing.metrics.asdict())
ancestor_data.finalise()
return ancestor_data

Expand Down Expand Up @@ -587,7 +587,7 @@ def match_ancestors(
if record_provenance:
record = provenance.get_provenance_dict(
command="match_ancestors",
resources=timing.get_metrics(),
resources=timing.metrics.asdict(),
mismatch_ratio=mismatch_ratio,
path_compression=path_compression,
precision=precision,
Expand Down Expand Up @@ -765,7 +765,7 @@ def match_ancestors_batch_groups(
logger.info(f"Dumping to {path}")
ts.dump(path)
with open(path + ".resources", "w") as f:
f.write(json.dumps(timing.get_metrics()))
f.write(json.dumps(timing.metrics.asdict()))
return ts


Expand Down Expand Up @@ -794,10 +794,8 @@ def match_ancestors_batch_group_partition(work_dir, group_index, partition_index
work_dir, f"group_{group_index}", f"partition_{partition_index}.pkl"
)
logger.info(f"Dumping to {partition_path}")
resources = timing.get_metrics()
resources["start_time"] = start_time
with open(partition_path, "wb") as f:
pickle.dump((resources, results), f)
pickle.dump((start_time, timing.metrics, results), f)


def match_ancestors_batch_group_finalise(work_dir, group_index):
Expand All @@ -814,33 +812,27 @@ def match_ancestors_batch_group_finalise(work_dir, group_index):
f"Finalising group {group_index}, loading "
f"{len(group['partitions'])} partitions"
)
results = []
start_times = []
timings = []
results = []
for partition_index in range(len(group["partitions"])):
partition_path = os.path.join(
work_dir, f"group_{group_index}", f"partition_{partition_index}.pkl"
)
with open(partition_path, "rb") as f:
part_timing, result = pickle.load(f)
start_time, part_timing, result = pickle.load(f)
start_times.append(start_time)
results.extend(result)
timings.append(part_timing)

ts = matcher.finalise_group(group, results, group_index)
path = os.path.join(work_dir, f"ancestors_{group_index}.trees")
ts.dump(path)

finalise_metrics = timing.get_metrics()
final_resource_dict = {
"elapsed_time": time_.perf_counter() - min(t["start_time"] for t in timings),
"user_time": sum(t["user_time"] for t in timings)
+ finalise_metrics["user_time"],
"sys_time": sum(t["sys_time"] for t in timings) + finalise_metrics["sys_time"],
"max_memory": max(
finalise_metrics["max_memory"], *[t["max_memory"] for t in timings]
),
}
combined_metrics = provenance.ResourceMetrics.combine(timings + [timing.metrics])
combined_metrics.elapsed_time = time_.perf_counter() - min(start_times)
with open(path + ".resources", "w") as f:
f.write(json.dumps(final_resource_dict))
f.write(json.dumps(combined_metrics.asdict()))
return ts


Expand All @@ -862,22 +854,15 @@ def match_ancestors_batch_finalise(work_dir):
for file in files:
if file.endswith(".resources"):
with open(os.path.join(root, file)) as f:
resource = json.load(f)
resource = provenance.ResourceMetrics(**json.load(f))
resources.append(resource)
final_resources = timing.get_metrics()
resource_dict = {
"elapsed_time": time_.perf_counter() - metadata["start_time"],
"user_time": sum(r["user_time"] for r in resources)
+ final_resources["user_time"],
"sys_time": sum(r["sys_time"] for r in resources)
+ final_resources["sys_time"],
"max_memory": max(
final_resources["max_memory"], *[r["max_memory"] for r in resources]
),
}
combined_resources = provenance.ResourceMetrics.combine(
resources + [timing.metrics]
)
combined_resources.elapsed_time = time_.perf_counter() - metadata["start_time"]
record = provenance.get_provenance_dict(
command="match_ancestors",
resources=resource_dict,
resources=combined_resources.asdict(),
mismatch_ratio=metadata["mismatch_ratio"],
path_compression=metadata["path_compression"],
precision=metadata["precision"],
Expand Down Expand Up @@ -1173,7 +1158,7 @@ def match_samples_batch_partition(work_dir, partition_index):
path = os.path.join(work_dir, f"partition_{partition_index}.pkl")
logger.info(f"Dumping to {path}")
with open(path, "wb") as f:
pickle.dump((timing.get_metrics(), results), f)
pickle.dump((timing.metrics, results), f)


def match_samples_batch_finalise(work_dir):
Expand Down Expand Up @@ -1202,20 +1187,12 @@ def match_samples_batch_finalise(work_dir):
)
# Rewrite the last provenance with the correct info
start_time = wd.start_time
finalise_metrics = timing.get_metrics()
final_resource_dict = {
"elapsed_time": time_.perf_counter() - start_time,
"user_time": sum(t["user_time"] for t in timings)
+ finalise_metrics["user_time"],
"sys_time": sum(t["sys_time"] for t in timings) + finalise_metrics["sys_time"],
"max_memory": max(
finalise_metrics["max_memory"], *[t["max_memory"] for t in timings]
),
}
combined_metrics = provenance.ResourceMetrics.combine(timings + [timing.metrics])
combined_metrics.elapsed_time = time_.perf_counter() - start_time
tables = ts.dump_tables()
prov = tables.provenances[-1]
record = json.loads(prov.record)
record["resources"] = final_resource_dict
record["resources"] = combined_metrics.asdict()
tables.provenances[-1] = prov.replace(record=json.dumps(record))
return tables.tree_sequence()

Expand Down Expand Up @@ -1357,7 +1334,7 @@ def match_samples(
# We don't have a source here because tree sequence files don't have a UUID yet.
record = provenance.get_provenance_dict(
command="match_samples",
resources=timing.get_metrics(),
resources=timing.metrics.asdict(),
mismatch_ratio=mismatch_ratio,
path_compression=path_compression,
precision=precision,
Expand Down
46 changes: 32 additions & 14 deletions tsinfer/provenance.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
Common provenance methods used to determine the state and versions
of various dependencies and the OS.
"""
import dataclasses
import platform
import resource
import sys
Expand All @@ -41,6 +42,28 @@
pass


@dataclasses.dataclass
class ResourceMetrics:
    """
    Resource usage recorded for a single timed operation, matching the
    keys stored in the "resources" section of a provenance record.
    """

    # Wall-clock seconds taken by the operation.
    elapsed_time: float
    # CPU seconds spent in user mode.
    user_time: float
    # CPU seconds spent in kernel mode.
    sys_time: float
    # Peak memory usage, in bytes.
    max_memory: int

    def asdict(self):
        """Return the metrics as a plain dict, e.g. for JSON serialisation."""
        return dataclasses.asdict(self)

    @classmethod
    def combine(cls, metrics_list):
        """
        Aggregate an iterable of ResourceMetrics into a single instance:
        the time fields are summed and max_memory is the maximum across
        the inputs.

        :raises ValueError: if ``metrics_list`` is empty.
        """
        # Materialise first so any iterable (including a generator) works:
        # a generator object is always truthy, so without this an empty
        # generator would bypass the guard and fail inside max() instead.
        metrics_list = list(metrics_list)
        if not metrics_list:
            raise ValueError("Cannot combine empty list of metrics")
        return cls(
            elapsed_time=sum(m.elapsed_time for m in metrics_list),
            user_time=sum(m.user_time for m in metrics_list),
            sys_time=sum(m.sys_time for m in metrics_list),
            max_memory=max(m.max_memory for m in metrics_list),
        )


def get_environment():
"""
Returns a dictionary describing the environment in which tsinfer
Expand Down Expand Up @@ -101,12 +124,11 @@ def get_peak_memory_bytes():
usage = resource.getrusage(resource.RUSAGE_SELF)
max_rss = usage.ru_maxrss

if sys.platform == "darwin":
# macOS reports in bytes
return max_rss
else:
if sys.platform == "linux":
# Linux reports in kilobytes
return max_rss * 1024 # Convert KB to bytes
# macOS reports in bytes
return max_rss

Check warning on line 131 in tsinfer/provenance.py

View check run for this annotation

Codecov / codecov/patch

tsinfer/provenance.py#L131

Added line #L131 was not covered by tests

elif sys.platform == "win32":
return psutil.Process().memory_info().peak_wset

Check warning on line 134 in tsinfer/provenance.py

View check run for this annotation

Codecov / codecov/patch

tsinfer/provenance.py#L134

Added line #L134 was not covered by tests
Expand All @@ -128,13 +150,9 @@ def __enter__(self):

def __exit__(self, exc_type, exc_val, exc_tb):
end_times = self.start_process.cpu_times()
self.metrics = {
"elapsed_time": time.perf_counter() - self.start_elapsed,
"user_time": end_times.user - self.start_times.user,
"sys_time": end_times.system - self.start_times.system,
"max_memory": get_peak_memory_bytes(),
}

def get_metrics(self):
"""Return the timing and memory metrics dictionary."""
return self.metrics
self.metrics = ResourceMetrics(
elapsed_time=time.perf_counter() - self.start_elapsed,
user_time=end_times.user - self.start_times.user,
sys_time=end_times.system - self.start_times.system,
max_memory=get_peak_memory_bytes(),
)

0 comments on commit 4743033

Please sign in to comment.