From 041b1d7707bc3cdba3c45cbb52654a3f59ec377a Mon Sep 17 00:00:00 2001 From: Ben Jeffery Date: Mon, 10 Feb 2025 16:20:56 +0000 Subject: [PATCH] Move to a dataclass --- requirements/development.txt | 2 +- tests/test_inference.py | 10 ++---- tests/test_provenance.py | 60 ++++++++++++++++++++++++++----- tsinfer/inference.py | 69 ++++++++++++------------------------ tsinfer/provenance.py | 50 ++++++++++++++++++-------- 5 files changed, 113 insertions(+), 78 deletions(-) diff --git a/requirements/development.txt b/requirements/development.txt index c0d105a7..28e2cf1d 100644 --- a/requirements/development.txt +++ b/requirements/development.txt @@ -34,4 +34,4 @@ sgkit[vcf] sphinx-book-theme jupyter-book sphinx-issues -ipywidgets \ No newline at end of file +ipywidgets diff --git a/tests/test_inference.py b/tests/test_inference.py index 3afaf25e..724b997b 100644 --- a/tests/test_inference.py +++ b/tests/test_inference.py @@ -1447,10 +1447,7 @@ def test_equivalance_with_partitions(self, tmp_path, tmpdir): prov = json.loads(ts.provenances()[-1].record) assert "resources" in prov # Check that the time taken was longer than finalise took - assert ( - prov["resources"]["elapsed_time"] - > final_timing.get_metrics()["elapsed_time"] - ) + assert prov["resources"]["elapsed_time"] > final_timing.metrics.elapsed_time ts2 = tsinfer.match_ancestors(samples, ancestors) ts.tables.assert_equals(ts2.tables, ignore_provenance=True) @@ -1559,10 +1556,7 @@ def test_match_samples_batch(self, tmp_path, tmpdir): prov = json.loads(mat_ts_batch.provenances()[-1].record) assert "resources" in prov # Check that the time taken was longer than finalise took - assert ( - prov["resources"]["elapsed_time"] - > final_timing.get_metrics()["elapsed_time"] - ) + assert prov["resources"]["elapsed_time"] > final_timing.metrics.elapsed_time mask_wd = tsinfer.match_samples_batch_init( work_dir=tmpdir / "working_mask", diff --git a/tests/test_provenance.py b/tests/test_provenance.py index 461c364e..59581102 100644 --- a/tests/test_provenance.py +++ b/tests/test_provenance.py @@ -70,6 +70,51 @@ def test_ancestors_file(self, small_sd_fixture): self.validate_file(ancestor_data) +class TestResourceMetrics: + """ + Tests for the ResourceMetrics dataclass. 
+ """ + + def test_create_and_asdict(self): + metrics = provenance.ResourceMetrics( + elapsed_time=1.5, user_time=1.0, sys_time=0.5, max_memory=1000 + ) + d = metrics.asdict() + assert d == { + "elapsed_time": 1.5, + "user_time": 1.0, + "sys_time": 0.5, + "max_memory": 1000, + } + + def test_combine_metrics(self): + m1 = provenance.ResourceMetrics( + elapsed_time=1.0, user_time=0.5, sys_time=0.2, max_memory=1000 + ) + m2 = provenance.ResourceMetrics( + elapsed_time=2.0, user_time=1.5, sys_time=0.3, max_memory=2000 + ) + combined = provenance.ResourceMetrics.combine([m1, m2]) + assert combined.elapsed_time == 3.0 + assert combined.user_time == 2.0 + assert combined.sys_time == 0.5 + assert combined.max_memory == 2000 + + def test_combine_empty_list(self): + with pytest.raises(ValueError): + provenance.ResourceMetrics.combine([]) + + def test_combine_single_metric(self): + m = provenance.ResourceMetrics( + elapsed_time=1.0, user_time=0.5, sys_time=0.2, max_memory=1000 + ) + combined = provenance.ResourceMetrics.combine([m]) + assert combined.elapsed_time == 1.0 + assert combined.user_time == 0.5 + assert combined.sys_time == 0.2 + assert combined.max_memory == 1000 + + class TestIncludeProvenance: """ Test that we can include or exclude provenances @@ -249,15 +294,14 @@ def test_timing_and_memory_context_manager(): math.sqrt(i) _ = [0] * 1000000 - metrics = timing.get_metrics() - assert metrics is not None - assert metrics["elapsed_time"] > 0.1 + assert timing.metrics is not None + assert timing.metrics.elapsed_time > 0.1 # Check we have highres timing - assert metrics["elapsed_time"] < 1 - assert metrics["user_time"] > 0 - assert metrics["sys_time"] >= 0 - assert metrics["max_memory"] > 100_000_000 + assert timing.metrics.elapsed_time < 1 + assert timing.metrics.user_time > 0 + assert timing.metrics.sys_time >= 0 + assert timing.metrics.max_memory > 100_000_000 # Test metrics are not available during context with provenance.TimingAndMemory() as timing2: - assert timing2.get_metrics() is None + assert timing2.metrics is None diff --git a/tsinfer/inference.py b/tsinfer/inference.py index 7733c1a4..de5fd9fc 100644 --- a/tsinfer/inference.py +++ b/tsinfer/inference.py @@ -381,7 +381,7 @@ def infer( tables = inferred_ts.dump_tables() record = provenance.get_provenance_dict( command="infer", - resources=timing.get_metrics(), + resources=timing.metrics.asdict(), mismatch_ratio=mismatch_ratio, path_compression=path_compression, precision=precision, @@ -500,7 +500,7 @@ def generate_ancestors( for timestamp, record in sample_data.provenances(): ancestor_data.add_provenance(timestamp, record) if record_provenance: - ancestor_data.record_provenance("generate_ancestors", timing.get_metrics()) + ancestor_data.record_provenance("generate_ancestors", timing.metrics.asdict()) ancestor_data.finalise() return ancestor_data @@ -587,7 +587,7 @@ def match_ancestors( if record_provenance: record = provenance.get_provenance_dict( command="match_ancestors", - resources=timing.get_metrics(), + resources=timing.metrics.asdict(), mismatch_ratio=mismatch_ratio, path_compression=path_compression, precision=precision, @@ -765,7 +765,7 @@ def match_ancestors_batch_groups( logger.info(f"Dumping to {path}") ts.dump(path) with open(path + ".resources", "w") as f: - f.write(json.dumps(timing.get_metrics())) + f.write(json.dumps(timing.metrics.asdict())) return ts @@ -794,10 +794,8 @@ def match_ancestors_batch_group_partition(work_dir, group_index, partition_index work_dir, f"group_{group_index}", 
f"partition_{partition_index}.pkl" ) logger.info(f"Dumping to {partition_path}") - resources = timing.get_metrics() - resources["start_time"] = start_time with open(partition_path, "wb") as f: - pickle.dump((resources, results), f) + pickle.dump((start_time, timing.metrics, results), f) def match_ancestors_batch_group_finalise(work_dir, group_index): @@ -814,14 +812,16 @@ def match_ancestors_batch_group_finalise(work_dir, group_index): f"Finalising group {group_index}, loading " f"{len(group['partitions'])} partitions" ) - results = [] + start_times = [] timings = [] + results = [] for partition_index in range(len(group["partitions"])): partition_path = os.path.join( work_dir, f"group_{group_index}", f"partition_{partition_index}.pkl" ) with open(partition_path, "rb") as f: - part_timing, result = pickle.load(f) + start_time, part_timing, result = pickle.load(f) + start_times.append(start_time) results.extend(result) timings.append(part_timing) @@ -829,18 +829,10 @@ def match_ancestors_batch_group_finalise(work_dir, group_index): path = os.path.join(work_dir, f"ancestors_{group_index}.trees") ts.dump(path) - finalise_metrics = timing.get_metrics() - final_resource_dict = { - "elapsed_time": time_.perf_counter() - min(t["start_time"] for t in timings), - "user_time": sum(t["user_time"] for t in timings) - + finalise_metrics["user_time"], - "sys_time": sum(t["sys_time"] for t in timings) + finalise_metrics["sys_time"], - "max_memory": max( - finalise_metrics["max_memory"], *[t["max_memory"] for t in timings] - ), - } + combined_metrics = provenance.ResourceMetrics.combine(timings + [timing.metrics]) + combined_metrics.elapsed_time = time_.perf_counter() - min(start_times) with open(path + ".resources", "w") as f: - f.write(json.dumps(final_resource_dict)) + f.write(json.dumps(combined_metrics.asdict())) return ts @@ -862,22 +854,15 @@ def match_ancestors_batch_finalise(work_dir): for file in files: if file.endswith(".resources"): with open(os.path.join(root, file)) as f: - resource = json.load(f) + resource = provenance.ResourceMetrics(**json.load(f)) resources.append(resource) - final_resources = timing.get_metrics() - resource_dict = { - "elapsed_time": time_.perf_counter() - metadata["start_time"], - "user_time": sum(r["user_time"] for r in resources) - + final_resources["user_time"], - "sys_time": sum(r["sys_time"] for r in resources) - + final_resources["sys_time"], - "max_memory": max( - final_resources["max_memory"], *[r["max_memory"] for r in resources] - ), - } + combined_resources = provenance.ResourceMetrics.combine( + resources + [timing.metrics] + ) + combined_resources.elapsed_time = time_.perf_counter() - metadata["start_time"] record = provenance.get_provenance_dict( command="match_ancestors", - resources=resource_dict, + resources=combined_resources.asdict(), mismatch_ratio=metadata["mismatch_ratio"], path_compression=metadata["path_compression"], precision=metadata["precision"], @@ -1173,7 +1158,7 @@ def match_samples_batch_partition(work_dir, partition_index): path = os.path.join(work_dir, f"partition_{partition_index}.pkl") logger.info(f"Dumping to {path}") with open(path, "wb") as f: - pickle.dump((timing.get_metrics(), results), f) + pickle.dump((timing.metrics, results), f) def match_samples_batch_finalise(work_dir): @@ -1202,20 +1187,12 @@ def match_samples_batch_finalise(work_dir): ) # Rewrite the last provenance with the correct info start_time = wd.start_time - finalise_metrics = timing.get_metrics() - final_resource_dict = { - "elapsed_time": time_.perf_counter() 
- start_time, - "user_time": sum(t["user_time"] for t in timings) - + finalise_metrics["user_time"], - "sys_time": sum(t["sys_time"] for t in timings) + finalise_metrics["sys_time"], - "max_memory": max( - finalise_metrics["max_memory"], *[t["max_memory"] for t in timings] - ), - } + combined_metrics = provenance.ResourceMetrics.combine(timings + [timing.metrics]) + combined_metrics.elapsed_time = time_.perf_counter() - start_time tables = ts.dump_tables() prov = tables.provenances[-1] record = json.loads(prov.record) - record["resources"] = final_resource_dict + record["resources"] = combined_metrics.asdict() tables.provenances[-1] = prov.replace(record=json.dumps(record)) return tables.tree_sequence() @@ -1357,7 +1334,7 @@ def match_samples( # We don't have a source here because tree sequence files don't have a UUID yet. record = provenance.get_provenance_dict( command="match_samples", - resources=timing.get_metrics(), + resources=timing.metrics.asdict(), mismatch_ratio=mismatch_ratio, path_compression=path_compression, precision=precision, diff --git a/tsinfer/provenance.py b/tsinfer/provenance.py index 503fe869..d6ca42d3 100644 --- a/tsinfer/provenance.py +++ b/tsinfer/provenance.py @@ -20,8 +20,8 @@ Common provenance methods used to determine the state and versions of various dependencies and the OS. """ +import dataclasses import platform -import resource import sys import time @@ -31,6 +31,9 @@ import tskit import zarr +if sys.platform != "win32": + import resource + __version__ = "undefined" try: @@ -41,6 +44,28 @@ pass +@dataclasses.dataclass +class ResourceMetrics: + elapsed_time: float + user_time: float + sys_time: float + max_memory: int + + def asdict(self): + return dataclasses.asdict(self) + + @classmethod + def combine(cls, metrics_list): + if not metrics_list: + raise ValueError("Cannot combine empty list of metrics") + return cls( + elapsed_time=sum(m.elapsed_time for m in metrics_list), + user_time=sum(m.user_time for m in metrics_list), + sys_time=sum(m.sys_time for m in metrics_list), + max_memory=max(m.max_memory for m in metrics_list), + ) + + def get_environment(): """ Returns a dictionary describing the environment in which tsinfer @@ -101,12 +126,11 @@ def get_peak_memory_bytes(): usage = resource.getrusage(resource.RUSAGE_SELF) max_rss = usage.ru_maxrss - if sys.platform == "darwin": - # macOS reports in bytes - return max_rss - else: + if sys.platform == "linux": # Linux reports in kilobytes return max_rss * 1024 # Convert KB to bytes + # macOS reports in bytes + return max_rss elif sys.platform == "win32": return psutil.Process().memory_info().peak_wset @@ -128,13 +152,9 @@ def __enter__(self): def __exit__(self, exc_type, exc_val, exc_tb): end_times = self.start_process.cpu_times() - self.metrics = { - "elapsed_time": time.perf_counter() - self.start_elapsed, - "user_time": end_times.user - self.start_times.user, - "sys_time": end_times.system - self.start_times.system, - "max_memory": get_peak_memory_bytes(), - } - - def get_metrics(self): - """Return the timing and memory metrics dictionary.""" - return self.metrics + self.metrics = ResourceMetrics( + elapsed_time=time.perf_counter() - self.start_elapsed, + user_time=end_times.user - self.start_times.user, + sys_time=end_times.system - self.start_times.system, + max_memory=get_peak_memory_bytes(), + )
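
For reference, a minimal usage sketch (not part of the patch) of the ResourceMetrics and TimingAndMemory API as reworked above. The workload and variable names are illustrative; only the class, field, and method names are taken from the diff.

    import json

    from tsinfer import provenance

    # Time a block of work; metrics is None inside the context and is
    # populated with a ResourceMetrics instance on exit.
    with provenance.TimingAndMemory() as timing:
        total = sum(i * i for i in range(10**6))  # stand-in workload

    print(timing.metrics.elapsed_time, timing.metrics.max_memory)

    # Combine several measurements into one record, as the batch finalise
    # steps now do, then serialise it for a provenance record.
    partition_metrics = [timing.metrics, timing.metrics]
    combined = provenance.ResourceMetrics.combine(partition_metrics)
    # combine() sums elapsed_time; the batch finalise code then overwrites it
    # with the true wall-clock duration before writing the provenance record.
    resources = json.dumps(combined.asdict())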