From 8382bd92b03f81aaecfa840807af331a54709d7f Mon Sep 17 00:00:00 2001 From: "shengzhe.li" Date: Mon, 19 Aug 2024 14:16:41 +0900 Subject: [PATCH] Fix multiGPU case --- src/jmteb/__main__.py | 32 +++++++++++++++++--------------- 1 file changed, 17 insertions(+), 15 deletions(-) diff --git a/src/jmteb/__main__.py b/src/jmteb/__main__.py index d195e49..38f929c 100644 --- a/src/jmteb/__main__.py +++ b/src/jmteb/__main__.py @@ -8,6 +8,7 @@ from jmteb.embedders import TextEmbedder from jmteb.evaluators import EmbeddingEvaluator +from jmteb.utils.dist import is_main_process from jmteb.utils.score_recorder import JsonScoreRecorder @@ -17,7 +18,8 @@ def main( save_dir: str | None = None, overwrite_cache: bool = False, ): - logger.info(f"Start evaluating the following tasks\n{list(evaluators.keys())}") + if is_main_process(): + logger.info(f"Start evaluating the following tasks\n{list(evaluators.keys())}") if save_dir: Path(save_dir).mkdir(parents=True, exist_ok=True) @@ -25,30 +27,30 @@ def main( score_recorder = JsonScoreRecorder(save_dir) for eval_name, evaluator in evaluators.items(): - logger.info(f"Evaluating {eval_name}") + if is_main_process(): + logger.info(f"Evaluating {eval_name}") cache_dir = None if save_dir is not None: cache_dir = Path(save_dir) / "cache" / eval_name metrics = evaluator(text_embedder, cache_dir=cache_dir, overwrite_cache=overwrite_cache) - if not metrics: - continue - score_recorder.record_task_scores( - scores=metrics, - dataset_name=eval_name, - task_name=evaluator.__class__.__name__.replace("Evaluator", ""), - ) - if getattr(evaluator, "log_predictions", False): - score_recorder.record_predictions( - metrics, eval_name, evaluator.__class__.__name__.replace("Evaluator", "") + if metrics is not None: + score_recorder.record_task_scores( + scores=metrics, + dataset_name=eval_name, + task_name=evaluator.__class__.__name__.replace("Evaluator", ""), ) + if getattr(evaluator, "log_predictions", False): + score_recorder.record_predictions( + metrics, eval_name, evaluator.__class__.__name__.replace("Evaluator", "") + ) - logger.info(f"Results for {eval_name}\n{json.dumps(metrics.as_dict(), indent=4, ensure_ascii=False)}") + logger.info(f"Results for {eval_name}\n{json.dumps(metrics.as_dict(), indent=4, ensure_ascii=False)}") - if save_dir: + if save_dir and is_main_process(): logger.info(f"Saving result summary to {Path(save_dir) / 'summary.json'}") - score_recorder.record_summary() + score_recorder.record_summary() if __name__ == "__main__":