diff --git a/grain/_src/python/dataset/BUILD b/grain/_src/python/dataset/BUILD index 680796f2..bcd6abad 100644 --- a/grain/_src/python/dataset/BUILD +++ b/grain/_src/python/dataset/BUILD @@ -45,6 +45,7 @@ py_library( py_test( name = "dataset_test", srcs = ["dataset_test.py"], + shard_count = 10, srcs_version = "PY3", deps = [ ":base", diff --git a/grain/_src/python/dataset/dataset.py b/grain/_src/python/dataset/dataset.py index 81c66b6e..925cdca1 100644 --- a/grain/_src/python/dataset/dataset.py +++ b/grain/_src/python/dataset/dataset.py @@ -572,9 +572,13 @@ def random_map( The seed can be either provided explicitly or set via `ds.seed(seed)`. Prefer the latter if you don't need to control the random map seed individually. It allows to pass a single seed to derive seeds for all - downstream random transformations in the pipeline. The geenrator is seeded + downstream random transformations in the pipeline. The generator is seeded by a combination of the seed and the index of the element in the dataset. + NOTE: Avoid using the provided RNG outside of the `transform` function + (e.g. by passing it to the next transformation along with the data). + The RNG is going to be reused. + Example usage: ``` ds = MapDataset.range(5) @@ -601,7 +605,7 @@ def random_map( map as map_dataset, ) # pylint: enable=g-import-not-at-top - return map_dataset.MapMapDataset( + return map_dataset.RandomMapMapDataset( parent=self, transform=transform, seed=seed ) @@ -961,6 +965,10 @@ def random_map( by a combination of the seed and a counter of elements produced by the dataset. + NOTE: Avoid using the provided RNG outside of the `transform` function + (e.g. by passing it to the next transformation along with the data). + The RNG is going to be reused. + Example usage: ``` ds = MapDataset.range(5).to_iter_dataset() @@ -987,7 +995,7 @@ def random_map( map as map_dataset, ) # pylint: enable=g-import-not-at-top - return map_dataset.MapIterDataset( + return map_dataset.RandomMapIterDataset( parent=self, transform=transform, seed=seed ) diff --git a/grain/_src/python/dataset/dataset_test.py b/grain/_src/python/dataset/dataset_test.py index 1c846764..d780824a 100644 --- a/grain/_src/python/dataset/dataset_test.py +++ b/grain/_src/python/dataset/dataset_test.py @@ -493,7 +493,7 @@ def test_seed_with_map(self, initial_ds): ds1 = initial_ds.seed(seed).random_map(AddRandomInteger()) ds2 = initial_ds.seed(seed).random_map(AddRandomInteger()) self.assertEqual(list(ds1), list(ds2)) - ds3 = initial_ds.seed(seed + 1).map(AddRandomInteger()) + ds3 = initial_ds.seed(seed + 1).random_map(AddRandomInteger()) self.assertNotEqual(list(ds1), list(ds3)) @parameterized.parameters( diff --git a/grain/_src/python/dataset/transformations/map.py b/grain/_src/python/dataset/transformations/map.py index 69c5aba4..3ecc1777 100644 --- a/grain/_src/python/dataset/transformations/map.py +++ b/grain/_src/python/dataset/transformations/map.py @@ -15,9 +15,8 @@ import functools import threading -from typing import Any, Callable, Optional, TypeVar, Union +from typing import Any, Callable, TypeVar -from absl import logging from grain._src.core import transforms from grain._src.python.dataset import dataset import numpy as np @@ -25,10 +24,6 @@ T = TypeVar("T") # pylint: disable=invalid-name -_MapTransformType = Union[ - transforms.MapTransform, transforms.RandomMapTransform, Callable[..., T] -] - # We need this little helper class to handle RNG generator for random map # transformations. It manages a pool of RNG objects that can be re-used. @@ -75,71 +70,22 @@ def release_rng(self, rng: np.random.Generator): self._generator_cache.append(rng) -def _get_map_fn_and_seed( - transform: _MapTransformType, seed: Optional[int] = None -) -> tuple[Callable[..., T], Optional[int]]: - """Extracts a map fn from `transform`. - - If a seed is returned map fn requires a seed. - - Args: - transform: A (random) map transform as object or callable. - seed: Seed for random transform. Don't pass a seed if the transform is not - random. - - Returns: - Tuple of a callable and a seed. The callable expects the element to be - mapped as first argument. If seed is not None the callable expects a - second argument with a np.random.Generator. - """ - if isinstance(transform, transforms.MapTransform): - if seed is not None: - logging.warning( - "Provided seed for MapTransform %s which doesn't need a seed.", - transform, - ) - return transform.map, None - elif isinstance(transform, transforms.RandomMapTransform): - if seed is None: - raise ValueError( - "RandomMapTransform requires random seed. Please provide it with" - " `ds.seed(seed)`" - ) - return transform.random_map, seed - elif isinstance(transform, transforms.TfRandomMapTransform): - if seed is None: - raise ValueError( - "RandomMapTransform requires random seed. Please provide it with" - " `ds.seed(seed)`" - ) - return transform.np_random_map, seed - else: - # If a `seed` is provided we treat the Callable as RandomMapTransform - return transform, seed - - class MapMapDataset(dataset.MapDataset[T]): - """Map MapDataset.""" + """Map transformation for MapDataset.""" def __init__( self, parent: dataset.MapDataset, - transform: _MapTransformType, - seed: Optional[int] = None, + transform: transforms.MapTransform | Callable[[Any], T], ): super().__init__(parent) - if isinstance( - transform, - (transforms.RandomMapTransform, transforms.TfRandomMapTransform), - ): - seed = self._default_seed if seed is None else seed + if isinstance(transform, transforms.MapTransform): # Use the transform class name. The `cached_property` below will not # be called. self._transform_name = transform.__class__.__name__ - if isinstance(transform, transforms.MapTransform): - self._transform_name = transform.__class__.__name__ - self._map_fn, seed = _get_map_fn_and_seed(transform, seed) - self._rng_pool = None if seed is None else RngPool(seed) + self._map_fn = transform.map + else: + self._map_fn = transform def __len__(self) -> int: return len(self._parent) @@ -158,24 +104,68 @@ def __getitem__(self, index): with self._stats.record_self_time(): if element is None: return None - if self._rng_pool: - rng = self._rng_pool.acquire_rng(index) - element = self._map_fn(element, rng) - self._rng_pool.release_rng(rng) - else: - element = self._map_fn(element) - return self._stats.record_output_spec(element) + return self._stats.record_output_spec(self._map_fn(element)) + + +class RandomMapMapDataset(dataset.MapDataset[T]): + """Random map transformation for MapDataset.""" + + def __init__( + self, + parent: dataset.MapDataset, + transform: ( + transforms.RandomMapTransform + | Callable[[Any, np.random.Generator], T] + ), + seed: int | None = None, + ): + super().__init__(parent) + if isinstance(transform, transforms.RandomMapTransform): + # Use the transform class name. The `cached_property` below will not + # be called. + self._transform_name = transform.__class__.__name__ + self._map_fn = transform.random_map + else: + self._map_fn = transform + seed = self._default_seed if seed is None else seed + if seed is None: + raise ValueError( + "`random_map` requires a seed. Please either provide it with" + " `ds.seed(seed)` before any random transformations or pass it" + " directly with `ds.random_map(transform, seed=seed)`." + ) + self._rng_pool = RngPool(seed) + + def __len__(self) -> int: + return len(self._parent) + + @functools.cached_property + def _transform_name(self): + return transforms.get_pretty_transform_name(self._map_fn) + + def __str__(self) -> str: + return f"RandomMapMapDataset(transform={self._transform_name})" + + def __getitem__(self, index): + if isinstance(index, slice): + return self.slice(index) + element = self._parent[index] + with self._stats.record_self_time(): + if element is None: + return None + rng = self._rng_pool.acquire_rng(index) + element = self._map_fn(element, rng) + self._rng_pool.release_rng(rng) + return self._stats.record_output_spec(element) class MapWithIndexMapDataset(dataset.MapDataset[T]): - """Map with index MapDataset.""" + """Map with index transformation for MapDataset.""" def __init__( self, parent: dataset.MapDataset, - transform: Union[ - transforms.MapWithIndexTransform, Callable[[int, Any], T] - ], + transform: transforms.MapWithIndexTransform | Callable[[int, Any], T], ): super().__init__(parent) if isinstance(transform, transforms.MapWithIndexTransform): @@ -184,7 +174,6 @@ def __init__( # be called. self._transform_name = transform.__class__.__name__ else: - # Expect Callable[[int, Any], T]. self._map_fn = transform @functools.cached_property @@ -198,10 +187,10 @@ def __str__(self) -> str: return f"MapWithIndexMapDataset(transform={self._transform_name})" def __getitem__(self, index): + if isinstance(index, slice): + return self.slice(index) + element = self._parent[index] with self._stats.record_self_time(): - if isinstance(index, slice): - return self.slice(index) - element = self._parent[index] if element is None: return None return self._stats.record_output_spec(self._map_fn(index, element)) @@ -213,8 +202,38 @@ class _MapDatasetIterator(dataset.DatasetIterator[T]): def __init__( self, parent: dataset.DatasetIterator, - map_fn: Callable[..., T], - seed: Optional[int], + map_fn: Callable[[Any], T], + transform_name: str, + ): + super().__init__(parent) + self._map_fn = map_fn + self._transform_name = transform_name + + def __next__(self): + element = next(self._parent) + with self._stats.record_self_time(): + if element is not None: + element = self._map_fn(element) + return self._stats.record_output_spec(element) + + def get_state(self): + return self._parent.get_state() + + def set_state(self, state): + self._parent.set_state(state) + + def __str__(self) -> str: + return f"MapDatasetIterator(transform={self._transform_name})" + + +class _RandomMapDatasetIterator(dataset.DatasetIterator[T]): + """Iterator that applies random map transformation to elements.""" + + def __init__( + self, + parent: dataset.DatasetIterator, + map_fn: Callable[[Any, np.random.Generator], T], + seed: int, transform_name: str, ): super().__init__(parent) @@ -225,25 +244,18 @@ def __init__( self._transform_name = transform_name def __next__(self): - try: - element = next(self._parent) - except StopIteration as e: - raise e - + element = next(self._parent) with self._stats.record_self_time(): if element is not None: - if self._seed is not None: - # Shift index for the current worker process in case of multiprocess - # execution. The actual index value doesn't matter as long as it is - # unique for each process. - index_for_rng = ( - self._ctx.mp_context.process_index - + self._index_for_rng * self._ctx.mp_context.process_count - ) - _reset_rng_state(self._rng, op_seed=0, index=index_for_rng) - element = self._map_fn(element, self._rng) - else: - element = self._map_fn(element) + # Shift index for the current worker process in case of multiprocess + # execution. The actual index value doesn't matter as long as it is + # unique for each process. + index_for_rng = ( + self._ctx.mp_context.process_index + + self._index_for_rng * self._ctx.mp_context.process_count + ) + _reset_rng_state(self._rng, op_seed=0, index=index_for_rng) + element = self._map_fn(element, self._rng) self._index_for_rng += 1 return self._stats.record_output_spec(element) @@ -259,41 +271,78 @@ def set_state(self, state): self._index_for_rng = state["index_for_rng"] def __str__(self) -> str: - return f"MapDatasetIterator(transform={self._transform_name})" + return f"RandomMapDatasetIterator(transform={self._transform_name})" -class MapIterDataset(dataset.IterDataset[T]): - """Map transformation for IterDatasets.""" +class RandomMapIterDataset(dataset.IterDataset[T]): + """Random map transformation for IterDataset.""" def __init__( self, parent: dataset.IterDataset, - transform: _MapTransformType, - seed: Optional[int] = None, + transform: ( + transforms.RandomMapTransform + | Callable[[Any, np.random.Generator], T] + ), + seed: int | None = None, ): super().__init__(parent) - if isinstance( - transform, - (transforms.RandomMapTransform, transforms.TfRandomMapTransform), - ): - seed = self._default_seed if seed is None else seed + if isinstance(transform, transforms.RandomMapTransform): # Use the transform class name. The `cached_property` below will not # be called. self._transform_name = transform.__class__.__name__ + self._map_fn = transform.random_map + else: + self._map_fn = transform + self._seed = self._default_seed if seed is None else seed + if self._seed is None: + raise ValueError( + "`random_map` requires a seed. Please either provide it with" + " `ds.seed(seed)` before any random transformations or pass it" + " directly with `ds.random_map(transform, seed=seed)`." + ) + + @functools.cached_property + def _transform_name(self): + return transforms.get_pretty_transform_name(self._map_fn) + + def __iter__(self) -> _RandomMapDatasetIterator[T]: + return _RandomMapDatasetIterator( + self._parent.__iter__(), + map_fn=self._map_fn, + seed=self._seed, + transform_name=self._transform_name, + ) + + def __str__(self) -> str: + return f"RandomMapIterDataset(transform={self._transform_name})" + + +class MapIterDataset(dataset.IterDataset[T]): + """Map transformation for IterDatasets.""" + + def __init__( + self, + parent: dataset.IterDataset, + transform: transforms.MapTransform | Callable[[Any], T], + ): + super().__init__(parent) if isinstance(transform, transforms.MapTransform): + # Use the transform class name. The `cached_property` below will not + # be called. self._transform_name = transform.__class__.__name__ - self._map_fn, self._seed = _get_map_fn_and_seed(transform, seed) + self._map_fn = transform.map + else: + self._map_fn = transform @functools.cached_property def _transform_name(self): return transforms.get_pretty_transform_name(self._map_fn) def __iter__(self) -> _MapDatasetIterator[T]: - parent_iter = self._parent.__iter__() return _MapDatasetIterator( - parent_iter, + self._parent.__iter__(), map_fn=self._map_fn, - seed=self._seed, transform_name=self._transform_name, ) diff --git a/grain/_src/python/dataset/transformations/map_test.py b/grain/_src/python/dataset/transformations/map_test.py index 2659a678..7b632ad3 100644 --- a/grain/_src/python/dataset/transformations/map_test.py +++ b/grain/_src/python/dataset/transformations/map_test.py @@ -20,7 +20,7 @@ import cloudpickle from grain._src.core import transforms from grain._src.python.dataset import dataset -from grain._src.python.dataset.transformations import map as ldmap +from grain._src.python.dataset.transformations import map as map_ds import numpy as np @@ -33,7 +33,7 @@ def test_reset_rng_state(self): with self.assertRaises(AssertionError): np.testing.assert_equal(rng.bit_generator.state, old_state) - ldmap._reset_rng_state(rng, 0, 0) + map_ds._reset_rng_state(rng, 0, 0) np.testing.assert_equal(rng.bit_generator.state, old_state) @@ -80,31 +80,26 @@ def setUp(self): self.range_ds = dataset.MapDataset.range(0, 10) def test_map_size(self): - map_ds_no_transform = ldmap.MapMapDataset( + map_ds_no_transform = map_ds.MapMapDataset( self.range_ds, MapWithNoTransform() ) - map_ds_with_transform = ldmap.MapMapDataset( + map_ds_with_transform = map_ds.MapMapDataset( self.range_ds, MapWithTransform() ) - map_ds_with_random_transform = ldmap.MapMapDataset( - self.range_ds, RandomMapWithTransform(), seed=0 - ) self.assertLen(map_ds_no_transform, len(self.range_ds)) self.assertLen(map_ds_with_transform, len(self.range_ds)) - self.assertLen(map_ds_with_random_transform, len(self.range_ds)) @parameterized.parameters( MapWithNoTransform, MapWithTransform, - RandomMapWithTransform, ) def test_map_picklable(self, map_cls): - ds = ldmap.MapMapDataset(self.range_ds, map_cls(), seed=0) + ds = map_ds.MapMapDataset(self.range_ds, map_cls()) ds = cloudpickle.loads(cloudpickle.dumps(ds)) self.assertLen(ds, len(self.range_ds)) def test_map_data_no_transform(self): - map_ds_no_transform = ldmap.MapMapDataset( + map_ds_no_transform = map_ds.MapMapDataset( self.range_ds, MapWithNoTransform() ) expected_data = [i for i in range(10)] @@ -114,7 +109,7 @@ def test_map_data_no_transform(self): self.assertEqual(expected_data, actual_data) def test_map_data_with_transform(self): - map_ds_with_transform = ldmap.MapMapDataset( + map_ds_with_transform = map_ds.MapMapDataset( self.range_ds, MapWithTransform() ) expected_data = [i + 1 for i in range(10)] @@ -123,47 +118,70 @@ def test_map_data_with_transform(self): ] self.assertEqual(expected_data, actual_data) + +class RandomMapMapDatasetTest(parameterized.TestCase): + + def setUp(self): + super().setUp() + self.range_ds = dataset.MapDataset.range(0, 10) + + def test_map_size(self): + ds = map_ds.RandomMapMapDataset( + self.range_ds, RandomMapWithTransform(), seed=0 + ) + self.assertLen(ds, len(self.range_ds)) + + def test_map_picklable(self): + ds = map_ds.RandomMapMapDataset( + self.range_ds, RandomMapWithTransform(), seed=0 + ) + ds = cloudpickle.loads(cloudpickle.dumps(ds)) + self.assertLen(ds, len(self.range_ds)) + def test_random_map_data_with_transform(self): - map_ds_with_random_transform = ldmap.MapMapDataset( + ds = map_ds.RandomMapMapDataset( self.range_ds, RandomMapWithTransform(), seed=0 ) expected_data = [_ for _ in range(10)] - actual_data = [ - map_ds_with_random_transform[i] - for i in range(len(map_ds_with_random_transform)) - ] + actual_data = [ds[i] for i in range(len(ds))] np.testing.assert_almost_equal(expected_data, actual_data, decimal=1) def test_random_map_with_default_seed(self): seed = 42 ds1 = self.range_ds.seed(seed) - ds1 = ldmap.MapMapDataset(ds1, RandomMapWithTransform()) + ds1 = map_ds.RandomMapMapDataset(ds1, RandomMapWithTransform()) ds2 = self.range_ds.seed(seed) - ds2 = ldmap.MapMapDataset(ds2, RandomMapWithTransform()) + ds2 = map_ds.RandomMapMapDataset(ds2, RandomMapWithTransform()) np.testing.assert_almost_equal(list(ds1), list(ds2), decimal=1) def test_random_map_overrides_default_seed(self): seed = 42 ds1 = self.range_ds.seed(seed) - ds1 = ldmap.MapMapDataset(ds1, RandomMapWithTransform()) + ds1 = map_ds.RandomMapMapDataset(ds1, RandomMapWithTransform()) ds2 = self.range_ds.seed(seed) - ds2 = ldmap.MapMapDataset(ds2, RandomMapWithTransform(), seed=seed + 1) + ds2 = map_ds.RandomMapMapDataset( + ds2, RandomMapWithTransform(), seed=seed + 1 + ) np.testing.assert_array_compare(operator.__ne__, list(ds1), list(ds2)) - def test_random_map_raises_with_no_seed(self): + @parameterized.parameters( + RandomMapWithTransform(), + lambda x, rng: x, + ) + def test_random_map_raises_with_no_seed(self, transform): with self.assertRaises(ValueError): - ldmap.MapMapDataset(self.range_ds, RandomMapWithTransform()) + map_ds.RandomMapMapDataset(self.range_ds, transform) -class MapIterDatasetTest(absltest.TestCase): +class MapIterDatasetTest(parameterized.TestCase): def setUp(self): super().setUp() - self.range_iter_ds = dataset.MapDataset.range(0, 10).to_iter_dataset() + self.range_iter_ds = dataset.MapDataset.range(10).to_iter_dataset() def test_map_data_no_transform(self): map_no_transform_iter_ds = iter( - ldmap.MapIterDataset(self.range_iter_ds, MapWithNoTransform()) + map_ds.MapIterDataset(self.range_iter_ds, MapWithNoTransform()) ) expected_data = [_ for _ in range(10)] actual_data = [next(map_no_transform_iter_ds) for _ in range(10)] @@ -171,15 +189,30 @@ def test_map_data_no_transform(self): def test_map_data_with_transform(self): map_with_transform_iter_ds = iter( - ldmap.MapIterDataset(self.range_iter_ds, MapWithTransform()) + map_ds.MapIterDataset(self.range_iter_ds, MapWithTransform()) ) expected_data = [i + 1 for i in range(10)] actual_data = [next(map_with_transform_iter_ds) for _ in range(10)] self.assertEqual(expected_data, actual_data) + def test_map_past_one_epoch_raises_exception(self): + map_no_transform_iter_ds = iter( + map_ds.MapIterDataset(self.range_iter_ds, MapWithNoTransform()) + ) + with self.assertRaises(StopIteration): + next(map_no_transform_iter_ds) + _ = [next(map_no_transform_iter_ds) for _ in range(20)] + + +class RandomMapIterDatasetTest(parameterized.TestCase): + + def setUp(self): + super().setUp() + self.range_iter_ds = dataset.MapDataset.range(0, 10).to_iter_dataset() + def test_random_map_data_with_transform(self): map_with_random_transform_iter_ds = iter( - ldmap.MapIterDataset( + map_ds.RandomMapIterDataset( self.range_iter_ds, RandomMapWithTransform(), seed=0 ) ) @@ -192,40 +225,38 @@ def test_random_map_data_with_transform_deterministic_with_seed(self): # check if elements are reproducible for multiple runs for _ in range(10): map_with_random_transform_iter_ds = iter( - ldmap.MapIterDataset( + map_ds.RandomMapIterDataset( self.range_iter_ds, RandomMapWithDeterminismTransform(), seed=42 ) ) actual_data = [next(map_with_random_transform_iter_ds) for _ in range(10)] np.testing.assert_equal(expected_data, actual_data) - def test_map_past_one_epoch_raises_exception(self): - map_no_transform_iter_ds = iter( - ldmap.MapIterDataset(self.range_iter_ds, MapWithNoTransform()) - ) - with self.assertRaises(StopIteration): - next(map_no_transform_iter_ds) - _ = [next(map_no_transform_iter_ds) for _ in range(20)] - def test_random_map_with_default_seed(self): seed = 42 ds1 = self.range_iter_ds.seed(seed) - ds1 = ldmap.MapIterDataset(ds1, RandomMapWithTransform()) + ds1 = map_ds.RandomMapIterDataset(ds1, RandomMapWithTransform()) ds2 = self.range_iter_ds.seed(seed) - ds2 = ldmap.MapIterDataset(ds2, RandomMapWithTransform()) + ds2 = map_ds.RandomMapIterDataset(ds2, RandomMapWithTransform()) np.testing.assert_almost_equal(list(ds1), list(ds2), decimal=1) def test_random_map_overrides_default_seed(self): seed = 42 ds1 = self.range_iter_ds.seed(seed) - ds1 = ldmap.MapIterDataset(ds1, RandomMapWithTransform()) + ds1 = map_ds.RandomMapIterDataset(ds1, RandomMapWithTransform()) ds2 = self.range_iter_ds.seed(seed) - ds2 = ldmap.MapIterDataset(ds2, RandomMapWithTransform(), seed=seed + 1) + ds2 = map_ds.RandomMapIterDataset( + ds2, RandomMapWithTransform(), seed=seed + 1 + ) np.testing.assert_array_compare(operator.__ne__, list(ds1), list(ds2)) - def test_random_map_raises_with_no_seed(self): + @parameterized.parameters( + RandomMapWithTransform(), + lambda x, rng: x, + ) + def test_random_map_raises_with_no_seed(self, transform): with self.assertRaises(ValueError): - ldmap.MapIterDataset(self.range_iter_ds, RandomMapWithTransform()) + map_ds.RandomMapIterDataset(self.range_iter_ds, transform) class MapWithIndexMapDatasetTest(absltest.TestCase): @@ -235,17 +266,17 @@ def setUp(self): self.range_ds = dataset.MapDataset.range(3, 6) def test_length(self): - map_ds = ldmap.MapWithIndexMapDataset(self.range_ds, AddIndexTransform()) - self.assertLen(map_ds, len(self.range_ds)) + ds = map_ds.MapWithIndexMapDataset(self.range_ds, AddIndexTransform()) + self.assertLen(ds, len(self.range_ds)) def test_getitem(self): - map_ds = ldmap.MapWithIndexMapDataset(self.range_ds, AddIndexTransform()) - self.assertEqual(map_ds[0], (0, 3)) - self.assertEqual(map_ds[1], (1, 4)) - self.assertEqual(map_ds[2], (2, 5)) - self.assertEqual(map_ds[3], (3, 3)) - self.assertEqual(map_ds[4], (4, 4)) - self.assertEqual(map_ds[5], (5, 5)) + ds = map_ds.MapWithIndexMapDataset(self.range_ds, AddIndexTransform()) + self.assertEqual(ds[0], (0, 3)) + self.assertEqual(ds[1], (1, 4)) + self.assertEqual(ds[2], (2, 5)) + self.assertEqual(ds[3], (3, 3)) + self.assertEqual(ds[4], (4, 4)) + self.assertEqual(ds[5], (5, 5)) if __name__ == "__main__": diff --git a/grain/_src/python/dataset/transformations/prefetch_test.py b/grain/_src/python/dataset/transformations/prefetch_test.py index 92ab1ddb..94b7b197 100644 --- a/grain/_src/python/dataset/transformations/prefetch_test.py +++ b/grain/_src/python/dataset/transformations/prefetch_test.py @@ -27,7 +27,6 @@ from grain._src.python.dataset import base from grain._src.python.dataset import dataset from grain._src.python.dataset.transformations import filter as filter_lazy_dataset -from grain._src.python.dataset.transformations import map as map_lazy_dataset from grain._src.python.dataset.transformations import prefetch import numpy as np @@ -75,8 +74,8 @@ class PrefetchIterDatasetTest(parameterized.TestCase): def setUp(self): super().setUp() self.range_ds = dataset.MapDataset.range(20) - self.filtered_range_ds = filter_lazy_dataset.FilterMapDataset( - self.range_ds, FilterKeepingOddElementsOnly() + self.filtered_range_ds = self.range_ds.filter( + FilterKeepingOddElementsOnly() ) self.prefetch_lazy_iter_ds = prefetch.PrefetchIterDataset( self.range_ds, read_options=options.ReadOptions() @@ -224,9 +223,7 @@ def setUp(self): super().setUp() ds = dataset.MapDataset.range(20) ds = prefetch.PrefetchIterDataset(ds, read_options=options.ReadOptions()) - self.iter_ds = filter_lazy_dataset.FilterIterDataset( - ds, FilterKeepingOddElementsOnly() - ) + self.iter_ds = ds.filter(FilterKeepingOddElementsOnly()) @parameterized.named_parameters( dict( @@ -526,7 +523,7 @@ def map(self, features): return features ds = dataset.MapDataset.range(10) - ds = map_lazy_dataset.MapMapDataset(parent=ds, transform=_SleepTransform()) + ds = ds.map(_SleepTransform()) ds = prefetch.PrefetchIterDataset(ds, read_options=options.ReadOptions()) ds = prefetch.MultiprocessPrefetchIterDataset( ds, @@ -562,7 +559,7 @@ def map(self, features): return features ds = dataset.MapDataset.range(10) - ds = map_lazy_dataset.MapMapDataset(parent=ds, transform=_SleepTransform()) + ds = ds.map(_SleepTransform()) ds = prefetch.PrefetchIterDataset(ds, read_options=options.ReadOptions()) ds = prefetch.MultiprocessPrefetchIterDataset( ds, @@ -614,9 +611,10 @@ class ThreadPrefetchIterDatasetTest(parameterized.TestCase): def setUp(self): super().setUp() - self.ds = filter_lazy_dataset.FilterIterDataset( - dataset.MapDataset.range(20).to_iter_dataset(), - FilterKeepingOddElementsOnly(), + self.ds = ( + dataset.MapDataset.range(20) + .to_iter_dataset() + .filter(FilterKeepingOddElementsOnly()) ) @parameterized.named_parameters( diff --git a/grain/_src/python/dataset/visualize.py b/grain/_src/python/dataset/visualize.py index 4942c68b..aa936317 100644 --- a/grain/_src/python/dataset/visualize.py +++ b/grain/_src/python/dataset/visualize.py @@ -16,17 +16,13 @@ from __future__ import annotations import contextlib -import inspect import pprint from typing import Any, Callable, Generator, TypeVar from grain._src.core import tree_lib from grain._src.python import options from grain._src.python.dataset import dataset -from grain._src.python.dataset.transformations import filter as filter_dataset -from grain._src.python.dataset.transformations import map as map_dataset from grain._src.python.dataset.transformations import prefetch -from grain._src.python.dataset.transformations import source T = TypeVar('T') @@ -39,45 +35,6 @@ """ -def _function_repr(fn: Callable[..., Any]) -> str: - """Produces a human-readable string representation of a function.""" - fn_name = fn.__name__ - if inspect.ismethod(fn): - # Function is defined as a method in a class. Prepend the class name. - class_name = fn.__self__.__class__.__name__ - fn_name = f'{class_name}.{fn_name}' - fn_module = inspect.getmodule(fn).__name__ - _, line_number = inspect.getsourcelines(fn) - return f'{fn_name} in module {fn_module} on line {line_number}' - - -# TODO: Move this to the specific transformations' `__repr__`. Note that -# in that case it could also be used for checkpoint validation and will become -# available to users. -def _dataset_repr(ds: dataset.MapDataset | dataset.IterDataset) -> str: - """Produces a human-readable string representation of a dataset.""" - result = ds.__class__.__name__ - # pylint: disable=protected-access - if isinstance(ds, source.SourceMapDataset): - result += f'' - elif isinstance( - ds, - ( - map_dataset.MapMapDataset, - map_dataset.MapWithIndexMapDataset, - map_dataset.MapIterDataset, - ), - ): - result += f'' - elif isinstance(ds, filter_dataset.FilterMapDataset): - result += f'' - # pylint: enable=protected-access - # Display the number of parents for datasets with multiple parents. - if (num_parents := len(ds.parents)) > 1: - result += f'[{num_parents} parents]' - return result - - class _SpecTrackingMapDataset(dataset.MapDataset[T]): """A MapDataset that tracks spec of its parent outputs.""" @@ -208,10 +165,10 @@ def _build_visualization_from_tracked_spec( # first parent. This is true for `mix`, but is not necessarily true for # `select_from_datasets`. parent_vis = _build_visualization_from_tracked_spec(tracked_ds.parents[0]) - transform_repr = _dataset_repr(tracked_ds) + transform_repr = str(tracked_ds) else: # This dataset tracks the source output. - parent_vis = _dataset_repr(tracked_ds) + parent_vis = str(tracked_ds) transform_repr = '' return _EDGE_TEMPLATE.format( input_spec=parent_vis, @@ -245,7 +202,7 @@ def _patch_dataset_with_spec_tracking( def _patch(ds): if len(ds.parents) > 1: - multiparent_datasets.append(_dataset_repr(ds)) + multiparent_datasets.append(str(ds)) ds._parents = [_patch(p) for p in ds.parents] # pylint: disable=protected-access if isinstance(ds, dataset.MapDataset): return _SpecTrackingMapDataset(ds, mock_source_output) diff --git a/grain/_src/python/dataset/visualize_test.py b/grain/_src/python/dataset/visualize_test.py index 4c72644e..4573a918 100644 --- a/grain/_src/python/dataset/visualize_test.py +++ b/grain/_src/python/dataset/visualize_test.py @@ -22,7 +22,7 @@ import numpy as np -_MAP_DATASET_REPR = r"""RangeMapDataset +_MAP_DATASET_REPR = r"""RangeMapDataset(start=0, stop=10, step=1) ││ ││ ││ @@ -36,7 +36,7 @@ "[]" ││ - ││ MapWithIndexMapDataset + ││ MapWithIndexMapDataset(transform=_add_dummy_metadata @ .../python/dataset/visualize_test.py:XXX) ││ ╲╱ {'data': "[]", @@ -45,7 +45,7 @@ 'index': "[]"} ││ - ││ MapMapDataset + ││ MapMapDataset(transform=_identity @ .../python/dataset/visualize_test.py:XXX) ││ ╲╱ {'data': "[]", @@ -54,7 +54,7 @@ 'index': "[]"} """ -_ITER_DATASET_REPR = r"""RangeMapDataset +_ITER_DATASET_REPR = r"""RangeMapDataset(start=0, stop=10, step=1) ││ ││ ││ @@ -68,13 +68,13 @@ "[]" ││ - ││ PrefetchIterDataset + ││ PrefetchIterDataset(read_options=ReadOptions(num_threads=16, prefetch_buffer_size=500), allow_nones=False) ││ ╲╱ "[]" ││ - ││ MapIterDataset in module __main__ on line XXX> + ││ MapIterDataset(transform= @ .../python/dataset/visualize_test.py:XXX) ││ ╲╱ {'data': "[]", @@ -83,7 +83,7 @@ 'index': "[]"} ││ - ││ BatchIterDataset + ││ BatchIterDataset(batch_size=2, drop_remainder=False) ││ ╲╱ {'data': 'int64[2]', @@ -94,7 +94,7 @@ _MIX_DATASET_REPR = r"""WARNING: Detected multi-parent datasets: MixedMapDataset[2 parents]. Only displaying the first parent. -RangeMapDataset +RangeMapDataset(start=0, stop=10, step=1) ││ ││ ││ @@ -114,13 +114,13 @@ "[]" ││ - ││ MapMapDataset + ││ MapMapDataset(transform=_AddOne) ││ ╲╱ "[]" """ -_PREFETCH_DATASET_REPR = r"""RangeMapDataset +_PREFETCH_DATASET_REPR = r"""RangeMapDataset(start=0, stop=10, step=1) ││ ││ ││ @@ -134,7 +134,7 @@ "[]" ││ - ││ MapWithIndexMapDataset + ││ MapWithIndexMapDataset(transform=_add_dummy_metadata @ .../python/dataset/visualize_test.py:XXX) ││ ╲╱ {'data': "[]", @@ -143,7 +143,7 @@ 'index': "[]"} ││ - ││ PrefetchIterDataset + ││ PrefetchIterDataset(read_options=ReadOptions(num_threads=16, prefetch_buffer_size=500), allow_nones=False) ││ ╲╱ {'data': "[]", @@ -152,7 +152,7 @@ 'index': "[]"} ││ - ││ MultiprocessPrefetchIterDataset + ││ MultiprocessPrefetchIterDataset(multiprocessing_options=MultiprocessingOptions(num_workers=1, per_worker_buffer_size=1, enable_profiling=False)) ││ ╲╱ {'data': "[]", @@ -161,7 +161,7 @@ 'index': "[]"} """ -_SOURCE_DATASET_REPR = r"""SourceMapDataset +_SOURCE_DATASET_REPR = r"""SourceMapDataset(source=SourceMapDataset) ││ ││ ││ @@ -169,7 +169,7 @@ "[]" ││ - ││ MapMapDataset + ││ MapMapDataset(transform=_identity @ .../python/dataset/visualize_test.py:XXX) ││ ╲╱ "[]" @@ -213,7 +213,7 @@ def _assert_visualization(self, ds, expected): result = visualize._build_visualization_str(ds, None) print(result) # Remove line number from the result to make test less brittle. - result = re.sub(r"on line \d+", "on line XXX", result) + result = re.sub(r".py:\d+", ".py:XXX", result) self.assertEqual(result, expected) repr_after_visualize = _deep_dataset_repr(ds) result_after_visualize = list(ds)