diff --git a/lazyllm/flow/flow.py b/lazyllm/flow/flow.py index e928a145..2d8c94a0 100644 --- a/lazyllm/flow/flow.py +++ b/lazyllm/flow/flow.py @@ -347,11 +347,12 @@ class PostProcessType(Enum): SUM = 4 JOIN = 5 - def __init__(self, *args, _scatter=False, _concurrent=True, auto_capture=False, **kw): + def __init__(self, *args, _scatter: bool = False, _concurrent: Union[bool, int] = True, + auto_capture: bool = False, **kw): super().__init__(*args, **kw, auto_capture=auto_capture) self._post_process_type = Parallel.PostProcessType.NONE self._post_process_args = None - self._concurrent = _concurrent + self._concurrent = _concurrent if not isinstance(_concurrent, bool) else 5 if _concurrent else 0 self._scatter = _scatter @staticmethod @@ -396,8 +397,10 @@ def impl(func, barrier, global_data, *args, **kw): return r loop, barrier = asyncio.new_event_loop(), threading.Barrier(len(items)) - return package(loop.run_until_complete(asyncio.gather(*[loop.run_in_executor(None, partial( - impl, self.invoke, barrier, lazyllm.globals._data, it, inp, **kw)) for it, inp in zip(items, inputs)]))) + with concurrent.futures.ThreadPoolExecutor(max_workers=self._concurrent) as executor: + return package(loop.run_until_complete(asyncio.gather(*[loop.run_in_executor(executor, partial( + impl, self.invoke, barrier, lazyllm.globals._data, it, inp, **kw)) + for it, inp in zip(items, inputs)]))) else: return package(self.invoke(it, inp, **kw) for it, inp in zip(items, inputs)) @@ -420,7 +423,7 @@ def _post_process(self, output): # (in1, in2, in3) -> in2 -> module21 -> ... -> module2N -> out2 -> (out1, out2, out3) # \> in3 -> module31 -> ... -> module3N -> out3 / class Diverter(Parallel): - def __init__(self, *args, _concurrent=True, auto_capture=False, **kw): + def __init__(self, *args, _concurrent: Union[bool, int] = True, auto_capture: bool = False, **kw): super().__init__(*args, _scatter=True, _concurrent=_concurrent, auto_capture=auto_capture, **kw) @@ -430,7 +433,8 @@ def __init__(self, *args, _concurrent=True, auto_capture=False, **kw): # Attention: Cannot be used in async tasks, ie: training and deploy # TODO: add check for async tasks class Warp(Parallel): - def __init__(self, *args, _scatter=False, _concurrent=True, auto_capture=False, **kw): + def __init__(self, *args, _scatter: bool = False, _concurrent: Union[bool, int] = True, + auto_capture: bool = False, **kw): super().__init__(*args, _scatter=_scatter, _concurrent=_concurrent, auto_capture=auto_capture, **kw) if len(self._items) > 1: self._items = [Pipeline(*self._items)]