From a31d1824d03df92444bb94ba60c3946926f9c410 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20van=20Noord?= <13665637+DanielNoord@users.noreply.github.com> Date: Tue, 22 Oct 2024 22:28:47 +0200 Subject: [PATCH 1/3] Add type annotations to `TLSUpgradeProto` --- asyncpg/connect_utils.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/asyncpg/connect_utils.py b/asyncpg/connect_utils.py index 4890d007..6cefc020 100644 --- a/asyncpg/connect_utils.py +++ b/asyncpg/connect_utils.py @@ -764,14 +764,21 @@ def _parse_connect_arguments(*, dsn, host, port, user, password, passfile, class TLSUpgradeProto(asyncio.Protocol): - def __init__(self, loop, host, port, ssl_context, ssl_is_advisory): + def __init__( + self, + loop: asyncio.AbstractEventLoop, + host: str, + port: int, + ssl_context: ssl_module.SSLContext, + ssl_is_advisory: bool, + ) -> None: self.on_data = _create_future(loop) self.host = host self.port = port self.ssl_context = ssl_context self.ssl_is_advisory = ssl_is_advisory - def data_received(self, data): + def data_received(self, data: bytes) -> None: if data == b'S': self.on_data.set_result(True) elif (self.ssl_is_advisory and @@ -789,7 +796,7 @@ def data_received(self, data): 'rejected SSL upgrade'.format( host=self.host, port=self.port))) - def connection_lost(self, exc): + def connection_lost(self, exc: typing.Optional[Exception]) -> None: if not self.on_data.done(): if exc is None: exc = ConnectionError('unexpected connection_lost() call') From 76105cc6de9ce6d074540c9b5c82c16c953fdbed Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20van=20Noord?= <13665637+DanielNoord@users.noreply.github.com> Date: Wed, 23 Oct 2024 09:19:09 +0200 Subject: [PATCH 2/3] Add type annotations to `_create_ssl_connection` --- asyncpg/connect_utils.py | 21 +++++++++++++++++++-- 1 file changed, 19 insertions(+), 2 deletions(-) diff --git a/asyncpg/connect_utils.py b/asyncpg/connect_utils.py index 6cefc020..b1451f3f 100644 --- a/asyncpg/connect_utils.py +++ b/asyncpg/connect_utils.py @@ -7,6 +7,7 @@ import asyncio import collections +from collections.abc import Callable import enum import functools import getpass @@ -803,8 +804,23 @@ def connection_lost(self, exc: typing.Optional[Exception]) -> None: self.on_data.set_exception(exc) -async def _create_ssl_connection(protocol_factory, host, port, *, - loop, ssl_context, ssl_is_advisory=False): +_ProctolFactoryR = typing.TypeVar( + "_ProctolFactoryR", bound=asyncio.protocols.Protocol +) + + +async def _create_ssl_connection( + # TODO: The return type is a specific combination of subclasses of + # asyncio.protocols.Protocol that we can't express. For now, having the + # return type be dependent on signature of the factory is an improvement + protocol_factory: "Callable[[], _ProctolFactoryR]", + host: str, + port: int, + *, + loop: asyncio.AbstractEventLoop, + ssl_context: ssl_module.SSLContext, + ssl_is_advisory: bool = False, +) -> typing.Tuple[asyncio.Transport, _ProctolFactoryR]: tr, pr = await loop.create_connection( lambda: TLSUpgradeProto(loop, host, port, @@ -824,6 +840,7 @@ async def _create_ssl_connection(protocol_factory, host, port, *, try: new_tr = await loop.start_tls( tr, pr, ssl_context, server_hostname=host) + assert new_tr is not None except (Exception, asyncio.CancelledError): tr.close() raise From 69054972d68ce0c948e4619f0134f4acb08ee60b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20van=20Noord?= <13665637+DanielNoord@users.noreply.github.com> Date: Thu, 24 Oct 2024 08:18:40 +0200 Subject: [PATCH 3/3] Import `__future__` and unstringify annotation --- asyncpg/connect_utils.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/asyncpg/connect_utils.py b/asyncpg/connect_utils.py index b1451f3f..c65f68a6 100644 --- a/asyncpg/connect_utils.py +++ b/asyncpg/connect_utils.py @@ -4,6 +4,7 @@ # This module is part of asyncpg and is released under # the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0 +from __future__ import annotations import asyncio import collections @@ -813,7 +814,7 @@ async def _create_ssl_connection( # TODO: The return type is a specific combination of subclasses of # asyncio.protocols.Protocol that we can't express. For now, having the # return type be dependent on signature of the factory is an improvement - protocol_factory: "Callable[[], _ProctolFactoryR]", + protocol_factory: Callable[[], _ProctolFactoryR], host: str, port: int, *,