Skip to content

Commit

Permalink
Throw error on request if session is closed (#227)
Browse files Browse the repository at this point in the history
* throw error when connecting to WS or trying to make a request, but having a closed session

* add comments, fix import error

* fix tests

* define more detailed tests
  • Loading branch information
Snusmumr1000 authored Jan 26, 2024
1 parent 975a790 commit f44137a
Show file tree
Hide file tree
Showing 4 changed files with 82 additions and 2 deletions.
4 changes: 4 additions & 0 deletions curl_cffi/requests/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,7 @@ def __init__(self, msg, code=0, response=None, *args, **kwargs):

class CookieConflict(RequestsError):
pass


class SessionClosed(RequestsError):
pass
18 changes: 16 additions & 2 deletions curl_cffi/requests/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from .. import AsyncCurl, Curl, CurlError, CurlInfo, CurlOpt, CurlHttpVersion
from ..curl import CURL_WRITEFUNC_ERROR, CurlMime
from .cookies import Cookies, CookieTypes, CurlMorsel
from .errors import RequestsError
from .errors import RequestsError, SessionClosed
from .headers import Headers, HeaderTypes
from .models import Request, Response
from .websockets import WebSocket
Expand Down Expand Up @@ -206,6 +206,8 @@ def __init__(
self.proxies: ProxySpec = proxies or {}
self.proxy_auth = proxy_auth

self._closed = False

def _set_curl_options(
self,
curl,
Expand Down Expand Up @@ -554,6 +556,10 @@ def _parse_response(self, curl, buffer, header_buffer):

return rsp

def _check_session_closed(self):
if self._closed:
raise SessionClosed("Session is closed, cannot send request.")


# ThreadType = Literal["eventlet", "gevent", None]

Expand Down Expand Up @@ -643,6 +649,7 @@ def __exit__(self, *args):

def close(self):
"""Close the session."""
self._closed = True
self.curl.close()

@contextmanager
Expand All @@ -663,6 +670,8 @@ def ws_connect(
on_close: Optional[Callable] = None,
**kwargs,
):
self._check_session_closed()

self._set_curl_options(self.curl, "GET", url, *args, **kwargs)

# https://curl.se/docs/websocket.html
Expand Down Expand Up @@ -710,6 +719,8 @@ def request(
) -> Response:
"""Send the request, see [curl_cffi.requests.request](/api/curl_cffi.requests/#curl_cffi.requests.request) for details on parameters."""

self._check_session_closed()

# clone a new curl instance for streaming response
if stream:
c = self.curl.duphandle()
Expand Down Expand Up @@ -864,7 +875,6 @@ def __init__(
self._loop = loop
self._acurl = async_curl
self.max_clients = max_clients
self._closed = False
self.init_pool()

@property
Expand Down Expand Up @@ -936,6 +946,8 @@ async def stream(self, *args, **kwargs):
await rsp.aclose()

async def ws_connect(self, url, *args, **kwargs):
self._check_session_closed()

curl = await self.pop_curl()
# curl.debug()
self._set_curl_options(curl, "GET", url, *args, **kwargs)
Expand Down Expand Up @@ -974,6 +986,8 @@ async def request(
multipart: Optional[CurlMime] = None,
):
"""Send the request, see [curl_cffi.requests.request](/api/curl_cffi.requests/#curl_cffi.requests.request) for details on parameters."""
self._check_session_closed()

curl = await self.pop_curl()
req, buffer, header_buffer, q, header_recved, quit_now = self._set_curl_options(
curl=curl,
Expand Down
31 changes: 31 additions & 0 deletions tests/unittest/test_async_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import pytest

from curl_cffi.requests import AsyncSession, RequestsError
from curl_cffi.requests.errors import SessionClosed


async def test_get(server):
Expand Down Expand Up @@ -275,6 +276,36 @@ async def test_session_too_many_headers(server):
assert headers["Foo"][0] == "2"


# https://github.com/yifeikong/curl_cffi/issues/222
async def test_closed_session_throws_error():
async with AsyncSession() as s:
pass

with pytest.raises(SessionClosed):
await s.get('https://example.com')

with pytest.raises(SessionClosed):
await s.post('https://example.com')

with pytest.raises(SessionClosed):
await s.put('https://example.com')

with pytest.raises(SessionClosed):
await s.delete('https://example.com')

with pytest.raises(SessionClosed):
await s.options('https://example.com')

with pytest.raises(SessionClosed):
await s.head('https://example.com')

with pytest.raises(SessionClosed):
await s.patch('https://example.com')

with pytest.raises(SessionClosed):
await s.ws_connect('wss://example.com')


# https://github.com/yifeikong/curl_cffi/issues/39
async def test_post_body_cleaned(server):
async with AsyncSession() as s:
Expand Down
31 changes: 31 additions & 0 deletions tests/unittest/test_requests.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

from curl_cffi import requests, CurlOpt
from curl_cffi.const import CurlECode, CurlInfo
from curl_cffi.requests.errors import SessionClosed


def test_head(server):
Expand Down Expand Up @@ -460,6 +461,36 @@ def test_session_with_all_proxies(server, proxy_server):
assert r.text == 'Hello from man in the middle'


# https://github.com/yifeikong/curl_cffi/issues/222
def test_closed_session_throws_error():
with requests.Session() as s:
pass

with pytest.raises(SessionClosed):
s.get('https://example.com')

with pytest.raises(SessionClosed):
s.post('https://example.com')

with pytest.raises(SessionClosed):
s.put('https://example.com')

with pytest.raises(SessionClosed):
s.delete('https://example.com')

with pytest.raises(SessionClosed):
s.options('https://example.com')

with pytest.raises(SessionClosed):
s.head('https://example.com')

with pytest.raises(SessionClosed):
s.patch('https://example.com')

with pytest.raises(SessionClosed):
s.ws_connect('wss://example.com')


def test_stream_iter_content(server):
with requests.Session() as s:
url = str(server.url.copy_with(path="/stream"))
Expand Down

0 comments on commit f44137a

Please sign in to comment.