Skip to content

Commit

Permalink
Fix hang in mp_prefetch when iterator is deleted after start_prefetch.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 724167271
  • Loading branch information
iindyk authored and copybara-github committed Feb 7, 2025
1 parent 3d5baa8 commit 5cc034c
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 1 deletion.
11 changes: 10 additions & 1 deletion grain/_src/python/dataset/transformations/prefetch_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -551,7 +551,7 @@ def map(self, features):
if not start_prefetch_calls:
self.assertGreater(time_to_fetch, 1)

def test_prefetch_but_no_read(self):
def test_prefetch_but_no_read_with_sleep(self):
class _SleepTransform(transforms.MapTransform):

def map(self, features):
Expand All @@ -577,6 +577,15 @@ def map(self, features):
# buffers.
time.sleep(30)

def test_prefetch_but_no_read(self):
ds = dataset.MapDataset.source([1, 2, 3]).repeat()
ds = ds.filter(lambda x: x > 3)
ds = ds.to_iter_dataset()
ds = ds.mp_prefetch()
it = ds.__iter__()
it.start_prefetch()
del it

def test_prefetch_with_random_map(self):
ds = dataset.MapDataset.source([0]).repeat(100).to_iter_dataset()
ds = ds.random_map(lambda x, rng: x + rng.integers(sys.maxsize), seed=42)
Expand Down
3 changes: 3 additions & 0 deletions grain/_src/python/grain_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -605,6 +605,9 @@ def read_thread_should_stop():
worker_index_to_start_reading=worker_index_to_start_reading,
options=multiprocessing_options,
) as g_pool:
# The subprocesses can get shut down before we can pull any elements.
if read_thread_should_stop():
return
for element in g_pool:
if read_thread_should_stop():
break
Expand Down

0 comments on commit 5cc034c

Please sign in to comment.