diff --git a/drf_model_pusher/apps.py b/drf_model_pusher/apps.py index 57b1313..b714d7a 100644 --- a/drf_model_pusher/apps.py +++ b/drf_model_pusher/apps.py @@ -9,7 +9,9 @@ class DrfModelPusherConfig(AppConfig): name = "drf_model_pusher" def ready(self): + """Attach receivers to Signals and import pusher backends.""" from drf_model_pusher.config import connect_pusher_views + connect_pusher_views() pusher_backends_file = "pusher_backends.py" @@ -22,4 +24,3 @@ def ready(self): for app_config in apps.get_app_configs(): if os.path.exists(os.path.join(app_config.path, pusher_backends_file)): import_module("{0}.pusher_backends".format(app_config.name)) - diff --git a/drf_model_pusher/backends.py b/drf_model_pusher/backends.py index e9a99e0..c1a0dfe 100644 --- a/drf_model_pusher/backends.py +++ b/drf_model_pusher/backends.py @@ -3,6 +3,7 @@ """ from collections import defaultdict +from drf_model_pusher.providers import PusherProvider from drf_model_pusher.signals import view_pre_destroy, view_post_save pusher_backend_registry = defaultdict(list) @@ -11,7 +12,7 @@ class PusherBackendMetaclass(type): """ Register PusherBackend's with a registry for model lookups, supports - abstract classes + "abstract" classes which are not registered but can extend functionality. """ def __new__(mcs, cls, bases, dicts): @@ -39,15 +40,20 @@ class PusherBackend(metaclass=PusherBackendMetaclass): class Meta: abstract = True + provider_class = PusherProvider + def __init__(self, view): self.view = view - self.pusher_socket_id = self.get_pusher_socket(view) + self.socket_id = self.get_socket_id(view) + self.provider = self.provider_class() - def get_pusher_socket(self, view): + def get_socket_id(self, view): + """Return the socket from the request header.""" pusher_socket = view.request.META.get("HTTP_X_PUSHER_SOCKET_ID", None) return pusher_socket def push_change(self, event, instance=None, pre_destroy=False, ignore=True): + """Send a signal to push the update""" channels, event_name, data = self.get_packet(event, instance) if pre_destroy: view_pre_destroy.send( @@ -56,7 +62,8 @@ def push_change(self, event, instance=None, pre_destroy=False, ignore=True): channels=channels, event_name=event_name, data=data, - socket_id=self.pusher_socket_id if ignore else None, + socket_id=self.socket_id if ignore else None, + provider_class=self.provider_class, ) else: view_post_save.send( @@ -65,7 +72,8 @@ def push_change(self, event, instance=None, pre_destroy=False, ignore=True): channels=channels, event_name=event_name, data=data, - socket_id=self.pusher_socket_id if ignore else None, + socket_id=self.socket_id if ignore else None, + provider_class=self.provider_class, ) def get_event_name(self, event_type): @@ -75,13 +83,17 @@ def get_event_name(self, event_type): return "{0}.{1}".format(model_class_name, event_type) def get_serializer_class(self): - """Return the views serializer class""" + """Return the serializer class""" return self.view.get_serializer_class() - def get_serializer(self, view, *args, **kwargs): - """Return the serializer initialized with the views serializer context""" + def get_serializer_context(self): + """Return the context for the serializer.""" + return self.view.get_serializer_context() + + def get_serializer(self, *args, **kwargs): + """Return the serializer initialized with the serializer context""" serializer_class = self.get_serializer_class() - kwargs["context"] = view.get_serializer_context() + kwargs["context"] = self.get_serializer_context() return serializer_class(*args, **kwargs) def get_channels(self, instance=None): @@ -93,7 +105,8 @@ def get_packet(self, event, instance): """Return a tuple consisting of the channel, event name, and the JSON serializable data.""" channels = self.get_channels(instance=instance) event_name = self.get_event_name(event) - data = self.get_serializer(self.view, instance=instance).data + data = self.get_serializer(instance=instance).data + channels, event_name, data = self.provider.parse_packet(self, channels, event_name, data, self.socket_id) return channels, event_name, data @@ -104,10 +117,10 @@ class PrivatePusherBackend(PusherBackend): class Meta: abstract = True - def get_channel(self, instance=None): + def get_channels(self, instance=None): """Return the channel prefixed with `private-`""" - channel = super().get_channel(instance=instance) - return "private-{channel}".format(channel=channel) + channel = super().get_channels(instance=instance) + return ["private-{channel}".format(channel=channel)] def get_models_pusher_backends(model): diff --git a/drf_model_pusher/providers.py b/drf_model_pusher/providers.py new file mode 100644 index 0000000..2806cb9 --- /dev/null +++ b/drf_model_pusher/providers.py @@ -0,0 +1,60 @@ + +from django.conf import settings +from pusher import Pusher + + +class BaseProvider(object): + def configure(self): + raise NotImplementedError() + + def parse_packet(self, backend, channels, event_name, data, socket_id=None): + return channels, event_name, data + + def trigger(self, channels, event_name, data, socket_id=None): + raise NotImplementedError() + + +class PusherProvider(BaseProvider): + """ + This class provides a wrapper to Pusher so that we can mock it or disable it easily + """ + + def __init__(self): + self._pusher = None + self._disabled = False + + if hasattr(settings, "DRF_MODEL_PUSHER_DISABLED"): + self._disabled = settings.DRF_MODEL_PUSHER_DISABLED + + def configure(self): + try: + pusher_cluster = settings.PUSHER_CLUSTER + except AttributeError: + pusher_cluster = "mt1" + + self._pusher = Pusher( + app_id=settings.PUSHER_APP_ID, + key=settings.PUSHER_KEY, + secret=settings.PUSHER_SECRET, + cluster=pusher_cluster, + ) + + def parse_packet(self, backend, channels, event_name, data, socket_id=None): + return channels, event_name, data + + def trigger(self, channels, event_name, data, socket_id=None): + if self._disabled: + return + + self._pusher.trigger(channels, event_name, data, socket_id) + + +class AblyProvider(BaseProvider): + def __init__(self, *args, **kwargs): + pass + + def configure(self): + pass + + def trigger(self, channels, event_name, data, socket_id=None): + pass diff --git a/drf_model_pusher/proxies.py b/drf_model_pusher/proxies.py deleted file mode 100644 index 8d76bab..0000000 --- a/drf_model_pusher/proxies.py +++ /dev/null @@ -1,20 +0,0 @@ -from django.conf import settings -from pusher import Pusher - - -class PusherProxy(object): - """ - This class provides a wrapper to Pusher so that we can mock it or disable it easily - """ - def __init__(self, *args, **kwargs): - self._pusher = Pusher(*args, **kwargs) - self._disabled = False - - if hasattr(settings, "DRF_MODEL_PUSHER_DISABLED"): - self._disabled = settings.DRF_MODEL_PUSHER_DISABLED - - def trigger(self, channels, event_name, data): - if self._disabled: - return - - self._pusher.trigger(channels, event_name, data) \ No newline at end of file diff --git a/drf_model_pusher/receivers.py b/drf_model_pusher/receivers.py index b980e77..f6e0dd7 100644 --- a/drf_model_pusher/receivers.py +++ b/drf_model_pusher/receivers.py @@ -1,23 +1,15 @@ -from django.conf import settings - -from drf_model_pusher.proxies import PusherProxy +"""The receiver methods attach callbacks to signals""" +from drf_model_pusher.providers import PusherProvider def send_pusher_event( signal, sender, instance, channels, event_name, data, socket_id=None, **kwargs ): """ - Send a pusher event from a signal + Sends an update using the provided provider class """ - try: - pusher_cluster = settings.PUSHER_CLUSTER - except AttributeError: - pusher_cluster = "mt1" - pusher = PusherProxy( - app_id=settings.PUSHER_APP_ID, - key=settings.PUSHER_KEY, - secret=settings.PUSHER_SECRET, - cluster=pusher_cluster, - ) - pusher.trigger(channels, event_name, data) + push_provider_class = kwargs.get("provider_class", PusherProvider) + push_provider = push_provider_class() + push_provider.configure() + push_provider.trigger(channels, event_name, data, socket_id) diff --git a/drf_model_pusher/views.py b/drf_model_pusher/views.py index fe57ea3..ac9ba94 100644 --- a/drf_model_pusher/views.py +++ b/drf_model_pusher/views.py @@ -1,5 +1,3 @@ -from typing import List - from drf_model_pusher.backends import get_models_pusher_backends from drf_model_pusher.exceptions import ModelPusherException from drf_model_pusher.signals import view_post_save @@ -29,14 +27,17 @@ def get_models_pusher_backends(self): elif hasattr(self, "get_queryset"): model = self.get_queryset().model else: - raise ModelPusherException("View must have a queryset attribute or get_queryset method defined") + raise ModelPusherException( + "View must have a queryset attribute or get_queryset method defined" + ) return get_models_pusher_backends(model) - def get_pusher_channels(self) -> List[str]: + def get_pusher_channels(self): """Return the channel from the view""" raise NotImplementedError( "{0} must implement the `get_pusher_channels` method".format( - self.__class__.__name__) + self.__class__.__name__ + ) ) def get_pusher_backends(self): diff --git a/setup.py b/setup.py old mode 100644 new mode 100755 index 5c7ae59..51c2ace --- a/setup.py +++ b/setup.py @@ -18,7 +18,7 @@ EMAIL = "aljparr0@gmail.com" AUTHOR = "Adam Jacquier-Parr" REQUIRES_PYTHON = ">=3.6.0" -VERSION = "0.1.0" +VERSION = "0.2.0" # What packages are required for this module to be executed? REQUIRED = ["django", "djangorestframework", "pusher"] diff --git a/tests/test_views.py b/tests/test_views.py index 267b6ae..81a33de 100644 --- a/tests/test_views.py +++ b/tests/test_views.py @@ -26,7 +26,7 @@ def test_creations_are_pushed(self, trigger: Mock): self.assertEqual(response.status_code, 201, response.data) trigger.assert_called_once_with( - ["channel"], "mymodel.create", MyModelSerializer(instance=instance).data + ["channel"], "mymodel.create", MyModelSerializer(instance=instance).data, None ) @mock.patch("pusher.Pusher.trigger") @@ -46,7 +46,7 @@ def test_updates_are_pushed(self, trigger: Mock): self.assertEqual(instance.name, "Michelle") trigger.assert_called_once_with( - ["channel"], "mymodel.update", MyModelSerializer(instance=instance).data + ["channel"], "mymodel.update", MyModelSerializer(instance=instance).data, None ) @mock.patch("pusher.Pusher.trigger") @@ -64,5 +64,5 @@ def test_deletions_are_pushed(self, trigger: Mock): instance = MyModel.objects.get(pk=instance.pk) trigger.assert_called_once_with( - ["channel"], "mymodel.delete", MyModelSerializer(instance=instance).data + ["channel"], "mymodel.delete", MyModelSerializer(instance=instance).data, None )