Skip to content

Commit

Permalink
Fix validation bug for create index (#734)
Browse files Browse the repository at this point in the history
Fix bug where field names with underscores fail under dependentFields
  • Loading branch information
farshidz authored Jan 24, 2024
1 parent 4ccc4a4 commit 702d274
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 7 deletions.
5 changes: 3 additions & 2 deletions src/marqo/tensor_search/models/index_settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,9 @@ def validate_keys(d: Union[dict, list]):
if '_' in key:
raise ValueError(f"Invalid field name '{key}'. "
f"See Create Index API reference here https://docs.marqo.ai/2.0.0/API-Reference/Indexes/create_index/")

validate_keys(d[key])

if key not in ['dependentFields', 'modelProperties']:
validate_keys(d[key])
elif isinstance(d, list):
for item in d:
validate_keys(item)
Expand Down
55 changes: 50 additions & 5 deletions tests/tensor_search/test_api.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import uuid
from unittest import mock

from fastapi.testclient import TestClient
Expand Down Expand Up @@ -148,7 +149,7 @@ def test_create_index_snake_case_fails(self):
"""
Verify snake case rejected for fields that have camel case as alias
"""
test_cases = [
test_cases_fail = [
({
"type": "structured",
"allFields": [
Expand All @@ -158,8 +159,12 @@ def test_create_index_snake_case_fails(self):
},
{
"name": "field2",
"type": "text"
},
{
"name": "field3",
"type": "multimodal_combination",
"dependent_fields": ["field1"]
"dependent_fields": {"field1": 0.5, "field2": 0.5}
}
],
"tensorFields": [],
Expand All @@ -175,7 +180,7 @@ def test_create_index_snake_case_fails(self):
'm': 16
}
}
}, 'ef_construction', 'Snake case within a dict'),
}, 'ef_construction', 'Snake case within a dict is invalid'),
({
"type": "unstructured",
'annParameters': {
Expand All @@ -185,9 +190,39 @@ def test_create_index_snake_case_fails(self):
'm': 16
}
}
}, 'ef_construction', 'Snake case within a dict, unstructured index')
}, 'ef_construction', 'Snake case within a dict is invalid, unstructured index')
]
for test_case, field, test_name in test_cases:
test_cases_pass = [
({
"type": "structured",
"allFields": [
{
"name": "field_1",
"type": "text"
},
{
"name": "field_2",
"type": "text"
},
{
"name": "field_3",
"type": "multimodal_combination",
"dependentFields": {"field_1": 0.5, "field_2": 0.5}
}
],
"tensorFields": ['field_3'],
"model": "ViT-L/14",
"modelProperties": {
"name": "ViT-L/14",
"dimensions": 768,
"url": "https://7b4d1a66-507d-43f1-b99f-7368b655de46.s3.amazonaws.com/e5a7d9c7-0736-4301-a037-b1307f43a314/23fa0cb1-68d5-40f6-8039-e9e1265b6103.pt",
"type": "open_clip",
"field_1": "sth"
}
}, 'Snake case in field name is valid'),
]

for test_case, field, test_name in test_cases_fail:
with self.subTest(test_name):
response = self.client.post(
"/indexes/my_index",
Expand All @@ -196,3 +231,13 @@ def test_create_index_snake_case_fails(self):

self.assertEqual(response.status_code, 422)
self.assertTrue(f"Invalid field name '{field}'" in response.text)

for test_case, test_name in test_cases_pass:
with self.subTest(test_name):
index_name = 'a' + str(uuid.uuid4()).replace('-', '')
response = self.client.post(
f"/indexes/{index_name}",
json=test_case
)

self.assertEqual(response.status_code, 200)

0 comments on commit 702d274

Please sign in to comment.