diff --git a/docs/docs/middlewares.md b/docs/docs/middlewares.md index 8d87a3db..487966a5 100644 --- a/docs/docs/middlewares.md +++ b/docs/docs/middlewares.md @@ -16,16 +16,16 @@ And you can write your own custom middlewares too ## Structure of middlewares `MIDDLEWARES` itself is a `list` of `tuples` which each `tuple` is like below: -(`Address of Middleware Class`, `kwargs as dict`) +(`Dotted Address of The Middleware Class`, `kwargs as dict`) ## Database Middleware -This middleware will create a `db` connection that uses in `ODM` or you can use it manually from: +This middleware will create a `db` connection which is used in `ODM` and you can use it manually too, it gives you a database connection: ```python from panther.db.connection import db ``` -We only support 2 database: `PantherDB` & `MongoDB` +We only support 2 database for now: `PantherDB` & `MongoDB` - Address of Middleware: `panther.middlewares.db.DatabaseMiddleware` - kwargs: @@ -36,13 +36,13 @@ We only support 2 database: `PantherDB` & `MongoDB` - Example of `PantherDB` (`Built-in Local Storage`): ```python MIDDLEWARES = [ - ('panther.middlewares.db.DatabaseMiddleware', {'url': f'pantherdb://{BASE_DIR}/{DB_NAME}.pdb'}), + ('panther.middlewares.db.DatabaseMiddleware', {'url': 'pantherdb://project_directory/database.pdb'}), ] ``` - Example of `MongoDB`: ```python MIDDLEWARES = [ - ('panther.middlewares.db.DatabaseMiddleware', {'url': f'mongodb://{DB_HOST}:27017/{DB_NAME}'}), + ('panther.middlewares.db.DatabaseMiddleware', {'url': 'mongodb://127.0.0.1:27017/example'}), ] ``` @@ -61,43 +61,88 @@ We only support 2 database: `PantherDB` & `MongoDB` ``` ## Custom Middleware -Write a `class` and inherit from -```python -from panther.middlewares.base import BaseMiddleware -``` +### Middleware Types + We have 3 type of Middlewares, make sure that you are inheriting from the correct one: + - `Base Middleware`: which is used for both `websocket` and `http` requests + - `HTTP Middleware`: which is only used for `http` requests + - `Websocket Middleware`: which is only used for `websocket` requests + +### Write Custom Middleware + - Write a `class` and inherit from one of the classes below + ```python + # For HTTP Requests + from panther.middlewares.base import HTTPMiddleware + + # For Websocket Requests + from panther.middlewares.base import WebsocketMiddleware + + # For Both HTTP and Websocket Requests + from panther.middlewares.base import BaseMiddleware + ``` -Then you can write your custom `before()` and `after()` methods + - Then you can write your custom `before()` and `after()` methods -- The `methods` should be `async` -- `before()` should have `request` parameter -- `after()` should have `response` parameter -- overwriting the `before()` and `after()` are optional -- The `methods` can get `kwargs` from their `__init__` + - The `methods` should be `async` + - `before()` should have `request` parameter + - `after()` should have `response` parameter + - overwriting the `before()` and `after()` are optional + - The `methods` can get `kwargs` from their `__init__` -### Custom Middleware Example -core/middlewares.py -```python -from panther.request import Request -from panther.response import Response -from panther.middlewares.base import BaseMiddleware +### Custom HTTP Middleware Example +- **core/middlewares.py** + ```python + from panther.middlewares.base import HTTPMiddleware + from panther.request import Request + from panther.response import Response -class CustomMiddleware(BaseMiddleware): + class CustomMiddleware(HTTPMiddleware): - def __init__(self, something): - self.something = something + def __init__(self, something): + self.something = something - async def before(self, request: Request) -> Request: - print('Before Endpoint', self.something) - return request + async def before(self, request: Request) -> Request: + print('Before Endpoint', self.something) + return request - async def after(self, response: Response) -> Response: - print('After Endpoint', self.something) - return response -``` -core/configs.py -```python - MIDDLEWARES = [ - ('core.middlewares.CustomMiddleware', {'something': 'hello-world'}), - ] -``` \ No newline at end of file + async def after(self, response: Response) -> Response: + print('After Endpoint', self.something) + return response + ``` + +- **core/configs.py** + ```python + MIDDLEWARES = [ + ('core.middlewares.CustomMiddleware', {'something': 'hello-world'}), + ] + ``` + +### Custom HTTP + Websocket Middleware Example +- **core/middlewares.py** + ```python + from panther.middlewares.base import BaseMiddleware + from panther.request import Request + from panther.response import Response + from panther.websocket import GenericWebsocket + + + class SayHiMiddleware(BaseMiddleware): + + def __init__(self, name): + self.name = name + + async def before(self, request: Request | GenericWebsocket) -> Request | GenericWebsocket: + print('Hello ', self.name) + return request + + async def after(self, response: Response | GenericWebsocket) -> Response | GenericWebsocket: + print('Goodbye ', self.name) + return response + ``` + +- **core/configs.py** + ```python + MIDDLEWARES = [ + ('core.middlewares.SayHiMiddleware', {'name': 'Ali Rn'}), + ] + ``` \ No newline at end of file diff --git a/docs/docs/release_notes.md b/docs/docs/release_notes.md index 346b25ba..e5db7b1d 100644 --- a/docs/docs/release_notes.md +++ b/docs/docs/release_notes.md @@ -1,3 +1,6 @@ +### 3.4.0 +- Support `WebsocketMiddleware` + ### 3.3.2 - Add `content-length` to response header diff --git a/panther/__init__.py b/panther/__init__.py index fa5e53cf..da275868 100644 --- a/panther/__init__.py +++ b/panther/__init__.py @@ -1,6 +1,6 @@ from panther.main import Panther # noqa: F401 -__version__ = '3.3.2' +__version__ = '3.4.0' def version(): diff --git a/panther/_load_configs.py b/panther/_load_configs.py index 06a8d0c7..7fcd9ca9 100644 --- a/panther/_load_configs.py +++ b/panther/_load_configs.py @@ -11,6 +11,7 @@ from panther._utils import import_class from panther.configs import JWTConfig, config from panther.exceptions import PantherException +from panther.middlewares.base import WebsocketMiddleware, HTTPMiddleware from panther.routings import finalize_urls, flatten_urls from panther.throttling import Throttling @@ -73,11 +74,13 @@ def load_default_cache_exp(configs: dict, /) -> timedelta | None: return configs.get('DEFAULT_CACHE_EXP', config['default_cache_exp']) -def load_middlewares(configs: dict, /) -> list: - """Collect The Middlewares & Set db_engine If One Of Middlewares Was For DB""" +def load_middlewares(configs: dict, /) -> dict: + """ + Collect The Middlewares & Set db_engine If One Of Middlewares Was For DB + And Return a dict with two list, http and ws middlewares""" from panther.middlewares import BaseMiddleware - middlewares = [] + middlewares = {'http': [], 'ws': []} for middleware in configs.get('MIDDLEWARES') or []: if not isinstance(middleware, list | tuple): @@ -103,7 +106,11 @@ def load_middlewares(configs: dict, /) -> list: if issubclass(Middleware, BaseMiddleware) is False: raise _exception_handler(field='MIDDLEWARES', error='is not a sub class of BaseMiddleware') - middlewares.append(Middleware(**data)) # noqa: Py Argument List + middleware_instance = Middleware(**data) + if isinstance(middleware_instance, BaseMiddleware | HTTPMiddleware): + middlewares['http'].append(middleware_instance) + if isinstance(middleware_instance, BaseMiddleware | WebsocketMiddleware): + middlewares['ws'].append(middleware_instance) return middlewares diff --git a/panther/_utils.py b/panther/_utils.py index 447a282e..c397dab0 100644 --- a/panther/_utils.py +++ b/panther/_utils.py @@ -2,6 +2,8 @@ import importlib import logging import re +import subprocess +import types from collections.abc import Callable from traceback import TracebackException from uuid import uuid4 @@ -9,9 +11,9 @@ import orjson as json from panther import status +from panther.exceptions import PantherException from panther.file_handler import File - logger = logging.getLogger('panther') @@ -150,3 +152,29 @@ def clean_traceback_message(exception: Exception) -> str: tb.stack.remove(t) _traceback = list(tb.format(chain=False)) return exception if len(_traceback) == 1 else f'{exception}\n' + ''.join(_traceback) + + +def reformat_code(base_dir): + try: + subprocess.run(['ruff', 'format', base_dir]) + subprocess.run(['ruff', 'check', '--select', 'I', '--fix', base_dir]) + except FileNotFoundError: + raise PantherException("Module 'ruff' not found, Hint: `pip install ruff`") + + +def check_function_type_endpoint(endpoint: types.FunctionType) -> Callable: + # Function Doesn't Have @API Decorator + if not hasattr(endpoint, '__wrapped__'): + logger.critical(f'You may have forgotten to use @API() on the {endpoint.__name__}()') + raise TypeError + return endpoint + + +def check_class_type_endpoint(endpoint: Callable) -> Callable: + from panther.app import GenericAPI + + if not issubclass(endpoint, GenericAPI): + logger.critical(f'You may have forgotten to inherit from GenericAPI on the {endpoint.__name__}()') + raise TypeError + + return endpoint.call_method diff --git a/panther/configs.py b/panther/configs.py index 46c1d439..51fd59fb 100644 --- a/panther/configs.py +++ b/panther/configs.py @@ -35,8 +35,10 @@ class Config(TypedDict): default_cache_exp: timedelta | None throttling: Throttling | None secret_key: bytes | None - middlewares: list - reversed_middlewares: list + http_middlewares: list + ws_middlewares: list + reversed_http_middlewares: list + reversed_ws_middlewares: list user_model: ModelMetaclass | None authentication: ModelMetaclass | None jwt_config: JWTConfig | None diff --git a/panther/main.py b/panther/main.py index 1e1cd556..653cb182 100644 --- a/panther/main.py +++ b/panther/main.py @@ -1,7 +1,6 @@ import asyncio import contextlib import logging -import subprocess import sys import types from collections.abc import Callable @@ -12,7 +11,8 @@ import panther.logging from panther import status from panther._load_configs import * -from panther._utils import clean_traceback_message, http_response, is_function_async +from panther._utils import clean_traceback_message, http_response, is_function_async, reformat_code, \ + check_class_type_endpoint, check_function_type_endpoint from panther.background_tasks import background_tasks from panther.cli.utils import print_info from panther.configs import config @@ -22,7 +22,6 @@ from panther.response import Response from panther.routings import collect_path_variables, find_endpoint - dictConfig(panther.logging.LOGGING) logger = logging.getLogger('panther') @@ -38,7 +37,8 @@ def __init__(self, name: str, configs=None, urls: dict | None = None, startup: C try: self.load_configs() - self.reformat_code() + if config['auto_reformat']: + reformat_code(base_dir=config['base_dir']) except Exception as e: # noqa: BLE001 if isinstance(e, PantherException): logger.error(e.args[0]) @@ -55,7 +55,7 @@ def __init__(self, name: str, configs=None, urls: dict | None = None, startup: C # Start Websocket Listener (Redis Required) if config['has_ws']: Thread( - target=self.websocket_connections, + target=config['websocket_connections'], daemon=True, args=(self.ws_redis_connection,), ).start() @@ -72,8 +72,11 @@ def load_configs(self) -> None: config['background_tasks'] = load_background_tasks(self._configs_module) config['throttling'] = load_throttling(self._configs_module) config['default_cache_exp'] = load_default_cache_exp(self._configs_module) - config['middlewares'] = load_middlewares(self._configs_module) - config['reversed_middlewares'] = config['middlewares'][::-1] + middlewares = load_middlewares(self._configs_module) + config['http_middlewares'] = middlewares['http'] + config['ws_middlewares'] = middlewares['ws'] + config['reversed_http_middlewares'] = middlewares['http'][::-1] + config['reversed_ws_middlewares'] = middlewares['ws'][::-1] config['user_model'] = load_user_model(self._configs_module) config['authentication'] = load_authentication_class(self._configs_module) config['jwt_config'] = load_jwt_config(self._configs_module) @@ -107,7 +110,7 @@ def _create_ws_connections_instance(self): # Create websocket connections instance if config['has_ws']: - config['websocket_connections'] = self.websocket_connections = WebsocketConnections() + config['websocket_connections'] = WebsocketConnections() # Websocket Redis Connection for middleware in config['middlewares']: if middleware.__class__.__name__ == 'RedisMiddleware': @@ -116,15 +119,6 @@ def _create_ws_connections_instance(self): else: self.ws_redis_connection = None - @classmethod - def reformat_code(cls): - if config['auto_reformat']: - try: - subprocess.run(['ruff', 'format', config['base_dir']]) - subprocess.run(['ruff', 'check', '--select', 'I', '--fix', config['base_dir']]) - except FileNotFoundError: - raise PantherException("Module 'ruff' not found, Hint: `pip install ruff`") - async def __call__(self, scope: dict, receive: Callable, send: Callable) -> None: """ 1. @@ -151,47 +145,64 @@ async def __call__(self, scope: dict, receive: Callable, send: Callable) -> None async def handle_ws(self, scope: dict, receive: Callable, send: Callable) -> None: from panther.websocket import GenericWebsocket, Websocket + + # Monitoring monitoring = Monitoring(is_active=config['monitoring'], is_ws=True) + # Create Temp Connection temp_connection = Websocket(scope=scope, receive=receive, send=send) await monitoring.before(request=temp_connection) + # Find Endpoint endpoint, found_path = find_endpoint(path=temp_connection.path) if endpoint is None: await monitoring.after('Rejected') return await temp_connection.close(status.WS_1000_NORMAL_CLOSURE) + # Check Endpoint Type if not issubclass(endpoint, GenericWebsocket): logger.critical(f'You may have forgotten to inherit from GenericWebsocket on the {endpoint.__name__}()') await monitoring.after('Rejected') return await temp_connection.close(status.WS_1014_BAD_GATEWAY) + # Collect Path Variables path_variables: dict = collect_path_variables(request_path=temp_connection.path, found_path=found_path) + # Create The Connection del temp_connection connection = endpoint(scope=scope, receive=receive, send=send) connection.set_path_variables(path_variables=path_variables) # Call 'Before' Middlewares - for middleware in config['middlewares']: + if await self._run_ws_middlewares_before_listen(connection=connection): + # Only Listen() If Middlewares Didn't Raise Anything + await config['websocket_connections'].new_connection(connection=connection) + await monitoring.after('Accepted') + await connection.listen() + + # Call 'After' Middlewares + await self._run_ws_middlewares_after_listen(connection=connection) + + # Done + await monitoring.after('Closed') + return None + + @classmethod + async def _run_ws_middlewares_before_listen(cls, *, connection) -> bool: + for middleware in config['ws_middlewares']: try: connection = await middleware.before(request=connection) except APIException: await connection.close() - break - else: - await self.websocket_connections.new_connection(connection=connection) - await monitoring.after('Accepted') - await connection.listen() + return False + return True - # Call 'After' Middleware - for middleware in config['reversed_middlewares']: + @classmethod + async def _run_ws_middlewares_after_listen(cls, *, connection): + for middleware in config['reversed_ws_middlewares']: with contextlib.suppress(APIException): await middleware.after(response=connection) - await monitoring.after('Closed') - return None - async def handle_http(self, scope: dict, receive: Callable, send: Callable) -> None: request = Request(scope=scope, receive=receive, send=send) @@ -202,75 +213,45 @@ async def handle_http(self, scope: dict, receive: Callable, send: Callable) -> N await request.read_body() # Find Endpoint - endpoint, found_path = find_endpoint(path=request.path) - path_variables: dict = collect_path_variables(request_path=request.path, found_path=found_path) + _endpoint, found_path = find_endpoint(path=request.path) + if _endpoint is None: + return await self._raise(send, status_code=status.HTTP_404_NOT_FOUND) - if endpoint is None: - return await http_response( - send, - status_code=status.HTTP_404_NOT_FOUND, - monitoring=self.monitoring, - exception=True, - ) + # Check Endpoint Type + try: + if isinstance(_endpoint, types.FunctionType): + endpoint = check_function_type_endpoint(endpoint=_endpoint) + else: + endpoint = check_class_type_endpoint(endpoint=_endpoint) + except TypeError: + return await self._raise(send, status_code=status.HTTP_501_NOT_IMPLEMENTED) + + # Collect Path Variables + path_variables: dict = collect_path_variables(request_path=request.path, found_path=found_path) try: # They Both(middleware.before() & _endpoint()) Have The Same Exception (APIException) # Call 'Before' Middlewares - for middleware in config['middlewares']: + for middleware in config['http_middlewares']: request = await middleware.before(request=request) - # Function - if isinstance(endpoint, types.FunctionType): - # Function Doesn't Have @API Decorator - if not hasattr(endpoint, '__wrapped__'): - logger.critical(f'You may have forgotten to use @API() on the {endpoint.__name__}()') - return await http_response( - send, - status_code=status.HTTP_501_NOT_IMPLEMENTED, - monitoring=self.monitoring, - exception=True, - ) - - # Declare Endpoint - _endpoint = endpoint - - # Class - else: - from panther.app import GenericAPI - - if not issubclass(endpoint, GenericAPI): - logger.critical(f'You may have forgotten to inherit from GenericAPI on the {endpoint.__name__}()') - return await http_response( - send, - status_code=status.HTTP_501_NOT_IMPLEMENTED, - monitoring=self.monitoring, - exception=True, - ) - # Declare Endpoint - _endpoint = endpoint.call_method - # Call Endpoint - response = await _endpoint(request=request, **path_variables) + response = await endpoint(request=request, **path_variables) except APIException as e: - response = self.handle_exceptions(e) + response = self._handle_exceptions(e) + except Exception as e: # noqa: BLE001 # Every unhandled exception in Panther or code will catch here exception = clean_traceback_message(exception=e) logger.critical(exception) - - return await http_response( - send, - status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - monitoring=self.monitoring, - exception=True, - ) + return await self._raise(send, status_code=status.HTTP_500_INTERNAL_SERVER_ERROR) # Call 'After' Middleware - for middleware in config['reversed_middlewares']: + for middleware in config['reversed_http_middlewares']: try: response = await middleware.after(response=response) except APIException as e: # noqa: PERF203 - response = self.handle_exceptions(e) + response = self._handle_exceptions(e) await http_response( send, @@ -298,8 +279,16 @@ def __del__(self): self.handle_shutdown() @classmethod - def handle_exceptions(cls, e: APIException, /) -> Response: + def _handle_exceptions(cls, e: APIException, /) -> Response: return Response( data=e.detail if isinstance(e.detail, dict) else {'detail': e.detail}, status_code=e.status_code, ) + + async def _raise(self, send, *, status_code: int): + await http_response( + send, + status_code=status_code, + monitoring=self.monitoring, + exception=True, + ) diff --git a/panther/middlewares/base.py b/panther/middlewares/base.py index a0095eda..32d7d2cf 100644 --- a/panther/middlewares/base.py +++ b/panther/middlewares/base.py @@ -1,10 +1,27 @@ from panther.request import Request from panther.response import Response +from panther.websocket import GenericWebsocket class BaseMiddleware: + async def before(self, request: Request | GenericWebsocket): + raise NotImplementedError + + async def after(self, response: Response | GenericWebsocket): + raise NotImplementedError + + +class HTTPMiddleware(BaseMiddleware): async def before(self, request: Request): return request async def after(self, response: Response): return response + + +class WebsocketMiddleware(BaseMiddleware): + async def before(self, request: GenericWebsocket): + return request + + async def after(self, response: GenericWebsocket): + return response diff --git a/setup.py b/setup.py index 29a31ce1..0b5b66d1 100644 --- a/setup.py +++ b/setup.py @@ -17,6 +17,7 @@ def panther_version() -> str: 'pymongo~=4.4', 'bpython~=0.24', 'ruff~=0.1.9', + 'websockets~=12.0', ], } diff --git a/tests/test_run.py b/tests/test_run.py index df7672cd..627b214e 100644 --- a/tests/test_run.py +++ b/tests/test_run.py @@ -37,13 +37,21 @@ def test_load_configs(self): assert config['throttling'].duration == timedelta(seconds=10) assert config['secret_key'] == secret_key.encode() - assert len(config['middlewares']) == 1 - assert config['middlewares'][0].__class__.__name__ == 'DatabaseMiddleware' - assert config['middlewares'][0].url == 'pantherdb://test.pdb' + assert len(config['http_middlewares']) == 1 + assert config['http_middlewares'][0].__class__.__name__ == 'DatabaseMiddleware' + assert config['http_middlewares'][0].url == 'pantherdb://test.pdb' - assert len(config['reversed_middlewares']) == 1 - assert config['reversed_middlewares'][0].__class__.__name__ == 'DatabaseMiddleware' - assert config['reversed_middlewares'][0].url == 'pantherdb://test.pdb' + assert len(config['reversed_http_middlewares']) == 1 + assert config['reversed_http_middlewares'][0].__class__.__name__ == 'DatabaseMiddleware' + assert config['reversed_http_middlewares'][0].url == 'pantherdb://test.pdb' + + assert len(config['ws_middlewares']) == 1 + assert config['ws_middlewares'][0].__class__.__name__ == 'DatabaseMiddleware' + assert config['ws_middlewares'][0].url == 'pantherdb://test.pdb' + + assert len(config['reversed_ws_middlewares']) == 1 + assert config['reversed_ws_middlewares'][0].__class__.__name__ == 'DatabaseMiddleware' + assert config['reversed_ws_middlewares'][0].url == 'pantherdb://test.pdb' assert config['user_model'].__name__ == tests.sample_project.app.models.User.__name__ assert config['user_model'].__module__.endswith('app.models')