Skip to content

Commit

Permalink
Fix snake case validation bug for allFields (#717)
Browse files Browse the repository at this point in the history
Fix bug where validation fails to catch dependent_fields (snake case) because it's within a list
  • Loading branch information
farshidz authored Jan 23, 2024
1 parent fad54b8 commit c445a58
Show file tree
Hide file tree
Showing 2 changed files with 65 additions and 30 deletions.
24 changes: 14 additions & 10 deletions src/marqo/tensor_search/models/index_settings.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import time
from typing import Dict, Any, Optional, List
from typing import Dict, Any, Optional, List, Union

from pydantic import root_validator

Expand Down Expand Up @@ -45,15 +45,19 @@ class IndexSettings(StrictBaseModel):
@root_validator(pre=True)
def validate_field_names(cls, values):
# Verify no snake case field names (pydantic won't catch these due to allow_population_by_field_name = True)
def validate_dict_keys(d: dict):
for key in d.keys():
if '_' in key:
raise ValueError(f"Invalid field name '{key}'. "
f"See Create Index API reference here https://docs.marqo.ai/2.0.0/API-Reference/Indexes/create_index/")
if isinstance(d[key], dict):
validate_dict_keys(d[key])

validate_dict_keys(values)
def validate_keys(d: Union[dict, list]):
if isinstance(d, dict):
for key in d.keys():
if '_' in key:
raise ValueError(f"Invalid field name '{key}'. "
f"See Create Index API reference here https://docs.marqo.ai/2.0.0/API-Reference/Indexes/create_index/")

validate_keys(d[key])
elif isinstance(d, list):
for item in d:
validate_keys(item)

validate_keys(values)

return values

Expand Down
71 changes: 51 additions & 20 deletions tests/tensor_search/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,23 +145,54 @@ def test_invalid_argument_error(self):
assert "Could not find model properties for" in response.json()["message"]

def test_create_index_snake_case_fails(self):
# Verify snake case rejected for fields that have camel case as alias
response = self.client.post(
"/indexes/my_index",
json={
"type": "structured",
"allFields": [{"name": "field1", "type": "text"}],
"tensorFields": [],
'annParameters': {
'spaceType': 'dotproduct',
'parameters': {
'ef_construction': 128,
'm': 16
}
}
}
)

self.assertEqual(response.status_code, 422)
self.assertTrue("Invalid field name 'ef_construction'" in response.text)

"""
Verify snake case rejected for fields that have camel case as alias
"""
test_cases = [
({
"type": "structured",
"allFields": [
{
"name": "field1",
"type": "text"
},
{
"name": "field2",
"type": "multimodal_combination",
"dependent_fields": ["field1"]
}
],
"tensorFields": [],
}, 'dependent_fields', 'Snake case within a list'),
({
"type": "structured",
"allFields": [],
"tensorFields": [],
'annParameters': {
'spaceType': 'dotproduct',
'parameters': {
'ef_construction': 128,
'm': 16
}
}
}, 'ef_construction', 'Snake case within a dict'),
({
"type": "unstructured",
'annParameters': {
'spaceType': 'dotproduct',
'parameters': {
'ef_construction': 128,
'm': 16
}
}
}, 'ef_construction', 'Snake case within a dict, unstructured index')
]
for test_case, field, test_name in test_cases:
with self.subTest(test_name):
response = self.client.post(
"/indexes/my_index",
json=test_case
)

self.assertEqual(response.status_code, 422)
self.assertTrue(f"Invalid field name '{field}'" in response.text)

0 comments on commit c445a58

Please sign in to comment.