Skip to content

Commit

Permalink
Feature: No model (#789)
Browse files Browse the repository at this point in the history
Add no_model and optional search query feature
  • Loading branch information
wanliAlex authored Mar 27, 2024
1 parent d051df7 commit 1a9ca83
Show file tree
Hide file tree
Showing 22 changed files with 715 additions and 550 deletions.
24 changes: 23 additions & 1 deletion src/marqo/core/models/marqo_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from marqo.exceptions import InvalidArgumentError
from marqo.logging import get_logger
from marqo.s2_inference import s2_inference
from marqo.s2_inference.errors import UnknownModelError
from marqo.s2_inference.errors import UnknownModelError, InvalidModelPropertiesError

logger = get_logger(__name__)

Expand Down Expand Up @@ -123,6 +123,24 @@ class Model(StrictBaseModel):
properties: Optional[Dict[str, Any]]
custom: bool = False

@root_validator(pre=False)
def validate_custom_properties(cls, values):
    """Validate the properties of a custom model.

    The check only runs when the model is flagged as ``custom`` and
    ``properties`` is non-empty; in every other case ``values`` is returned
    without any validation.

    Args:
        values: The field values collected for this model.

    Returns:
        The same ``values`` mapping that was passed in.

    Raises:
        ValueError: If ``s2_inference.validate_model_properties`` rejects the
            properties with an ``InvalidModelPropertiesError``. The original
            error is attached as the cause.
    """
    model_name = values.get('name')
    properties = values.get('properties')
    custom = values.get('custom')

    if not (properties and custom):
        # Nothing to validate for registry models or when no properties were supplied.
        return values

    try:
        s2_inference.validate_model_properties(model_name, properties)
    except InvalidModelPropertiesError as e:
        # Chain explicitly so the underlying reason stays visible in tracebacks.
        raise ValueError(
            f'Invalid model properties for model={model_name}. Reason: {e}.') from e
    return values

def dict(self, *args, **kwargs):
"""
Custom dict method that removes the properties field if the model is not custom. This ensures we don't store
Expand Down Expand Up @@ -166,6 +184,10 @@ def _update_model_properties_from_registry(self) -> None:
f'Could not find model properties for model={model_name}. '
f'Please check that the model name is correct. '
f'Please provide model_properties if the model is a custom model and is not supported by default')
except InvalidModelPropertiesError as e:
raise InvalidArgumentError(
f'Invalid model properties for model={model_name}. Reason: {e}.'
)


class MarqoIndex(ImmutableStrictBaseModel, ABC):
Expand Down
8 changes: 8 additions & 0 deletions src/marqo/marqo_docs.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,3 +35,11 @@ def list_of_models():
def search_context():
    """Return the documentation URL built for the search ``context`` section."""
    doc_path = 'API-Reference/Search/search/#context'
    return _build_url(doc_path)


def configuring_preloaded_models():
    """Return the documentation URL built for the 'configuring preloaded models' section."""
    doc_path = 'Guides/Advanced-Usage/configuration/#configuring-preloaded-models'
    return _build_url(doc_path)


def bring_your_own_model():
    """Return the documentation URL built for the 'bring your own model' guide."""
    doc_path = 'Guides/Models-Reference/bring_your_own_model/'
    return _build_url(doc_path)

18 changes: 16 additions & 2 deletions src/marqo/s2_inference/model_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from marqo.s2_inference.sbert_onnx_utils import SBERT_ONNX
from marqo.s2_inference.sbert_utils import SBERT, TEST
from marqo.s2_inference.types import Dict
from marqo.s2_inference.no_model_utils import NO_MODEL


# we need to keep track of the embed dim and model load functions/classes
Expand Down Expand Up @@ -783,7 +784,7 @@ def _get_onnx_clip_properties() -> Dict:
"resolution": 224,
"pretrained": "laionb_s32b_b82k",
"image_mean": (0.5, 0.5, 0.5),
"image_std": (0.5, 0.5, 0.5),
"image_std": (0.5, 0.5, 0.5),
},
"onnx32/open_clip/ViT-L-14-336/openai":
{
Expand Down Expand Up @@ -1734,6 +1735,15 @@ def _get_random_properties() -> Dict:
}
return RANDOM_MODEL_PROPERTIES


def _get_no_model_properties() -> Dict:
return {
'no_model': {
'type': 'no_model',
'note': "This is a special model no_model that requires users to provide 'dimensions'"
}
}

def _get_model_load_mappings() -> Dict:
return {'clip':CLIP,
'open_clip': OPEN_CLIP,
Expand All @@ -1744,7 +1754,8 @@ def _get_model_load_mappings() -> Dict:
"multilingual_clip" : MULTILINGUAL_CLIP,
"fp16_clip": FP16_CLIP,
'random':Random,
'hf':HF_MODEL}
'hf':HF_MODEL,
"no_model": NO_MODEL}

def load_model_properties() -> Dict:
# also truncate the name if not already
Expand All @@ -1761,6 +1772,7 @@ def load_model_properties() -> Dict:
onnx_clip_model_properties = _get_onnx_clip_properties()
multilingual_clip_model_properties = get_multilingual_clip_properties()
fp16_clip_model_properties = _get_fp16_clip_properties()
no_model_properties = _get_no_model_properties()

# combine the above dicts
model_properties = dict(clip_model_properties.items())
Expand All @@ -1773,6 +1785,8 @@ def load_model_properties() -> Dict:
model_properties.update(onnx_clip_model_properties)
model_properties.update(multilingual_clip_model_properties)
model_properties.update(fp16_clip_model_properties)
model_properties.update(no_model_properties)


all_properties = dict()
all_properties['models'] = model_properties
Expand Down
Empty file.
16 changes: 16 additions & 0 deletions src/marqo/s2_inference/models/model_type.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
from enum import Enum


class ModelType(str, Enum):
    """Enums for the different types of models that can be used for inference.

    Members are also ``str`` instances, so they compare equal to their raw
    string value (e.g. ``ModelType.SBERT == "sbert"``).
    """
    OpenCLIP = "open_clip"
    CLIP = 'clip'
    SBERT = 'sbert'
    Test = 'test'
    SBERT_ONNX = 'sbert_onnx'
    CLIP_ONNX = 'clip_onnx'
    MultilingualClip = "multilingual_clip"
    FP16_CLIP = "fp16_clip"
    Random = 'random'
    HF_MODEL = 'hf'
    NO_MODEL = "no_model"

    def __str__(self) -> str:
        """Render the member as its raw value.

        Python 3.11 changed ``format()``/f-strings on mixed-in enums to follow
        ``Enum.__str__`` (``'ModelType.SBERT'``). Overriding ``__str__`` keeps
        ``str()``, ``format()`` and f-strings producing the plain value on
        every Python version, which is what the error messages that
        interpolate these members rely on.
        """
        return self.value
16 changes: 16 additions & 0 deletions src/marqo/s2_inference/no_model_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
from marqo.s2_inference.errors import VectoriseError
from marqo.s2_inference.sbert_utils import Model
from marqo.s2_inference.models.model_type import ModelType


class NO_MODEL(Model):
    """Placeholder model that never vectorises anything.

    ``load`` is a no-op and ``encode`` always raises, so this model can only
    be used when vectors are supplied by the caller.
    """

    def __init__(self, *args, **kwargs) -> None:
        """Forward all arguments unchanged to the base ``Model`` initialiser."""
        super().__init__(*args, **kwargs)

    def load(self, *args, **kwargs) -> None:
        """Do nothing: there is no underlying model to load."""
        pass

    def encode(self, *args, **kwargs) -> None:
        """Always fail, because this model cannot produce embeddings.

        Raises:
            VectoriseError: On every call.
        """
        # Interpolate the enum's value explicitly: from Python 3.11 on, formatting a
        # (str, Enum) member directly renders 'ModelType.NO_MODEL' instead of 'no_model'.
        raise VectoriseError(f"Cannot vectorise anything with '{ModelType.NO_MODEL.value}'. "
                             "This model is intended for adding documents and searching with custom vectors only. "
                             "If vectorisation is needed, please use a different model ")
100 changes: 75 additions & 25 deletions src/marqo/s2_inference/s2_inference.py
Original file line number Diff line number Diff line change
@@ -1,25 +1,28 @@
"""This is the interface for interacting with S2 Inference
The functions defined here would have endpoints, later on.
"""
import datetime
import threading

import numpy as np
from marqo.api.exceptions import ModelCacheManagementError, InvalidArgError, ConfigurationError, InternalError
import torch
from PIL import UnidentifiedImageError

from marqo import marqo_docs
from marqo.api.exceptions import ModelCacheManagementError, ConfigurationError, InternalError
from marqo.s2_inference import constants
from marqo.s2_inference.configs import get_default_normalization, get_default_seq_length
from marqo.s2_inference.errors import (
VectoriseError, InvalidModelPropertiesError, ModelLoadError,
UnknownModelError, ModelNotInCacheError, ModelDownloadError, S2InferenceError)
from PIL import UnidentifiedImageError
UnknownModelError, ModelNotInCacheError, ModelDownloadError)
from marqo.s2_inference.logger import get_logger
from marqo.s2_inference.model_registry import load_model_properties
from marqo.s2_inference.configs import get_default_normalization, get_default_seq_length
from marqo.s2_inference.models.model_type import ModelType
from marqo.s2_inference.types import *
from marqo.s2_inference.logger import get_logger
import torch
import datetime
from marqo.s2_inference import constants
from marqo.tensor_search.enums import AvailableModelsKey
from marqo.tensor_search.configs import EnvVars
from marqo.tensor_search.enums import AvailableModelsKey
from marqo.tensor_search.models.private_models import ModelAuth
import threading
from marqo.tensor_search.utils import read_env_vars_and_defaults, generate_batches
from marqo.tensor_search.configs import EnvVars

logger = get_logger(__name__)

Expand Down Expand Up @@ -56,7 +59,7 @@ def vectorise(model_name: str, content: Union[str, List[str]], model_properties:
if not device:
raise InternalError(message=f"vectorise (internal function) cannot be called without setting device!")

validated_model_properties = _validate_model_properties(model_name, model_properties)
validated_model_properties = validate_model_properties(model_name, model_properties)
model_cache_key = _create_model_cache_key(model_name, device, validated_model_properties)

_update_available_models(
Expand All @@ -78,7 +81,7 @@ def vectorise(model_name: str, content: Union[str, List[str]], model_properties:
else:
vectorised = np.concatenate(vector_batches, axis=0)
except UnidentifiedImageError as e:
raise VectoriseError(str(e)) from e
raise VectoriseError(f"Could not process given image: {content}") from e

return _convert_vectorized_output(vectorised)

Expand Down Expand Up @@ -182,39 +185,82 @@ def _update_available_models(model_cache_key: str, model_name: str, validated_mo
f"Please wait for 10 seconds and send the request again.\n")


def validate_model_properties(model_name: str, model_properties: dict) -> dict:
    """Validate ``model_properties``; if none are given, return the model registry properties.

    This is a rough validation: it only checks the minimum required keys for
    the declared model type and that 'dimensions' is a valid value. More
    in-depth checks should be done when loading the model.

    Note:
        For sbert properties (or properties without a 'type'), missing 'type'
        and 'tokens' keys are filled in on the passed dict in place.

    Args:
        model_name: Name of the model the properties belong to.
        model_properties: User supplied properties, or None to look the model
            up in the registry.

    Returns:
        The validated properties (the same dict object when one was passed in).

    Raises:
        InvalidModelPropertiesError: if the model_properties are invalid
        UnknownModelError: if the model_name is not in the model registry
    """
    if model_properties is not None:
        # Check that every key required by the declared model type is present.
        required_keys = []

        if "type" not in model_properties:
            error_message_postfix = "Marqo is loading the model with default type 'sbert' as the type was not provided."
        else:
            error_message_postfix = ""

        model_type = model_properties.get("type", None)

        if model_type in (None, ModelType.SBERT):
            required_keys = ["dimensions", "name"]
            # Fill in defaults for the optional sbert keys. The plain string value is
            # stored (not the enum member) so the properties stay simple data.
            optional_keys_values = [("type", ModelType.SBERT.value), ("tokens", get_default_seq_length())]
            for key, value in optional_keys_values:
                model_properties.setdefault(key, value)

        elif model_type in (ModelType.OpenCLIP, ModelType.CLIP):
            required_keys = ["name", "dimensions"]

        elif model_type in (ModelType.HF_MODEL,):
            required_keys = ["dimensions"]

        elif model_type in (ModelType.NO_MODEL,):
            required_keys = ["dimensions"]
            # The 'no_model' type is only valid together with the 'no_model' model name.
            if model_name != ModelType.NO_MODEL.value:
                raise InvalidModelPropertiesError(f"To use the 'no_model' feature, you must provide 'model = no_model' "
                                                  f"and 'type = no_model', but received 'model = {model_name}' and "
                                                  f"'type = {model_type}'.")

        elif model_type in (ModelType.Test, ModelType.Random, ModelType.MultilingualClip, ModelType.FP16_CLIP,
                            ModelType.SBERT_ONNX, ModelType.CLIP_ONNX):
            # Known types without any minimum required keys at this stage.
            pass

        else:
            # Render the raw values explicitly so the message does not depend on how
            # the running Python version formats enum members.
            supported_types = ", ".join(
                f"'{supported.value}'" for supported in (
                    ModelType.SBERT, ModelType.OpenCLIP, ModelType.CLIP, ModelType.HF_MODEL, ModelType.NO_MODEL,
                    ModelType.Test, ModelType.Random, ModelType.MultilingualClip, ModelType.FP16_CLIP,
                    ModelType.SBERT_ONNX, ModelType.CLIP_ONNX))
            raise InvalidModelPropertiesError(f"Invalid model type. Please check the model type in model_properties. "
                                              f"Supported model types are {supported_types} ")

        for key in required_keys:
            if key not in model_properties:
                raise InvalidModelPropertiesError(f"model_properties has missing key '{key}'. "
                                                  f"please update your model properties with required key `{key}`. "
                                                  f"{error_message_postfix} "
                                                  f"check {marqo_docs.list_of_models()}, "
                                                  f"{marqo_docs.bring_your_own_model()} for more info")

    else:
        model_properties = get_model_properties_from_registry(model_name)

    _validate_model_properties_dimension(model_properties.get("dimensions", None))

    return model_properties


# Backward-compatible alias for callers that still use the former private name.
_validate_model_properties = validate_model_properties


def _validate_model_properties_dimension(dimensions: Optional[int]) -> None:
"""Validate the dimensions value in model_properties as the dimensions value must be a positive integer.
Raises:
InvalidModelPropertiesError: if the dimensions value is invalid
"""
if dimensions is None or not isinstance(dimensions, int) or dimensions < 1:
raise InvalidModelPropertiesError(
f"Invalid model properties: 'dimensions' must be a positive integer, but received {dimensions}.")


def _validate_model_into_device(model_name:str, model_properties: dict, device: str, calling_func: str = None) -> bool:
'''
Note: this function should only be called by `_update_available_models` for threading safeness.
Expand Down Expand Up @@ -373,7 +419,11 @@ def get_model_properties_from_registry(model_name: str) -> dict:
raise UnknownModelError(f"Could not find model properties in model registry for model={model_name}. "
f"Model is not supported by default.")

return MODEL_PROPERTIES['models'][model_name]
model_properties = MODEL_PROPERTIES['models'][model_name]

validate_model_properties(model_name, model_properties)

return model_properties


def _check_output_type(output: List[List[float]]) -> bool:
Expand Down
23 changes: 22 additions & 1 deletion src/marqo/tensor_search/api.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
"""The API entrypoint for Tensor Search"""
import json
from typing import List

import pydantic
import uvicorn
from fastapi import FastAPI
from fastapi import Request, Depends
Expand All @@ -10,8 +12,10 @@
from marqo import exceptions as base_exceptions
from marqo import version
from marqo.api import exceptions as api_exceptions
from marqo.api.exceptions import InvalidArgError
from marqo.api.models.health_response import HealthResponse
from marqo.api.models.rollback_request import RollbackRequest
from marqo.api.models.update_documents import UpdateDocumentsBodyParams
from marqo.api.route import MarqoCustomRoute
from marqo.core import exceptions as core_exceptions
from marqo.core.index_management.index_management import IndexManagement
Expand All @@ -28,7 +32,6 @@
from marqo.upgrades.upgrade import UpgradeRunner, RollbackRunner
from marqo.vespa import exceptions as vespa_exceptions
from marqo.vespa.vespa_client import VespaClient
from marqo.api.models.update_documents import UpdateDocumentsBodyParams

logger = get_logger(__name__)

Expand Down Expand Up @@ -136,6 +139,24 @@ def marqo_api_exception_handler(request: Request, exc: api_exceptions.MarqoWebEr
return JSONResponse(content=body, status_code=exc.status_code)


@app.exception_handler(pydantic.ValidationError)
async def validation_exception_handler(request, exc: pydantic.ValidationError) -> JSONResponse:
    """Turn a pydantic ValidationError into an InvalidArgError-shaped JSON response.

    The 'loc', 'msg' and 'type' entries of every reported error are kept and
    serialised as a JSON string in the response's "message" field; code, type,
    link and HTTP status are taken from InvalidArgError.
    """
    kept_fields = ('loc', 'msg', 'type')
    error_messages = []
    for error in exc.errors():
        # Missing entries fall back to an empty string.
        error_messages.append({field: error.get(field, '') for field in kept_fields})

    body = dict(
        message=json.dumps(error_messages),
        code=InvalidArgError.code,
        type=InvalidArgError.error_type,
        link=InvalidArgError.link,
    )
    return JSONResponse(content=body, status_code=InvalidArgError.status_code)


@app.exception_handler(api_exceptions.MarqoError)
def marqo_internal_exception_handler(request, exc: api_exceptions.MarqoError):
"""MarqoErrors are treated as internal errors"""
Expand Down
2 changes: 2 additions & 0 deletions src/marqo/tensor_search/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,5 @@
# NOTE(review): presumably the Python types accepted as multimodal field content — confirm against usage.
ALLOWED_MULTIMODAL_FIELD_TYPES = [str]

# NOTE(review): presumably the Python types accepted as custom vector content — confirm against usage.
ALLOWED_CUSTOM_VECTOR_CONTENT_TYPES = [str]

# NOTE(review): presumably the model names that are excluded from model preloading — confirm against the preloading code.
MODELS_TO_SKIP_PRELOADING = {"no_model"}
Loading

0 comments on commit 1a9ca83

Please sign in to comment.