Skip to content

Commit

Permalink
Disable multipart encoding in SyncApiProvider (#25)
Browse files Browse the repository at this point in the history
  • Loading branch information
caspervdw authored Oct 11, 2023
1 parent 1df7540 commit d7a75d5
Show file tree
Hide file tree
Showing 6 changed files with 122 additions and 16 deletions.
6 changes: 4 additions & 2 deletions CHANGES.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@
0.6.9 (unreleased)
------------------

- Nothing changed yet.
- Disable the default multipart encoding in `SyncApiProvider`.

- Added `file` parameter to `ApiProvider` to upload files (async is a TODO).


0.6.8 (2023-10-10)
Expand All @@ -16,7 +18,7 @@
0.6.7 (2023-10-09)
------------------

- Adapt call signature of the `fetch_token` callable in `ApiProvicer`.
- Adapt call signature of the `fetch_token` callable in `ApiProvider`.

- Add `clean_python.oauth.client_credentials`.

Expand Down
30 changes: 27 additions & 3 deletions clean_python/api_client/api_provider.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import asyncio
import re
from http import HTTPStatus
from io import BytesIO
from typing import Any
from typing import Awaitable
from typing import Callable
from typing import Dict
Expand All @@ -13,13 +15,15 @@
from aiohttp import ClientResponse
from aiohttp import ClientSession
from pydantic import AnyHttpUrl
from pydantic import field_validator

from clean_python import Json
from clean_python import ValueObject

from .exceptions import ApiException
from .response import Response

__all__ = ["ApiProvider"]
__all__ = ["ApiProvider", "FileFormPost"]


RETRY_STATUSES = frozenset({413, 429, 503}) # like in urllib3
Expand Down Expand Up @@ -57,6 +61,21 @@ def add_query_params(url: str, params: Optional[Json]) -> str:
return url + "?" + urlencode(params, doseq=True)


class FileFormPost(ValueObject):
file_name: str
file: Any # typing of BinaryIO / BytesIO is hard!
field_name: str = "file"
content_type: str = "application/octet-stream"

@field_validator("file")
@classmethod
def validate_file(cls, v):
if isinstance(v, bytes):
return BytesIO(v)
assert hasattr(v, "read") # poor-mans BinaryIO validation
return v


class ApiProvider:
"""Basic JSON API provider with retry policy and bearer tokens.
Expand Down Expand Up @@ -94,8 +113,11 @@ async def _request_with_retry(
params: Optional[Json],
json: Optional[Json],
fields: Optional[Json],
file: Optional[FileFormPost],
timeout: float,
) -> ClientResponse:
if file is not None:
raise NotImplementedError("ApiProvider doesn't yet support file uploads")
request_kwargs = {
"method": method,
"url": add_query_params(
Expand Down Expand Up @@ -130,10 +152,11 @@ async def request(
params: Optional[Json] = None,
json: Optional[Json] = None,
fields: Optional[Json] = None,
file: Optional[FileFormPost] = None,
timeout: float = 5.0,
) -> Optional[Json]:
response = await self._request_with_retry(
method, path, params, json, fields, timeout
method, path, params, json, fields, file, timeout
)
status = HTTPStatus(response.status)
content_type = response.headers.get("Content-Type")
Expand All @@ -156,10 +179,11 @@ async def request_raw(
params: Optional[Json] = None,
json: Optional[Json] = None,
fields: Optional[Json] = None,
file: Optional[FileFormPost] = None,
timeout: float = 5.0,
) -> Response:
response = await self._request_with_retry(
method, path, params, json, fields, timeout
method, path, params, json, fields, file, timeout
)
return Response(
status=response.status,
Expand Down
24 changes: 21 additions & 3 deletions clean_python/api_client/sync_api_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from clean_python import Json

from .api_provider import add_query_params
from .api_provider import FileFormPost
from .api_provider import is_json_content_type
from .api_provider import is_success
from .api_provider import join
Expand Down Expand Up @@ -55,6 +56,7 @@ def _request(
params: Optional[Json],
json: Optional[Json],
fields: Optional[Json],
file: Optional[FileFormPost],
timeout: float,
):
headers = {}
Expand All @@ -68,11 +70,25 @@ def _request(
# for urllib3<2, we dump json ourselves
if json is not None and fields is not None:
raise ValueError("Cannot both specify 'json' and 'fields'")
elif json is not None and file is not None:
raise ValueError("Cannot both specify 'json' and 'file'")
elif json is not None:
request_kwargs["body"] = json_lib.dumps(json).encode()
headers["Content-Type"] = "application/json"
elif fields is not None:
elif fields is not None and file is None:
request_kwargs["fields"] = fields
request_kwargs["encode_multipart"] = False
elif file is not None:
request_kwargs["fields"] = {
file.field_name: (
file.file_name,
file.file.read(),
file.content_type,
),
**(fields or {}),
}
request_kwargs["encode_multipart"] = True

headers.update(self._fetch_token())
return self._pool.request(headers=headers, **request_kwargs)

Expand All @@ -83,9 +99,10 @@ def request(
params: Optional[Json] = None,
json: Optional[Json] = None,
fields: Optional[Json] = None,
file: Optional[FileFormPost] = None,
timeout: float = 5.0,
) -> Optional[Json]:
response = self._request(method, path, params, json, fields, timeout)
response = self._request(method, path, params, json, fields, file, timeout)
status = HTTPStatus(response.status)
content_type = response.headers.get("Content-Type")
if status is HTTPStatus.NO_CONTENT:
Expand All @@ -107,9 +124,10 @@ def request_raw(
params: Optional[Json] = None,
json: Optional[Json] = None,
fields: Optional[Json] = None,
file: Optional[FileFormPost] = None,
timeout: float = 5.0,
) -> Response:
response = self._request(method, path, params, json, fields, timeout)
response = self._request(method, path, params, json, fields, file, timeout)
return Response(
status=response.status,
data=response.data,
Expand Down
24 changes: 18 additions & 6 deletions integration_tests/fastapi_example/presentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

from fastapi import Depends
from fastapi import Form
from fastapi import Request
from fastapi import Response
from fastapi import UploadFile
from fastapi.responses import JSONResponse
Expand Down Expand Up @@ -76,8 +77,8 @@ async def form(self, name: str = Form()):
return {"name": name}

@post("/file")
async def file(self, file: UploadFile):
return {file.filename: (await file.read()).decode()}
async def file(self, file: UploadFile, description: str = Form()):
return {file.filename: (await file.read()).decode(), "description": description}

@put("/urlencode/{name}", response_model=Author)
async def urlencode(self, name: str):
Expand All @@ -86,19 +87,30 @@ async def urlencode(self, name: str):
@post("/token")
def token(
self,
request: Request,
grant_type: str = Form(),
scope: str = Form(),
credentials: HTTPBasicCredentials = Depends(basic),
):
"""For testing client credentials grant"""
if request.headers["Content-Type"] != "application/x-www-form-urlencoded":
return Response(status_code=HTTPStatus.METHOD_NOT_ALLOWED)
if grant_type != "client_credentials":
return JSONResponse({"error": "invalid_grant"})
return JSONResponse(
{"error": "invalid_grant"}, status_code=HTTPStatus.BAD_REQUEST
)
if credentials.username != "testclient":
return JSONResponse({"error": "invalid_client"})
return JSONResponse(
{"error": "invalid_client"}, status_code=HTTPStatus.BAD_REQUEST
)
if credentials.password != "supersecret":
return JSONResponse({"error": "invalid_client"})
return JSONResponse(
{"error": "invalid_client"}, status_code=HTTPStatus.BAD_REQUEST
)
if scope != "all":
return JSONResponse({"error": "invalid_grant"})
return JSONResponse(
{"error": "invalid_grant"}, status_code=HTTPStatus.BAD_REQUEST
)
claims = {"user": "foo", "exp": int(time.time()) + 3600}
payload = base64.b64encode(json.dumps(claims).encode()).decode()
return {
Expand Down
10 changes: 8 additions & 2 deletions integration_tests/test_int_sync_api_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from clean_python import ctx
from clean_python import Tenant
from clean_python.api_client import ApiException
from clean_python.api_client import FileFormPost
from clean_python.api_client import SyncApiProvider


Expand Down Expand Up @@ -46,10 +47,15 @@ def test_request_form_body(provider: SyncApiProvider):


def test_request_form_file(provider: SyncApiProvider):
response = provider.request("POST", "v1/file", fields={"file": ("x.txt", b"foo")})
response = provider.request(
"POST",
"v1/file",
fields={"description": "bla"},
file=FileFormPost(file_name="x.txt", file=b"foo"),
)

assert isinstance(response, dict)
assert response["x.txt"] == "foo"
assert response == {"x.txt": "foo", "description": "bla"}


@pytest.fixture
Expand Down
44 changes: 44 additions & 0 deletions tests/api_client/test_sync_api_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from clean_python import ctx
from clean_python import Tenant
from clean_python.api_client import ApiException
from clean_python.api_client import FileFormPost
from clean_python.api_client import SyncApiProvider

MODULE = "clean_python.api_client.sync_api_provider"
Expand Down Expand Up @@ -159,3 +160,46 @@ def test_trailing_slash(api_provider: SyncApiProvider, path, trailing_slash, exp
api_provider._pool.request.call_args[1]["url"]
== "http://testserver/foo/" + expected
)


def test_post_file(api_provider: SyncApiProvider):
api_provider.request(
"POST",
"bar",
file=FileFormPost(file_name="test.zip", file=b"foo", field_name="x"),
)

assert api_provider._pool.request.call_count == 1

assert api_provider._pool.request.call_args[1] == dict(
method="POST",
url="http://testserver/foo/bar",
fields={"x": ("test.zip", b"foo", "application/octet-stream")},
headers={
"Authorization": "Bearer tenant-2",
},
timeout=5.0,
encode_multipart=True,
)


def test_post_file_with_fields(api_provider: SyncApiProvider):
api_provider.request(
"POST",
"bar",
fields={"a": "b"},
file=FileFormPost(file_name="test.zip", file=b"foo", field_name="x"),
)

assert api_provider._pool.request.call_count == 1

assert api_provider._pool.request.call_args[1] == dict(
method="POST",
url="http://testserver/foo/bar",
fields={"a": "b", "x": ("test.zip", b"foo", "application/octet-stream")},
headers={
"Authorization": "Bearer tenant-2",
},
timeout=5.0,
encode_multipart=True,
)

0 comments on commit d7a75d5

Please sign in to comment.