Skip to content

Commit

Permalink
Merge pull request #51 from sbintuitions/fix/multi_gpu_rerank_retrieval
Browse files (browse the repository at this point in the history)
[Fix] multi-GPU環境でのreranking・retrievalタスクにおけるOOM防止のための修正
  • Loading branch information
lsz05 authored Aug 5, 2024
2 parents 7881bd7 + 683fbd8 commit 0810589
Show file tree
Hide file tree
Showing 4 changed files with 20 additions and 4 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/ci.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ jobs:
run-tests:
runs-on: ubuntu-latest
env:
PYTHON_VERSION: "3.9"
PYTHON_VERSION: "3.10"
NO_CACHE: ${{ github.event.inputs.no-cache || 'false' }}
steps:
- name: Checkout
Expand Down Expand Up @@ -53,7 +53,7 @@ jobs:
lint_check:
runs-on: ubuntu-latest
env:
PYTHON_VERSION: "3.9"
PYTHON_VERSION: "3.10"
steps:
- uses: actions/checkout@v3

Expand Down
1 change: 1 addition & 0 deletions src/jmteb/configs/tasks/mrtydi.jsonnet
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
name: 'mrtydi-corpus',
},
},
"doc_chunk_size":10000
},
},
}
9 changes: 8 additions & 1 deletion src/jmteb/evaluators/reranking/evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import tqdm
from loguru import logger
from torch import Tensor
from torch import distributed as dist

from jmteb.embedders.base import TextEmbedder
from jmteb.evaluators.base import EmbeddingEvaluator, EvaluationResults
Expand Down Expand Up @@ -149,7 +150,13 @@ def _compute_metrics(
results: dict[str, float] = {}

with tqdm.tqdm(total=len(query_dataset), desc="Reranking docs") as pbar:
device = "cuda" if torch.cuda.is_available() else "cpu"
if torch.cuda.is_available():
if dist.is_available():
device = f"cuda:{dist.get_rank()}"
else:
device = "cuda"
else:
device = "cpu"
reranked_docs_list = []
for i, item in enumerate(query_dataset):
query_embedding = to_tensor(query_embeddings[i], device=device)
Expand Down
10 changes: 9 additions & 1 deletion src/jmteb/evaluators/retrieval/evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import tqdm
from loguru import logger
from torch import Tensor
from torch import distributed as dist

from jmteb.embedders.base import TextEmbedder
from jmteb.evaluators.base import EmbeddingEvaluator, EvaluationResults
Expand Down Expand Up @@ -158,7 +159,14 @@ def _compute_metrics(
for offset in range(0, len(doc_embeddings), self.doc_chunk_size):
doc_embeddings_chunk = doc_embeddings[offset : offset + self.doc_chunk_size]

device = "cuda" if torch.cuda.is_available() else "cpu"
if torch.cuda.is_available():
if dist.is_available():
device = f"cuda:{dist.get_rank()}"
else:
device = "cuda"
else:
device = "cpu"

query_embeddings = to_tensor(query_embeddings, device=device)
doc_embeddings_chunk = to_tensor(doc_embeddings_chunk, device=device)
similarity = dist_func(query_embeddings, doc_embeddings_chunk)
Expand Down

0 comments on commit 0810589

Please sign in to comment.