Commit 41ec3e3
Do not use distributed encoding when the number of samples is less than the number of processes, in order to prevent a bug.

lsz05 committed Jul 31, 2024
1 parent c2c7f69 commit 41ec3e3
Showing 1 changed file with 1 addition and 2 deletions.
src/jmteb/embedders/transformers_embedder.py: 1 addition, 2 deletions
@@ -149,7 +149,7 @@ def encode(
         prefix: str | None = None,
         dtype: Literal["float32", "float16", "bfloat16"] | None = None,
     ) -> torch.Tensor:
-        if self.distributed_state:
+        if self.distributed_state and len(text) >= self.distributed_state.num_processes:
            embeddings = self._encode_distributed(text, prefix)
        else:
            embeddings = self._encode(text, prefix)
@@ -212,7 +212,6 @@ def _encode_distributed(self, text: list[str], prefix: str | None = None) -> torch.Tensor:
        with self.distributed_state.split_between_processes(text) as t:
            sentence_embeddings = self._encode(t, prefix)
            batch_gather.extend(torch.Tensor(sentence_embeddings).to("cpu"))
-
        batch_embeddings = gather_object(batch_gather)
        return torch.stack(batch_embeddings)
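For context, here is a minimal sketch of the pattern behind the fix, assuming Hugging Face accelerate (PartialState, split_between_processes, gather_object). This is not JMTEB's code: encode_texts and encode_fn are hypothetical stand-ins for the embedder's encode/_encode. When there are fewer texts than processes, split_between_processes can hand some ranks an empty shard, and encoding and gathering those empty results is what the new length check sidesteps.

    # Minimal sketch, not the repository's code. `encode_fn` is a hypothetical
    # stand-in for TransformersEmbedder._encode and must return an array-like
    # of shape (n_texts, dim).
    from typing import Callable

    import torch
    from accelerate import PartialState
    from accelerate.utils import gather_object

    state = PartialState()

    def encode_texts(
        texts: list[str],
        encode_fn: Callable[[list[str]], torch.Tensor],
    ) -> torch.Tensor:
        # Mirror of the patched condition: fall back to single-process encoding
        # when not distributed, or when there are fewer samples than processes
        # (splitting would leave some ranks with an empty shard).
        if state.num_processes <= 1 or len(texts) < state.num_processes:
            return torch.Tensor(encode_fn(texts))

        gathered: list[torch.Tensor] = []
        with state.split_between_processes(texts) as shard:
            embeddings = encode_fn(shard)
            # Keep one row per sentence so gather_object collects comparable pieces.
            gathered.extend(torch.Tensor(embeddings).to("cpu"))
        # Collect rows from all ranks and rebuild a single (n_texts, dim) tensor.
        return torch.stack(gather_object(gathered))

Launched with accelerate launch across several processes, this follows the same shape as _encode_distributed; on a single process, or with a very small input, the guard routes everything through the plain per-process path.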
