
Fix test cases
lsz05 committed Jul 31, 2024
1 parent 5127941 commit c2c7f69
Showing 1 changed file with 2 additions and 5 deletions.
7 changes: 2 additions & 5 deletions tests/embedders/test_transformers.py
@@ -1,4 +1,3 @@
-import numpy as np
 import torch
 
 from jmteb.embedders.transformers_embedder import TransformersEmbedder
@@ -13,12 +12,12 @@ def setup_class(cls):
 
     def test_encode(self):
         embeddings = self.model.encode("任意のテキスト")
-        assert isinstance(embeddings, np.ndarray)
+        assert isinstance(embeddings, torch.Tensor)
         assert embeddings.shape == (OUTPUT_DIM,)
 
     def test_encode_list(self):
         embeddings = self.model.encode(["任意のテキスト", "hello world", "埋め込み"])
-        assert isinstance(embeddings, np.ndarray)
+        assert isinstance(embeddings, torch.Tensor)
         assert embeddings.shape == (3, OUTPUT_DIM)
 
     def test_get_output_dim(self):
@@ -31,11 +30,9 @@ def test_tokenizer_kwargs(self):
 
     def test_model_kwargs(self):
         model = TransformersEmbedder(MODEL_NAME_OR_PATH, model_kwargs={"torch_dtype": torch.float16})
-        assert model.convert_to_tensor
         assert model.encode("任意のテキスト").dtype is torch.float16
 
     def test_bf16(self):
         # As numpy doesn't support native bfloat16, add a test case for bf16
         model = TransformersEmbedder(MODEL_NAME_OR_PATH, model_kwargs={"torch_dtype": torch.bfloat16})
-        assert model.convert_to_tensor
         assert model.encode("任意のテキスト").dtype is torch.bfloat16

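For context on the test_bf16 case: the in-code comment notes that numpy has no native bfloat16 dtype, which is why encode's output is now asserted to be a torch.Tensor rather than an np.ndarray. A minimal standalone sketch of that limitation, using only plain torch and numpy (nothing below is taken from the jmteb codebase):

import numpy as np
import torch

# torch represents bfloat16 natively ...
bf16 = torch.ones(3, dtype=torch.bfloat16)
assert bf16.dtype is torch.bfloat16

# ... but numpy has no bfloat16 dtype: the dtype lookup fails,
# and so does converting the tensor to an ndarray.
try:
    np.dtype("bfloat16")
except TypeError as err:
    print(err)

try:
    bf16.numpy()
except TypeError as err:
    print(err)

Keeping the return type as torch.Tensor lets bf16 embeddings pass through without a lossy cast to a numpy-supported dtype.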