Skip to content

Commit

Permalink
Internal change.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 721744536
  • Loading branch information
Grain Team authored and copybara-github committed Jan 31, 2025
1 parent c736394 commit b61b705
Showing 1 changed file with 87 additions and 0 deletions.
87 changes: 87 additions & 0 deletions grain/_src/python/dataset/transformations/testing_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from absl.testing import parameterized
from grain._src.python.dataset.transformations import packing
from grain._src.python.dataset.transformations import source
from jax import numpy as jnp
import numpy as np
import tree

Expand Down Expand Up @@ -140,6 +141,92 @@ def test_pack_sequences_length_3(self, num_packing_bins: int):
num_packing_bins=num_packing_bins,
)

def test_bfloat16(self):
input_elements = [
{
"soft_tokens": np.asarray(
[
[[0.1, 0.2, 0.3]],
[[0.4, 0.5, 0.6]],
[[0.7, 0.8, 0.9]],
],
dtype=jnp.bfloat16,
),
},
{
"soft_tokens": np.asarray(
[
[[1.1, 1.2, 1.3]],
[[1.4, 1.5, 1.6]],
[[1.7, 1.8, 1.9]],
],
dtype=jnp.bfloat16,
),
},
{
"soft_tokens": np.asarray(
[
[[2.1, 2.2, 2.3]],
[[2.4, 2.5, 2.6]],
[[2.7, 2.8, 2.9]],
],
dtype=jnp.bfloat16,
),
},
]

length_struct = {"soft_tokens": 4}

expected_elements = [
{
"soft_tokens": np.asarray(
[
[[0.1, 0.2, 0.3]],
[[0.4, 0.5, 0.6]],
[[0.7, 0.8, 0.9]],
[[0.0, 0.0, 0.0]],
],
dtype=jnp.bfloat16,
),
"soft_tokens_positions": [0, 1, 2, 0],
"soft_tokens_segment_ids": [1, 1, 1, 0],
},
{
"soft_tokens": np.asarray(
[
[[1.1, 1.2, 1.3]],
[[1.4, 1.5, 1.6]],
[[1.7, 1.8, 1.9]],
[[0.0, 0.0, 0.0]],
],
dtype=jnp.bfloat16,
),
"soft_tokens_positions": [0, 1, 2, 0],
"soft_tokens_segment_ids": [1, 1, 1, 0],
},
{
"soft_tokens": np.asarray(
[
[[2.1, 2.2, 2.3]],
[[2.4, 2.5, 2.6]],
[[2.7, 2.8, 2.9]],
[[0.0, 0.0, 0.0]],
],
dtype=jnp.bfloat16,
),
"soft_tokens_positions": [0, 1, 2, 0],
"soft_tokens_segment_ids": [1, 1, 1, 0],
},
]

_common_test_body(
input_elements,
expected_elements,
length_struct,
kwargs=self.kwargs,
num_packing_bins=3,
)

@parameterized.parameters(
{"num_packing_bins": 3},
{"num_packing_bins": 5},
Expand Down

0 comments on commit b61b705

Please sign in to comment.