diff --git a/pypeln/process/queue.py b/pypeln/process/queue.py index 5c8b5dd..d9bdcd1 100644 --- a/pypeln/process/queue.py +++ b/pypeln/process/queue.py @@ -1,9 +1,14 @@ -import multiprocessing -from multiprocessing.queues import Empty, Queue import sys import traceback import typing as tp +if "multiprocess" in sys.modules: + from multiprocess import get_context + from multiprocess.queues import Empty, Queue +else: + from multiprocessing import get_context + from multiprocessing.queues import Empty, Queue + from pypeln import utils as pypeln_utils @@ -18,13 +23,13 @@ class IterableQueue(Queue, tp.Generic[T], tp.Iterable[T]): def __init__(self, maxsize: int = 0, total_sources: int = 1): - super().__init__(maxsize=maxsize, ctx=multiprocessing.get_context()) + super().__init__(maxsize=maxsize, ctx=get_context()) self.namespace = utils.Namespace( remaining=total_sources, exception=False, force_stop=False ) self.exception_queue: Queue[PipelineException] = Queue( - ctx=multiprocessing.get_context() + ctx=get_context() ) def get(self, block: bool = True, timeout: tp.Optional[float] = None) -> T: diff --git a/pypeln/process/supervisor.py b/pypeln/process/supervisor.py index d262cf5..4db9d68 100644 --- a/pypeln/process/supervisor.py +++ b/pypeln/process/supervisor.py @@ -1,5 +1,4 @@ from dataclasses import dataclass -import multiprocessing import threading import time import typing as tp diff --git a/pypeln/process/utils.py b/pypeln/process/utils.py index 1af7b0e..8e1414d 100644 --- a/pypeln/process/utils.py +++ b/pypeln/process/utils.py @@ -1,6 +1,10 @@ -import multiprocessing -import multiprocessing.synchronize import typing as tp +import sys + +if "multiprocess" in sys.modules: + from multiprocess import Manager, Lock +else: + from multiprocessing import Manager, Lock from pypeln import utils as pypeln_utils @@ -12,10 +16,10 @@ def __init__(self, **kwargs): global MANAGER if MANAGER is None: - MANAGER = multiprocessing.Manager() + MANAGER = Manager() self.__dict__["_namespace"] = MANAGER.Namespace(**kwargs) - self.__dict__["_lock"] = multiprocessing.Lock() + self.__dict__["_lock"] = Lock() def __getattr__(self, key) -> tp.Any: if key in ("_namespace", "_lock"): diff --git a/pypeln/process/worker.py b/pypeln/process/worker.py index 8a8157a..bae908e 100644 --- a/pypeln/process/worker.py +++ b/pypeln/process/worker.py @@ -1,12 +1,13 @@ -import abc -from copy import copy from dataclasses import dataclass, field -import functools -import multiprocessing -from multiprocessing import synchronize import threading import time import typing as tp +import sys + +if "multiprocess" in sys.modules: + from multiprocess import Process +else: + from multiprocessing import Process import stopit @@ -63,7 +64,7 @@ class Worker(tp.Generic[T]): namespace: utils.Namespace = field( default_factory=lambda: utils.Namespace(done=False, task_start_time=None) ) - process: tp.Optional[tp.Union[multiprocessing.Process, threading.Thread]] = None + process: tp.Optional[tp.Union[Process, threading.Thread]] = None def __call__(self): @@ -138,7 +139,7 @@ def stop(self): if not self.process.is_alive(): return - if isinstance(self.process, multiprocessing.Process): + if isinstance(self.process, Process): self.process.terminate() else: stopit.async_raise( @@ -227,7 +228,7 @@ def start_workers( args: tp.Tuple[tp.Any, ...] = tuple(), kwargs: tp.Optional[tp.Dict[tp.Any, tp.Any]] = None, use_threads: bool = False, -) -> tp.Union[tp.List[multiprocessing.Process], tp.List[threading.Thread]]: +) -> tp.Union[tp.List[Process], tp.List[threading.Thread]]: if kwargs is None: kwargs = {} @@ -237,7 +238,7 @@ def start_workers( if use_threads: t = threading.Thread(target=target, args=args, kwargs=kwargs) else: - t = multiprocessing.Process(target=target, args=args, kwargs=kwargs) + t = Process(target=target, args=args, kwargs=kwargs) t.daemon = True t.start() workers.append(t) diff --git a/pypeln/thread/supervisor.py b/pypeln/thread/supervisor.py index d262cf5..4db9d68 100644 --- a/pypeln/thread/supervisor.py +++ b/pypeln/thread/supervisor.py @@ -1,5 +1,4 @@ from dataclasses import dataclass -import multiprocessing import threading import time import typing as tp