Skip to content

Commit

Permalink
Added tests to catch dice_metrics asserts
Browse files Browse the repository at this point in the history
  • Loading branch information
ancestor-mithril committed Jul 12, 2024
1 parent 2bfe758 commit 88c9195
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 2 deletions.
2 changes: 1 addition & 1 deletion dice_score_3d/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def dice_metrics(ground_truths: str, predictions: str, output_path: str, indices
dtype = np.uint8 if dtype == 'uint8' else np.uint16
assert os.path.isfile(ground_truths) and os.path.isfile(predictions) or \
os.path.isdir(ground_truths) and os.path.isdir(predictions), ('Prediction path and GT path must both be a '
'a single file or a folder.')
'single file or a folder.')

if os.path.isdir(ground_truths):
gt_files = sorted([x for x in os.listdir(ground_truths) if x.startswith(prefix) and x.endswith(suffix)])
Expand Down
10 changes: 9 additions & 1 deletion tests/test_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@
import unittest

import numpy as np
from dice_score_3d.reader import read_mask

from dice_score_3d import dice_metrics
from dice_score_3d.metrics import dice, multi_class_dice, evaluate_prediction
from tests.utils import create_and_write_volume

Expand Down Expand Up @@ -43,6 +43,14 @@ def test_evaluate_prediction(self):
tmp.close()
os.unlink(tmp.name)

def test_dice_metrics(self):
self.assertRaisesRegex(AssertionError, 'Prediction path and GT path must both be a single file or a folder',
dice_metrics, './', './random_string?.!@3$not_a_path', 'results.csv', {'Lung': 1})
self.assertRaisesRegex(AssertionError, 'Output path must be either .csv or .json, is results.txt',
dice_metrics, './', './', 'results.txt', {'Lung': 1})
self.assertRaisesRegex(AssertionError, 'Indices must be integers, found .*',
dice_metrics, './', './', 'results.csv', {'Lung': 'text'})


if __name__ == '__main__':
unittest.main()

0 comments on commit 88c9195

Please sign in to comment.