Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Websocket Queue #64

Merged
merged 7 commits into from
Jan 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions docs/docs/release_notes.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
### 3.8.0
- Handle WebSocket connections when we have multiple workers with `multiprocessing.Manager`

### 3.7.0
- Add `ModelSerializer`

Expand Down
16 changes: 11 additions & 5 deletions docs/docs/websocket.md
Original file line number Diff line number Diff line change
Expand Up @@ -50,9 +50,15 @@ urls = {
from panther.websocket import send_message_to_websocket
await send_message_to_websocket(connection_id='7e82d57c9ec0478787b01916910a9f45', data='New Message From WS')
```
8. If you want to use `webscoket` in `multi-tread` or `multi-instance` backend, you should add `RedisMiddleware` in your `configs` or it won't work well.
8. If you want to use `webscoket` in a backend with `multiple workers`, we recommend you to add `RedisMiddleware` in your `configs`
[[Adding Redis Middleware]](https://pantherpy.github.io/middlewares/#redis-middleware)
9. If you want to close a connection:
9. If you don't want to add `RedisMiddleware` and you still want to use `websocket` in `multi-thread`,
you have to use `--preload` option while running the project like below:
```bash
gunicorn -w 10 -k uvicorn.workers.UvicornWorker main:app --preload
```

10. If you want to close a connection:
- In websocket class scope: You can close connection with `self.close()` method which takes 2 args, `code` and `reason`:
```python
from panther import status
Expand All @@ -65,7 +71,7 @@ urls = {
await close_websocket_connection(connection_id='7e82d57c9ec0478787b01916910a9f45', code=status.WS_1008_POLICY_VIOLATION, reason='')
```

10. `Path Variables` will be passed to `connect()`:
11. `Path Variables` will be passed to `connect()`:
```python
from panther.websocket import GenericWebsocket

Expand All @@ -77,6 +83,6 @@ urls = {
'/ws/<user_id>/<room_id>/': UserWebsocket
}
```
11. WebSocket Echo Example -> [Https://GitHub.com/PantherPy/echo_websocket](https://github.com/PantherPy/echo_websocket)
12. Enjoy.
12. WebSocket Echo Example -> [Https://GitHub.com/PantherPy/echo_websocket](https://github.com/PantherPy/echo_websocket)
13. Enjoy.

2 changes: 1 addition & 1 deletion example/model_serializer_example.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from pydantic import Field

from panther import status, Panther
from panther import Panther, status
from panther.app import API
from panther.db import Model
from panther.request import Request
Expand Down
2 changes: 1 addition & 1 deletion panther/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from panther.main import Panther # noqa: F401

__version__ = '3.7.0'
__version__ = '3.8.0'


def version():
Expand Down
101 changes: 71 additions & 30 deletions panther/base_websocket.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,67 +3,106 @@
import asyncio
import contextlib
import logging
from typing import TYPE_CHECKING
from multiprocessing import Manager
from typing import TYPE_CHECKING, Literal

import orjson as json

from panther import status
from panther._utils import generate_ws_connection_id
from panther.base_request import BaseRequest
from panther.configs import config
from panther.db.connection import redis
from panther.utils import Singleton

if TYPE_CHECKING:
from redis import Redis


logger = logging.getLogger('panther')


class PubSub:
def __init__(self, manager):
self._manager = manager
self._subscribers = self._manager.list()

def subscribe(self):
queue = self._manager.Queue()
self._subscribers.append(queue)
return queue

def publish(self, msg):
for queue in self._subscribers:
queue.put(msg)


class WebsocketConnections(Singleton):
def __init__(self):
def __init__(self, manager: Manager = None):
self.connections = {}
self.connections_count = 0
self.manager = manager

def __call__(self, r: Redis | None):
if r:
subscriber = r.pubsub()
subscriber.subscribe('websocket_connections')
logger.info("Subscribed to 'websocket_connections' channel")
for channel_data in subscriber.listen():
# Check Type of PubSub Message
match channel_data['type']:
# Subscribed
case 'subscribe':
continue

# Message Received
case 'message':
loaded_data = json.loads(channel_data['data'].decode())
if (
isinstance(loaded_data, dict)
and (connection_id := loaded_data.get('connection_id'))
and (data := loaded_data.get('data'))
and (action := loaded_data.get('action'))
and (connection := self.connections.get(connection_id))
):
# Check Action of WS
match action:
case 'send':
logger.debug(f'Sending Message to {connection_id}')
asyncio.run(connection.send(data=data))
case 'close':
with contextlib.suppress(RuntimeError):
asyncio.run(connection.close(code=data['code'], reason=data['reason']))
# We are trying to disconnect the connection between a thread and a user
# from another thread, it's working, but we have to find another solution it
#
# Error:
# Task <Task pending coro=<Websocket.close()>> got Future
# <Task pending coro=<WebSocketCommonProtocol.transfer_data()>>
# attached to a different loop
case _:
logger.debug(f'Unknown Message Action: {action}')
case _:
logger.debug(f'Unknown Channel Type: {channel_data["type"]}')
self._handle_received_message(received_message=loaded_data)

case unknown_type:
logger.debug(f'Unknown Channel Type: {unknown_type}')
else:
self.pubsub = PubSub(manager=self.manager)
queue = self.pubsub.subscribe()
logger.info("Subscribed to 'websocket_connections' queue")
while True:
received_message = queue.get()
self._handle_received_message(received_message=received_message)

def _handle_received_message(self, received_message):
if (
isinstance(received_message, dict)
and (connection_id := received_message.get('connection_id'))
and connection_id in self.connections
and 'action' in received_message
and 'data' in received_message
):
# Check Action of WS
match received_message['action']:
case 'send':
asyncio.run(self.connections[connection_id].send(data=received_message['data']))
case 'close':
with contextlib.suppress(RuntimeError):
asyncio.run(self.connections[connection_id].close(
code=received_message['data']['code'],
reason=received_message['data']['reason']
))
# We are trying to disconnect the connection between a thread and a user
# from another thread, it's working, but we have to find another solution for it
#
# Error:
# Task <Task pending coro=<Websocket.close()>> got Future
# <Task pending coro=<WebSocketCommonProtocol.transfer_data()>>
# attached to a different loop
case unknown_action:
logger.debug(f'Unknown Message Action: {unknown_action}')

def publish(self, connection_id: str, action: Literal['send', 'close'], data: any):
publish_data = {'connection_id': connection_id, 'action': action, 'data': data}

if redis.is_connected:
redis.publish('websocket_connections', json.dumps(publish_data))
else:
self.pubsub.publish(publish_data)

async def new_connection(self, connection: Websocket) -> None:
await connection.connect(**connection.path_variables)
Expand Down Expand Up @@ -106,6 +145,7 @@ async def receive(self, data: str | bytes) -> None:
pass

async def send(self, data: any = None) -> None:
logger.debug(f'Sending WS Message to {self.connection_id}')
if data:
if isinstance(data, bytes):
await self.send_bytes(bytes_data=data)
Expand All @@ -121,6 +161,7 @@ async def send_bytes(self, bytes_data: bytes) -> None:
await self.asgi_send({'type': 'websocket.send', 'bytes': bytes_data})

async def close(self, code: int = status.WS_1000_NORMAL_CLOSURE, reason: str = '') -> None:
logger.debug(f'Closing WS Connection {self.connection_id}')
self.is_connected = False
config['websocket_connections'].remove_connection(self)
await self.asgi_send({'type': 'websocket.close', 'code': code, 'reason': reason})
Expand Down
35 changes: 23 additions & 12 deletions panther/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import types
from collections.abc import Callable
from logging.config import dictConfig
from multiprocessing import Manager
from pathlib import Path
from threading import Thread

Expand Down Expand Up @@ -52,14 +53,6 @@ def __init__(self, name: str, configs=None, urls: dict | None = None, startup: C
# Print Info
print_info(config)

# Start Websocket Listener (Redis Required)
if config['has_ws']:
Thread(
target=config['websocket_connections'],
daemon=True,
args=(self.ws_redis_connection,),
).start()

def load_configs(self) -> None:

# Check & Read The Configs File
Expand Down Expand Up @@ -98,8 +91,7 @@ def load_configs(self) -> None:
self._create_ws_connections_instance()

def _create_ws_connections_instance(self):
from panther.base_websocket import Websocket
from panther.websocket import WebsocketConnections
from panther.base_websocket import Websocket, WebsocketConnections

# Check do we have ws endpoint
for endpoint in config['flat_urls'].values():
Expand All @@ -111,7 +103,6 @@ def _create_ws_connections_instance(self):

# Create websocket connections instance
if config['has_ws']:
config['websocket_connections'] = WebsocketConnections()
# Websocket Redis Connection
for middleware in config['http_middlewares']:
if middleware.__class__.__name__ == 'RedisMiddleware':
Expand All @@ -120,6 +111,10 @@ def _create_ws_connections_instance(self):
else:
self.ws_redis_connection = None

# Don't create Manager() if we are going to use Redis for PubSub
manager = None if self.ws_redis_connection else Manager()
config['websocket_connections'] = WebsocketConnections(manager=manager)

async def __call__(self, scope: dict, receive: Callable, send: Callable) -> None:
"""
1.
Expand All @@ -138,6 +133,7 @@ async def __call__(self, scope: dict, receive: Callable, send: Callable) -> None
if scope['type'] == 'lifespan':
message = await receive()
if message["type"] == "lifespan.startup":
await self.handle_ws_listener()
await self.handle_startup()
return

Expand Down Expand Up @@ -262,6 +258,15 @@ async def handle_http(self, scope: dict, receive: Callable, send: Callable) -> N
body=response.body,
)

async def handle_ws_listener(self):
# Start Websocket Listener (Redis/ Queue)
if config['has_ws']:
Thread(
target=config['websocket_connections'],
daemon=True,
args=(self.ws_redis_connection,),
).start()

async def handle_startup(self):
if startup := config['startup'] or self._startup:
if is_function_async(startup):
Expand All @@ -272,7 +277,13 @@ async def handle_startup(self):
def handle_shutdown(self):
if shutdown := config['shutdown'] or self._shutdown:
if is_function_async(shutdown):
asyncio.run(shutdown())
try:
asyncio.run(shutdown())
except ModuleNotFoundError:
# Error: import of asyncio halted; None in sys.modules
# And as I figured it out, it only happens when we running with
# gunicorn and Uvicorn workers (-k uvicorn.workers.UvicornWorker)
pass
else:
shutdown()

Expand Down
3 changes: 2 additions & 1 deletion panther/response.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,10 @@ def body(self) -> bytes:

@property
def headers(self) -> dict:
content_length = 0 if self.body == b'null' else len(self.body)
return {
'content-type': self.content_type,
'content-length': len(self.body),
'content-length': content_length,
'access-control-allow-origin': '*',
} | (self._headers or {})

Expand Down
39 changes: 26 additions & 13 deletions panther/serializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,31 +3,44 @@


class ModelSerializer:
def __new__(cls, *args, **kwargs):
def __new__(cls, *args, model=None, **kwargs):
# Check `metaclass`
if len(args) == 0:
msg = f"you should not inherit the 'ModelSerializer', you should use it as 'metaclass' -> {cls.__name__}"
address = f'{cls.__module__}.{cls.__name__}'
msg = f"you should not inherit the 'ModelSerializer', you should use it as 'metaclass' -> {address}"
raise TypeError(msg)

model_name = args[0]
if 'model' not in kwargs:
msg = f"'model' required while using 'ModelSerializer' metaclass -> {model_name}"
data = args[2]
address = f'{data["__module__"]}.{model_name}'

# Check `model`
if model is None:
msg = f"'model' required while using 'ModelSerializer' metaclass -> {address}"
raise AttributeError(msg)
# Check `fields`
if 'fields' not in data:
msg = f"'fields' required while using 'ModelSerializer' metaclass. -> {address}"
raise AttributeError(msg) from None

model_fields = kwargs['model'].model_fields
model_fields = model.model_fields
field_definitions = {}
if 'fields' not in args[2]:
msg = f"'fields' required while using 'ModelSerializer' metaclass. -> {model_name}"
raise AttributeError(msg) from None
for field_name in args[2]['fields']:

# Collect `fields`
for field_name in data['fields']:
if field_name not in model_fields:
msg = f"'{field_name}' is not in '{kwargs['model'].__name__}' -> {model_name}"
msg = f"'{field_name}' is not in '{model.__name__}' -> {address}"
raise AttributeError(msg) from None

field_definitions[field_name] = (model_fields[field_name].annotation, model_fields[field_name])
for required in args[2].get('required_fields', []):

# Change `required_fields
for required in data.get('required_fields', []):
if required not in field_definitions:
msg = f"'{required}' is in 'required_fields' but not in 'fields' -> {model_name}"
msg = f"'{required}' is in 'required_fields' but not in 'fields' -> {address}"
raise AttributeError(msg) from None
field_definitions[required][1].default = PydanticUndefined

# Create Model
return create_model(
__model_name=model_name,
**field_definitions
Expand Down
Loading