diff --git a/grain/_src/python/dataset/transformations/prefetch_test.py b/grain/_src/python/dataset/transformations/prefetch_test.py index fee62610..f847a3a4 100644 --- a/grain/_src/python/dataset/transformations/prefetch_test.py +++ b/grain/_src/python/dataset/transformations/prefetch_test.py @@ -551,7 +551,7 @@ def map(self, features): if not start_prefetch_calls: self.assertGreater(time_to_fetch, 1) - def test_prefetch_but_no_read(self): + def test_prefetch_but_no_read_with_sleep(self): class _SleepTransform(transforms.MapTransform): def map(self, features): @@ -577,6 +577,15 @@ def map(self, features): # buffers. time.sleep(30) + def test_prefetch_but_no_read(self): + ds = dataset.MapDataset.source([1, 2, 3]).repeat() + ds = ds.filter(lambda x: x > 3) + ds = ds.to_iter_dataset() + ds = ds.mp_prefetch() + it = ds.__iter__() + it.start_prefetch() + del it + def test_prefetch_with_random_map(self): ds = dataset.MapDataset.source([0]).repeat(100).to_iter_dataset() ds = ds.random_map(lambda x, rng: x + rng.integers(sys.maxsize), seed=42) diff --git a/grain/_src/python/grain_pool.py b/grain/_src/python/grain_pool.py index 53d9ae8e..5b3f1178 100644 --- a/grain/_src/python/grain_pool.py +++ b/grain/_src/python/grain_pool.py @@ -605,6 +605,9 @@ def read_thread_should_stop(): worker_index_to_start_reading=worker_index_to_start_reading, options=multiprocessing_options, ) as g_pool: + # The subprocesses can get shut down before we can pull any elements. + if read_thread_should_stop(): + return for element in g_pool: if read_thread_should_stop(): break