From b472b4f87f1f5d940da41ec8849b3e3e26ef8824 Mon Sep 17 00:00:00 2001 From: ancestor-mithril <sgeorge.sstoica99@gmail.com> Date: Mon, 29 Jul 2024 11:30:21 +0300 Subject: [PATCH] Added tqdm to track progress for multiple files --- dice_score_3d/metrics.py | 28 ++++++++++++++++++---------- pyproject.toml | 1 + 2 files changed, 19 insertions(+), 10 deletions(-) diff --git a/dice_score_3d/metrics.py b/dice_score_3d/metrics.py index 1a9b9df..a29ca5a 100644 --- a/dice_score_3d/metrics.py +++ b/dice_score_3d/metrics.py @@ -1,11 +1,13 @@ import json import os.path -from concurrent.futures import ProcessPoolExecutor from typing import List, Sequence, Tuple, Union import numpy as np -from dice_score_3d.reader import read_mask from numpy import ndarray +from tqdm import tqdm +from tqdm.contrib.concurrent import process_map + +from dice_score_3d.reader import read_mask def dice_metrics(ground_truths: str, predictions: str, output_path: Union[str, None], indices: dict, @@ -118,22 +120,28 @@ def evaluate_prediction(gt: str, pred: str, reorient: bool, dtype: np.dtype, ind return multi_class_dice(gt, pred, indices) +def execute_evaluate_predictions(gt_files: List[str], pred_files: List[str], reorient: bool, dtype: np.dtype, + indices: Sequence[int], num_workers: int) \ + -> Sequence[Tuple[ndarray, ndarray, ndarray, ndarray]]: + if num_workers == 0: + ret = [evaluate_prediction(gt, pred, reorient, dtype, indices) for gt, pred in tqdm(zip(gt_files, pred_files))] + else: + ret = process_map(evaluate_prediction, + [(gt, pred, reorient, dtype, indices) for gt, pred in zip(gt_files, pred_files)], + max_workers=num_workers) + return ret + + def evaluate_predictions(gt_files: List[str], pred_files: List[str], reorient: bool, dtype: np.dtype, indices: Sequence[int], num_workers: int) -> Tuple[ndarray, ndarray, ndarray, ndarray]: """ Evaluates each pair of prediction and GT sequentially or in parallel and collects metrics. """ - if num_workers == 0: - ret = [evaluate_prediction(gt, pred, reorient, dtype, indices) for gt, pred in zip(gt_files, pred_files)] - else: - with ProcessPoolExecutor(max_workers=num_workers) as executor: - ret = executor.map(evaluate_prediction, - [(gt, pred, reorient, dtype, indices) for gt, pred in zip(gt_files, pred_files)]) - + scores = execute_evaluate_predictions(gt_files, pred_files, reorient, dtype, indices, num_workers) common_voxels = [] all_voxels = [] gt_voxels = [] dice_scores = [] - for a, b, c, d in ret: + for a, b, c, d in scores: common_voxels.append(a) all_voxels.append(b) gt_voxels.append(c) diff --git a/pyproject.toml b/pyproject.toml index de1e32d..dc2f93c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -32,6 +32,7 @@ keywords = [ dependencies = [ "numpy", "SimpleITK", + "tqdm", ] [project.urls]