Skip to content

Commit

Permalink
.skip/.truncate: make no-op if arguments are None (closes #57);…
Browse files Browse the repository at this point in the history
… `.skip`: ok when both `count` and `until` are set (#59)
  • Loading branch information
ebonnal committed Jan 24, 2025
1 parent 58539af commit cb66477
Show file tree
Hide file tree
Showing 6 changed files with 85 additions and 55 deletions.
4 changes: 4 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -347,6 +347,8 @@ five_first_integers: Stream[int] = integers.truncate(when=lambda n: n == 5)
assert list(five_first_integers) == [0, 1, 2, 3, 4]
```

> If both `count` and `when` are set, truncation occurs as soon as either condition is met.
## `.skip`

> Skips the first specified number of elements:
Expand All @@ -365,6 +367,8 @@ integers_after_five: Stream[int] = integers.skip(until=lambda n: n >= 5)
assert list(integers_after_five) == [5, 6, 7, 8, 9]
```

> If both `count` and `until` are set, skipping stops as soon as either condition is met.
## `.catch`

> Catches a given type of exceptions, and optionally yields a `replacement` value:
Expand Down
16 changes: 9 additions & 7 deletions streamable/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
CatchIterator,
ConcurrentFlattenIterator,
ConsecutiveDistinctIterator,
CountAndPredicateSkipIterator,
CountSkipIterator,
CountTruncateIterator,
DistinctIterator,
Expand All @@ -41,10 +42,9 @@
validate_group_interval,
validate_group_size,
validate_iterator,
validate_skip_args,
validate_optional_count,
validate_throttle_interval,
validate_throttle_per_period,
validate_truncate_args,
)

with suppress(ImportError):
Expand Down Expand Up @@ -170,11 +170,13 @@ def skip(
until: Optional[Callable[[T], Any]] = None,
) -> Iterator[T]:
validate_iterator(iterator)
validate_skip_args(count, until)
validate_optional_count(count)
if until is not None:
iterator = PredicateSkipIterator(iterator, until)
elif count is not None:
iterator = CountSkipIterator(iterator, count)
if count is not None:
return CountAndPredicateSkipIterator(iterator, count, until)
return PredicateSkipIterator(iterator, until)
if count is not None:
return CountSkipIterator(iterator, count)
return iterator


Expand Down Expand Up @@ -210,7 +212,7 @@ def truncate(
when: Optional[Callable[[T], Any]] = None,
) -> Iterator[T]:
validate_iterator(iterator)
validate_truncate_args(count, when)
validate_optional_count(count)
if count is not None:
iterator = CountTruncateIterator(iterator, count)
if when is not None:
Expand Down
23 changes: 23 additions & 0 deletions streamable/iterators.py
Original file line number Diff line number Diff line change
Expand Up @@ -315,6 +315,29 @@ def __next__(self) -> T:
return elem


class CountAndPredicateSkipIterator(Iterator[T]):
def __init__(
self, iterator: Iterator[T], count: int, until: Callable[[T], Any]
) -> None:
validate_iterator(iterator)
validate_count(count)
self.iterator = iterator
self.count = count
self.until = wrap_error(until, StopIteration)
self._n_skipped = 0
self._done_skipping = False

def __next__(self) -> T:
elem = next(self.iterator)
if not self._done_skipping:
while self._n_skipped < self.count and not self.until(elem):
elem = next(self.iterator)
# do not count exceptions as skipped elements
self._n_skipped += 1
self._done_skipping = True
return elem


class CountTruncateIterator(Iterator[T]):
def __init__(self, iterator: Iterator[T], count: int) -> None:
validate_iterator(iterator)
Expand Down
15 changes: 8 additions & 7 deletions streamable/stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,9 @@
validate_concurrency,
validate_group_interval,
validate_group_size,
validate_skip_args,
validate_optional_count,
validate_throttle_interval,
validate_throttle_per_period,
validate_truncate_args,
validate_via,
)

Expand Down Expand Up @@ -433,16 +432,17 @@ def skip(
self, count: Optional[int] = None, until: Optional[Callable[[T], Any]] = None
) -> "Stream[T]":
"""
Skips the first `count` elements, or skips `until` a predicate becomes satisfied.
Skips elements until `until(elem)` is truthy, or `count` elements have been skipped.
If both `count` and `until` are set, skipping stops as soon as either condition is met.
Args:
count (Optional[int], optional): The number of elements to skip. (default: no count-based skipping)
count (Optional[int], optional): The maximum number of elements to skip. (default: no count-based skipping)
until (Optional[Callable[[T], Any]], optional): Elements are skipped until the first one for which `until(elem)` is truthy. This element and all the subsequent ones will be yielded. (default: no predicate-based skipping)
Returns:
Stream: A stream of the upstream elements remaining after skipping.
"""
validate_skip_args(count, until)
validate_optional_count(count)
return SkipStream(self, count, until)

def throttle(
Expand Down Expand Up @@ -480,7 +480,8 @@ def truncate(
self, count: Optional[int] = None, when: Optional[Callable[[T], Any]] = None
) -> "Stream[T]":
"""
Stops an iteration as soon as the `when` predicate is satisfied or `count` elements have been yielded.
Stops an iteration as soon as `when(elem)` is truthy, or `count` elements have been yielded.
If both `count` and `when` are set, truncation occurs as soon as either condition is met.
Args:
count (int, optional): The maximum number of elements to yield. (default: no count-based truncation)
Expand All @@ -489,7 +490,7 @@ def truncate(
Returns:
Stream[T]: A stream of at most `count` upstream elements not satisfying the `when` predicate.
"""
validate_truncate_args(count, when)
validate_optional_count(count)
return TruncateStream(self, count, when)


Expand Down
29 changes: 6 additions & 23 deletions streamable/util/validationtools.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import datetime
import sys
from typing import Any, Callable, Iterator, Optional, TypeVar
from typing import Iterator, Optional, TypeVar

T = TypeVar("T")

Expand Down Expand Up @@ -47,6 +47,11 @@ def validate_count(count: int):
raise ValueError(f"`count` must be < sys.maxsize but got {count}")


def validate_optional_count(count: Optional[int]):
if count is not None:
validate_count(count)


def validate_throttle_per_period(per_period_arg_name: str, value: int) -> None:
if value < 1:
raise ValueError(f"`{per_period_arg_name}` must be >= 1 but got {value}")
Expand All @@ -55,25 +60,3 @@ def validate_throttle_per_period(per_period_arg_name: str, value: int) -> None:
def validate_throttle_interval(interval: datetime.timedelta) -> None:
if interval < datetime.timedelta(0):
raise ValueError(f"`interval` must be >= 0 but got {repr(interval)}")


def validate_truncate_args(
count: Optional[int] = None, when: Optional[Callable[[T], Any]] = None
) -> None:
if count is None:
if when is None:
raise ValueError("`count` and `when` cannot both be None")
else:
validate_count(count)


def validate_skip_args(
count: Optional[int] = None, until: Optional[Callable[[T], Any]] = None
) -> None:
if count is None:
if until is None:
raise ValueError("`count` and `until` cannot both be None")
else:
if until is not None:
raise ValueError("`count` and `until` cannot both be set")
validate_count(count)
53 changes: 35 additions & 18 deletions tests/test_stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -818,19 +818,17 @@ def test_skip(self) -> None:
):
Stream(src).skip(-1)

with self.assertRaisesRegex(
ValueError,
"`count` and `until` cannot both be set",
msg="`skip` must raise ValueError if both `count` and `until` are set",
):
Stream(src).skip(0, until=bool)
self.assertListEqual(
list(Stream(src).skip()),
list(src),
msg="`skip` must be no-op if both `count` and `until` are None",
)

with self.assertRaisesRegex(
ValueError,
"`count` and `until` cannot both be None",
msg="`skip` must raise ValueError if both `count` and `until` are None",
):
Stream(src).skip()
self.assertListEqual(
list(Stream(src).skip(None)),
list(src),
msg="`skip` must be no-op if both `count` and `until` are None",
)

for count in [0, 1, 3]:
self.assertListEqual(
Expand All @@ -855,24 +853,43 @@ def test_skip(self) -> None:
msg="`skip` must yield starting from the first element satisfying `until`",
)

self.assertListEqual(
list(Stream(src).skip(count, until=lambda n: False)),
list(src)[count:],
msg="`skip` must ignore `count` elements if `until` is never satisfied",
)

self.assertListEqual(
list(Stream(src).skip(count * 2, until=lambda n: n >= count)),
list(src)[count:],
msg="`skip` must ignore less than `count` elements if `until` is satisfied first",
)

self.assertListEqual(
list(Stream(src).skip(until=lambda n: False)),
[],
msg="`skip` must not yield any element if `until` is never satisfied",
)

def test_truncate(self) -> None:
with self.assertRaisesRegex(
ValueError,
"`count` and `when` cannot both be None",
):
Stream(src).truncate()

self.assertListEqual(
list(Stream(src).truncate(N * 2)),
list(src),
msg="`truncate` must be ok with count >= stream length",
)

self.assertListEqual(
list(Stream(src).truncate()),
list(src),
msg="`truncate must be no-op if both `count` and `when` are None",
)

self.assertListEqual(
list(Stream(src).truncate(None)),
list(src),
msg="`truncate must be no-op if both `count` and `when` are None",
)

self.assertListEqual(
list(Stream(src).truncate(2)),
[0, 1],
Expand Down

0 comments on commit cb66477

Please sign in to comment.