Skip to content

Commit

Permalink
DynamoDB: Improve validation (getmoto#6986)
Browse files Browse the repository at this point in the history
  • Loading branch information
bblommers authored Nov 4, 2023
1 parent 87f816f commit 9136030
Show file tree
Hide file tree
Showing 10 changed files with 227 additions and 90 deletions.
6 changes: 3 additions & 3 deletions moto/dynamodb/comparisons.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
def get_filter_expression(
expr: Optional[str],
names: Optional[Dict[str, str]],
values: Optional[Dict[str, str]],
values: Optional[Dict[str, Dict[str, str]]],
) -> Union["Op", "Func"]:
"""
Parse a filter expression into an Op.
Expand Down Expand Up @@ -145,7 +145,7 @@ def __init__(
self,
condition_expression: Optional[str],
expression_attribute_names: Optional[Dict[str, str]],
expression_attribute_values: Optional[Dict[str, str]],
expression_attribute_values: Optional[Dict[str, Dict[str, str]]],
):
self.condition_expression = condition_expression
self.expression_attribute_names = expression_attribute_names
Expand Down Expand Up @@ -423,7 +423,7 @@ def _parse_path_element(self, name: str) -> Node:
children=[],
)

def _lookup_expression_attribute_value(self, name: str) -> str:
def _lookup_expression_attribute_value(self, name: str) -> Dict[str, str]:
return self.expression_attribute_values[name] # type: ignore[index]

def _lookup_expression_attribute_name(self, name: str) -> str:
Expand Down
6 changes: 6 additions & 0 deletions moto/dynamodb/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -368,3 +368,9 @@ def __init__(self) -> None:
class SerializationException(DynamodbException):
def __init__(self, msg: str):
super().__init__(error_type="SerializationException", message=msg)


class UnknownKeyType(MockValidationException):
def __init__(self, key_type: str, position: str):
msg = f"1 validation error detected: Value '{key_type}' at '{position}' failed to satisfy constraint: Member must satisfy enum value set: [HASH, RANGE]"
super().__init__(msg)
2 changes: 1 addition & 1 deletion moto/dynamodb/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -317,7 +317,7 @@ def query(
projection_expressions: Optional[List[List[str]]],
index_name: Optional[str] = None,
expr_names: Optional[Dict[str, str]] = None,
expr_values: Optional[Dict[str, str]] = None,
expr_values: Optional[Dict[str, Dict[str, str]]] = None,
filter_expression: Optional[str] = None,
**filter_kwargs: Any,
) -> Tuple[List[Item], int, Optional[Dict[str, Any]]]:
Expand Down
6 changes: 3 additions & 3 deletions moto/dynamodb/parsing/key_condition_expression.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from enum import Enum
from typing import Any, List, Dict, Tuple, Optional
from typing import Any, List, Dict, Tuple, Optional, Union
from moto.dynamodb.exceptions import MockValidationException
from moto.utilities.tokenizer import GenericTokenizer

Expand All @@ -19,7 +19,7 @@ def get_key(schema: List[Dict[str, str]], key_type: str) -> Optional[str]:

def parse_expression(
key_condition_expression: str,
expression_attribute_values: Dict[str, str],
expression_attribute_values: Dict[str, Dict[str, str]],
expression_attribute_names: Dict[str, str],
schema: List[Dict[str, str]],
) -> Tuple[Dict[str, Any], Optional[str], List[Dict[str, Any]]]:
Expand All @@ -35,7 +35,7 @@ def parse_expression(
current_stage: Optional[EXPRESSION_STAGES] = None
current_phrase = ""
key_name = comparison = ""
key_values = []
key_values: List[Union[Dict[str, str], str]] = []
results: List[Tuple[str, str, Any]] = []
tokenizer = GenericTokenizer(key_condition_expression)
for crnt_char in tokenizer:
Expand Down
43 changes: 37 additions & 6 deletions moto/dynamodb/responses.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from .exceptions import (
MockValidationException,
ResourceNotFoundException,
UnknownKeyType,
)
from moto.dynamodb.models import dynamodb_backends, Table, DynamoDBBackend
from moto.dynamodb.models.utilities import dynamo_json_dump
Expand Down Expand Up @@ -242,21 +243,42 @@ def create_table(self) -> str:
sse_spec = body.get("SSESpecification")
# getting the schema
key_schema = body["KeySchema"]
for idx, _key in enumerate(key_schema, start=1):
key_type = _key["KeyType"]
if key_type not in ["HASH", "RANGE"]:
raise UnknownKeyType(
key_type=key_type, position=f"keySchema.{idx}.member.keyType"
)
# getting attribute definition
attr = body["AttributeDefinitions"]
# getting the indexes

# getting/validating the indexes
global_indexes = body.get("GlobalSecondaryIndexes")
if global_indexes == []:
raise MockValidationException(
"One or more parameter values were invalid: List of GlobalSecondaryIndexes is empty"
)
global_indexes = global_indexes or []
for idx, g_idx in enumerate(global_indexes, start=1):
for idx2, _key in enumerate(g_idx["KeySchema"], start=1):
key_type = _key["KeyType"]
if key_type not in ["HASH", "RANGE"]:
position = f"globalSecondaryIndexes.{idx}.member.keySchema.{idx2}.member.keyType"
raise UnknownKeyType(key_type=key_type, position=position)

local_secondary_indexes = body.get("LocalSecondaryIndexes")
if local_secondary_indexes == []:
raise MockValidationException(
"One or more parameter values were invalid: List of LocalSecondaryIndexes is empty"
)
local_secondary_indexes = local_secondary_indexes or []
for idx, g_idx in enumerate(local_secondary_indexes, start=1):
for idx2, _key in enumerate(g_idx["KeySchema"], start=1):
key_type = _key["KeyType"]
if key_type not in ["HASH", "RANGE"]:
position = f"localSecondaryIndexes.{idx}.member.keySchema.{idx2}.member.keyType"
raise UnknownKeyType(key_type=key_type, position=position)

# Verify AttributeDefinitions list all
expected_attrs = []
expected_attrs.extend([key["AttributeName"] for key in key_schema])
Expand Down Expand Up @@ -462,7 +484,7 @@ def put_item(self) -> str:
# expression
condition_expression = self.body.get("ConditionExpression")
expression_attribute_names = self.body.get("ExpressionAttributeNames", {})
expression_attribute_values = self.body.get("ExpressionAttributeValues", {})
expression_attribute_values = self._get_expr_attr_values()

if condition_expression:
overwrite = False
Expand Down Expand Up @@ -650,7 +672,7 @@ def query(self) -> str:
projection_expression = self._get_projection_expression()
expression_attribute_names = self.body.get("ExpressionAttributeNames", {})
filter_expression = self._get_filter_expression()
expression_attribute_values = self.body.get("ExpressionAttributeValues", {})
expression_attribute_values = self._get_expr_attr_values()

projection_expressions = self._adjust_projection_expression(
projection_expression, expression_attribute_names
Expand Down Expand Up @@ -776,7 +798,7 @@ def scan(self) -> str:
filters[attribute_name] = (comparison_operator, comparison_values)

filter_expression = self._get_filter_expression()
expression_attribute_values = self.body.get("ExpressionAttributeValues", {})
expression_attribute_values = self._get_expr_attr_values()
expression_attribute_names = self.body.get("ExpressionAttributeNames", {})
projection_expression = self._get_projection_expression()
exclusive_start_key = self.body.get("ExclusiveStartKey")
Expand Down Expand Up @@ -824,7 +846,7 @@ def delete_item(self) -> str:
# expression
condition_expression = self.body.get("ConditionExpression")
expression_attribute_names = self.body.get("ExpressionAttributeNames", {})
expression_attribute_values = self.body.get("ExpressionAttributeValues", {})
expression_attribute_values = self._get_expr_attr_values()

item = self.dynamodb_backend.delete_item(
name,
Expand Down Expand Up @@ -879,7 +901,7 @@ def update_item(self) -> str:
# expression
condition_expression = self.body.get("ConditionExpression")
expression_attribute_names = self.body.get("ExpressionAttributeNames", {})
expression_attribute_values = self.body.get("ExpressionAttributeValues", {})
expression_attribute_values = self._get_expr_attr_values()

item = self.dynamodb_backend.update_item(
name,
Expand Down Expand Up @@ -920,6 +942,15 @@ def update_item(self) -> str:
)
return dynamo_json_dump(item_dict)

def _get_expr_attr_values(self) -> Dict[str, Dict[str, str]]:
values = self.body.get("ExpressionAttributeValues", {})
for key in values.keys():
if not key.startswith(":"):
raise MockValidationException(
f'ExpressionAttributeValues contains invalid key: Syntax error; key: "{key}"'
)
return values

def _build_updated_new_attributes(self, original: Any, changed: Any) -> Any:
if type(changed) != type(original):
return changed
Expand Down
82 changes: 45 additions & 37 deletions tests/test_dynamodb/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from uuid import uuid4


def dynamodb_aws_verified(func):
def dynamodb_aws_verified(create_table: bool = True):
"""
Function that is verified to work against AWS.
Can be run against AWS at any time by setting:
Expand All @@ -19,39 +19,47 @@ def dynamodb_aws_verified(func):
- Delete the table
"""

@wraps(func)
def pagination_wrapper():
client = boto3.client("dynamodb", region_name="us-east-1")
table_name = "t" + str(uuid4())[0:6]

allow_aws_request = (
os.environ.get("MOTO_TEST_ALLOW_AWS_REQUEST", "false").lower() == "true"
)

if allow_aws_request:
print(f"Test {func} will create DynamoDB Table {table_name}")
resp = create_table_and_test(table_name, client)
else:
with mock_dynamodb():
resp = create_table_and_test(table_name, client)
return resp

def create_table_and_test(table_name, client):
client.create_table(
TableName=table_name,
KeySchema=[{"AttributeName": "pk", "KeyType": "HASH"}],
AttributeDefinitions=[{"AttributeName": "pk", "AttributeType": "S"}],
ProvisionedThroughput={"ReadCapacityUnits": 1, "WriteCapacityUnits": 5},
Tags=[{"Key": "environment", "Value": "moto_tests"}],
)
waiter = client.get_waiter("table_exists")
waiter.wait(TableName=table_name)
try:
resp = func(table_name)
finally:
### CLEANUP ###
client.delete_table(TableName=table_name)

return resp

return pagination_wrapper
def inner(func):
@wraps(func)
def pagination_wrapper():
client = boto3.client("dynamodb", region_name="us-east-1")
table_name = "t" + str(uuid4())[0:6]

allow_aws_request = (
os.environ.get("MOTO_TEST_ALLOW_AWS_REQUEST", "false").lower() == "true"
)

if allow_aws_request:
if create_table:
print(f"Test {func} will create DynamoDB Table {table_name}")
return create_table_and_test(table_name, client)
else:
return func()
else:
with mock_dynamodb():
if create_table:
return create_table_and_test(table_name, client)
else:
return func()

def create_table_and_test(table_name, client):
client.create_table(
TableName=table_name,
KeySchema=[{"AttributeName": "pk", "KeyType": "HASH"}],
AttributeDefinitions=[{"AttributeName": "pk", "AttributeType": "S"}],
ProvisionedThroughput={"ReadCapacityUnits": 1, "WriteCapacityUnits": 5},
Tags=[{"Key": "environment", "Value": "moto_tests"}],
)
waiter = client.get_waiter("table_exists")
waiter.wait(TableName=table_name)
try:
resp = func(table_name)
finally:
### CLEANUP ###
client.delete_table(TableName=table_name)

return resp

return pagination_wrapper

return inner
35 changes: 34 additions & 1 deletion tests/test_dynamodb/exceptions/test_dynamodb_exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -1117,7 +1117,7 @@ def test_query_with_missing_expression_attribute():


@pytest.mark.aws_verified
@dynamodb_aws_verified
@dynamodb_aws_verified()
def test_update_item_returns_old_item(table_name=None):
dynamodb = boto3.resource("dynamodb", region_name="us-east-1")
table = dynamodb.Table(table_name)
Expand Down Expand Up @@ -1164,3 +1164,36 @@ def test_update_item_returns_old_item(table_name=None):
"lock": {"M": {"acquired_at": {"N": "123"}}},
"pk": {"S": "mark"},
}


@pytest.mark.aws_verified
@dynamodb_aws_verified()
def test_scan_with_missing_value(table_name=None):
dynamodb = boto3.resource("dynamodb", region_name="us-east-1")
table = dynamodb.Table(table_name)

with pytest.raises(ClientError) as exc:
table.scan(
FilterExpression="attr = loc",
# Missing ':'
ExpressionAttributeValues={"loc": "sth"},
)
err = exc.value.response["Error"]
assert err["Code"] == "ValidationException"
assert (
err["Message"]
== 'ExpressionAttributeValues contains invalid key: Syntax error; key: "loc"'
)

with pytest.raises(ClientError) as exc:
table.query(
KeyConditionExpression="attr = loc",
# Missing ':'
ExpressionAttributeValues={"loc": "sth"},
)
err = exc.value.response["Error"]
assert err["Code"] == "ValidationException"
assert (
err["Message"]
== 'ExpressionAttributeValues contains invalid key: Syntax error; key: "loc"'
)
4 changes: 2 additions & 2 deletions tests/test_dynamodb/test_dynamodb.py
Original file line number Diff line number Diff line change
Expand Up @@ -3686,7 +3686,7 @@ def test_transact_write_items_put():


@pytest.mark.aws_verified
@dynamodb_aws_verified
@dynamodb_aws_verified()
def test_transact_write_items_put_conditional_expressions(table_name=None):
dynamodb = boto3.client("dynamodb", region_name="us-east-1")

Expand Down Expand Up @@ -3731,7 +3731,7 @@ def test_transact_write_items_put_conditional_expressions(table_name=None):


@pytest.mark.aws_verified
@dynamodb_aws_verified
@dynamodb_aws_verified()
def test_transact_write_items_failure__return_item(table_name=None):
dynamodb = boto3.client("dynamodb", region_name="us-east-1")
dynamodb.put_item(TableName=table_name, Item={"pk": {"S": "foo2"}})
Expand Down
Loading

0 comments on commit 9136030

Please sign in to comment.