From b61b7056227953d94a1434285d28ce1572233d02 Mon Sep 17 00:00:00 2001
From: Grain Team
Date: Fri, 31 Jan 2025 05:55:38 -0800
Subject: [PATCH] Internal change.

PiperOrigin-RevId: 721744536
---
 .../dataset/transformations/testing_util.py | 48 +++++++++++++++++++
 1 file changed, 48 insertions(+)

diff --git a/grain/_src/python/dataset/transformations/testing_util.py b/grain/_src/python/dataset/transformations/testing_util.py
index 271d9f09..c1bd94ed 100644
--- a/grain/_src/python/dataset/transformations/testing_util.py
+++ b/grain/_src/python/dataset/transformations/testing_util.py
@@ -9,6 +9,7 @@ from absl.testing import parameterized
 
 from grain._src.python.dataset.transformations import packing
 from grain._src.python.dataset.transformations import source
+from jax import numpy as jnp
 import numpy as np
 import tree
 
@@ -140,6 +141,53 @@ def test_pack_sequences_length_3(self, num_packing_bins: int):
         num_packing_bins=num_packing_bins,
     )
 
+  def test_bfloat16(self):
+    """Runs `_common_test_body` on bfloat16 `soft_tokens` features."""
+    # Three features, each made of three rows of shape (1, 3).
+    raw_features = [
+        [
+            [[0.1, 0.2, 0.3]],
+            [[0.4, 0.5, 0.6]],
+            [[0.7, 0.8, 0.9]],
+        ],
+        [
+            [[1.1, 1.2, 1.3]],
+            [[1.4, 1.5, 1.6]],
+            [[1.7, 1.8, 1.9]],
+        ],
+        [
+            [[2.1, 2.2, 2.3]],
+            [[2.4, 2.5, 2.6]],
+            [[2.7, 2.8, 2.9]],
+        ],
+    ]
+    zero_row = [[0.0, 0.0, 0.0]]
+
+    input_elements = [
+        {"soft_tokens": np.asarray(feature, dtype=jnp.bfloat16)}
+        for feature in raw_features
+    ]
+    length_struct = {"soft_tokens": 4}
+    # Each expected feature is its input plus one trailing all-zero row.
+    expected_elements = [
+        {
+            "soft_tokens": np.asarray(
+                feature + [zero_row], dtype=jnp.bfloat16
+            ),
+            "soft_tokens_positions": [0, 1, 2, 0],
+            "soft_tokens_segment_ids": [1, 1, 1, 0],
+        }
+        for feature in raw_features
+    ]
+
+    _common_test_body(
+        input_elements,
+        expected_elements,
+        length_struct,
+        kwargs=self.kwargs,
+        num_packing_bins=3,
+    )
+
   @parameterized.parameters(
       {"num_packing_bins": 3},
       {"num_packing_bins": 5},