Skip to content

Commit

Permalink
Support preset types in TypedData (#1377)
Browse files Browse the repository at this point in the history
  • Loading branch information
franciszekjob authored Jul 5, 2024
1 parent 6508f33 commit 1565660
Show file tree
Hide file tree
Showing 3 changed files with 180 additions and 82 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
{
"types": {
"StarknetDomain": [
{ "name": "name", "type": "shortstring" },
{ "name": "version", "type": "shortstring" },
{ "name": "chainId", "type": "shortstring" },
{ "name": "revision", "type": "shortstring" }
],
"Example": [
{ "name": "n0", "type": "TokenAmount" },
{ "name": "n1", "type": "NftId" }
]
},
"primaryType": "Example",
"domain": {
"name": "StarkNet Mail",
"version": "1",
"chainId": "1",
"revision": "1"
},
"message": {
"n0": {
"token_address": "0x049d36570d4e46f48e99674bd3fcc84644ddd6b96f7c741b1562b82f9e004dc7",
"amount": {
"low": "0x3e8",
"high": "0x0"
}
},
"n1": {
"collection_address": "0x049d36570d4e46f48e99674bd3fcc84644ddd6b96f7c741b1562b82f9e004dc7",
"token_id": {
"low": "0x3e8",
"high": "0x0"
}
}
}
}
76 changes: 56 additions & 20 deletions starknet_py/utils/typed_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,12 @@ class BasicType(Enum):
TIMESTAMP = "timestamp"


class PresetType(Enum):
U256 = "u256"
TOKEN_AMOUNT = "TokenAmount"
NFT_ID = "NftId"


@dataclass(frozen=True)
class TypedData:
"""
Expand All @@ -120,6 +126,14 @@ class TypedData:
def __post_init__(self):
self._verify_types()

@property
def _all_types(self):
preset_types = _get_preset_types(self.domain.resolved_revision)
return {
**preset_types,
**self.types,
}

@property
def _hash_method(self) -> HashMethod:
if self.domain.resolved_revision == Revision.V0:
Expand Down Expand Up @@ -202,7 +216,7 @@ def _encode_value(
value: Union[int, str, dict, list],
context: Optional[TypeContext] = None,
) -> int:
if type_name in self.types and isinstance(value, dict):
if type_name in self._all_types and isinstance(value, dict):
return self.struct_hash(type_name, value)

if is_pointer(type_name) and isinstance(value, list):
Expand Down Expand Up @@ -245,7 +259,7 @@ def _encode_value(

def _encode_data(self, type_name: str, data: dict) -> List[int]:
values = []
for param in self.types[type_name]:
for param in self._all_types[type_name]:
encoded_value = self._encode_value(
param.type,
data[param.name],
Expand Down Expand Up @@ -273,13 +287,21 @@ def _verify_types(self):
referenced_types.update([self.domain.separator_name, self.primary_type])

basic_type_names = _get_basic_type_names(self.domain.resolved_revision)
preset_type_names = _get_preset_types(self.domain.resolved_revision).keys()

for type_name in self.types:
if not type_name:
raise ValueError("Type names cannot be empty.")

if type_name in basic_type_names:
raise ValueError(f"Reserved type name: {type_name}")
raise ValueError(
f"Types must not contain basic types. [{type_name}] was found."
)

if type_name in preset_type_names:
raise ValueError(
f"Types must not contain preset types. [{type_name}] was found."
)

if is_pointer(type_name):
raise ValueError(
Expand Down Expand Up @@ -318,7 +340,7 @@ def _get_dependencies(self, type_name: str) -> List[str]:

while to_visit:
current_type = to_visit.pop(0)
params = self.types.get(current_type, [])
params = self._all_types.get(current_type, [])

for param in params:
if isinstance(param, EnumParameter):
Expand All @@ -333,7 +355,7 @@ def _get_dependencies(self, type_name: str) -> List[str]:
]
for extracted_type in extracted_types:
if (
extracted_type in self.types
extracted_type in self._all_types
and extracted_type not in dependencies
):
dependencies.append(extracted_type)
Expand All @@ -351,11 +373,11 @@ def escape(s: str) -> str:
return s
return f'"{s}"'

if dependency not in self.types:
if dependency not in self._all_types:
raise ValueError(f"Dependency [{dependency}] is not defined in types.")

encoded_params = []
for param in self.types[dependency]:
for param in self._all_types[dependency]:
target_type = (
param.contains
if isinstance(param, EnumParameter)
Expand Down Expand Up @@ -433,10 +455,10 @@ def _get_merkle_tree_leaves_type(self, context: TypeContext) -> str:
def _resolve_type(self, context: TypeContext) -> Parameter:
parent, key = context.parent, context.key

if parent not in self.types:
if parent not in self._all_types:
raise ValueError(f"Parent {parent} is not defined in types.")

parent_type = self.types[parent]
parent_type = self._all_types[parent]

target_type = next((item for item in parent_type if item.name == key), None)
if target_type is None:
Expand Down Expand Up @@ -480,10 +502,10 @@ def _get_enum_variants(self, context: TypeContext) -> List[Parameter]:
enum_type = self._resolve_type(context)
if not isinstance(enum_type, EnumParameter):
raise ValueError(f"Type [{context.key}] is not an enum.")
if enum_type.contains not in self.types:
if enum_type.contains not in self._all_types:
raise ValueError(f"Type [{enum_type.contains}] is not defined in types")

return self.types[enum_type.contains]
return self._all_types[enum_type.contains]

def _encode_long_string(self, value: str) -> int:
byte_array_serializer = ByteArraySerializer()
Expand Down Expand Up @@ -604,20 +626,34 @@ def _get_basic_type_names(revision: Revision) -> List[str]:
BasicType.BOOL,
]

basic_types_v1 = basic_types_v0 + [
BasicType.SHORT_STRING,
BasicType.CONTRACT_ADDRESS,
BasicType.CLASS_HASH,
BasicType.U128,
BasicType.I128,
BasicType.TIMESTAMP,
BasicType.ENUM,
]
basic_types_v1 = list(BasicType)

basic_types = basic_types_v0 if revision == Revision.V0 else basic_types_v1
return [basic_type.value for basic_type in basic_types]


def _get_preset_types(
revision: Revision,
) -> Dict[str, List[StandardParameter]]:
if revision == Revision.V0:
return {}

return {
PresetType.U256.value: [
StandardParameter(name="low", type="u128"),
StandardParameter(name="high", type="u128"),
],
PresetType.TOKEN_AMOUNT.value: [
StandardParameter(name="token_address", type="ContractAddress"),
StandardParameter(name="amount", type="u256"),
],
PresetType.NFT_ID.value: [
StandardParameter(name="collection_address", type="ContractAddress"),
StandardParameter(name="token_id", type="u256"),
],
}


# pylint: disable=unused-argument
# pylint: disable=no-self-use

Expand Down
Loading

0 comments on commit 1565660

Please sign in to comment.