Skip to content

Commit

Permalink
Internal change.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 712462394
  • Loading branch information
Grain Team authored and copybara-github committed Jan 28, 2025
1 parent 01a946c commit a6e4378
Show file tree
Hide file tree
Showing 5 changed files with 804 additions and 624 deletions.
12 changes: 12 additions & 0 deletions grain/_src/python/dataset/transformations/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,17 @@ py_test(
deps = ["//grain/_src/python/dataset"],
)

py_library(
name = "testing_util",
testonly = 1,
srcs = ["testing_util.py"],
srcs_version = "PY3",
deps = [
":packing",
"//grain/_src/python/dataset",
],
)

py_library(
name = "flatmap",
srcs = ["flatmap.py"],
Expand Down Expand Up @@ -116,6 +127,7 @@ py_test(
srcs_version = "PY3",
deps = [
":packing",
":testing_util",
"//grain/_src/python/dataset",
],
)
Expand Down
1 change: 1 addition & 0 deletions grain/_src/python/dataset/transformations/packing.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from collections.abc import Sequence
import copy
from typing import Any, Optional
from absl import logging
from grain._src.python.dataset import dataset
from grain._src.python.dataset import stats as dataset_stats
from grain._src.python.dataset.transformations import packing_packed_batch
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ def make_packed_buffer(length: int, x: np.ndarray | int):
shape = ()
dtype = np.int64 if isinstance(x, int) else np.asarray(x).dtype
else:
assert isinstance(x, np.ndarray)
assert isinstance(x, np.ndarray), type(x)
shape = x.shape[1:]
dtype = x.dtype
return zeros(
Expand Down
Loading

0 comments on commit a6e4378

Please sign in to comment.