From c2c7f6937f8754e071008b3d5f8e8fcb4eb982fc Mon Sep 17 00:00:00 2001
From: "shengzhe.li"
Date: Wed, 31 Jul 2024 16:08:58 +0900
Subject: [PATCH] Fix test cases

---
 tests/embedders/test_transformers.py | 7 ++-----
 1 file changed, 2 insertions(+), 5 deletions(-)

diff --git a/tests/embedders/test_transformers.py b/tests/embedders/test_transformers.py
index 0ab4943..0e1eed3 100644
--- a/tests/embedders/test_transformers.py
+++ b/tests/embedders/test_transformers.py
@@ -1,4 +1,3 @@
-import numpy as np
 import torch
 
 from jmteb.embedders.transformers_embedder import TransformersEmbedder
@@ -13,12 +12,12 @@ def setup_class(cls):
 
     def test_encode(self):
         embeddings = self.model.encode("任意のテキスト")
-        assert isinstance(embeddings, np.ndarray)
+        assert isinstance(embeddings, torch.Tensor)
         assert embeddings.shape == (OUTPUT_DIM,)
 
     def test_encode_list(self):
         embeddings = self.model.encode(["任意のテキスト", "hello world", "埋め込み"])
-        assert isinstance(embeddings, np.ndarray)
+        assert isinstance(embeddings, torch.Tensor)
         assert embeddings.shape == (3, OUTPUT_DIM)
 
     def test_get_output_dim(self):
@@ -31,11 +30,9 @@ def test_tokenizer_kwargs(self):
 
     def test_model_kwargs(self):
         model = TransformersEmbedder(MODEL_NAME_OR_PATH, model_kwargs={"torch_dtype": torch.float16})
-        assert model.convert_to_tensor
         assert model.encode("任意のテキスト").dtype is torch.float16
 
     def test_bf16(self):
         # As numpy doesn't support native bfloat16, add a test case for bf16
         model = TransformersEmbedder(MODEL_NAME_OR_PATH, model_kwargs={"torch_dtype": torch.bfloat16})
-        assert model.convert_to_tensor
         assert model.encode("任意のテキスト").dtype is torch.bfloat16
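
Note (not part of the patch): the bf16 test exists because numpy has no native
bfloat16 dtype, so bf16 embeddings can only be asserted on as torch tensors.
A minimal repro of the limitation, assuming only torch is installed:

    import torch

    t = torch.zeros(3, dtype=torch.bfloat16)
    # t.numpy() fails with a TypeError (unsupported ScalarType BFloat16)
    # Upcasting to float32 first is the usual workaround:
    arr = t.float().numpy()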