Commit
Fix bug with ds.seed + ds.random_map when passed a callable.
`random_map` did not recognize the seed and raised a cryptic error. Also refactored the map transformation implementation to be less error-prone and more readable.

PiperOrigin-RevId: 721405525
iindyk authored and copybara-github committed Jan 30, 2025
1 parent df8ca35 commit 277a6e1
Showing 8 changed files with 282 additions and 238 deletions.
1 change: 1 addition & 0 deletions grain/_src/python/dataset/BUILD
@@ -45,6 +45,7 @@ py_library(
py_test(
name = "dataset_test",
srcs = ["dataset_test.py"],
+shard_count = 10,
srcs_version = "PY3",
deps = [
":base",
14 changes: 11 additions & 3 deletions grain/_src/python/dataset/dataset.py
@@ -572,9 +572,13 @@ def random_map(
The seed can be either provided explicitly or set via `ds.seed(seed)`.
Prefer the latter if you don't need to control the random map seed
individually. It allows to pass a single seed to derive seeds for all
-downstream random transformations in the pipeline. The geenrator is seeded
+downstream random transformations in the pipeline. The generator is seeded
by a combination of the seed and the index of the element in the dataset.
+NOTE: Avoid using the provided RNG outside of the `transform` function
+(e.g. by passing it to the next transformation along with the data).
+The RNG is going to be reused.
Example usage:
```
ds = MapDataset.range(5)
@@ -601,7 +605,7 @@ def random_map(
map as map_dataset,
)
# pylint: enable=g-import-not-at-top
-return map_dataset.MapMapDataset(
+return map_dataset.RandomMapMapDataset(
parent=self, transform=transform, seed=seed
)

@@ -961,6 +965,10 @@ def random_map(
by a combination of the seed and a counter of elements produced by the
dataset.
+NOTE: Avoid using the provided RNG outside of the `transform` function
+(e.g. by passing it to the next transformation along with the data).
+The RNG is going to be reused.
Example usage:
```
ds = MapDataset.range(5).to_iter_dataset()
@@ -987,7 +995,7 @@ def random_map(
map as map_dataset,
)
# pylint: enable=g-import-not-at-top
-return map_dataset.MapIterDataset(
+return map_dataset.RandomMapIterDataset(
parent=self, transform=transform, seed=seed
)
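
The NOTE added to both docstrings warns that the RNG handed to `transform` is going to be reused. The commit does not show how the reuse works, so the following is only a sketch of why holding on to the generator is risky. It assumes a single generator object that is re-seeded in place for each element; `shared_rng`, `rng_for`, `leaky_transform`, and `safe_transform` are hypothetical names, and the code uses Python's standard `random` module rather than Grain.

```python
import random

# Assumption for illustration: one generator object is reused for all elements.
shared_rng = random.Random()


def rng_for(seed: int, index: int) -> random.Random:
    """Re-seeds the shared generator in place and returns it (hypothetical)."""
    shared_rng.seed(f"{seed}:{index}")
    return shared_rng


def leaky_transform(x, rng):
    # Anti-pattern: the generator escapes the transform along with the data.
    return x, rng


def safe_transform(x, rng):
    # Draw every random value inside the transform and return plain data only.
    return x, rng.random()


leaked = [leaky_transform(i, rng_for(0, i)) for i in range(3)]
# Every "per-element" generator that escaped is the same object, so its state
# reflects whichever element was processed last, not the element it came with.
same_object = all(rng is shared_rng for _, rng in leaked)

safe = [safe_transform(i, rng_for(0, i)) for i in range(3)]
safe_again = [safe_transform(i, rng_for(0, i)) for i in range(3)]
```

Under this assumption, `same_object` is true while `safe` and `safe_again` are equal, which is why drawing all random values inside `transform` keeps results reproducible.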

2 changes: 1 addition & 1 deletion grain/_src/python/dataset/dataset_test.py
@@ -493,7 +493,7 @@ def test_seed_with_map(self, initial_ds):
ds1 = initial_ds.seed(seed).random_map(AddRandomInteger())
ds2 = initial_ds.seed(seed).random_map(AddRandomInteger())
self.assertEqual(list(ds1), list(ds2))
-ds3 = initial_ds.seed(seed + 1).map(AddRandomInteger())
+ds3 = initial_ds.seed(seed + 1).random_map(AddRandomInteger())
self.assertNotEqual(list(ds1), list(ds3))
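
The property this test checks (same seed, same output; different seed, different output) follows from the docstring's statement that the generator is seeded by a combination of the seed and the element index. A small standalone sketch of that idea is below. It uses Python's standard `random` module rather than Grain; `random_map_at` and `add_random_integer` are hypothetical names, and the `f"{seed}:{index}"` seeding scheme is an assumption chosen for illustration, not how Grain derives its seeds.

```python
import random
from typing import Callable, Sequence


def random_map_at(
    data: Sequence[int],
    transform: Callable[[int, random.Random], int],
    seed: int,
    index: int,
) -> int:
    """Applies `transform` to data[index] with an RNG derived from (seed, index)."""
    # Hypothetical seeding scheme: the result depends only on seed and index,
    # not on the order in which elements are accessed.
    rng = random.Random(f"{seed}:{index}")
    return transform(data[index], rng)


def add_random_integer(x: int, rng: random.Random) -> int:
    # Stand-in for the AddRandomInteger transform used in the test.
    return x + rng.randrange(2**32)


data = range(5)
forward = [random_map_at(data, add_random_integer, 42, i) for i in range(5)]
backward = [
    random_map_at(data, add_random_integer, 42, i) for i in reversed(range(5))
]
other_seed = [random_map_at(data, add_random_integer, 43, i) for i in range(5)]
```

Here `forward` equals `backward` reversed, because each element's randomness is tied to its index, while `other_seed` is expected to differ because the seed changed.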

@parameterized.parameters(