Skip to content

Commit

Permalink
end-2-end faster tests (#431)
Browse files Browse the repository at this point in the history
* end-2-end faster tests

* add torch vision
  • Loading branch information
michaelfeil authored Oct 19, 2024
1 parent 0a80dff commit 6df448f
Show file tree
Hide file tree
Showing 6 changed files with 10 additions and 30 deletions.
2 changes: 2 additions & 0 deletions libs/infinity_emb/infinity_emb/fastapi_schemas/pymodels.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,8 @@ def Field(*args, **kwargs): # type: ignore
def conlist(): # type: ignore
pass

DataURIorURL = None # type: ignore


class _Usage(BaseModel):
prompt_tokens: int
Expand Down
8 changes: 5 additions & 3 deletions libs/infinity_emb/infinity_emb/infinity_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import time
import uuid
from contextlib import asynccontextmanager
from typing import Any, Optional, Union
from typing import Any, Optional, Union, TYPE_CHECKING

import infinity_emb
from infinity_emb._optional_imports import CHECK_TYPER, CHECK_UVICORN
Expand All @@ -21,7 +21,6 @@
AudioEmbeddingInput,
ClassifyInput,
ClassifyResult,
DataURIorURL,
ImageEmbeddingInput,
MultiModalOpenAIEmbedding,
OpenAIEmbeddingResult,
Expand All @@ -44,6 +43,9 @@
)
from infinity_emb.telemetry import PostHog, StartupTelemetry, telemetry_log_info

if TYPE_CHECKING:
from infinity_emb.fastapi_schemas.pymodels import DataURIorURL


def create_server(
*,
Expand Down Expand Up @@ -232,7 +234,7 @@ def _resolve_engine(model: str) -> "AsyncEmbeddingEngine":
return engine

def _resolve_mixed_input(
inputs: Union[DataURIorURL, list[DataURIorURL]],
inputs: Union["DataURIorURL", list["DataURIorURL"]],
) -> list[Union[str, bytes]]:
if hasattr(inputs, "host"):
# if it is a single url
Expand Down
6 changes: 3 additions & 3 deletions libs/infinity_emb/tests/end_to_end/test_ct2_sentence.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ async def client():
yield client


@pytest.mark.skipif(sys.platform == "darwin", reason="Does not run on macOS")
@pytest.mark.skipif(sys.platform == "darwin", reason="CTranslate2 Does not run on macOS")
def test_load_model(model_base):
# this makes sure that the error below is not based on a slow download
# or internal pytorch errors
Expand All @@ -58,7 +58,7 @@ def test_load_model(model_base):


@pytest.mark.anyio
@pytest.mark.skipif(sys.platform == "darwin", reason="Does not run on macOS")
@pytest.mark.skipif(sys.platform == "darwin", reason="CTranslate2 Does not run on macOS")
async def test_model_route(client):
response = await client.get(f"{PREFIX}/models")
assert response.status_code == 200
Expand All @@ -69,7 +69,7 @@ async def test_model_route(client):


@pytest.mark.anyio
@pytest.mark.skipif(sys.platform == "darwin", reason="Does not run on macOS")
@pytest.mark.skipif(sys.platform == "darwin", reason="CTranslate2 Does not run on macOS")
async def test_embedding(client, model_base, helpers):
await helpers.embedding_verify(client, model_base, prefix=PREFIX, model_name=MODEL)

Expand Down
9 changes: 0 additions & 9 deletions libs/infinity_emb/tests/end_to_end/test_torch_audio.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,9 +187,6 @@ async def test_audio_base64_fail(client):

assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY


@pytest.mark.anyio
async def test_audio_fail(client):
for route in [f"{PREFIX}/embeddings_audio", f"{PREFIX}/embeddings"]:
audio_url = "https://www.google.com/404"

Expand All @@ -203,9 +200,6 @@ async def test_audio_fail(client):
)
assert response.status_code == status.HTTP_400_BAD_REQUEST


@pytest.mark.anyio
async def test_audio_empty(client):
audio_url_empty = []

response_empty = await client.post(
Expand All @@ -218,9 +212,6 @@ async def test_audio_empty(client):
)
assert response_empty.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY


@pytest.mark.anyio
async def test_unsupported_endpoints(client):
response_unsupported = await client.post(
f"{PREFIX}/classify",
json={"model": MODEL, "input": ["test"]},
Expand Down
9 changes: 0 additions & 9 deletions libs/infinity_emb/tests/end_to_end/test_torch_classify.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,12 +73,3 @@ async def test_classifier(client, model_base):
rdata = response.json()
assert "model" in rdata
assert "usage" in rdata
# rdata_results = rdata["results"]

# predictions = [
# model_base.predict({"text": query, "text_pair": doc}) for doc in documents
# ]

# assert len(rdata_results) == len(predictions)
# for i, pred in enumerate(predictions):
# assert abs(rdata_results[i]["relevance_score"] - pred["score"]) < 0.01
6 changes: 0 additions & 6 deletions libs/infinity_emb/tests/end_to_end/test_torch_vision.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,19 +206,13 @@ async def test_vision_fail(client):
)
assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY


@pytest.mark.anyio
async def test_vision_empty(client):
image_url_empty = []
response = await client.post(
f"{PREFIX}/embeddings_image",
json={"model": MODEL, "input": image_url_empty},
)
assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY


@pytest.mark.anyio
async def test_unsupported_endpoints(client):
response_unsupported = await client.post(
f"{PREFIX}/classify",
json={"model": MODEL, "input": ["test"]},
Expand Down

0 comments on commit 6df448f

Please sign in to comment.