Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for more providers (Ably, Channels, SocketIO) #26

Open
wants to merge 7 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion drf_model_pusher/apps.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,9 @@ class DrfModelPusherConfig(AppConfig):
name = "drf_model_pusher"

def ready(self):
"""Attach receivers to Signals and import pusher backends."""
from drf_model_pusher.config import connect_pusher_views

connect_pusher_views()

pusher_backends_file = "pusher_backends.py"
Expand All @@ -22,4 +24,3 @@ def ready(self):
for app_config in apps.get_app_configs():
if os.path.exists(os.path.join(app_config.path, pusher_backends_file)):
import_module("{0}.pusher_backends".format(app_config.name))

39 changes: 26 additions & 13 deletions drf_model_pusher/backends.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
"""
from collections import defaultdict

from drf_model_pusher.providers import PusherProvider
from drf_model_pusher.signals import view_pre_destroy, view_post_save

pusher_backend_registry = defaultdict(list)
Expand All @@ -11,7 +12,7 @@
class PusherBackendMetaclass(type):
"""
Register PusherBackend's with a registry for model lookups, supports
abstract classes
"abstract" classes which are not registered but can extend functionality.
"""

def __new__(mcs, cls, bases, dicts):
Expand Down Expand Up @@ -39,15 +40,20 @@ class PusherBackend(metaclass=PusherBackendMetaclass):
class Meta:
abstract = True

provider_class = PusherProvider

def __init__(self, view):
self.view = view
self.pusher_socket_id = self.get_pusher_socket(view)
self.socket_id = self.get_socket_id(view)
self.provider = self.provider_class()

def get_pusher_socket(self, view):
def get_socket_id(self, view):
"""Return the socket from the request header."""
pusher_socket = view.request.META.get("HTTP_X_PUSHER_SOCKET_ID", None)
return pusher_socket

def push_change(self, event, instance=None, pre_destroy=False, ignore=True):
"""Send a signal to push the update"""
channels, event_name, data = self.get_packet(event, instance)
if pre_destroy:
view_pre_destroy.send(
Expand All @@ -56,7 +62,8 @@ def push_change(self, event, instance=None, pre_destroy=False, ignore=True):
channels=channels,
event_name=event_name,
data=data,
socket_id=self.pusher_socket_id if ignore else None,
socket_id=self.socket_id if ignore else None,
provider_class=self.provider_class,
)
else:
view_post_save.send(
Expand All @@ -65,7 +72,8 @@ def push_change(self, event, instance=None, pre_destroy=False, ignore=True):
channels=channels,
event_name=event_name,
data=data,
socket_id=self.pusher_socket_id if ignore else None,
socket_id=self.socket_id if ignore else None,
provider_class=self.provider_class,
)

def get_event_name(self, event_type):
Expand All @@ -75,13 +83,17 @@ def get_event_name(self, event_type):
return "{0}.{1}".format(model_class_name, event_type)

def get_serializer_class(self):
"""Return the views serializer class"""
"""Return the serializer class"""
return self.view.get_serializer_class()

def get_serializer(self, view, *args, **kwargs):
"""Return the serializer initialized with the views serializer context"""
def get_serializer_context(self):
"""Return the context for the serializer."""
return self.view.get_serializer_context()

def get_serializer(self, *args, **kwargs):
"""Return the serializer initialized with the serializer context"""
serializer_class = self.get_serializer_class()
kwargs["context"] = view.get_serializer_context()
kwargs["context"] = self.get_serializer_context()
return serializer_class(*args, **kwargs)

def get_channels(self, instance=None):
Expand All @@ -93,7 +105,8 @@ def get_packet(self, event, instance):
"""Return a tuple consisting of the channel, event name, and the JSON serializable data."""
channels = self.get_channels(instance=instance)
event_name = self.get_event_name(event)
data = self.get_serializer(self.view, instance=instance).data
data = self.get_serializer(instance=instance).data
channels, event_name, data = self.provider.parse_packet(self, channels, event_name, data, self.socket_id)
return channels, event_name, data


Expand All @@ -104,10 +117,10 @@ class PrivatePusherBackend(PusherBackend):
class Meta:
abstract = True

def get_channel(self, instance=None):
def get_channels(self, instance=None):
"""Return the channel prefixed with `private-`"""
channel = super().get_channel(instance=instance)
return "private-{channel}".format(channel=channel)
channel = super().get_channels(instance=instance)
return ["private-{channel}".format(channel=channel)]
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If there are multiple channels returned from get_channels then the return result is an array of 1 string "private-[a, b, c]" rather than a array ["private-a", "private-b", "private-c"]



def get_models_pusher_backends(model):
Expand Down
60 changes: 60 additions & 0 deletions drf_model_pusher/providers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@

from django.conf import settings
from pusher import Pusher


class BaseProvider(object):
def configure(self):
raise NotImplementedError()

def parse_packet(self, backend, channels, event_name, data, socket_id=None):
return channels, event_name, data

def trigger(self, channels, event_name, data, socket_id=None):
raise NotImplementedError()


class PusherProvider(BaseProvider):
"""
This class provides a wrapper to Pusher so that we can mock it or disable it easily
"""

def __init__(self):
self._pusher = None
self._disabled = False

if hasattr(settings, "DRF_MODEL_PUSHER_DISABLED"):
self._disabled = settings.DRF_MODEL_PUSHER_DISABLED

def configure(self):
try:
pusher_cluster = settings.PUSHER_CLUSTER
except AttributeError:
pusher_cluster = "mt1"

self._pusher = Pusher(
app_id=settings.PUSHER_APP_ID,
key=settings.PUSHER_KEY,
secret=settings.PUSHER_SECRET,
cluster=pusher_cluster,
)

def parse_packet(self, backend, channels, event_name, data, socket_id=None):
return channels, event_name, data

def trigger(self, channels, event_name, data, socket_id=None):
if self._disabled:
return

self._pusher.trigger(channels, event_name, data, socket_id)


class AblyProvider(BaseProvider):
def __init__(self, *args, **kwargs):
pass

def configure(self):
pass

def trigger(self, channels, event_name, data, socket_id=None):
pass
20 changes: 0 additions & 20 deletions drf_model_pusher/proxies.py

This file was deleted.

22 changes: 7 additions & 15 deletions drf_model_pusher/receivers.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,15 @@
from django.conf import settings

from drf_model_pusher.proxies import PusherProxy
"""The receiver methods attach callbacks to signals"""
from drf_model_pusher.providers import PusherProvider


def send_pusher_event(
signal, sender, instance, channels, event_name, data, socket_id=None, **kwargs
):
"""
Send a pusher event from a signal
Sends an update using the provided provider class
"""
try:
pusher_cluster = settings.PUSHER_CLUSTER
except AttributeError:
pusher_cluster = "mt1"

pusher = PusherProxy(
app_id=settings.PUSHER_APP_ID,
key=settings.PUSHER_KEY,
secret=settings.PUSHER_SECRET,
cluster=pusher_cluster,
)
pusher.trigger(channels, event_name, data)
push_provider_class = kwargs.get("provider_class", PusherProvider)
push_provider = push_provider_class()
push_provider.configure()
push_provider.trigger(channels, event_name, data, socket_id)
11 changes: 6 additions & 5 deletions drf_model_pusher/views.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
from typing import List

from drf_model_pusher.backends import get_models_pusher_backends
from drf_model_pusher.exceptions import ModelPusherException
from drf_model_pusher.signals import view_post_save
Expand Down Expand Up @@ -29,14 +27,17 @@ def get_models_pusher_backends(self):
elif hasattr(self, "get_queryset"):
model = self.get_queryset().model
else:
raise ModelPusherException("View must have a queryset attribute or get_queryset method defined")
raise ModelPusherException(
"View must have a queryset attribute or get_queryset method defined"
)
return get_models_pusher_backends(model)

def get_pusher_channels(self) -> List[str]:
def get_pusher_channels(self):
"""Return the channel from the view"""
raise NotImplementedError(
"{0} must implement the `get_pusher_channels` method".format(
self.__class__.__name__)
self.__class__.__name__
)
)

def get_pusher_backends(self):
Expand Down
2 changes: 1 addition & 1 deletion setup.py
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
EMAIL = "[email protected]"
AUTHOR = "Adam Jacquier-Parr"
REQUIRES_PYTHON = ">=3.6.0"
VERSION = "0.1.0"
VERSION = "0.2.0"
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We might want to bump this to 1.0.0 in this PR since we are probably introducing breaking changes.


# What packages are required for this module to be executed?
REQUIRED = ["django", "djangorestframework", "pusher"]
Expand Down
6 changes: 3 additions & 3 deletions tests/test_views.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def test_creations_are_pushed(self, trigger: Mock):
self.assertEqual(response.status_code, 201, response.data)

trigger.assert_called_once_with(
["channel"], "mymodel.create", MyModelSerializer(instance=instance).data
["channel"], "mymodel.create", MyModelSerializer(instance=instance).data, None
)

@mock.patch("pusher.Pusher.trigger")
Expand All @@ -46,7 +46,7 @@ def test_updates_are_pushed(self, trigger: Mock):
self.assertEqual(instance.name, "Michelle")

trigger.assert_called_once_with(
["channel"], "mymodel.update", MyModelSerializer(instance=instance).data
["channel"], "mymodel.update", MyModelSerializer(instance=instance).data, None
)

@mock.patch("pusher.Pusher.trigger")
Expand All @@ -64,5 +64,5 @@ def test_deletions_are_pushed(self, trigger: Mock):
instance = MyModel.objects.get(pk=instance.pk)

trigger.assert_called_once_with(
["channel"], "mymodel.delete", MyModelSerializer(instance=instance).data
["channel"], "mymodel.delete", MyModelSerializer(instance=instance).data, None
)