Commit

Merge pull request #1225 from Essential-AI:fp32_attention
PiperOrigin-RevId: 725390928
maxtext authors committed Feb 11, 2025
2 parents aaf467e + 66339fa commit 0ab2455
Showing 8 changed files with 16 additions and 6 deletions.
2 changes: 2 additions & 0 deletions MaxText/configs/base.yml
@@ -120,6 +120,8 @@ logits_via_embedding: False
normalize_embedding_logits: True # whether to normalize pre-softmax logits if logits_via_embedding is true
logits_dot_in_fp32: False # whether to use fp32 in logits_dense or shared_embedding dot product for stability
cast_logits_to_fp32: True # whether to cast the logits to fp32. The higher precision is generally beneficial, but it can vary slightly.
+float32_qk_product: False # in dot_product attention, whether to cast to fp32 the inputs to qk product
+float32_logits: False # in dot_product attention, whether to cast to fp32 the inputs to softmax

# mixture of experts (moe)
num_experts: 1
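
The two new options gate optional float32 casts inside dot-product attention: float32_qk_product upcasts the query and key inputs of the QK dot product, and float32_logits upcasts the attention logits before the softmax. A minimal, self-contained JAX sketch of the pattern they control (illustrative only, not MaxText's apply_attention_dot; query scaling, masking, and dropout are omitted):

import jax
import jax.numpy as jnp

def dot_product_attention(query, key, value, float32_qk_product=False, float32_logits=False):
  """Toy sketch: query/key/value are [batch, seq, head_dim] arrays."""
  if float32_qk_product:
    # Upcast the inputs of the query-key dot product.
    query = query.astype(jnp.float32)
    key = key.astype(jnp.float32)
  attn_weights = jnp.einsum("btd,bsd->bts", query, key)
  if float32_logits:
    # Upcast the softmax inputs for a more stable normalization.
    attn_weights = attn_weights.astype(jnp.float32)
  probs = jax.nn.softmax(attn_weights, axis=-1)
  return jnp.einsum("bts,bsd->btd", probs.astype(value.dtype), value)

With both flags left at their default of False, the products stay in the activation dtype (typically bfloat16); setting them to True trades some speed and memory for extra numerical headroom.
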
4 changes: 2 additions & 2 deletions MaxText/layers/attentions.py
@@ -477,7 +477,7 @@ def apply_attention_dot(
"""Apply Attention."""
validate_compute_axis_order(self.compute_axis_order)
# Casting qk_product and softmax computation to float32 for model stability.
-if model_mode == common_types.MODEL_MODE_TRAIN and self.float32_qk_product:
+if self.float32_qk_product:
if isinstance(key, KVTensor):
key = key.dequant()
query = query.astype(jnp.float32)
@@ -491,7 +491,7 @@ def apply_attention_dot(
attn_weights = attn_weights * self.attn_logits_soft_cap

# Casting softmax computation to float32 for model stability.
-if model_mode == common_types.MODEL_MODE_TRAIN and self.float32_logits:
+if self.float32_logits:
attn_weights = attn_weights.astype(jnp.float32)
attn_mask = self.generate_attention_mask(query, key, decoder_segment_ids, model_mode)
if attn_mask is not None:
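
Besides reading the two flags from config, this hunk drops the model_mode == common_types.MODEL_MODE_TRAIN guard, so the casts (when enabled) now apply in every model mode, including prefill and autoregressive decoding, rather than only during training. The reason upcasting the softmax inputs helps is resolution: bfloat16 carries roughly 8 bits of mantissa, so large logits that differ by less than their bf16 spacing collapse to the same value. A toy illustration (not MaxText code):

import jax
import jax.numpy as jnp

# Two logits that differ by 0.5; at magnitude ~256 the spacing between
# representable bfloat16 values is 2.0, so both round to 256 and the
# bf16 softmax degenerates to a uniform split.
logits = jnp.array([256.0, 256.5])
print(jax.nn.softmax(logits.astype(jnp.bfloat16)))  # [0.5, 0.5]
print(jax.nn.softmax(logits.astype(jnp.float32)))   # ~[0.378, 0.622]

Note that gemma.py and gemma2.py below previously hard-coded both casts to True; with every model now reading the values from config (default False), Gemma runs that want the old behavior need float32_qk_product and float32_logits set to True explicitly.
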
4 changes: 2 additions & 2 deletions MaxText/layers/gemma.py
@@ -91,8 +91,8 @@ def __call__(
weight_dtype=cfg.weight_dtype,
dropout_rate=cfg.dropout_rate,
name="self_attention",
-float32_qk_product=True,
-float32_logits=True,
+float32_qk_product=cfg.float32_qk_product,
+float32_logits=cfg.float32_logits,
quant=self.quant,
kv_quant=quantizations.configure_kv_quant(cfg),
use_ragged_attention=cfg.use_ragged_attention,
4 changes: 2 additions & 2 deletions MaxText/layers/gemma2.py
@@ -91,8 +91,8 @@ def __call__(
weight_dtype=cfg.weight_dtype,
dropout_rate=cfg.dropout_rate,
name="self_attention_local",
-float32_qk_product=True,
-float32_logits=True,
+float32_qk_product=cfg.float32_qk_product,
+float32_logits=cfg.float32_logits,
quant=self.quant,
kv_quant=quantizations.configure_kv_quant(cfg),
attention_type=attentions.AttentionType.LOCAL_SLIDING,
2 changes: 2 additions & 0 deletions MaxText/layers/gpt3.py
@@ -312,6 +312,8 @@ def __call__(
mesh=mesh,
dropout_rate=cfg.dropout_rate,
name="self_attention",
+float32_qk_product=cfg.float32_qk_product,
+float32_logits=cfg.float32_logits,
fused_qkv=cfg.fused_qkv,
use_bias=True,
quant=self.quant,
2 changes: 2 additions & 0 deletions MaxText/layers/llama2.py
@@ -105,6 +105,8 @@ def __call__(
weight_dtype=cfg.weight_dtype,
dropout_rate=cfg.dropout_rate,
name="self_attention",
+float32_qk_product=cfg.float32_qk_product,
+float32_logits=cfg.float32_logits,
quant=self.quant,
kv_quant=quantizations.configure_kv_quant(cfg),
prefill_cache_axis_order=tuple([int(i) for i in cfg.prefill_cache_axis_order.split(",")]),
2 changes: 2 additions & 0 deletions MaxText/layers/mistral.py
@@ -97,6 +97,8 @@ def __call__(
weight_dtype=cfg.weight_dtype,
dropout_rate=cfg.dropout_rate,
name="self_attention",
+float32_qk_product=cfg.float32_qk_product,
+float32_logits=cfg.float32_logits,
quant=self.quant,
kv_quant=quantizations.configure_kv_quant(cfg),
)
2 changes: 2 additions & 0 deletions MaxText/layers/models.py
@@ -92,6 +92,8 @@ def __call__(
weight_dtype=cfg.weight_dtype,
dropout_rate=cfg.dropout_rate,
name="self_attention",
+float32_qk_product=cfg.float32_qk_product,
+float32_logits=cfg.float32_logits,
quant=self.quant,
kv_quant=quantizations.configure_kv_quant(cfg),
prefill_cache_axis_order=tuple([int(i) for i in cfg.prefill_cache_axis_order.split(",")]),
