Skip to content

Commit

Permalink
Fix a bug where search and embed return 500 errors when having no acc…
Browse files Browse the repository at this point in the history
…ess to a private image
  • Loading branch information
wanliAlex authored and vicilliar committed Nov 28, 2024
1 parent f382844 commit 3313327
Show file tree
Hide file tree
Showing 7 changed files with 124 additions and 41 deletions.
9 changes: 4 additions & 5 deletions src/marqo/api/models/embed_request.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,21 +3,20 @@
Choices (enum-type structure) in fastAPI:
https://pydantic-docs.helpmanual.io/usage/types/#enums-and-choices
"""
import pydantic
from typing import Union, List, Dict, Optional, Any
from typing import Union, List, Dict, Optional

import pydantic
from pydantic import Field, root_validator

from marqo.tensor_search.models.private_models import ModelAuth
from marqo.base_model import MarqoBaseModel
from marqo.core.embed.embed import EmbedContentType

from marqo.tensor_search.models.private_models import ModelAuth


class EmbedRequest(MarqoBaseModel):
# content can be a single query or list of queries. Queries can be a string or a dictionary.
content: Union[str, Dict[str, float], List[Union[str, Dict[str, float]]]]
image_download_headers: Optional[Dict] = Field(default=None, alias="imageDownloadHeaders")
imageDownloadHeaders: Optional[Dict] = Field(default=None, alias="image_download_headers")
mediaDownloadHeaders: Optional[Dict] = None
modelAuth: Optional[ModelAuth] = None
content_type: Optional[EmbedContentType] = Field(default=EmbedContentType.Query, alias="contentType")
Expand Down
11 changes: 7 additions & 4 deletions src/marqo/core/embed/embed.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,20 @@
from enum import Enum
from timeit import default_timer as timer
from typing import List, Optional, Union, Dict
from enum import Enum

import pydantic

import marqo.api.exceptions as api_exceptions
import marqo.s2_inference.errors as s2_inference_errors
from marqo import exceptions as base_exceptions
from marqo.core.index_management.index_management import IndexManagement
from marqo.tensor_search import utils
from marqo.tensor_search.models.api_models import BulkSearchQueryEntity
from marqo.tensor_search.models.private_models import ModelAuth
from marqo.tensor_search.models.search import Qidx
from marqo.tensor_search.telemetry import RequestMetricsStore
from marqo.tensor_search.tensor_search_logging import get_logger
from marqo.core.utils.prefix import determine_text_prefix, DeterminePrefixContentType
from marqo.vespa.vespa_client import VespaClient
from marqo.tensor_search import utils

logger = get_logger(__name__)

Expand Down Expand Up @@ -116,7 +117,9 @@ def embed_content(

# Vectorise the queries
with RequestMetricsStore.for_request().time(f"embed.vector_inference_full_pipeline"):
qidx_to_vectors: Dict[Qidx, List[float]] = tensor_search.run_vectorise_pipeline(temp_config, queries, device)
qidx_to_vectors: Dict[Qidx, List[float]] = tensor_search.run_vectorise_pipeline(
temp_config, queries, device
)

embeddings: List[List[float]] = list(qidx_to_vectors.values())

Expand Down
17 changes: 15 additions & 2 deletions src/marqo/tensor_search/tensor_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -2018,6 +2018,12 @@ def get_content_vector(


def add_prefix_to_queries(queries: List[BulkSearchQueryEntity]) -> List[BulkSearchQueryEntity]:
"""
Add prefix to the queries if it is a text query.
Raises:
MediaDownloadError: If the media cannot be downloaded
"""
prefixed_queries = []
for q in queries:
text_query_prefix = q.index.model.get_text_query_prefix(q.text_query_prefix)
Expand Down Expand Up @@ -2064,10 +2070,17 @@ def add_prefix_to_queries(queries: List[BulkSearchQueryEntity]) -> List[BulkSear

def run_vectorise_pipeline(config: Config, queries: List[BulkSearchQueryEntity], device: Union[Device, str]) -> Dict[
Qidx, List[float]]:
"""Run the query vectorisation process"""
"""Run the query vectorisation process
Raise:
api_exceptions.InvalidArgError: If the vectorisation process fails or if the media cannot be downloaded.
"""

# Prepend the prefixes to the queries if it exists (output should be of type List[BulkSearchQueryEntity])
prefixed_queries = add_prefix_to_queries(queries)
try:
prefixed_queries = add_prefix_to_queries(queries)
except s2_inference_errors.MediaDownloadError as e:
raise api_exceptions.InvalidArgError(message=str(e)) from e

# 1. Pre-process inputs ready for s2_inference.vectorise
# we can still use qidx_to_job. But the jobs structure may need to be different
Expand Down
2 changes: 1 addition & 1 deletion src/marqo/version.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
__version__ = "2.13.0"
__version__ = "2.13.1"

def get_version() -> str:
return f"{__version__}"
Original file line number Diff line number Diff line change
Expand Up @@ -1129,8 +1129,7 @@ def test_add_private_images_proper_error_returned(self):

def test_add_private_images_success(self):
"""Test to ensure that private images can be downloaded with proper headers"""
# test_indexes = [self.structured_marqo_index_name, self.unstructured_marqo_index_name]
test_indexes = [self.unstructured_marqo_index_name, ]
test_indexes = [self.structured_marqo_index_name, self.unstructured_marqo_index_name]
documents = [
{
"image_field_1": "https://d2k91vq0avo7lq.cloudfront.net/ai_hippo_realistic_small.png",
Expand Down
86 changes: 60 additions & 26 deletions tests/tensor_search/integ_tests/test_embed.py
Original file line number Diff line number Diff line change
@@ -1,34 +1,25 @@
from marqo.core.models.add_docs_params import AddDocsParams
from marqo.tensor_search import tensor_search
from marqo.tensor_search import enums
from marqo.tensor_search.models.api_models import BulkSearchQuery, BulkSearchQueryEntity, ScoreModifierLists
from tests.marqo_test import MarqoTestCase, TestImageUrls
from marqo.tensor_search.tensor_search import add_documents
from marqo.tensor_search.models.search import SearchContext
from marqo.tensor_search.api import embed
import os
import unittest
import uuid
from unittest import mock
from unittest.mock import patch

import numpy as np
import torch
import pytest
import torch

import requests
import json
from unittest import mock
from unittest.mock import patch
from marqo.core.models.marqo_index_request import FieldRequest
from marqo.api.exceptions import MarqoWebError, IndexNotFoundError, InvalidArgError, DocumentNotFoundError
import marqo.exceptions as base_exceptions
from marqo.api.models.embed_request import EmbedRequest
from marqo.core.models.add_docs_params import AddDocsParams
from marqo.core.models.marqo_index import *
from marqo.vespa.models import VespaDocument, QueryResult, FeedBatchDocumentResponse, FeedBatchResponse, \
FeedDocumentResponse
from marqo.core.models.marqo_index_request import FieldRequest
from marqo.tensor_search import enums
from marqo.tensor_search import tensor_search
from marqo.tensor_search.api import embed
from marqo.tensor_search.models.private_models import S3Auth, ModelAuth
from marqo.vespa.models import QueryResult
from marqo.vespa.models.query_result import Root, Child, RootFields
from marqo.tensor_search.models.private_models import S3Auth, ModelAuth, HfAuth
from marqo.api.models.embed_request import EmbedRequest
from marqo.tensor_search import utils
import os
import pprint
import unittest
import httpx
import uuid
from tests.marqo_test import MarqoTestCase, TestImageUrls
from marqo.api.exceptions import InvalidArgError


class TestEmbed(MarqoTestCase):
Expand Down Expand Up @@ -756,3 +747,46 @@ def test_embed_prefix_content_type(self):
# Assert vectors are equal
self.assertEqual(embed_res_hardcoded["content"], ["test passage: I am the GOAT."])
self.assertTrue(np.allclose(embed_res_hardcoded["embeddings"][0], embed_res_prefix_query["embeddings"][0]))

def test_embed_private_image_proper_error_raised(self):
"""Test that a proper 400 error is raised when trying to embed a private image and have no access."""
test_content_lists = [
("https://d2k91vq0avo7lq.cloudfront.net/ai_hippo_realistic_small", "a single private image url"),
(["https://d2k91vq0avo7lq.cloudfront.net/ai_hippo_realistic_small", "test"],
"a list of content with a private image url")
]

for index_name in [self.unstructured_default_image_index.name, self.structured_default_image_index.name]:
for test_content, msg in test_content_lists:
with self.subTest(f"{index_name} - {msg}"):
with self.assertRaises(InvalidArgError) as e:
embed_res = embed(
marqo_config=self.config, index_name=index_name,
embedding_request=EmbedRequest(
content=test_content
),
device="cpu"
)
self.assertIn("Error downloading media file", str(e.exception))
self.assertIn("403 Client Error", str(e.exception))

def test_embed_invalid_image_proper_error_raised(self):
"""Test that a proper 400 error is raised when trying to embed an invalid image url."""
test_content_lists = [
("https://a-dummy-image-url.jpg", "a single invalid image url"),
(["https://a-dummy-image-url.jpg", "test"],
"a list of content with an invalid image url")
]

for index_name in [self.unstructured_default_image_index.name, self.structured_default_image_index.name]:
for test_content, msg in test_content_lists:
with self.subTest(f"{index_name} - {msg}"):
with self.assertRaises(InvalidArgError) as e:
embed_res = embed(
marqo_config=self.config, index_name=index_name,
embedding_request=EmbedRequest(
content=test_content
),
device="cpu"
)
self.assertIn("Error vectorising content", str(e.exception))
37 changes: 36 additions & 1 deletion tests/tensor_search/integ_tests/test_search_combined.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from marqo.core.unstructured_vespa_index.unstructured_vespa_index import UnstructuredVespaIndex
from marqo.tensor_search.models.api_models import SearchQuery
from pydantic import ValidationError
import marqo.api.exceptions as api_exceptions


class TestSearch(MarqoTestCase):
Expand Down Expand Up @@ -822,7 +823,6 @@ def test_wildcard_lexical_query(self):
index_name=index.name,
search_method=SearchMethod.LEXICAL, result_count=10,
filter=filter_term)
print(res)
self.assertIn("hits", res)
self.assertEqual(expected_count, len(res['hits']))

Expand Down Expand Up @@ -1012,3 +1012,38 @@ def test_lexical_search_DoesNotErrorWithEscapedQuotes(self):
)
self.assertEqual(len(expected_ids), len(res['hits']))
self.assertEqual(set(expected_ids), {hit['_id'] for hit in res['hits']})

def test_search_private_image_return_proper_error(self):
"""A test to ensure that InvalidArgumentError is raised when searching for a private image."""
test_queries_list = [
("https://d2k91vq0avo7lq.cloudfront.net/ai_hippo_realistic_small", "A private image"),
({"https://d2k91vq0avo7lq.cloudfront.net/ai_hippo_realistic_small": 1, "test": 1},
"A private image in the dictionary")
]

for index_name in [self.structured_default_image_index, self.unstructured_default_image_index]:
for query, msg in test_queries_list:
with self.subTest(f"{index_name} - {query}"):
with self.assertRaises(api_exceptions.InvalidArgError) as e:
tensor_search.search(
text=query, config=self.config, index_name=index_name.name,
)
self.assertIn("Error downloading media file", str(e.exception))
self.assertIn("403 Client Error", str(e.exception))

def test_search_invalid_image_url_image_return_proper_error(self):
"""A test to ensure that InvalidArgumentError is raised when searching for an invalid image url."""
test_queries_list = [
("https://a-dummy-image-url.jpg", "A invalid image"),
({"https://a-dummy-image-url.jpg": 1, "test": 1},
"A invalid image in the dictionary")
]

for index_name in [self.structured_default_image_index, self.unstructured_default_image_index]:
for query, msg in test_queries_list:
with self.subTest(f"{index_name} - {query}"):
with self.assertRaises(api_exceptions.InvalidArgError) as e:
tensor_search.search(
text=query, config=self.config, index_name=index_name.name,
)
self.assertIn("Error vectorising content", str(e.exception))

0 comments on commit 3313327

Please sign in to comment.