From 5cc034c00017e7e2f6dd14b1c40fc7de36bf9d25 Mon Sep 17 00:00:00 2001 From: Ihor Indyk Date: Thu, 6 Feb 2025 19:26:44 -0800 Subject: [PATCH] Fix hang in `mp_prefetch` when iterator is deleted after start_prefetch. PiperOrigin-RevId: 724167271 --- .../python/dataset/transformations/prefetch_test.py | 11 ++++++++++- grain/_src/python/grain_pool.py | 3 +++ 2 files changed, 13 insertions(+), 1 deletion(-) diff --git a/grain/_src/python/dataset/transformations/prefetch_test.py b/grain/_src/python/dataset/transformations/prefetch_test.py index fee62610..f847a3a4 100644 --- a/grain/_src/python/dataset/transformations/prefetch_test.py +++ b/grain/_src/python/dataset/transformations/prefetch_test.py @@ -551,7 +551,7 @@ def map(self, features): if not start_prefetch_calls: self.assertGreater(time_to_fetch, 1) - def test_prefetch_but_no_read(self): + def test_prefetch_but_no_read_with_sleep(self): class _SleepTransform(transforms.MapTransform): def map(self, features): @@ -577,6 +577,15 @@ def map(self, features): # buffers. time.sleep(30) + def test_prefetch_but_no_read(self): + ds = dataset.MapDataset.source([1, 2, 3]).repeat() + ds = ds.filter(lambda x: x > 3) + ds = ds.to_iter_dataset() + ds = ds.mp_prefetch() + it = ds.__iter__() + it.start_prefetch() + del it + def test_prefetch_with_random_map(self): ds = dataset.MapDataset.source([0]).repeat(100).to_iter_dataset() ds = ds.random_map(lambda x, rng: x + rng.integers(sys.maxsize), seed=42) diff --git a/grain/_src/python/grain_pool.py b/grain/_src/python/grain_pool.py index 53d9ae8e..5b3f1178 100644 --- a/grain/_src/python/grain_pool.py +++ b/grain/_src/python/grain_pool.py @@ -605,6 +605,9 @@ def read_thread_should_stop(): worker_index_to_start_reading=worker_index_to_start_reading, options=multiprocessing_options, ) as g_pool: + # The subprocesses can get shut down before we can pull any elements. + if read_thread_should_stop(): + return for element in g_pool: if read_thread_should_stop(): break