-
Notifications
You must be signed in to change notification settings - Fork 4.3k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Define Caller & SetupTeardown interfaces
- Loading branch information
1 parent
5e7c27c
commit 20f1c4d
Showing
2 changed files
with
242 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,74 @@ | ||
# | ||
# Licensed to the Apache Software Foundation (ASF) under one or more | ||
# contributor license agreements. See the NOTICE file distributed with | ||
# this work for additional information regarding copyright ownership. | ||
# The ASF licenses this file to You under the Apache License, Version 2.0 | ||
# (the "License"); you may not use this file except in compliance with | ||
# the License. You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# | ||
|
||
"""``PTransform`` for reading from and writing to Web APIs. | ||
``RequestResponseIO`` minimally requires implementing the ``Caller`` interface: | ||
requests = ... | ||
responses = (requests | ||
| RequestResponseIO(MyCaller()) | ||
) | ||
""" | ||
import abc | ||
from typing import TypeVar | ||
|
||
RequestT = TypeVar('RequestT') | ||
ResponseT = TypeVar('ResponseT') | ||
|
||
|
||
class UserCodeExecutionException(Exception): | ||
"""Base class for errors related to calling Web APIs.""" | ||
|
||
|
||
class UserCodeQuotaException(UserCodeExecutionException): | ||
"""Extends ``UserCodeExecutionException`` to signal specifically that | ||
the Web API client encountered a Quota or API overuse related error. | ||
""" | ||
|
||
|
||
class UserCodeTimeoutException(UserCodeExecutionException): | ||
"""Extends ``UserCodeExecutionException`` to signal a user code timeout.""" | ||
|
||
|
||
class Caller(metaclass=abc.ABCMeta): | ||
"""Interfaces user custom code intended for API calls.""" | ||
@abc.abstractmethod | ||
def call(self, request: RequestT) -> ResponseT: | ||
"""Calls a Web API with the ``RequestT`` and returns a | ||
``ResponseT``. ``RequestResponseIO`` expects implementations of the | ||
call method to throw either a ``UserCodeExecutionException``, | ||
``UserCodeQuotaException``, or ``UserCodeTimeoutException``. | ||
""" | ||
pass | ||
|
||
|
||
class SetupTeardown(metaclass=abc.ABCMeta): | ||
"""Interfaces user custom code to set up and teardown the API clients. | ||
Called by ``RequestResponseIO`` within its DoFn's setup and teardown | ||
methods. | ||
""" | ||
@abc.abstractmethod | ||
def setup(self) -> None: | ||
"""Called during the DoFn's setup lifecycle method.""" | ||
pass | ||
|
||
@abc.abstractmethod | ||
def teardown(self) -> None: | ||
"""Called during the DoFn's teardown lifecycle method.""" | ||
pass |
168 changes: 168 additions & 0 deletions
168
sdks/python/apache_beam/io/requestresponseio_it_test.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,168 @@ | ||
# | ||
# Licensed to the Apache Software Foundation (ASF) under one or more | ||
# contributor license agreements. See the NOTICE file distributed with | ||
# this work for additional information regarding copyright ownership. | ||
# The ASF licenses this file to You under the Apache License, Version 2.0 | ||
# (the "License"); you may not use this file except in compliance with | ||
# the License. You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# | ||
import base64 | ||
import sys | ||
import unittest | ||
from dataclasses import dataclass | ||
from typing import Tuple | ||
from typing import Union | ||
|
||
import urllib3 | ||
|
||
from apache_beam.io.requestresponseio import Caller | ||
from apache_beam.io.requestresponseio import UserCodeExecutionException | ||
from apache_beam.io.requestresponseio import UserCodeQuotaException | ||
from apache_beam.options.pipeline_options import PipelineOptions | ||
|
||
_HTTP_PATH = "/v1/echo" | ||
_PAYLOAD = base64.b64encode(bytes('payload', 'utf-8')) | ||
|
||
|
||
class EchoITOptions(PipelineOptions): | ||
"""Shared options for running integration tests on a deployed | ||
``EchoServiceGrpc`` See https://github.com/apache/beam/tree/master/.test | ||
-infra/mock-apis#integration for details on how to acquire values | ||
required by ``EchoITOptions``. | ||
""" | ||
@classmethod | ||
def _add_argparse_args(cls, parser) -> None: | ||
parser.add_argument( | ||
'--httpEndpointAddress', | ||
required=True, | ||
dest='http_endpoint_address', | ||
help='The HTTP address of the Echo API endpoint; must being with ' | ||
'http(s)://') | ||
parser.add_argument( | ||
'--neverExceedQuotaId', | ||
default='echo-should-never-exceed-quota', | ||
dest='never_exceed_quota_id', | ||
help='The ID for an allocated quota that should never exceed.') | ||
parser.add_argument( | ||
'--shouldExceedQuotaId', | ||
default='echo-should-exceed-quota', | ||
dest='should_exceed_quota_id', | ||
help='The ID for an allocated quota that should exceed.') | ||
|
||
|
||
# TODO(riteshghorse,damondouglas) replace Echo(Request|Response) with proto | ||
# generated classes from .test-infra/mock-apis: | ||
@dataclass | ||
class EchoRequest: | ||
id: str | ||
payload: bytes | ||
|
||
|
||
@dataclass | ||
class EchoResponse: | ||
id: str | ||
payload: bytes | ||
|
||
|
||
class EchoHTTPCaller(Caller): | ||
"""Implements ``Caller`` to call the ``EchoServiceGrpc``'s HTTP handler. | ||
The purpose of ``EchoHTTPCaller`` is to support integration tests. | ||
""" | ||
def __init__(self, url: str): | ||
self.url = url + _HTTP_PATH | ||
|
||
def call(self, request: EchoRequest) -> EchoResponse: | ||
"""Overrides ``Caller``'s call method invoking the | ||
``EchoServiceGrpc``'s HTTP handler with an ``EchoRequest``, returning | ||
either a successful ``EchoResponse`` or throwing either a | ||
``UserCodeExecutionException``, ``UserCodeTimeoutException``, | ||
or a ``UserCodeQuotaException``. | ||
""" | ||
|
||
try: | ||
resp = urllib3.request( | ||
"POST", | ||
self.url, | ||
json={ | ||
"id": request.id, "payload": str(request.payload, 'utf-8') | ||
}, | ||
retries=False) | ||
|
||
if resp.status < 300: | ||
resp_body = resp.json() | ||
resp_id = resp_body['id'] | ||
payload = resp_body['payload'] | ||
return EchoResponse(id=resp_id, payload=bytes(payload, 'utf-8')) | ||
|
||
if resp.status == 429: # Too Many Requests | ||
raise UserCodeQuotaException(resp.reason) | ||
|
||
raise UserCodeExecutionException(resp.reason) | ||
|
||
except urllib3.exceptions.HTTPError as e: | ||
raise UserCodeExecutionException(e) | ||
|
||
|
||
class EchoHTTPCallerTestIT(unittest.TestCase): | ||
options: Union[EchoITOptions, None] = None | ||
client: Union[EchoHTTPCaller, None] = None | ||
|
||
@classmethod | ||
def setUpClass(cls) -> None: | ||
cls.options = EchoITOptions() | ||
cls.client = EchoHTTPCaller(cls.options.http_endpoint_address) | ||
|
||
def setUp(self) -> None: | ||
client, options = EchoHTTPCallerTestIT._get_client_and_options() | ||
|
||
req = EchoRequest(id=options.should_exceed_quota_id, payload=_PAYLOAD) | ||
try: | ||
# The following is needed to exceed the API | ||
client.call(req) | ||
client.call(req) | ||
client.call(req) | ||
except UserCodeExecutionException as e: | ||
if not isinstance(e, UserCodeQuotaException): | ||
raise e | ||
|
||
@classmethod | ||
def _get_client_and_options(cls) -> Tuple[EchoHTTPCaller, EchoITOptions]: | ||
assert cls.options is not None | ||
assert cls.client is not None | ||
return cls.client, cls.options | ||
|
||
def test_given_valid_request_receives_response(self): | ||
client, options = EchoHTTPCallerTestIT._get_client_and_options() | ||
|
||
req = EchoRequest(id=options.never_exceed_quota_id, payload=_PAYLOAD) | ||
|
||
response: EchoResponse = client.call(req) | ||
|
||
self.assertEqual(req.id, response.id) | ||
self.assertEqual(req.payload, response.payload) | ||
|
||
def test_given_exceeded_quota_should_raise(self): | ||
client, options = EchoHTTPCallerTestIT._get_client_and_options() | ||
|
||
req = EchoRequest(id=options.should_exceed_quota_id, payload=_PAYLOAD) | ||
|
||
self.assertRaises(UserCodeQuotaException, lambda: client.call(req)) | ||
|
||
def test_not_found_should_raise(self): | ||
client, _ = EchoHTTPCallerTestIT._get_client_and_options() | ||
|
||
req = EchoRequest(id='i-dont-exist-quota-id', payload=_PAYLOAD) | ||
self.assertRaisesRegex( | ||
UserCodeExecutionException, "Not Found", lambda: client.call(req)) | ||
|
||
|
||
if __name__ == '__main__': | ||
unittest.main(argv=sys.argv[:1]) |