Skip to content

Commit

Permalink
Add take_events() function (#392)
Browse files Browse the repository at this point in the history
* Add take_events()

* lint streams.py

Co-authored-by: William Barnhart <[email protected]>
Co-authored-by: William Barnhart <[email protected]>
  • Loading branch information
3 people authored Oct 24, 2022
1 parent ecee01b commit 94ff38f
Showing 1 changed file with 93 additions and 0 deletions.
93 changes: 93 additions & 0 deletions faust/streams.py
Original file line number Diff line number Diff line change
Expand Up @@ -391,6 +391,99 @@ async def add_to_buffer(value: T) -> T:
self.enable_acks = stream_enable_acks
self._processors.remove(add_to_buffer)

async def take_events(
self, max_: int, within: Seconds
) -> AsyncIterable[Sequence[EventT]]:
"""Buffer n events at a time and yield a list of buffered events.
Arguments:
max_: Max number of messages to receive. When more than this
number of messages are received within the specified number of
seconds then we flush the buffer immediately.
within: Timeout for when we give up waiting for another value,
and process the values we have.
Warning: If there's no timeout (i.e. `timeout=None`),
the agent is likely to stall and block buffered events for an
unreasonable length of time(!).
"""
buffer: List[T_co] = []
events: List[EventT] = []
buffer_add = buffer.append
event_add = events.append
buffer_size = buffer.__len__
buffer_full = asyncio.Event()
buffer_consumed = asyncio.Event()
timeout = want_seconds(within) if within else None
stream_enable_acks: bool = self.enable_acks

buffer_consuming: Optional[asyncio.Future] = None

channel_it = aiter(self.channel)

# We add this processor to populate the buffer, and the stream
# is passively consumed in the background (enable_passive below).
async def add_to_buffer(value: T) -> T:
try:
# buffer_consuming is set when consuming buffer after timeout.
nonlocal buffer_consuming
if buffer_consuming is not None:
try:
await buffer_consuming
finally:
buffer_consuming = None
buffer_add(cast(T_co, value))
event = self.current_event
if event is None:
raise RuntimeError("Take buffer found current_event is None")
event_add(event)
if buffer_size() >= max_:
# signal that the buffer is full and should be emptied.
buffer_full.set()
# strict wait for buffer to be consumed after buffer full.
# If max is 1000, we are not allowed to return 1001 values.
buffer_consumed.clear()
await self.wait(buffer_consumed)
except CancelledError: # pragma: no cover
raise
except Exception as exc:
self.log.exception("Error adding to take buffer: %r", exc)
await self.crash(exc)
return value

# Disable acks to ensure this method acks manually
# events only after they are consumed by the user
self.enable_acks = False

self.add_processor(add_to_buffer)
self._enable_passive(cast(ChannelT, channel_it))
try:
while not self.should_stop:
# wait until buffer full, or timeout
await self.wait_for_stopped(buffer_full, timeout=timeout)
if buffer:
# make sure background thread does not add new items to
# buffer while we read.
buffer_consuming = self.loop.create_future()
try:
yield list(events)
finally:
buffer.clear()
for event in events:
await self.ack(event)
events.clear()
# allow writing to buffer again
notify(buffer_consuming)
buffer_full.clear()
buffer_consumed.set()
else: # pragma: no cover
pass
else: # pragma: no cover
pass

finally:
# Restore last behaviour of "enable_acks"
self.enable_acks = stream_enable_acks
self._processors.remove(add_to_buffer)

async def take_with_timestamp(
self, max_: int, within: Seconds, timestamp_field_name: str
) -> AsyncIterable[Sequence[T_co]]:
Expand Down

0 comments on commit 94ff38f

Please sign in to comment.