Skip to content

Commit

Permalink
add number_worker for parallel (#437)
Browse files Browse the repository at this point in the history
  • Loading branch information
wzh1994 authored Feb 19, 2025
1 parent 7aaeaad commit fb8e5dd
Showing 1 changed file with 10 additions and 6 deletions.
16 changes: 10 additions & 6 deletions lazyllm/flow/flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -347,11 +347,12 @@ class PostProcessType(Enum):
SUM = 4
JOIN = 5

def __init__(self, *args, _scatter=False, _concurrent=True, auto_capture=False, **kw):
def __init__(self, *args, _scatter: bool = False, _concurrent: Union[bool, int] = True,
auto_capture: bool = False, **kw):
super().__init__(*args, **kw, auto_capture=auto_capture)
self._post_process_type = Parallel.PostProcessType.NONE
self._post_process_args = None
self._concurrent = _concurrent
self._concurrent = _concurrent if not isinstance(_concurrent, bool) else 5 if _concurrent else 0
self._scatter = _scatter

@staticmethod
Expand Down Expand Up @@ -396,8 +397,10 @@ def impl(func, barrier, global_data, *args, **kw):
return r

loop, barrier = asyncio.new_event_loop(), threading.Barrier(len(items))
return package(loop.run_until_complete(asyncio.gather(*[loop.run_in_executor(None, partial(
impl, self.invoke, barrier, lazyllm.globals._data, it, inp, **kw)) for it, inp in zip(items, inputs)])))
with concurrent.futures.ThreadPoolExecutor(max_workers=self._concurrent) as executor:
return package(loop.run_until_complete(asyncio.gather(*[loop.run_in_executor(executor, partial(
impl, self.invoke, barrier, lazyllm.globals._data, it, inp, **kw))
for it, inp in zip(items, inputs)])))
else:
return package(self.invoke(it, inp, **kw) for it, inp in zip(items, inputs))

Expand All @@ -420,7 +423,7 @@ def _post_process(self, output):
# (in1, in2, in3) -> in2 -> module21 -> ... -> module2N -> out2 -> (out1, out2, out3)
# \> in3 -> module31 -> ... -> module3N -> out3 /
class Diverter(Parallel):
def __init__(self, *args, _concurrent=True, auto_capture=False, **kw):
def __init__(self, *args, _concurrent: Union[bool, int] = True, auto_capture: bool = False, **kw):
super().__init__(*args, _scatter=True, _concurrent=_concurrent, auto_capture=auto_capture, **kw)


Expand All @@ -430,7 +433,8 @@ def __init__(self, *args, _concurrent=True, auto_capture=False, **kw):
# Attention: Cannot be used in async tasks, ie: training and deploy
# TODO: add check for async tasks
class Warp(Parallel):
def __init__(self, *args, _scatter=False, _concurrent=True, auto_capture=False, **kw):
def __init__(self, *args, _scatter: bool = False, _concurrent: Union[bool, int] = True,
auto_capture: bool = False, **kw):
super().__init__(*args, _scatter=_scatter, _concurrent=_concurrent, auto_capture=auto_capture, **kw)
if len(self._items) > 1: self._items = [Pipeline(*self._items)]

Expand Down

0 comments on commit fb8e5dd

Please sign in to comment.