diff --git a/django_rq/management/commands/rqworker-pool.py b/django_rq/management/commands/rqworker-pool.py index a7329801..7339e2c0 100644 --- a/django_rq/management/commands/rqworker-pool.py +++ b/django_rq/management/commands/rqworker-pool.py @@ -1,9 +1,8 @@ import os import sys -from rq.serializers import resolve_serializer -from rq.worker_pool import WorkerPool from rq.logutils import setup_loghandlers +from rq.serializers import resolve_serializer from django.core.management.base import BaseCommand @@ -11,6 +10,7 @@ from ...utils import configure_sentry from ...queues import get_queues from ...workers import get_worker_class +from ...worker_pool import DjangoWorkerPool class Command(BaseCommand): @@ -89,7 +89,7 @@ def handle(self, *args, **options): worker_class = get_worker_class(options.get('worker_class', None)) serializer = resolve_serializer(options['serializer']) - pool = WorkerPool( + pool = DjangoWorkerPool( queues=queues, connection=queues[0].connection, num_workers=options['num_workers'], diff --git a/django_rq/tests/tests.py b/django_rq/tests/tests.py index 50f2733d..30151ace 100644 --- a/django_rq/tests/tests.py +++ b/django_rq/tests/tests.py @@ -1,5 +1,6 @@ -import sys import datetime +import multiprocessing +import sys import time from unittest import skipIf, mock from unittest.mock import patch, PropertyMock, MagicMock @@ -37,6 +38,8 @@ from django_rq.utils import get_jobs, get_statistics, get_scheduler_pid from django_rq.workers import get_worker, get_worker_class +from .utils import query_queue + try: from rq_scheduler import Scheduler from ..queues import get_scheduler @@ -303,6 +306,18 @@ def test_pass_queue_via_commandline_args(self): self.assertTrue(job['job'].is_finished) self.assertIn(job['job'].id, job['finished_job_registry'].get_job_ids()) + def test_rqworker_pool_process_start_method(self) -> None: + for start_method in ['spawn', 'fork']: + with mock.patch.object(multiprocessing, "get_start_method", return_value=start_method): + queue_name = 'django_rq_test' + queue = get_queue(queue_name) + job = queue.enqueue(query_queue) + finished_job_registry = FinishedJobRegistry(queue.name, queue.connection) + call_command('rqworker-pool', queue_name, burst=True) + + self.assertTrue(job.is_finished) + self.assertIn(job.id, finished_job_registry.get_job_ids()) + def test_configure_sentry(self): rqworker.configure_sentry('https://1@sentry.io/1') self.mock_sdk.init.assert_called_once_with( diff --git a/django_rq/tests/utils.py b/django_rq/tests/utils.py index afe4df2a..e33754de 100644 --- a/django_rq/tests/utils.py +++ b/django_rq/tests/utils.py @@ -1,4 +1,5 @@ from django_rq.queues import get_connection, get_queue_by_index +from django_rq.models import Queue def get_queue_index(name='default'): @@ -17,3 +18,7 @@ def get_queue_index(name='default'): queue_index = i break return queue_index + + +def query_queue(): + return Queue.objects.first() diff --git a/django_rq/worker_pool.py b/django_rq/worker_pool.py new file mode 100644 index 00000000..68876cb4 --- /dev/null +++ b/django_rq/worker_pool.py @@ -0,0 +1,38 @@ +import django +from multiprocessing import Process, get_start_method +from typing import Any + +from rq.worker_pool import WorkerPool, run_worker + + +class DjangoWorkerPool(WorkerPool): + def get_worker_process( + self, + name: str, + burst: bool, + _sleep: float = 0, + logging_level: str = "INFO", + ) -> Process: + """Returns the worker process""" + return Process( + target=run_django_worker, + args=(name, self._queue_names, self._connection_class, self._pool_class, self._pool_kwargs), + kwargs={ + '_sleep': _sleep, + 'burst': burst, + 'logging_level': logging_level, + 'worker_class': self.worker_class, + 'job_class': self.job_class, + 'serializer': self.serializer, + }, + name=f'Worker {name} (WorkerPool {self.name})', + ) + + +def run_django_worker(*args: Any, **kwargs: Any) -> None: + # multiprocessing library default process start method may be + # `spawn` or `fork` depending on the host OS + if get_start_method() == 'spawn': + django.setup() + + run_worker(*args, **kwargs)