diff --git a/.flake8 b/.flake8 new file mode 100644 index 0000000..b387539 --- /dev/null +++ b/.flake8 @@ -0,0 +1,6 @@ +[flake8] +# Based directly on Black's recommendations: +# https://black.readthedocs.io/en/stable/the_black_code_style.html#line-length +max-line-length = 81 +select = A,C,E,F,W,B,B950 +ignore = E203, E501, W503 \ No newline at end of file diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 0000000..a5565c2 --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,29 @@ +repos: + - repo: https://github.com/pre-commit/pre-commit-hooks + rev: v5.0.0 + hooks: + - id: check-merge-conflict + - id: debug-statements + - id: mixed-line-ending + - id: check-case-conflict + - id: check-yaml + - repo: https://github.com/asottile/reorder_python_imports + rev: v3.14.0 + hooks: + - id: reorder-python-imports + - repo: https://github.com/asottile/pyupgrade + rev: v3.19.0 + hooks: + - id: pyupgrade + args: [--py39-plus] + - repo: https://github.com/psf/black + rev: 24.10.0 + hooks: + - id: black + language_version: python3 + - repo: https://github.com/pycqa/flake8 + rev: 7.1.1 + hooks: + - id: flake8 + args: [--config=.flake8] + additional_dependencies: ["flake8-bugbear==24.12.12", "flake8-builtins==2.5.0"] diff --git a/tests/test_methods.py b/tests/test_methods.py index 9933bb0..e4654f0 100644 --- a/tests/test_methods.py +++ b/tests/test_methods.py @@ -23,7 +23,6 @@ """ Test tools for mapping between node sets of different tree sequences """ - from collections import defaultdict from itertools import combinations @@ -99,10 +98,13 @@ def naive_compare(ts, other, transform=None): Ineffiecient but transparent function to compute dissimilarity and root-mean-square-error between two tree sequences. """ + def f(t): return np.log(1 + t) - if transform is not None: - f = transform + + if transform is None: + transform = f + shared_spans = naive_shared_node_spans(ts, other).toarray() max_span = np.max(shared_spans, axis=1) assert len(max_span) == ts.num_nodes @@ -115,7 +117,9 @@ def f(t): else: for j in range(other.num_nodes): if shared_spans[i, j] == max_span[i]: - time_array[i, j] = np.abs(f(ts.nodes_time[i]) - f(other.nodes_time[j])) + time_array[i, j] = np.abs( + transform(ts.nodes_time[i]) - transform(other.nodes_time[j]) + ) dissimilarity_matrix[i, j] = 1 / (1 + time_array[i, j]) best_match = np.argmax(dissimilarity_matrix, axis=1) best_match_spans = np.zeros((ts.num_nodes,)) @@ -180,9 +184,7 @@ def test_node_spans(self, ts): naive_ns = naive_node_span(ts) assert np.allclose(eval_ns, naive_ns) - @pytest.mark.parametrize( - "pair", combinations([true_simpl, true_unary], 2) - ) + @pytest.mark.parametrize("pair", combinations([true_simpl, true_unary], 2)) def test_shared_spans(self, pair): """ Check that efficient implementation returns same answer as naive @@ -205,13 +207,16 @@ def test_match_self(self, ts): assert np.allclose(time, ts.nodes_time) assert np.array_equal(hit, np.arange(ts.num_nodes)) + class TestDissimilarity: def verify_compare(self, ts, other, transform=None): - match_span, ts_span, other_span, rmse = naive_compare(ts, other, transform=transform) + match_span, ts_span, other_span, rmse = naive_compare( + ts, other, transform=transform + ) dis = tscompare.compare(ts, other, transform=transform) - assert np.isclose(1.0 - match_span/ts_span, dis.arf) - assert np.isclose(match_span/other_span, dis.tpr) + assert np.isclose(1.0 - match_span / ts_span, dis.arf) + assert np.isclose(match_span / other_span, dis.tpr) assert np.isclose(ts_span - match_span, dis.dissimilarity) assert np.isclose(ts_span, dis.total_span[0]) assert np.isclose(other_span, dis.total_span[1]) @@ -235,7 +240,7 @@ def test_basic_comparison(self, pair): def test_zero_dissimilarity(self, pair): dis = tscompare.compare(pair[0], pair[1]) assert np.isclose(dis.dissimilarity, 0) - assert np.isclose(dis.arf, 0) + assert np.isclose(dis.arf, 0) assert np.isclose(dis.rmse, 0) def test_transform(self): @@ -243,7 +248,7 @@ def test_transform(self): dis2 = tscompare.compare(true_simpl, true_simpl, transform=None) assert dis1.dissimilarity == dis2.dissimilarity assert dis1.rmse == dis2.rmse - self.verify_compare(true_simpl, true_ext, transform=lambda t: 1/(1 + t)) + self.verify_compare(true_simpl, true_ext, transform=lambda t: 1 / (1 + t)) def get_simple_ts(self, samples=None, time=False, span=False, no_match=False): # A simple tree sequence we can use to properly test various @@ -397,12 +402,17 @@ def test_rmse(self): true_total_span = 46 assert dis.total_span[0] == true_total_span assert dis.total_span[1] == true_total_span + def f(t): return np.log(1 + t) - true_rmse = np.sqrt(( - 2 * 6 * (f(500) - f(200))**2 # nodes 4, 5 - + 2 * 2 * (f(750) - f(600))**2 # nodes, 7, 8 - ) / true_total_span) + + true_rmse = np.sqrt( + ( + 2 * 6 * (f(500) - f(200)) ** 2 # nodes 4, 5 + + 2 * 2 * (f(750) - f(600)) ** 2 # nodes, 7, 8 + ) + / true_total_span + ) assert np.isclose(dis.arf, 0.0) assert np.isclose(dis.tpr, 1.0) assert np.isclose(dis.dissimilarity, 0.0) @@ -414,12 +424,17 @@ def test_value_and_error(self): dis = tscompare.compare(ts, other) true_total_spans = (46, 47) assert dis.total_span == true_total_spans + def f(t): return np.log(1 + t) - true_rmse = np.sqrt(( - 2 * 6 * (f(500) - f(200))**2 # nodes 4, 5 - + 2 * 2 * (f(750) - f(600))**2 # nodes, 7, 8 - ) / true_total_spans[0]) + + true_rmse = np.sqrt( + ( + 2 * 6 * (f(500) - f(200)) ** 2 # nodes 4, 5 + + 2 * 2 * (f(750) - f(600)) ** 2 # nodes, 7, 8 + ) + / true_total_spans[0] + ) assert np.isclose(dis.arf, 4 / true_total_spans[0]) assert np.isclose(dis.tpr, (true_total_spans[0] - 4) / true_total_spans[1]) assert np.isclose(dis.dissimilarity, 4) diff --git a/tscompare/__init__.py b/tscompare/__init__.py index c88e8f7..35777ef 100644 --- a/tscompare/__init__.py +++ b/tscompare/__init__.py @@ -22,5 +22,10 @@ """ Tools for comparing tree sequences """ -from .methods import compare, node_spans, CladeMap, shared_node_spans, match_node_ages, ARFResult -from .provenance import __version__ +from .methods import ARFResult # noqa F401 +from .methods import CladeMap # noqa F401 +from .methods import compare # noqa F401 +from .methods import match_node_ages # noqa F401 +from .methods import node_spans # noqa F401 +from .methods import shared_node_spans # noqa F401 +from .provenance import __version__ # noqa F401 diff --git a/tscompare/methods.py b/tscompare/methods.py index e8e2561..ee92669 100644 --- a/tscompare/methods.py +++ b/tscompare/methods.py @@ -22,17 +22,16 @@ """ Tools for comparing node times between tree sequences with different node sets """ - -from dataclasses import dataclass +import copy from collections import defaultdict -from itertools import groupby, product +from dataclasses import dataclass +from itertools import product -import copy import numpy as np import scipy.sparse - import tskit + def node_spans(ts): """ Returns the array of "node spans", i.e., the `j`th entry gives @@ -97,7 +96,7 @@ def _propagate(self, edge, downdate=False): node = self.tree.parent(node) return nodes - def next(self): + def next(self): # noqa: A003 """ Advance to the next tree, returning the difference between trees as a dictionary of the form `node : (last_clade, next_clade)` @@ -254,18 +253,18 @@ def shared_node_spans(ts, other): def match_node_ages(ts, other): """ - For each node in `ts`, return the age of a matched node from `other`. Node - matching is accomplished as described in :func:`.compare`. - + For each node in `ts`, return the age of a matched node from `other`. Node + matching is accomplished as described in :func:`.compare`. + - Returns a tuple of three vectors of length `ts.num_nodes`, in this order: - the age of the best matching node in `other`; - the proportion of the node span in `ts` that is covered by the best match; - and the node id of the best match in `other`. + Returns a tuple of three vectors of length `ts.num_nodes`, in this order: + the age of the best matching node in `other`; + the proportion of the node span in `ts` that is covered by the best match; + and the node id of the best match in `other`. -:return: A tuple of arrays of length `ts.num_nodes` containing - (time of matching node, proportion overlap, and node ID of match). + :return: A tuple of arrays of length `ts.num_nodes` containing + (time of matching node, proportion overlap, and node ID of match). """ shared_spans = shared_node_spans(ts, other) @@ -283,7 +282,6 @@ def match_node_ages(ts, other): @dataclass class ARFResult: - """ The result of a call to tscompare.compare(ts, other), returning metrics associated with the ARG Robinson-Foulds @@ -302,7 +300,7 @@ class ARFResult: `dissimilarity`: The total span of `ts` that is not represented in `other`. - + `total_span`: The total of all node spans of the two tree sequences, in order (`ts`, `other`). @@ -314,6 +312,7 @@ class ARFResult: `transform`: The transformation function used to transform times for computing `rmse`. """ + arf: float tpr: float dissimilarity: float @@ -326,28 +325,29 @@ def __str__(self): Return a plain text summary of the ARF result. """ out = "Tree sequence comparison:\n" - out += f" ARF: {100*self.arf:.2f}%\n" - out += f" TPR: {100*self.tpr:.2f}%\n" + out += f" ARF: {100 * self.arf:.2f}%\n" + out += f" TPR: {100 * self.tpr:.2f}%\n" out += f" dissimilarity: {self.dissimilarity}\n" - out += f" total span (ts, other): {self.total_span[0]}, {self.total_span[1]}\n" + out += ( + f" total span (ts, other): {self.total_span[0]}, {self.total_span[1]}\n" + ) out += f" time RMSE: {self.rmse}\n" return out def compare(ts, other, transform=None): - """ For two tree sequences `ts` and `other`, this method returns an object of type :class:`.ARFResult`. The values reported summarize the degree to which nodes in `ts` "match" corresponding nodes in `other`. - + To match nodes, for each node in `ts`, the best matching node(s) from `other` has the longest matching span using :func:`.shared_node_spans`. If there are multiple matches with the same longest shared span for a single node, the best match is the match that is closest in time. - + Then, :class:`.ARFResult` contains: - (`dissimilarity`) @@ -356,8 +356,8 @@ def compare(ts, other, transform=None): samples as its best match in `other`. - (`arf`) - The fraction of the total span of `ts` over which each nodes' - descendant sample set does not match its' best match's descendant + The fraction of the total span of `ts` over which each nodes' + descendant sample set does not match its' best match's descendant sample set (i.e., the total *un*-matched span divided by the total span of `ts`). @@ -387,8 +387,11 @@ def compare(ts, other, transform=None): :rtype: ARFResult """ + def f(t): + return np.log(1 + t) + if transform is None: - transform = lambda t: np.log(1 + t) + transform = f shared_spans = shared_node_spans(ts, other) # Find all potential matches for a node based on max shared span length @@ -403,7 +406,9 @@ def compare(ts, other, transform=None): # determine best matches with the best_match_matrix ts_times = ts.nodes_time[row_ind[match]] other_times = other.nodes_time[col_ind[match]] - time_difference = np.absolute(np.asarray(transform(ts_times) - transform(other_times))) + time_difference = np.absolute( + np.asarray(transform(ts_times) - transform(other_times)) + ) # If a node x in `ts` has no match then we set time_difference to zero # This node then does not effect the rmse for j in range(len(shared_spans.data[match])): @@ -438,13 +443,10 @@ def compare(ts, other, transform=None): product = np.multiply((time_discrepancies**2), ts_node_spans) rmse = np.sqrt(np.sum(product) / total_span_ts) return ARFResult( - - arf = 1.0 - total_match_span / total_span_ts, - tpr = total_match_span / total_span_other, - - dissimilarity = total_span_ts - total_match_span, - total_span = (total_span_ts, total_span_other), - rmse = rmse, - transform = transform, + arf=1.0 - total_match_span / total_span_ts, + tpr=total_match_span / total_span_other, + dissimilarity=total_span_ts - total_match_span, + total_span=(total_span_ts, total_span_other), + rmse=rmse, + transform=transform, ) - diff --git a/tscompare/provenance.py b/tscompare/provenance.py index e6ed3d3..ddd0850 100644 --- a/tscompare/provenance.py +++ b/tscompare/provenance.py @@ -32,4 +32,3 @@ __version__ = get_version(root="..", relative_to=__file__) except ImportError: pass -