Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Disable multipart encoding in SyncApiProvider #25

Merged
merged 5 commits into from
Oct 11, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions CHANGES.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@
0.6.9 (unreleased)
------------------

- Nothing changed yet.
- Disable the default multipart encoding in `SyncApiProvider`.

- Added `file` parameter to `ApiProvider` to upload files (async is a TODO).


0.6.8 (2023-10-10)
Expand All @@ -16,7 +18,7 @@
0.6.7 (2023-10-09)
------------------

- Adapt call signature of the `fetch_token` callable in `ApiProvicer`.
- Adapt call signature of the `fetch_token` callable in `ApiProvider`.

- Add `clean_python.oauth.client_credentials`.

Expand Down
30 changes: 27 additions & 3 deletions clean_python/api_client/api_provider.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import asyncio
import re
from http import HTTPStatus
from io import BytesIO
from typing import Any
from typing import Awaitable
from typing import Callable
from typing import Dict
Expand All @@ -13,13 +15,15 @@
from aiohttp import ClientResponse
from aiohttp import ClientSession
from pydantic import AnyHttpUrl
from pydantic import field_validator

from clean_python import Json
from clean_python import ValueObject

from .exceptions import ApiException
from .response import Response

__all__ = ["ApiProvider"]
__all__ = ["ApiProvider", "FileFormPost"]


RETRY_STATUSES = frozenset({413, 429, 503}) # like in urllib3
Expand Down Expand Up @@ -57,6 +61,21 @@ def add_query_params(url: str, params: Optional[Json]) -> str:
return url + "?" + urlencode(params, doseq=True)


class FileFormPost(ValueObject):
file_name: str
file: Any # typing of BinaryIO / BytesIO is hard!
field_name: str = "file"
content_type: str = "application/octet-stream"

@field_validator("file")
@classmethod
def validate_file(cls, v):
if isinstance(v, bytes):
return BytesIO(v)
assert hasattr(v, "read") # poor-mans BinaryIO validation
return v


class ApiProvider:
"""Basic JSON API provider with retry policy and bearer tokens.

Expand Down Expand Up @@ -94,8 +113,11 @@ async def _request_with_retry(
params: Optional[Json],
json: Optional[Json],
fields: Optional[Json],
file: Optional[FileFormPost],
timeout: float,
) -> ClientResponse:
if file is not None:
raise NotImplementedError("ApiProvider doesn't yet support file uploads")
request_kwargs = {
"method": method,
"url": add_query_params(
Expand Down Expand Up @@ -130,10 +152,11 @@ async def request(
params: Optional[Json] = None,
json: Optional[Json] = None,
fields: Optional[Json] = None,
file: Optional[FileFormPost] = None,
timeout: float = 5.0,
) -> Optional[Json]:
response = await self._request_with_retry(
method, path, params, json, fields, timeout
method, path, params, json, fields, file, timeout
)
status = HTTPStatus(response.status)
content_type = response.headers.get("Content-Type")
Expand All @@ -156,10 +179,11 @@ async def request_raw(
params: Optional[Json] = None,
json: Optional[Json] = None,
fields: Optional[Json] = None,
file: Optional[FileFormPost] = None,
timeout: float = 5.0,
) -> Response:
response = await self._request_with_retry(
method, path, params, json, fields, timeout
method, path, params, json, fields, file, timeout
)
return Response(
status=response.status,
Expand Down
24 changes: 21 additions & 3 deletions clean_python/api_client/sync_api_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from clean_python import Json

from .api_provider import add_query_params
from .api_provider import FileFormPost
from .api_provider import is_json_content_type
from .api_provider import is_success
from .api_provider import join
Expand Down Expand Up @@ -55,6 +56,7 @@ def _request(
params: Optional[Json],
json: Optional[Json],
fields: Optional[Json],
file: Optional[FileFormPost],
timeout: float,
):
headers = {}
Expand All @@ -68,11 +70,25 @@ def _request(
# for urllib3<2, we dump json ourselves
if json is not None and fields is not None:
raise ValueError("Cannot both specify 'json' and 'fields'")
elif json is not None and file is not None:
raise ValueError("Cannot both specify 'json' and 'file'")
elif json is not None:
request_kwargs["body"] = json_lib.dumps(json).encode()
headers["Content-Type"] = "application/json"
elif fields is not None:
elif fields is not None and file is None:
request_kwargs["fields"] = fields
request_kwargs["encode_multipart"] = False
elif file is not None:
request_kwargs["fields"] = {
file.field_name: (
file.file_name,
file.file.read(),
file.content_type,
),
**(fields or {}),
}
request_kwargs["encode_multipart"] = True

headers.update(self._fetch_token())
return self._pool.request(headers=headers, **request_kwargs)

Expand All @@ -83,9 +99,10 @@ def request(
params: Optional[Json] = None,
json: Optional[Json] = None,
fields: Optional[Json] = None,
file: Optional[FileFormPost] = None,
timeout: float = 5.0,
) -> Optional[Json]:
response = self._request(method, path, params, json, fields, timeout)
response = self._request(method, path, params, json, fields, file, timeout)
status = HTTPStatus(response.status)
content_type = response.headers.get("Content-Type")
if status is HTTPStatus.NO_CONTENT:
Expand All @@ -107,9 +124,10 @@ def request_raw(
params: Optional[Json] = None,
json: Optional[Json] = None,
fields: Optional[Json] = None,
file: Optional[FileFormPost] = None,
timeout: float = 5.0,
) -> Response:
response = self._request(method, path, params, json, fields, timeout)
response = self._request(method, path, params, json, fields, file, timeout)
return Response(
status=response.status,
data=response.data,
Expand Down
24 changes: 18 additions & 6 deletions integration_tests/fastapi_example/presentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

from fastapi import Depends
from fastapi import Form
from fastapi import Request
from fastapi import Response
from fastapi import UploadFile
from fastapi.responses import JSONResponse
Expand Down Expand Up @@ -76,8 +77,8 @@ async def form(self, name: str = Form()):
return {"name": name}

@post("/file")
async def file(self, file: UploadFile):
return {file.filename: (await file.read()).decode()}
async def file(self, file: UploadFile, description: str = Form()):
return {file.filename: (await file.read()).decode(), "description": description}

@put("/urlencode/{name}", response_model=Author)
async def urlencode(self, name: str):
Expand All @@ -86,19 +87,30 @@ async def urlencode(self, name: str):
@post("/token")
def token(
self,
request: Request,
grant_type: str = Form(),
scope: str = Form(),
credentials: HTTPBasicCredentials = Depends(basic),
):
"""For testing client credentials grant"""
if request.headers["Content-Type"] != "application/x-www-form-urlencoded":
return Response(status_code=HTTPStatus.METHOD_NOT_ALLOWED)
if grant_type != "client_credentials":
return JSONResponse({"error": "invalid_grant"})
return JSONResponse(
{"error": "invalid_grant"}, status_code=HTTPStatus.BAD_REQUEST
)
if credentials.username != "testclient":
return JSONResponse({"error": "invalid_client"})
return JSONResponse(
{"error": "invalid_client"}, status_code=HTTPStatus.BAD_REQUEST
)
if credentials.password != "supersecret":
return JSONResponse({"error": "invalid_client"})
return JSONResponse(
{"error": "invalid_client"}, status_code=HTTPStatus.BAD_REQUEST
)
if scope != "all":
return JSONResponse({"error": "invalid_grant"})
return JSONResponse(
{"error": "invalid_grant"}, status_code=HTTPStatus.BAD_REQUEST
)
claims = {"user": "foo", "exp": int(time.time()) + 3600}
payload = base64.b64encode(json.dumps(claims).encode()).decode()
return {
Expand Down
10 changes: 8 additions & 2 deletions integration_tests/test_int_sync_api_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from clean_python import ctx
from clean_python import Tenant
from clean_python.api_client import ApiException
from clean_python.api_client import FileFormPost
from clean_python.api_client import SyncApiProvider


Expand Down Expand Up @@ -46,10 +47,15 @@ def test_request_form_body(provider: SyncApiProvider):


def test_request_form_file(provider: SyncApiProvider):
response = provider.request("POST", "v1/file", fields={"file": ("x.txt", b"foo")})
response = provider.request(
"POST",
"v1/file",
fields={"description": "bla"},
file=FileFormPost(file_name="x.txt", file=b"foo"),
)

assert isinstance(response, dict)
assert response["x.txt"] == "foo"
assert response == {"x.txt": "foo", "description": "bla"}


@pytest.fixture
Expand Down
44 changes: 44 additions & 0 deletions tests/api_client/test_sync_api_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from clean_python import ctx
from clean_python import Tenant
from clean_python.api_client import ApiException
from clean_python.api_client import FileFormPost
from clean_python.api_client import SyncApiProvider

MODULE = "clean_python.api_client.sync_api_provider"
Expand Down Expand Up @@ -159,3 +160,46 @@ def test_trailing_slash(api_provider: SyncApiProvider, path, trailing_slash, exp
api_provider._pool.request.call_args[1]["url"]
== "http://testserver/foo/" + expected
)


def test_post_file(api_provider: SyncApiProvider):
api_provider.request(
"POST",
"bar",
file=FileFormPost(file_name="test.zip", file=b"foo", field_name="x"),
)

assert api_provider._pool.request.call_count == 1

assert api_provider._pool.request.call_args[1] == dict(
method="POST",
url="http://testserver/foo/bar",
fields={"x": ("test.zip", b"foo", "application/octet-stream")},
headers={
"Authorization": "Bearer tenant-2",
},
timeout=5.0,
encode_multipart=True,
)


def test_post_file_with_fields(api_provider: SyncApiProvider):
api_provider.request(
"POST",
"bar",
fields={"a": "b"},
file=FileFormPost(file_name="test.zip", file=b"foo", field_name="x"),
)

assert api_provider._pool.request.call_count == 1

assert api_provider._pool.request.call_args[1] == dict(
method="POST",
url="http://testserver/foo/bar",
fields={"a": "b", "x": ("test.zip", b"foo", "application/octet-stream")},
headers={
"Authorization": "Bearer tenant-2",
},
timeout=5.0,
encode_multipart=True,
)