Skip to content

Commit

Permalink
Fix LanguageBind models normalization parameters parse and always gen…
Browse files Browse the repository at this point in the history
…erated normalised embeddings bug (#1031)

This PR fixes two bugs:
- 1. An internal error is raised if a string is passed to the vectorise function for LanguageBind models;
- 2. LanguageBind models always generate normalised embeddings, even when normalisation is turned off
  • Loading branch information
wanliAlex authored Nov 4, 2024
1 parent 7737c8d commit d78a563
Show file tree
Hide file tree
Showing 2 changed files with 96 additions and 18 deletions.
13 changes: 8 additions & 5 deletions src/marqo/s2_inference/multimodal_model_load.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,10 +110,13 @@ def preprocessor(self, modality):
raise ValueError("Model has not been loaded yet. Call _load_model() first.")
return self.encoder.preprocessor(modality)

def encode(self, content, modality, media_download_headers: Optional[Dict]=None, **kwargs):
def encode(self, content, modality, media_download_headers: Optional[Dict]=None, normalize=True, **kwargs):
if self.encoder is None:
raise ValueError("Model has not been loaded yet. Call _load_model() first.")
return self.encoder.encode(content, modality, media_download_headers, **kwargs)
return self.encoder.encode(
content=content, modality=modality, media_download_headers=media_download_headers,
normalize=normalize, **kwargs
)


class ModelEncoder(ABC):
Expand Down Expand Up @@ -255,7 +258,7 @@ def preprocessor(self, modality):

return self._preprocessors.get(modality)

def encode(self, content, modality, normalize=True, media_download_headers: Optional[Dict]=None, **kwargs):
def encode(self, content, modality, media_download_headers: Optional[Dict]=None, normalize=True, **kwargs):
inputs = {}

if modality == Modality.TEXT:
Expand All @@ -275,7 +278,7 @@ def encode(self, content, modality, normalize=True, media_download_headers: Opti
elif isinstance(content, str) and "http" in content:
self._download_content(content, temp_filename, media_download_headers)
else:
return self.encode([content], modality=Modality.TEXT)
return self.encode([content], normalize=normalize, modality=Modality.TEXT)

preprocessed_image = self.preprocessor(Modality.IMAGE)([temp_filename], return_tensors='pt')
inputs['image'] = to_device(preprocessed_image, self.model.device)['pixel_values']
Expand All @@ -292,7 +295,7 @@ def encode(self, content, modality, normalize=True, media_download_headers: Opti
# If media has already been preprocessed
inputs[modality.value] = to_device(content[0], self.model.device)['pixel_values']
elif isinstance(content[0], str) and 'http' in content[0]:
return self.encode(content[0], modality=modality, media_download_headers=media_download_headers)
return self.encode(content[0], modality=modality, normalize=normalize, media_download_headers=media_download_headers)
else:
raise ValueError(f"Unsupported {modality.value} content type: {type(content)}, content: {content}")

Expand Down
101 changes: 88 additions & 13 deletions tests/s2_inference/test_large_model_encoding.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,27 @@
import functools
import os
import torch
import pytest
from marqo.s2_inference.types import FloatTensor
from marqo.s2_inference.s2_inference import clear_loaded_models, get_model_properties_from_registry, \
_convert_tensor_to_numpy
import unittest
from unittest.mock import patch

import numpy as np
import unittest
import pytest
import torch

from marqo.s2_inference.s2_inference import (
_check_output_type, vectorise,
_convert_vectorized_output,
)
import functools
from marqo.s2_inference.s2_inference import _load_model as og_load_model
from marqo.s2_inference.s2_inference import clear_loaded_models, get_model_properties_from_registry, \
_convert_tensor_to_numpy
from marqo.s2_inference.types import FloatTensor
from tests.marqo_test import TestImageUrls

_load_model = functools.partial(og_load_model, calling_func="unit_test")
from marqo.s2_inference.multimodal_model_load import Modality
from marqo.s2_inference.configs import ModelCache
import shutil

_load_model = functools.partial(og_load_model, calling_func="unit_test")


def remove_cached_model_files():
'''
Expand Down Expand Up @@ -367,10 +370,82 @@ def test_multilingual_e5_model_performance(self):
assert np.allclose(english_feature, other_language_feature, atol=e)

def test_cuda_encode_type(self):
run_test_cuda_encode_type(self.models + ["fp16/ViT-B/32", "open_clip/convnext_base_w/laion2b_s13b_b82k",
"open_clip/convnext_base_w_320/laion_aesthetic_s13b_b82k_augreg",
"all-MiniLM-L6-v1", "all_datasets_v4_MiniLM-L6", "hf/all-MiniLM-L6-v1",
"hf/all_datasets_v4_MiniLM-L6"])
run_test_cuda_encode_type(
self.models + ["fp16/ViT-B/32", "open_clip/convnext_base_w/laion2b_s13b_b82k",
"open_clip/convnext_base_w_320/laion_aesthetic_s13b_b82k_augreg",
"all-MiniLM-L6-v1", "all_datasets_v4_MiniLM-L6", "hf/all-MiniLM-L6-v1",
"hf/all_datasets_v4_MiniLM-L6",]
)


@pytest.mark.largemodel
@pytest.mark.skipif(not torch.cuda.is_available(),
                    reason="We skip the large model test if we don't have cuda support")
class TestLanguageBindModels(unittest.TestCase):
    """Check that LanguageBind models respect the ``normalize_embeddings`` flag.

    Every model in ``self.models`` is run through ``vectorise`` for each
    modality, once with a single content item and once with a list of items.
    The L2 norm of every returned embedding is then compared against 1.
    """

    # Maximum |norm - 1| accepted for an embedding requested as normalized.
    NORMALIZED_TOLERANCE = 1e-6
    # Minimum |norm - 1| required for an embedding requested as unnormalized.
    UNNORMALIZED_MARGIN = 1e-2

    @classmethod
    def setUpClass(cls) -> None:
        clear_loaded_models()

    @classmethod
    def tearDownClass(cls) -> None:
        clear_loaded_models()

    def setUp(self):
        self.models = ["LanguageBind/Video_V1.5_FT_Audio_FT_Image"]
        self.device = "cuda"

    def _help_test_vectorise(self, model_name, modality, test_content_list):
        """Vectorise each content item and assert on the embedding norms.

        Args:
            model_name: Name of the model handed to ``vectorise``.
            modality: The ``Modality`` member handed to ``vectorise``.
            test_content_list: Items to vectorise one call at a time; each
                item is either a single string or a list of strings.
        """
        for content in test_content_list:
            with self.subTest(model=model_name, content=content, normalized=True):
                normalized_embeddings_list = vectorise(
                    model_name=model_name,
                    content=content, device=self.device, normalize_embeddings=True,
                    modality=modality
                )
                for embeddings in normalized_embeddings_list:
                    norm = np.linalg.norm(np.array(embeddings))
                    # Use the absolute deviation: a one-sided `norm - 1 < tol`
                    # check would also accept any vector shorter than unit
                    # length, including a zero vector.
                    self.assertLess(abs(norm - 1), self.NORMALIZED_TOLERANCE)

            if modality != Modality.TEXT:  # Text embeddings are always normalized
                with self.subTest(model=model_name, content=content, normalized=False):
                    unnormalized_embeddings_list = vectorise(
                        model_name=model_name,
                        content=content, device=self.device, normalize_embeddings=False,
                        modality=modality
                    )
                    for embeddings in unnormalized_embeddings_list:
                        norm = np.linalg.norm(np.array(embeddings))
                        # An unnormalized embedding may be shorter or longer
                        # than unit length; only require that it is clearly
                        # not unit length.
                        self.assertGreater(abs(norm - 1), self.UNNORMALIZED_MARGIN)

    def test_models(self):
        """Run the norm checks for every model and every modality."""
        test_cases = {
            Modality.TEXT: ["test", ["test2", "test3"]],
            Modality.AUDIO: [
                "https://marqo-ecs-50-audio-test-dataset.s3.amazonaws.com/audios/4-145081-A-9.wav",
                [
                    "https://marqo-ecs-50-audio-test-dataset.s3.us-east-1.amazonaws.com/audios/1-115920-A-22.wav",
                    "https://marqo-ecs-50-audio-test-dataset.s3.us-east-1.amazonaws.com/audios/1-115920-A-22.wav"
                ]
            ],
            Modality.IMAGE: [
                'https://raw.githubusercontent.com/marqo-ai/marqo/mainline/examples/ImageSearchGuide/data/image0.jpg',
                [
                    'https://raw.githubusercontent.com/marqo-ai/marqo/mainline/examples/ImageSearchGuide/data/image1.jpg',
                    'https://raw.githubusercontent.com/marqo-ai/marqo/mainline/examples/ImageSearchGuide/data/image2.jpg'
                ]
            ],
            Modality.VIDEO: [
                'https://marqo-k400-video-test-dataset.s3.us-east-1.amazonaws.com/videos/--bO6XwZ9HI_000041_000051.mp4',
                [
                    'https://marqo-k400-video-test-dataset.s3.us-east-1.amazonaws.com/videos/-0MVWb7nJLY_000008_000018.mp4',
                    'https://marqo-k400-video-test-dataset.s3.us-east-1.amazonaws.com/videos/-0oMsq-9b6c_000095_000105.mp4'
                ]
            ]
        }

        for model_name in self.models:
            for modality, test_content_list in test_cases.items():
                with self.subTest(model=model_name, modality=modality):
                    self._help_test_vectorise(model_name, modality, test_content_list)


@pytest.mark.largemodel
Expand Down

0 comments on commit d78a563

Please sign in to comment.