diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index 87f9a7a..4133e1d 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -21,7 +21,7 @@ jobs: run-tests: runs-on: ubuntu-latest env: - PYTHON_VERSION: "3.9" + PYTHON_VERSION: "3.10" NO_CACHE: ${{ github.event.inputs.no-cache || 'false' }} steps: - name: Checkout @@ -53,7 +53,7 @@ jobs: lint_check: runs-on: ubuntu-latest env: - PYTHON_VERSION: "3.9" + PYTHON_VERSION: "3.10" steps: - uses: actions/checkout@v3 diff --git a/src/jmteb/configs/tasks/mrtydi.jsonnet b/src/jmteb/configs/tasks/mrtydi.jsonnet index be7cc5b..db2bf9e 100644 --- a/src/jmteb/configs/tasks/mrtydi.jsonnet +++ b/src/jmteb/configs/tasks/mrtydi.jsonnet @@ -26,6 +26,7 @@ name: 'mrtydi-corpus', }, }, + doc_chunk_size: 10000, }, }, } diff --git a/src/jmteb/evaluators/reranking/evaluator.py b/src/jmteb/evaluators/reranking/evaluator.py index 8d268c5..f2e136b 100644 --- a/src/jmteb/evaluators/reranking/evaluator.py +++ b/src/jmteb/evaluators/reranking/evaluator.py @@ -10,6 +10,7 @@ import tqdm from loguru import logger from torch import Tensor +from torch import distributed as dist from jmteb.embedders.base import TextEmbedder from jmteb.evaluators.base import EmbeddingEvaluator, EvaluationResults @@ -149,7 +150,13 @@ def _compute_metrics( results: dict[str, float] = {} with tqdm.tqdm(total=len(query_dataset), desc="Reranking docs") as pbar: - device = "cuda" if torch.cuda.is_available() else "cpu" + if torch.cuda.is_available(): + if dist.is_available() and dist.is_initialized(): + device = f"cuda:{dist.get_rank()}" + else: + device = "cuda" + else: + device = "cpu" reranked_docs_list = [] for i, item in enumerate(query_dataset): query_embedding = to_tensor(query_embeddings[i], device=device) diff --git a/src/jmteb/evaluators/retrieval/evaluator.py b/src/jmteb/evaluators/retrieval/evaluator.py index d5339f9..3d91633 100644 --- a/src/jmteb/evaluators/retrieval/evaluator.py +++ b/src/jmteb/evaluators/retrieval/evaluator.py @@ -11,6 +11,7 @@ 
import tqdm from loguru import logger from torch import Tensor +from torch import distributed as dist from jmteb.embedders.base import TextEmbedder from jmteb.evaluators.base import EmbeddingEvaluator, EvaluationResults @@ -158,7 +159,14 @@ def _compute_metrics( for offset in range(0, len(doc_embeddings), self.doc_chunk_size): doc_embeddings_chunk = doc_embeddings[offset : offset + self.doc_chunk_size] - device = "cuda" if torch.cuda.is_available() else "cpu" + if torch.cuda.is_available(): + if dist.is_available() and dist.is_initialized(): + device = f"cuda:{dist.get_rank()}" + else: + device = "cuda" + else: + device = "cpu" + query_embeddings = to_tensor(query_embeddings, device=device) doc_embeddings_chunk = to_tensor(doc_embeddings_chunk, device=device) similarity = dist_func(query_embeddings, doc_embeddings_chunk)