Merge pull request #936 from AI-Hypercomputer:bvandermoon-temp
PiperOrigin-RevId: 683240399
maxtext authors committed on Oct 7, 2024
Parents: 4aeabf1 + 5ed77bc · Commit: 55dab25
Showing 2 changed files with 52 additions and 0 deletions.
MaxText/layers/attentions.py: 2 additions & 0 deletions
@@ -1163,6 +1163,7 @@ def key_rotary(self, key: Array, inputs_positions: Array):
         min_timescale=self.config.rope_min_timescale,
         max_timescale=self.config.rope_max_timescale,
         embedding_dims=self.head_dim,
+        fprop_dtype=self.dtype,
         name="key_rotary",
     )(inputs=key, position=inputs_positions)
     return key
@@ -1213,6 +1214,7 @@ def __call__(
         min_timescale=self.config.rope_min_timescale,
         max_timescale=self.config.rope_max_timescale,
         embedding_dims=self.head_dim,
+        fprop_dtype=self.dtype,
         name="query_rotary",
     )(inputs=query, position=inputs_positions)
     key = self.key_rotary(key, inputs_positions)
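The two added fprop_dtype=self.dtype lines thread the attention layer's activation dtype into the rotary-embedding modules applied to the query and key, so their outputs match the rest of the forward pass instead of silently promoting to float32. As a rough illustration, the following is a standalone sketch, not MaxText's RotaryEmbedding implementation (only the fprop_dtype name comes from the diff, everything else is assumed): a rotary embedding typically computes its sinusoids in float32 for precision and casts the result back on the way out.

# Standalone sketch of dtype handling in a rotary embedding (illustrative only).
import jax.numpy as jnp

def apply_rotary(inputs, position, *, min_timescale=1.0,
                 max_timescale=10_000.0, fprop_dtype=jnp.bfloat16):
  """inputs: [batch, seq, heads, head_dim]; position: [batch, seq]."""
  head_dim = inputs.shape[-1]
  fraction = 2.0 * jnp.arange(head_dim // 2) / head_dim
  timescale = min_timescale * (max_timescale / min_timescale) ** fraction
  # Compute the rotation angles in float32 regardless of the activation dtype.
  angle = position[:, :, None, None] / timescale[None, None, None, :]
  sin, cos = jnp.sin(angle), jnp.cos(angle)
  first, second = jnp.split(inputs.astype(jnp.float32), 2, axis=-1)
  rotated = jnp.concatenate(
      [first * cos - second * sin, second * cos + first * sin], axis=-1)
  # Without this final cast, a bfloat16 forward pass would get float32
  # queries/keys back, which is what the new test below guards against.
  return rotated.astype(fprop_dtype)

x = jnp.ones((1, 8, 2, 4), dtype=jnp.bfloat16)
pos = jnp.arange(8)[None, :]
print(apply_rotary(x, pos).dtype)  # bfloat16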
MaxText/tests/attention_test.py: 50 additions & 0 deletions
@@ -174,6 +174,56 @@ def test_autoregression(self):
     self.assertTrue(mha_full_this_idx.shape == mha_idx.shape)
     self.assertTrue(jax.numpy.allclose(mha_full_this_idx, mha_idx, rtol=1e-02, atol=1e-02, equal_nan=False))

+  @pytest.mark.tpu
+  def test_model_mode_prefill_dtype_float32(self):
+    self._test_model_mode_prefill_dtype(jnp.float32)
+
+  @pytest.mark.tpu
+  def test_model_mode_prefill_dtype_bfloat16(self):
+    self._test_model_mode_prefill_dtype(jnp.bfloat16)
+
+  def _test_model_mode_prefill_dtype(self, dtype):
+    lnx, decoder_segment_ids, decoder_positions = self.get_data(dtype)
+    prefill_length = self.cfg.max_prefill_predict_length
+    lnx_prefill = lnx[:, 0:prefill_length, :]
+    decoder_segment_ids_prefill = decoder_segment_ids[:, 0:prefill_length]
+    decoder_positions_prefill = decoder_positions[:, 0:prefill_length]
+
+    attention_as_mha_generic = Attention(
+        config=self.cfg,
+        num_query_heads=self.num_query_heads,
+        num_kv_heads=self.num_kv_heads,
+        head_dim=self.head_dim,
+        max_target_length=self.max_target_length,
+        max_prefill_predict_length=self.cfg.max_prefill_predict_length,
+        mesh=self.mesh,
+        attention_kernel="dot_product",
+        dtype=dtype,
+        dropout_rate=self.cfg.dropout_rate,
+        name="self_attention",
+    )
+
+    attention_as_mha_generic_variable = attention_as_mha_generic.init(
+        {"params": self.rng, "aqt": self.rng},
+        jnp.ones((self.global_batch_size, self.max_target_length, self.embed_dim)),
+        jnp.ones((self.global_batch_size, self.max_target_length, self.embed_dim)),
+        jnp.ones((self.global_batch_size, self.max_target_length)),
+    )
+
+    mha_prefill, _ = attention_as_mha_generic.apply(
+        attention_as_mha_generic_variable,
+        lnx_prefill,
+        lnx_prefill,
+        decoder_segment_ids=decoder_segment_ids_prefill,
+        inputs_positions=decoder_positions_prefill,
+        deterministic=True,
+        model_mode=common_types.MODEL_MODE_PREFILL,
+        rngs={"aqt": self.rng},
+        mutable=["cache"],
+    )
+
+    self.assertEqual(dtype, mha_prefill.dtype)
+
   @pytest.mark.tpu
   def test_tpu_kernel_attention_mha(self):
     self.tpu_kernel_attention_helper(self.num_kv_heads)
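One detail of the test body worth unpacking: Flax's apply(..., mutable=["cache"]) returns an (outputs, mutated_variables) pair, which is why the test unpacks mha_prefill, _ before asserting on the dtype. A minimal self-contained sketch of that pattern follows (hypothetical module and names, not MaxText code):

# Minimal sketch of the apply(..., mutable=["cache"]) pattern (assumed names).
import flax.linen as nn
import jax
import jax.numpy as jnp

class CachingDense(nn.Module):
  features: int = 4

  @nn.compact
  def __call__(self, x):
    y = nn.Dense(self.features, dtype=x.dtype)(x)
    # A "cache" collection, mutated during the call much like a KV cache.
    cache = self.variable("cache", "last_output", jnp.zeros, y.shape)
    cache.value = y
    return y

m = CachingDense()
x = jnp.ones((2, 3), dtype=jnp.bfloat16)
variables = m.init(jax.random.PRNGKey(0), x)
y, mutated = m.apply(variables, x, mutable=["cache"])
print(y.dtype)  # bfloat16, since Dense runs with dtype=x.dtype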
