Skip to content

Commit

Permalink
Merge branch 'dev' into feature/apply_devset
Browse files Browse the repository at this point in the history
  • Loading branch information
lsz05 committed Apr 16, 2024
2 parents c7f4132 + 26a5979 commit 54b51d5
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 3 deletions.
6 changes: 4 additions & 2 deletions src/jmteb/embedders/openai_embedder.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def __init__(self, model: str = "text-embedding-3-small", dim: int | None = None
self.client = OpenAI() # API key written in .env
assert model in MODEL_DIM.keys(), f"`model` must be one of {list(MODEL_DIM.keys())}!"
self.model = model
if not dim:
if not dim or model == "text-embedding-ada-002":
self.dim = MODEL_DIM[self.model]
else:
if dim > MODEL_DIM[self.model]:
Expand All @@ -43,13 +43,15 @@ def __init__(self, model: str = "text-embedding-3-small", dim: int | None = None
self.dim = dim

def encode(self, text: str | list[str]) -> np.ndarray:
kwargs = {"dimensions": self.dim} if self.model != "text-embedding-ada-002" else {}
# specifying `dimensions` is not allowed for "text-embedding-ada-002"
result = np.asarray(
[
data.embedding
for data in self.client.embeddings.create(
input=text,
model=self.model,
dimensions=self.dim,
**kwargs,
).data
]
)
Expand Down
20 changes: 19 additions & 1 deletion tests/embedders/test_openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,13 @@ class MockEmbedding:


class MockOpenAIClientEmbedding:
def create(input: str | list[str], model: str, dimensions: int):
def create(input: str | list[str], model: str, **kwargs):
if model == "text-embedding-ada-002":
assert "dimensions" not in kwargs
dimensions = OUTPUT_DIM
else:
assert "dimensions" in kwargs
dimensions = kwargs.get("dimensions")
if isinstance(input, str):
input = [input]
return MockData(data=[MockEmbedding(embedding=[0.1] * dimensions)] * len(input))
Expand Down Expand Up @@ -62,6 +68,18 @@ def test_model_dim(self):
assert OpenAIEmbedder(model="text-embedding-3-large").dim == 3072
assert OpenAIEmbedder(model="text-embedding-ada-002").dim == 1536

def test_ada_002_dim(self):
# check that no `dimensions` argument is set for model "text-embedding-ada-002"
# else an assertion error will be raised in MockOpenAIClientEmbedding
# and model "text-embedding-ada-002" has a fixed output dimension
embeddings = OpenAIEmbedder(model="text-embedding-ada-002", dim=2 * OUTPUT_DIM).encode("任意のテキスト")
assert isinstance(embeddings, np.ndarray)
assert embeddings.shape == (OUTPUT_DIM,)

embeddings = OpenAIEmbedder(model="text-embedding-ada-002", dim=OUTPUT_DIM // 2).encode("任意のテキスト")
assert isinstance(embeddings, np.ndarray)
assert embeddings.shape == (OUTPUT_DIM,)

def test_dim_over_max(self):
assert OpenAIEmbedder(dim=2 * OUTPUT_DIM).dim == OUTPUT_DIM

Expand Down

0 comments on commit 54b51d5

Please sign in to comment.