From fbbf19357bc31ae901004924138cb9aeadfcc37a Mon Sep 17 00:00:00 2001 From: David Sanders Date: Sat, 26 May 2018 00:13:27 -0600 Subject: [PATCH 1/2] Fixes to handle simple coder callables --- eth_abi/abi.py | 9 +++++++-- eth_abi/base.py | 2 +- eth_abi/decoding.py | 4 ++-- eth_abi/encoding.py | 9 ++++++--- 4 files changed, 16 insertions(+), 8 deletions(-) diff --git a/eth_abi/abi.py b/eth_abi/abi.py index 5efa9980..964add1c 100644 --- a/eth_abi/abi.py +++ b/eth_abi/abi.py @@ -59,8 +59,13 @@ def is_encodable(typ, arg): encoder.validate_value(arg) except EncodingError: return False - else: - return True + except AttributeError: + try: + encoder(arg) + except EncodingError: + return False + + return True # Decodes a single base datum diff --git a/eth_abi/base.py b/eth_abi/base.py index 1d55cd1f..7e0cef15 100644 --- a/eth_abi/base.py +++ b/eth_abi/base.py @@ -144,4 +144,4 @@ def from_type_str(cls, type_str, registry): # pragma: no cover Used by ``ABIRegistry`` to get an appropriate encoder or decoder instance for the given type string and type registry. """ - raise NotImplementedError('Must implement `from_type_str`') + raise cls() diff --git a/eth_abi/decoding.py b/eth_abi/decoding.py index 76408c96..cf3e5ee5 100644 --- a/eth_abi/decoding.py +++ b/eth_abi/decoding.py @@ -145,11 +145,11 @@ def __init__(self, **kwargs): super().__init__(**kwargs) self.decoders = tuple( - HeadTailDecoder(tail_decoder=d) if d.is_dynamic else d + HeadTailDecoder(tail_decoder=d) if getattr(d, 'is_dynamic', False) else d for d in self.decoders ) - self.is_dynamic = any(d.is_dynamic for d in self.decoders) + self.is_dynamic = any(getattr(d, 'is_dynamic', False) for d in self.decoders) def validate(self): super().validate() diff --git a/eth_abi/encoding.py b/eth_abi/encoding.py index 8091d12f..0cd8134f 100644 --- a/eth_abi/encoding.py +++ b/eth_abi/encoding.py @@ -71,7 +71,7 @@ class TupleEncoder(BaseEncoder): def __init__(self, **kwargs): super().__init__(**kwargs) - self.is_dynamic = any(e.is_dynamic for e in self.encoders) + self.is_dynamic = any(getattr(e, 'is_dynamic', False) for e in self.encoders) def validate(self): super().validate() @@ -95,7 +95,10 @@ def validate_value(self, value): ) for item, encoder in zip(value, self.encoders): - encoder.validate_value(item) + try: + encoder.validate_value(item) + except AttributeError: + encoder(item) def encode(self, values): self.validate_value(values) @@ -103,7 +106,7 @@ def encode(self, values): raw_head_chunks = [] tail_chunks = [] for value, encoder in zip(values, self.encoders): - if encoder.is_dynamic: + if getattr(encoder, 'is_dynamic', False): raw_head_chunks.append(None) tail_chunks.append(encoder(value)) else: From d634eb127925882d3da2162bdd9cb03ba109aede Mon Sep 17 00:00:00 2001 From: David Sanders Date: Mon, 4 Jun 2018 23:39:28 -0600 Subject: [PATCH 2/2] Add tests for custom registrations --- tests/test_integration/__init__.py | 0 .../test_custom_registrations.py | 99 +++++++++++++++++++ 2 files changed, 99 insertions(+) create mode 100644 tests/test_integration/__init__.py create mode 100644 tests/test_integration/test_custom_registrations.py diff --git a/tests/test_integration/__init__.py b/tests/test_integration/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/test_integration/test_custom_registrations.py b/tests/test_integration/test_custom_registrations.py new file mode 100644 index 00000000..85ad6b5f --- /dev/null +++ b/tests/test_integration/test_custom_registrations.py @@ -0,0 +1,99 @@ +from eth_abi.abi import ( + decode_single, + encode_single, +) +from eth_abi.decoding import ( + BaseDecoder, +) +from eth_abi.encoding import ( + BaseEncoder, +) +from eth_abi.exceptions import ( + DecodingError, + EncodingError, +) +from eth_abi.registry import ( + registry, +) + +NULL_ENCODING = b'\x00' * 32 + + +def encode_null(x): + if x is not None: + raise EncodingError('Unsupported value') + + return NULL_ENCODING + + +def decode_null(stream): + if stream.read(32) != NULL_ENCODING: + raise DecodingError('Not enough data or wrong data') + + return None + + +class EncodeNull(BaseEncoder): + word_width = None + + @classmethod + def from_type_str(cls, type_str, registry): + word_width = int(type_str[4:]) + return cls(word_width=word_width) + + def encode(self, value): + self.validate_value(value) + return NULL_ENCODING * self.word_width + + def validate_value(self, value): + if value is not None: + raise EncodingError('Unsupported value') + + +class DecodeNull(BaseDecoder): + word_width = None + + @classmethod + def from_type_str(cls, type_str, registry): + word_width = int(type_str[4:]) + return cls(word_width=word_width) + + def decode(self, stream): + byts = stream.read(32 * self.word_width) + if byts != NULL_ENCODING * self.word_width: + raise DecodingError('Not enough data or wrong data') + + return None + + +def test_register_and_use_callables(): + registry.register('null', encode_null, decode_null) + + assert encode_single('null', None) == NULL_ENCODING + assert decode_single('null', NULL_ENCODING) is None + + encoded_tuple = encode_single('(int,null)', (1, None)) + + assert encoded_tuple == b'\x00' * 31 + b'\x01' + NULL_ENCODING + assert decode_single('(int,null)', encoded_tuple) == (1, None) + + registry.unregister('null') + + +def test_register_and_use_coder_classes(): + registry.register( + lambda x: x.startswith('null'), + EncodeNull, + DecodeNull, + label='null', + ) + + assert encode_single('null2', None) == NULL_ENCODING * 2 + assert decode_single('null2', NULL_ENCODING * 2) is None + + encoded_tuple = encode_single('(int,null2)', (1, None)) + + assert encoded_tuple == b'\x00' * 31 + b'\x01' + NULL_ENCODING * 2 + assert decode_single('(int,null2)', encoded_tuple) == (1, None) + + registry.unregister('null')