diff --git a/temporal/conversions.py b/temporal/conversions.py index 5ceae42..6e109a9 100644 --- a/temporal/conversions.py +++ b/temporal/conversions.py @@ -1,6 +1,8 @@ import json import re -from typing import List, Optional, Union, Iterable +from typing import List, Optional, Union, Iterable, Type + +import betterproto from temporal.api.common.v1 import Payload, Payloads @@ -30,10 +32,10 @@ def snake_to_title(snake_str): METADATA_ENCODING_RAW = METADATA_ENCODING_RAW_NAME.encode("utf-8") METADATA_ENCODING_JSON_NAME = "json/plain" METADATA_ENCODING_JSON = METADATA_ENCODING_JSON_NAME.encode("utf-8") - -# TODO: Implement encode/decode for these: METADATA_ENCODING_PROTOBUF_JSON_NAME = "json/protobuf" METADATA_ENCODING_PROTOBUF_JSON = METADATA_ENCODING_PROTOBUF_JSON_NAME.encode("utf-8") + +# TODO: Implement encode/decode for these: METADATA_ENCODING_PROTOBUF_NAME = "binary/protobuf" METADATA_ENCODING_PROTOBUF = METADATA_ENCODING_PROTOBUF_NAME.encode('utf-8') @@ -49,7 +51,7 @@ def encode_null(value: object) -> Optional[Payload]: # noinspection PyUnusedLocal -def decode_null(payload: Payload) -> object: +def decode_null(payload: Payload, type_hint) -> object: return None @@ -63,7 +65,7 @@ def encode_binary(value: object) -> Optional[Payload]: return None -def decode_binary(payload: Payload) -> object: +def decode_binary(payload: Payload, type_hint) -> object: return payload.data @@ -78,7 +80,7 @@ def encode_json_string(value: object) -> Payload: return p -def decode_json_string(payload: Payload) -> object: +def decode_json_string(payload: Payload, type_hint) -> object: # TODO: # mapper.configure(SerializationFeature.WRITE_DATES_AS_TIMESTAMPS, false); # mapper.registerModule(new JavaTimeModule()); @@ -87,9 +89,24 @@ def decode_json_string(payload: Payload) -> object: return json.loads(b) +def encode_protobuf_json(value: object) -> Payload: + if not isinstance(value, betterproto.Message): + return None + p: Payload = Payload() + p.metadata = {METADATA_ENCODING_KEY: METADATA_ENCODING_PROTOBUF_JSON} + p.data = value.to_json().encode("utf-8") + return p + + +def decode_protobuf_json(payload: Payload, type_hint: Type[betterproto.Message]) -> betterproto.Message: + b = str(payload.data, "utf-8") + return type_hint().from_json(b) + + ENCODINGS = [ encode_null, encode_binary, + encode_protobuf_json, encode_json_string ] @@ -97,6 +114,7 @@ def decode_json_string(payload: Payload) -> object: DECODINGS = { METADATA_ENCODING_NULL: decode_null, METADATA_ENCODING_RAW: decode_binary, + METADATA_ENCODING_PROTOBUF_JSON: decode_protobuf_json, METADATA_ENCODING_JSON: decode_json_string } diff --git a/temporal/converter.py b/temporal/converter.py index b1332fe..b317601 100644 --- a/temporal/converter.py +++ b/temporal/converter.py @@ -67,7 +67,7 @@ def from_payload(self, payload: Payload, type_hint: type = None) -> object: decoding = DECODINGS.get(encoding) if not decoding: raise Exception(f"Unsupported encoding: {str(encoding, 'utf-8')}") - return decoding(payload) + return decoding(payload, type_hint) DEFAULT_DATA_CONVERTER_INSTANCE = DefaultDataConverter()