Skip to content

Commit

Permalink
**actually** removed ties arg
Browse files Browse the repository at this point in the history
  • Loading branch information
hfr1tz3 committed Jan 14, 2025
1 parent ed589f1 commit 001957d
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 25 deletions.
37 changes: 13 additions & 24 deletions tests/test_methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ def naive_node_span(ts):
return node_spans


def naive_compare(ts, other, transform=None, ties="average"):
def naive_compare(ts, other, transform=None):
"""
Ineffiecient but transparent function to compute dissimilarity
and root-mean-square-error between two tree sequences.
Expand Down Expand Up @@ -125,10 +125,7 @@ def f(t):
best_match_spans = np.zeros((ts.num_nodes,))
time_discrepancies = np.zeros((ts.num_nodes,))
for i, j in enumerate(best_match):
if ties == 'average':
best_match_spans[i] = shared_spans[i, j]/np.bincount(best_match)[j]
if ties is None:
best_match_spans[i] = shared_spans[i, j]
best_match_spans[i] = shared_spans[i, j]/np.bincount(best_match)[j]
time_discrepancies[i] = time_array[i, j]
node_span = naive_node_span(ts)
total_node_spans = np.sum(node_span)
Expand Down Expand Up @@ -213,11 +210,11 @@ def test_match_self(self, ts):

class TestDissimilarity:

def verify_compare(self, ts, other, transform=None, ties="average"):
def verify_compare(self, ts, other, transform=None):
match_span, ts_span, other_span, rmse = naive_compare(
ts, other, transform=transform, ties=ties,
ts, other, transform=transform,
)
dis = tscompare.compare(ts, other, transform=transform, ties=ties)
dis = tscompare.compare(ts, other, transform=transform)
assert np.isclose(1.0 - match_span / ts_span, dis.arf)
assert np.isclose(match_span / other_span, dis.tpr)
assert np.isclose(ts_span - match_span, dis.dissimilarity)
Expand Down Expand Up @@ -247,20 +244,13 @@ def test_zero_dissimilarity(self, pair):
assert np.isclose(dis.rmse, 0)

def test_transform(self):
dis1 = tscompare.compare(true_simpl, true_simpl, transform=lambda t: t, ties=None)
dis2 = tscompare.compare(true_simpl, true_simpl, transform=None, ties=None)
dis1 = tscompare.compare(true_simpl, true_simpl, transform=lambda t: t)
dis2 = tscompare.compare(true_simpl, true_simpl, transform=None)
assert dis1.dissimilarity == dis2.dissimilarity
assert dis1.rmse == dis2.rmse
self.verify_compare(true_simpl, true_ext, transform=lambda t: 1 / (1 + t), ties=None)
self.verify_compare(true_simpl, true_ext, transform=lambda t: 1 / (1 + t))

def test_ties(self):
dis1 = tscompare.compare(true_simpl, true_ext, transform=None, ties="average")
dis2 = tscompare.compare(true_simpl, true_ext, transform=None, ties=None)
assert dis1.dissimilarity == dis2.dissimilarity
assert dis1.rmse == dis2.rmse
self.verify_compare(true_ext, true_simpl, transform=None, ties="average")

def get_simple_ts(self, samples=None, time=False, span=False, no_match=False):
def get_simple_ts(self, samples=None, time=False, span=False, no_match=False, extra_match=False):
# A simple tree sequence we can use to properly test various
# dissimilarity and MSRE values.
#
Expand Down Expand Up @@ -439,20 +429,19 @@ def test_with_no_match(self):
ts = self.get_simple_ts()
other = self.get_simple_ts(span=True, time=True, no_match=True)
self.verify_compare(ts, other)
self.verify_compare(ts, other, transform=lambda t: np.sqrt(1 + t), ties=None)
self.verify_compare(ts, other, transform=lambda t: np.sqrt(1 + t), ties="average")
self.verify_compare(ts, other, transform=lambda t: np.sqrt(1 + t))

def test_dissimilarity_value(self):
ts = self.get_simple_ts()
other = self.get_simple_ts(span=True)
dis = tscompare.compare(ts, other, transform=None, ties=None)
dis = tscompare.compare(ts, other, transform=None)
assert np.isclose(dis.arf, 4 / 46)
assert np.isclose(dis.rmse, 0.0)

def test_rmse(self):
ts = self.get_simple_ts()
other = self.get_simple_ts(time=True)
dis = tscompare.compare(ts, other, transform=None, ties=None)
dis = tscompare.compare(ts, other, transform=None)
true_total_span = 46
assert dis.total_span[0] == true_total_span
assert dis.total_span[1] == true_total_span
Expand All @@ -475,7 +464,7 @@ def f(t):
def test_value_and_error(self):
ts = self.get_simple_ts()
other = self.get_simple_ts(span=True, time=True)
dis = tscompare.compare(ts, other, transform=None, ties=None)
dis = tscompare.compare(ts, other, transform=None)
true_total_spans = (46, 47)
assert dis.total_span == true_total_spans

Expand Down
4 changes: 3 additions & 1 deletion tscompare/methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -431,7 +431,9 @@ def f(t):
)
# Between each pair of nodes, find the maximum shared span
best_match = best_match_matrix.argmax(axis=1).A1
best_match_spans = shared_spans[np.arange(len(best_match)), best_match].reshape(-1)/np.bincount(best_match)[best_match].reshape(-1)
best_match_spans = shared_spans[np.arange(len(best_match)), best_match].reshape(
-1
) / np.bincount(best_match)[best_match].reshape(-1)
total_match_span = np.sum(best_match_spans)
ts_node_spans = node_spans(ts)
total_span_ts = np.sum(ts_node_spans)
Expand Down

0 comments on commit 001957d

Please sign in to comment.