Skip to content

Commit

Permalink
Internal change.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 712462394
  • Loading branch information
Grain Team authored and copybara-github committed Jan 6, 2025
1 parent 20f1174 commit f4e8cc7
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 8 deletions.
5 changes: 4 additions & 1 deletion grain/_src/python/dataset/transformations/packing.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,10 @@
from collections.abc import Sequence
import copy
from typing import Any, Optional
from absl import logging
from grain._src.cpp.transformations.packing.python import packing
from grain._src.python.dataset import dataset
from grain._src.python.dataset import stats as dataset_stats
from grain._src.python.dataset.transformations import packing_packed_batch
from jaxtyping import PyTree # pylint: disable=g-importing-member
import numpy as np
import tree
Expand Down Expand Up @@ -289,6 +290,7 @@ def __init__(
num_packing_bins: int,
shuffle_bins: bool = True,
meta_features: Sequence[str] = (),
use_cc_version: bool = False,
):
"""Creates a dataset that packs sequences from the parent dataset.
Expand All @@ -300,6 +302,7 @@ def __init__(
shuffle_bins: Whether to shuffle bins after packing.
meta_features: Meta features that do not need *_segment_ids and
*_positions features.
use_cc_version: Whether to use the faster C++ implementation.
"""
super().__init__(parent)
self._length_struct = length_struct
Expand Down
35 changes: 28 additions & 7 deletions grain/_src/python/dataset/transformations/packing_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -451,7 +451,10 @@ def test_pack_sequences_length_3(self, num_packing_bins: int):
{"num_packing_bins": 3},
{"num_packing_bins": 5},
)
def test_pack_sequences_length_shuffle_bins(self, num_packing_bins: int):
def test_pack_sequences_length_shuffle_bins(
self,
num_packing_bins: int,
):
input_elements = [
{
"inputs": [1, 2, 3],
Expand Down Expand Up @@ -541,7 +544,10 @@ def test_pack_sequences_length_4(self):
]

_common_test_body(
input_elements, expected_elements, length_struct, num_packing_bins=2
input_elements,
expected_elements,
length_struct,
num_packing_bins=2,
)

def test_pack_sequences_length_5(self):
Expand Down Expand Up @@ -581,7 +587,10 @@ def test_pack_sequences_length_5(self):
]

_common_test_body(
input_elements, expected_elements, length_struct, num_packing_bins=2
input_elements,
expected_elements,
length_struct,
num_packing_bins=2,
)

def test_pack_sequences_length_6(self):
Expand Down Expand Up @@ -611,7 +620,10 @@ def test_pack_sequences_length_6(self):
}]

_common_test_body(
input_elements, expected_elements, length_struct, num_packing_bins=2
input_elements,
expected_elements,
length_struct,
num_packing_bins=2,
)

def test_pack_sequences_length_7(self):
Expand Down Expand Up @@ -641,7 +653,10 @@ def test_pack_sequences_length_7(self):
}]

_common_test_body(
input_elements, expected_elements, length_struct, num_packing_bins=1
input_elements,
expected_elements,
length_struct,
num_packing_bins=1,
)

def test_pack_sequences_different_lengths(self):
Expand Down Expand Up @@ -688,7 +703,10 @@ def test_pack_sequences_different_lengths(self):
},
]
_common_test_body(
input_elements, expected_elements, length_struct, num_packing_bins=3
input_elements,
expected_elements,
length_struct,
num_packing_bins=3,
)

def test_pack_sequences_two_dimensional_features(self):
Expand Down Expand Up @@ -738,7 +756,10 @@ def test_pack_sequences_two_dimensional_features(self):
]

_common_test_body(
input_elements, expected_elements, length_struct, num_packing_bins=2
input_elements,
expected_elements,
length_struct,
num_packing_bins=2,
)

@parameterized.parameters(
Expand Down

0 comments on commit f4e8cc7

Please sign in to comment.