Skip to content

Commit

Permalink
add simpler modality arg (#401)
Browse files (browse the repository at this point in the history)
  • Loading branch information
michaelfeil authored Oct 7, 2024
1 parent 99339a9 commit e139fcf
Show file tree
Hide file tree
Showing 6 changed files with 27 additions and 27 deletions.
2 changes: 1 addition & 1 deletion docs/assets/openapi.json

Large diffs are not rendered by default.

12 changes: 6 additions & 6 deletions libs/infinity_emb/infinity_emb/fastapi_schemas/pymodels.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ class _OpenAIEmbeddingInput_Text(_OpenAIEmbeddingInput):
),
Annotated[str, INPUT_STRING],
]
infinity_extra_modality: Literal[Modality.text] = Modality.text # type: ignore
modality: Literal[Modality.text] = Modality.text # type: ignore


class _OpenAIEmbeddingInput_URI(_OpenAIEmbeddingInput):
Expand All @@ -115,21 +115,21 @@ class _OpenAIEmbeddingInput_URI(_OpenAIEmbeddingInput):


class OpenAIEmbeddingInput_Audio(_OpenAIEmbeddingInput_URI):
infinity_extra_modality: Literal[Modality.audio] = Modality.audio # type: ignore
modality: Literal[Modality.audio] = Modality.audio # type: ignore


class OpenAIEmbeddingInput_Image(_OpenAIEmbeddingInput_URI):
infinity_extra_modality: Literal[Modality.image] = Modality.image # type: ignore
modality: Literal[Modality.image] = Modality.image # type: ignore


def get_infinity_extra_modality(obj: dict) -> str:
def get_modality(obj: dict) -> str:
"""resolve the modality of the extra_body.
If not present, default to text
Function name is used to return error message, keep it explicit
"""
try:
return obj.get("infinity_extra_modality", Modality.text.value)
return obj.get("modality", Modality.text.value)
except AttributeError:
# in case a very weird request is sent, validate it against the default
return Modality.text.value
Expand All @@ -142,7 +142,7 @@ class MultiModalOpenAIEmbedding(RootModel):
Annotated[OpenAIEmbeddingInput_Audio, Tag(Modality.audio.value)],
Annotated[OpenAIEmbeddingInput_Image, Tag(Modality.image.value)],
],
Discriminator(get_infinity_extra_modality),
Discriminator(get_modality),
]


Expand Down
14 changes: 7 additions & 7 deletions libs/infinity_emb/infinity_emb/infinity_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,7 +246,7 @@ async def _embeddings(data: MultiModalOpenAIEmbedding):
# can also be base64 encoded
],
# set extra modality to image to process as image
"infinity_extra_modality": "image"
"modality": "image"
)
```
Expand All @@ -271,7 +271,7 @@ def url_to_base64(url, modality = "image"):
url, url_to_base64(url, "audio")
],
# set extra modality to audio to process as audio
"infinity_extra_modality": "audio"
"modality": "audio"
}
)
```
Expand All @@ -285,7 +285,7 @@ def url_to_base64(url, modality = "image"):
input=[url_to_base64(url, "audio")],
encoding_format= "base64",
extra_body={
"infinity_extra_modality": "audio"
"modality": "audio"
}
)
Expand All @@ -294,7 +294,7 @@ def url_to_base64(url, modality = "image"):
input=["the sound of a beep", "the sound of a cat"],
encoding_format= "base64",
extra_body={
"infinity_extra_modality": "text"
"modality": "text"
}
)
```
Expand All @@ -305,7 +305,7 @@ def url_to_base64(url, modality = "image"):
```
"""

modality = data.root.infinity_extra_modality
modality = data.root.modality
data_root = data.root
engine = _resolve_engine(data_root.model)

Expand Down Expand Up @@ -471,7 +471,7 @@ async def _classify(data: ClassifyInput):
dependencies=route_dependencies,
operation_id="embeddings_image",
deprecated=True,
summary="Deprecated: Use `embeddings` with `infinity_extra_modality` set to `image`",
summary="Deprecated: Use `embeddings` with `modality` set to `image`",
)
async def _embeddings_image(data: ImageEmbeddingInput):
"""Encode Embeddings from Image files
Expand Down Expand Up @@ -530,7 +530,7 @@ async def _embeddings_image(data: ImageEmbeddingInput):
dependencies=route_dependencies,
operation_id="embeddings_audio",
deprecated=True,
summary="Deprecated: Use `embeddings` with `infinity_extra_modality` set to `audio`",
summary="Deprecated: Use `embeddings` with `modality` set to `audio`",
)
async def _embeddings_audio(data: AudioEmbeddingInput):
"""Encode Embeddings from Audio files
Expand Down
16 changes: 8 additions & 8 deletions libs/infinity_emb/tests/end_to_end/test_openapi_client_compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,46 +79,46 @@ async def test_openai(client: AsyncClient):
"the sound of a bird",
],
encoding_format="float",
extra_body={"infinity_extra_modality": "text"},
extra_body={"modality": "text"},
)
emb1_audio = await client_oai.embeddings.create(
model=pytest.DEFAULT_AUDIO_MODEL,
input=[url_to_base64(pytest.AUDIO_SAMPLE_URL, "audio")],
encoding_format="float",
extra_body={"infinity_extra_modality": "audio"},
extra_body={"modality": "audio"},
)
emb1_1_audio = await client_oai.embeddings.create(
model=pytest.DEFAULT_AUDIO_MODEL,
input=[pytest.AUDIO_SAMPLE_URL],
encoding_format="float",
extra_body={"infinity_extra_modality": "audio"},
extra_body={"modality": "audio"},
)
# test: image
emb_1_image_from_text = await client_oai.embeddings.create(
model=pytest.DEFAULT_IMAGE_MODEL,
input=["a cat", "a dog", "a bird"],
encoding_format="float",
extra_body={"infinity_extra_modality": "text"},
extra_body={"modality": "text"},
)
emb_1_image = await client_oai.embeddings.create(
model=pytest.DEFAULT_IMAGE_MODEL,
input=[url_to_base64(pytest.IMAGE_SAMPLE_URL, "image")], # image is a cat
encoding_format="float",
extra_body={"infinity_extra_modality": "image"},
extra_body={"modality": "image"},
)
emb_1_1_image = await client_oai.embeddings.create(
model=pytest.DEFAULT_IMAGE_MODEL,
input=[pytest.IMAGE_SAMPLE_URL],
encoding_format="float",
extra_body={"infinity_extra_modality": "image"},
extra_body={"modality": "image"},
)

# test: text
emb_1_text = await client_oai.embeddings.create(
model=pytest.DEFAULT_BERT_MODEL,
input=["a cat", "a cat", "a bird"],
encoding_format="float",
extra_body={"infinity_extra_modality": "text"},
extra_body={"modality": "text"},
)

# test AUDIO: cosine distance of beep to cat and dog
Expand Down Expand Up @@ -156,5 +156,5 @@ async def test_openai(client: AsyncClient):
model=pytest.DEFAULT_AUDIO_MODEL,
input=[pytest.AUDIO_SAMPLE_URL],
encoding_format="float",
extra_body={"infinity_extra_modality": "audio"},
extra_body={"modality": "audio"},
)
6 changes: 3 additions & 3 deletions libs/infinity_emb/tests/end_to_end/test_torch_audio.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ async def test_audio_multiple(client):
json={
"model": MODEL,
"input": audio_urls,
"infinity_extra_modality": "audio",
"modality": "audio",
},
)
assert response.status_code == 200
Expand All @@ -151,7 +151,7 @@ async def test_audio_fail(client):
json={
"model": MODEL,
"input": audio_url,
"infinity_extra_modality": "audio",
"modality": "audio",
},
)
assert response.status_code == status.HTTP_400_BAD_REQUEST
Expand All @@ -166,7 +166,7 @@ async def test_audio_empty(client):
json={
"model": MODEL,
"input": audio_url_empty,
"infinity_extra_modality": "audio",
"modality": "audio",
},
)
assert response_empty.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY
Expand Down
4 changes: 2 additions & 2 deletions libs/infinity_emb/tests/end_to_end/test_torch_vision.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ async def test_meta(client, helpers):
json={
"model": MODEL,
"input": image_input,
"infinity_extra_modality": "image",
"modality": "image",
},
)

Expand Down Expand Up @@ -166,7 +166,7 @@ async def test_vision_multiple(client):
json={
"model": MODEL,
"input": image_urls,
"infinity_extra_modality": "image",
"modality": "image",
},
)
assert response.status_code == 200
Expand Down

0 comments on commit e139fcf

Please sign in to comment.