Skip to content

Commit

Permalink
Fix multiGPU case
Browse files Browse the repository at this point in the history
  • Loading branch information
lsz05 committed Aug 19, 2024
1 parent 46c01e8 commit 8382bd9
Showing 1 changed file with 17 additions and 15 deletions.
32 changes: 17 additions & 15 deletions src/jmteb/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

from jmteb.embedders import TextEmbedder
from jmteb.evaluators import EmbeddingEvaluator
from jmteb.utils.dist import is_main_process
from jmteb.utils.score_recorder import JsonScoreRecorder


Expand All @@ -17,38 +18,39 @@ def main(
save_dir: str | None = None,
overwrite_cache: bool = False,
):
logger.info(f"Start evaluating the following tasks\n{list(evaluators.keys())}")
if is_main_process():
logger.info(f"Start evaluating the following tasks\n{list(evaluators.keys())}")

if save_dir:
Path(save_dir).mkdir(parents=True, exist_ok=True)

score_recorder = JsonScoreRecorder(save_dir)

for eval_name, evaluator in evaluators.items():
logger.info(f"Evaluating {eval_name}")
if is_main_process():
logger.info(f"Evaluating {eval_name}")

cache_dir = None
if save_dir is not None:
cache_dir = Path(save_dir) / "cache" / eval_name

metrics = evaluator(text_embedder, cache_dir=cache_dir, overwrite_cache=overwrite_cache)
if not metrics:
continue
score_recorder.record_task_scores(
scores=metrics,
dataset_name=eval_name,
task_name=evaluator.__class__.__name__.replace("Evaluator", ""),
)
if getattr(evaluator, "log_predictions", False):
score_recorder.record_predictions(
metrics, eval_name, evaluator.__class__.__name__.replace("Evaluator", "")
if metrics is not None:
score_recorder.record_task_scores(
scores=metrics,
dataset_name=eval_name,
task_name=evaluator.__class__.__name__.replace("Evaluator", ""),
)
if getattr(evaluator, "log_predictions", False):
score_recorder.record_predictions(
metrics, eval_name, evaluator.__class__.__name__.replace("Evaluator", "")
)

logger.info(f"Results for {eval_name}\n{json.dumps(metrics.as_dict(), indent=4, ensure_ascii=False)}")
logger.info(f"Results for {eval_name}\n{json.dumps(metrics.as_dict(), indent=4, ensure_ascii=False)}")

if save_dir:
if save_dir and is_main_process():
logger.info(f"Saving result summary to {Path(save_dir) / 'summary.json'}")
score_recorder.record_summary()
score_recorder.record_summary()


if __name__ == "__main__":
Expand Down

0 comments on commit 8382bd9

Please sign in to comment.