From b472b4f87f1f5d940da41ec8849b3e3e26ef8824 Mon Sep 17 00:00:00 2001
From: ancestor-mithril <sgeorge.sstoica99@gmail.com>
Date: Mon, 29 Jul 2024 11:30:21 +0300
Subject: [PATCH] Added tqdm to track progress for multiple files

---
 dice_score_3d/metrics.py | 28 ++++++++++++++++++----------
 pyproject.toml           |  1 +
 2 files changed, 19 insertions(+), 10 deletions(-)

diff --git a/dice_score_3d/metrics.py b/dice_score_3d/metrics.py
index 1a9b9df..a29ca5a 100644
--- a/dice_score_3d/metrics.py
+++ b/dice_score_3d/metrics.py
@@ -1,11 +1,13 @@
 import json
 import os.path
-from concurrent.futures import ProcessPoolExecutor
 from typing import List, Sequence, Tuple, Union
 
 import numpy as np
-from dice_score_3d.reader import read_mask
 from numpy import ndarray
+from tqdm import tqdm
+from tqdm.contrib.concurrent import process_map
+
+from dice_score_3d.reader import read_mask
 
 
 def dice_metrics(ground_truths: str, predictions: str, output_path: Union[str, None], indices: dict,
@@ -118,22 +120,28 @@ def evaluate_prediction(gt: str, pred: str, reorient: bool, dtype: np.dtype, ind
     return multi_class_dice(gt, pred, indices)
 
 
+def execute_evaluate_predictions(gt_files: List[str], pred_files: List[str], reorient: bool, dtype: np.dtype,
+                                 indices: Sequence[int], num_workers: int) \
+        -> Sequence[Tuple[ndarray, ndarray, ndarray, ndarray]]:
+    if num_workers == 0:
+        ret = [evaluate_prediction(gt, pred, reorient, dtype, indices) for gt, pred in tqdm(zip(gt_files, pred_files))]
+    else:
+        ret = process_map(evaluate_prediction,
+                          [(gt, pred, reorient, dtype, indices) for gt, pred in zip(gt_files, pred_files)],
+                          max_workers=num_workers)
+    return ret
+
+
 def evaluate_predictions(gt_files: List[str], pred_files: List[str], reorient: bool, dtype: np.dtype,
                          indices: Sequence[int], num_workers: int) -> Tuple[ndarray, ndarray, ndarray, ndarray]:
     """ Evaluates each pair of prediction and GT sequentially or in parallel and collects metrics.
     """
-    if num_workers == 0:
-        ret = [evaluate_prediction(gt, pred, reorient, dtype, indices) for gt, pred in zip(gt_files, pred_files)]
-    else:
-        with ProcessPoolExecutor(max_workers=num_workers) as executor:
-            ret = executor.map(evaluate_prediction,
-                               [(gt, pred, reorient, dtype, indices) for gt, pred in zip(gt_files, pred_files)])
-
+    scores = execute_evaluate_predictions(gt_files, pred_files, reorient, dtype, indices, num_workers)
     common_voxels = []
     all_voxels = []
     gt_voxels = []
     dice_scores = []
-    for a, b, c, d in ret:
+    for a, b, c, d in scores:
         common_voxels.append(a)
         all_voxels.append(b)
         gt_voxels.append(c)
diff --git a/pyproject.toml b/pyproject.toml
index de1e32d..dc2f93c 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -32,6 +32,7 @@ keywords = [
 dependencies = [
     "numpy",
     "SimpleITK",
+    "tqdm",
 ]
 
 [project.urls]