From a007b912601ed6692711e25afe1fc543654faa39 Mon Sep 17 00:00:00 2001
From: Philip Monk <169196560+philip-essential@users.noreply.github.com>
Date: Fri, 31 Jan 2025 23:55:06 +0000
Subject: [PATCH 1/3] make attention fp32 flags apply during inference

---
 MaxText/layers/attentions.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/MaxText/layers/attentions.py b/MaxText/layers/attentions.py
index f5990e7d9..261b2f12e 100644
--- a/MaxText/layers/attentions.py
+++ b/MaxText/layers/attentions.py
@@ -477,7 +477,7 @@ def apply_attention_dot(
     """Apply Attention."""
     validate_compute_axis_order(self.compute_axis_order)
     # Casting qk_product and softmaxt computation for float32 for model stability.
-    if model_mode == common_types.MODEL_MODE_TRAIN and self.float32_qk_product:
+    if self.float32_qk_product:
       if isinstance(key, KVTensor):
         key = key.dequant()
       query = query.astype(jnp.float32)
@@ -491,7 +491,7 @@ def apply_attention_dot(
       attn_weights = attn_weights * self.attn_logits_soft_cap

     # Casting softmaxt computation for float32 for model stability.
-    if model_mode == common_types.MODEL_MODE_TRAIN and self.float32_logits:
+    if self.float32_logits:
       attn_weights = attn_weights.astype(jnp.float32)
     attn_mask = self.generate_attention_mask(query, key, decoder_segment_ids, model_mode)
     if attn_mask is not None:
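
[Note on PATCH 1/3] The only behavioural change is dropping the MODEL_MODE_TRAIN guard: the fp32 upcasts controlled by float32_qk_product and float32_logits now run during prefill and autoregressive decode as well as training. The sketch below is a minimal, self-contained JAX illustration of that upcast path; it is not MaxText's apply_attention_dot (scaling, masking, soft-capping and KV-cache quantization are omitted).

    # Illustrative sketch of the upcast path, not MaxText's apply_attention_dot.
    # With the MODEL_MODE_TRAIN guard gone, the same code runs for training,
    # prefill and decode whenever the flags are set.
    import jax
    import jax.numpy as jnp

    def dot_product_attention(query, key, value,
                              float32_qk_product=True, float32_logits=True):
      if float32_qk_product:                 # upcast the qk-product inputs
        query = query.astype(jnp.float32)
        key = key.astype(jnp.float32)
      attn_weights = jnp.einsum("btnh,bsnh->bnts", query, key)
      if float32_logits:                     # compute the softmax in fp32
        attn_weights = attn_weights.astype(jnp.float32)
      probs = jax.nn.softmax(attn_weights, axis=-1).astype(value.dtype)
      return jnp.einsum("bnts,bsnh->btnh", probs, value)

    q = k = v = jnp.ones((1, 8, 4, 16), dtype=jnp.bfloat16)  # [batch, seq, heads, head_dim]
    out = dot_product_attention(q, k, v)  # output is bfloat16, logits were fp32
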
From 73f6acd02aa30d4585f1956935111aab1ce65562 Mon Sep 17 00:00:00 2001
From: Philip Monk <169196560+philip-essential@users.noreply.github.com>
Date: Fri, 31 Jan 2025 23:56:17 +0000
Subject: [PATCH 2/3] move float32_qk_product and float32_logits to config

---
 MaxText/configs/base.yml              | 4 ++++
 MaxText/configs/models/gemma-2b.yml   | 4 +++-
 MaxText/configs/models/gemma-7b.yml   | 4 +++-
 MaxText/configs/models/gemma2-27b.yml | 2 ++
 MaxText/configs/models/gemma2-2b.yml  | 2 ++
 MaxText/configs/models/gemma2-9b.yml  | 2 ++
 MaxText/layers/gemma.py               | 4 ++--
 MaxText/layers/gemma2.py              | 4 ++--
 MaxText/layers/gpt3.py                | 2 ++
 MaxText/layers/llama2.py              | 2 ++
 MaxText/layers/mistral.py             | 2 ++
 MaxText/layers/models.py              | 2 ++
 12 files changed, 28 insertions(+), 6 deletions(-)

diff --git a/MaxText/configs/base.yml b/MaxText/configs/base.yml
index 6ec504b99..ffe639cf2 100644
--- a/MaxText/configs/base.yml
+++ b/MaxText/configs/base.yml
@@ -192,6 +192,10 @@ final_logits_soft_cap: 0.0
 use_post_attn_norm: False
 use_post_ffw_norm: False

+# In dot_product attention, whether to upcast the qk product and attention logits to fp32
+float32_qk_product: False
+float32_logits: False
+
 # Combine matmuls for QKV and MLP
 fused_qkv: False

diff --git a/MaxText/configs/models/gemma-2b.yml b/MaxText/configs/models/gemma-2b.yml
index 985b240dc..d7d210b03 100644
--- a/MaxText/configs/models/gemma-2b.yml
+++ b/MaxText/configs/models/gemma-2b.yml
@@ -24,4 +24,6 @@ mlp_activations: ["gelu","linear"]
 vocab_size: 256128
 decoder_block: "gemma"
 normalization_layer_epsilon: 1.e-06
-logits_via_embedding: True
\ No newline at end of file
+logits_via_embedding: True
+float32_qk_product: True
+float32_qk_logits: True
diff --git a/MaxText/configs/models/gemma-7b.yml b/MaxText/configs/models/gemma-7b.yml
index 3201b37a5..b07412acd 100644
--- a/MaxText/configs/models/gemma-7b.yml
+++ b/MaxText/configs/models/gemma-7b.yml
@@ -24,4 +24,6 @@ mlp_activations: ["gelu","linear"]
 vocab_size: 256128
 decoder_block: "gemma"
 normalization_layer_epsilon: 1.e-06
-logits_via_embedding: True
\ No newline at end of file
+logits_via_embedding: True
+float32_qk_product: True
+float32_qk_logits: True
diff --git a/MaxText/configs/models/gemma2-27b.yml b/MaxText/configs/models/gemma2-27b.yml
index 84a8fb98b..9cb6cde2b 100644
--- a/MaxText/configs/models/gemma2-27b.yml
+++ b/MaxText/configs/models/gemma2-27b.yml
@@ -30,3 +30,5 @@ attn_logits_soft_cap: 50.0
 sliding_window_size: 4096
 use_post_attn_norm: True
 use_post_ffw_norm: True
+float32_qk_product: True
+float32_qk_logits: True
diff --git a/MaxText/configs/models/gemma2-2b.yml b/MaxText/configs/models/gemma2-2b.yml
index 8647196b2..a35964238 100644
--- a/MaxText/configs/models/gemma2-2b.yml
+++ b/MaxText/configs/models/gemma2-2b.yml
@@ -30,3 +30,5 @@ attn_logits_soft_cap: 50.0
 sliding_window_size: 4096
 use_post_attn_norm: True
 use_post_ffw_norm: True
+float32_qk_product: True
+float32_qk_logits: True
diff --git a/MaxText/configs/models/gemma2-9b.yml b/MaxText/configs/models/gemma2-9b.yml
index d54352b95..74e30b857 100644
--- a/MaxText/configs/models/gemma2-9b.yml
+++ b/MaxText/configs/models/gemma2-9b.yml
@@ -30,3 +30,5 @@ attn_logits_soft_cap: 50.0
 sliding_window_size: 4096
 use_post_attn_norm: True
 use_post_ffw_norm: True
+float32_qk_product: True
+float32_qk_logits: True
diff --git a/MaxText/layers/gemma.py b/MaxText/layers/gemma.py
index 52ff40243..051bdf91e 100644
--- a/MaxText/layers/gemma.py
+++ b/MaxText/layers/gemma.py
@@ -91,8 +91,8 @@ def __call__(
         weight_dtype=cfg.weight_dtype,
         dropout_rate=cfg.dropout_rate,
         name="self_attention",
-        float32_qk_product=True,
-        float32_logits=True,
+        float32_qk_product=cfg.float32_qk_product,
+        float32_logits=cfg.float32_logits,
         quant=self.quant,
         kv_quant=quantizations.configure_kv_quant(cfg),
         use_ragged_attention=cfg.use_ragged_attention,
diff --git a/MaxText/layers/gemma2.py b/MaxText/layers/gemma2.py
index dd2db3a58..1acf8cd67 100644
--- a/MaxText/layers/gemma2.py
+++ b/MaxText/layers/gemma2.py
@@ -91,8 +91,8 @@ def __call__(
         weight_dtype=cfg.weight_dtype,
         dropout_rate=cfg.dropout_rate,
         name="self_attention_local",
-        float32_qk_product=True,
-        float32_logits=True,
+        float32_qk_product=cfg.float32_qk_product,
+        float32_logits=cfg.float32_logits,
         quant=self.quant,
         kv_quant=quantizations.configure_kv_quant(cfg),
         attention_type=attentions.AttentionType.LOCAL_SLIDING,
diff --git a/MaxText/layers/gpt3.py b/MaxText/layers/gpt3.py
index e9b6e65e9..cdf07f23c 100644
--- a/MaxText/layers/gpt3.py
+++ b/MaxText/layers/gpt3.py
@@ -312,6 +312,8 @@ def __call__(
         mesh=mesh,
         dropout_rate=cfg.dropout_rate,
         name="self_attention",
+        float32_qk_product=cfg.float32_qk_product,
+        float32_logits=cfg.float32_logits,
         fused_qkv=cfg.fused_qkv,
         use_bias=True,
         quant=self.quant,
diff --git a/MaxText/layers/llama2.py b/MaxText/layers/llama2.py
index 9b198c594..9769edace 100644
--- a/MaxText/layers/llama2.py
+++ b/MaxText/layers/llama2.py
@@ -105,6 +105,8 @@ def __call__(
         weight_dtype=cfg.weight_dtype,
         dropout_rate=cfg.dropout_rate,
         name="self_attention",
+        float32_qk_product=cfg.float32_qk_product,
+        float32_logits=cfg.float32_logits,
         quant=self.quant,
         kv_quant=quantizations.configure_kv_quant(cfg),
         prefill_cache_axis_order=tuple([int(i) for i in cfg.prefill_cache_axis_order.split(",")]),
diff --git a/MaxText/layers/mistral.py b/MaxText/layers/mistral.py
index 5fbb9e1e1..85350320c 100644
--- a/MaxText/layers/mistral.py
+++ b/MaxText/layers/mistral.py
@@ -97,6 +97,8 @@ def __call__(
         weight_dtype=cfg.weight_dtype,
         dropout_rate=cfg.dropout_rate,
         name="self_attention",
+        float32_qk_product=cfg.float32_qk_product,
+        float32_logits=cfg.float32_logits,
         quant=self.quant,
         kv_quant=quantizations.configure_kv_quant(cfg),
     )
diff --git a/MaxText/layers/models.py b/MaxText/layers/models.py
index 4c2046c1f..1b60641c6 100644
--- a/MaxText/layers/models.py
+++ b/MaxText/layers/models.py
@@ -92,6 +92,8 @@ def __call__(
         weight_dtype=cfg.weight_dtype,
         dropout_rate=cfg.dropout_rate,
         name="self_attention",
+        float32_qk_product=cfg.float32_qk_product,
+        float32_logits=cfg.float32_logits,
         quant=self.quant,
         kv_quant=quantizations.configure_kv_quant(cfg),
         prefill_cache_axis_order=tuple([int(i) for i in cfg.prefill_cache_axis_order.split(",")]),
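
[Note on PATCH 2/3] With the flags read from cfg, each model YAML decides the precision of the qk product and softmax inputs, while every other decoder block (gpt3, llama2, mistral and the generic models.py path) inherits the base.yml default of False. The "model stability" concern mentioned in the attentions.py comments comes down to bfloat16's short mantissa; a toy illustration in plain JAX (not MaxText code):

    # Toy illustration of why a bf16 qk product can hurt: with ~8 bits of
    # mantissa, nearby attention logits can round to the same bf16 value
    # before the softmax, while fp32 keeps them distinct.
    import jax
    import jax.numpy as jnp

    logits = jnp.array([10.00, 10.03, 10.06])
    print(logits.astype(jnp.bfloat16))               # 10.00 and 10.03 both round to 10
    print(jax.nn.softmax(logits.astype(jnp.bfloat16)))
    print(jax.nn.softmax(logits))                    # fp32 softmax keeps all three distinct
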
From 66339fabf038d8264de9c0ea1fd9298433a39b10 Mon Sep 17 00:00:00 2001
From: Philip Monk <169196560+philip-essential@users.noreply.github.com>
Date: Sat, 1 Feb 2025 01:27:17 +0000
Subject: [PATCH 3/3] address review comments

---
 MaxText/configs/base.yml              | 6 ++----
 MaxText/configs/models/gemma-2b.yml   | 4 +---
 MaxText/configs/models/gemma-7b.yml   | 4 +---
 MaxText/configs/models/gemma2-27b.yml | 2 --
 MaxText/configs/models/gemma2-2b.yml  | 2 --
 MaxText/configs/models/gemma2-9b.yml  | 2 --
 6 files changed, 4 insertions(+), 16 deletions(-)

diff --git a/MaxText/configs/base.yml b/MaxText/configs/base.yml
index ffe639cf2..705aac23e 100644
--- a/MaxText/configs/base.yml
+++ b/MaxText/configs/base.yml
@@ -120,6 +120,8 @@ logits_via_embedding: False
 normalize_embedding_logits: True # whether to normlize pre-softmax logits if logits_via_embedding is true
 logits_dot_in_fp32: False # whether to use fp32 in logits_dense or shared_embedding dot product for stability
 cast_logits_to_fp32: True # whether to cast the logits to fp32. The higher precision is generally beneficial, but it can vary slightly.
+float32_qk_product: False # in dot_product attention, whether to cast to fp32 the inputs to qk product
+float32_logits: False # in dot_product attention, whether to cast to fp32 the inputs to softmax

 # mixture of experts (moe)
 num_experts: 1
@@ -192,10 +194,6 @@ final_logits_soft_cap: 0.0
 use_post_attn_norm: False
 use_post_ffw_norm: False

-# In dot_product attention, whether to upcast the qk product and attention logits to fp32
-float32_qk_product: False
-float32_logits: False
-
 # Combine matmuls for QKV and MLP
 fused_qkv: False

diff --git a/MaxText/configs/models/gemma-2b.yml b/MaxText/configs/models/gemma-2b.yml
index d7d210b03..985b240dc 100644
--- a/MaxText/configs/models/gemma-2b.yml
+++ b/MaxText/configs/models/gemma-2b.yml
@@ -24,6 +24,4 @@ mlp_activations: ["gelu","linear"]
 vocab_size: 256128
 decoder_block: "gemma"
 normalization_layer_epsilon: 1.e-06
-logits_via_embedding: True
-float32_qk_product: True
-float32_qk_logits: True
+logits_via_embedding: True
\ No newline at end of file
diff --git a/MaxText/configs/models/gemma-7b.yml b/MaxText/configs/models/gemma-7b.yml
index b07412acd..3201b37a5 100644
--- a/MaxText/configs/models/gemma-7b.yml
+++ b/MaxText/configs/models/gemma-7b.yml
@@ -24,6 +24,4 @@ mlp_activations: ["gelu","linear"]
 vocab_size: 256128
 decoder_block: "gemma"
 normalization_layer_epsilon: 1.e-06
-logits_via_embedding: True
-float32_qk_product: True
-float32_qk_logits: True
+logits_via_embedding: True
\ No newline at end of file
diff --git a/MaxText/configs/models/gemma2-27b.yml b/MaxText/configs/models/gemma2-27b.yml
index 9cb6cde2b..84a8fb98b 100644
--- a/MaxText/configs/models/gemma2-27b.yml
+++ b/MaxText/configs/models/gemma2-27b.yml
@@ -30,5 +30,3 @@ attn_logits_soft_cap: 50.0
 sliding_window_size: 4096
 use_post_attn_norm: True
 use_post_ffw_norm: True
-float32_qk_product: True
-float32_qk_logits: True
diff --git a/MaxText/configs/models/gemma2-2b.yml b/MaxText/configs/models/gemma2-2b.yml
index a35964238..8647196b2 100644
--- a/MaxText/configs/models/gemma2-2b.yml
+++ b/MaxText/configs/models/gemma2-2b.yml
@@ -30,5 +30,3 @@ attn_logits_soft_cap: 50.0
 sliding_window_size: 4096
 use_post_attn_norm: True
 use_post_ffw_norm: True
-float32_qk_product: True
-float32_qk_logits: True
diff --git a/MaxText/configs/models/gemma2-9b.yml b/MaxText/configs/models/gemma2-9b.yml
index 74e30b857..d54352b95 100644
--- a/MaxText/configs/models/gemma2-9b.yml
+++ b/MaxText/configs/models/gemma2-9b.yml
@@ -30,5 +30,3 @@ attn_logits_soft_cap: 50.0
 sliding_window_size: 4096
 use_post_attn_norm: True
 use_post_ffw_norm: True
-float32_qk_product: True
-float32_qk_logits: True
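
[Note on PATCH 3/3] End state of the series: both flags live in base.yml with a default of False, every decoder block forwards them from cfg, and the per-model True settings added in PATCH 2/3 are removed again, leaving the shared defaults as the only place the flags are set within this series. A minimal sketch of the resulting plumbing, using a hypothetical Config dataclass as a stand-in for MaxText's real pyconfig object:

    # Hypothetical stand-in for the config plumbing this series ends up with;
    # MaxText's real cfg comes from pyconfig/base.yml, not this dataclass.
    from dataclasses import dataclass

    @dataclass
    class Config:
      float32_qk_product: bool = False  # base.yml default after PATCH 3/3
      float32_logits: bool = False      # base.yml default after PATCH 3/3

    def attention_kwargs(cfg: Config) -> dict:
      # Mirrors the call sites touched in gemma.py, gemma2.py, gpt3.py,
      # llama2.py, mistral.py and models.py: flags come from cfg, not literals.
      return dict(
          float32_qk_product=cfg.float32_qk_product,
          float32_logits=cfg.float32_logits,
      )

    print(attention_kwargs(Config()))   # both off by default
    gemma_like = Config(float32_qk_product=True, float32_logits=True)
    print(attention_kwargs(gemma_like)) # opt-in, as the Gemma YAMLs did in PATCH 2/3

In practice the flags can be flipped per model YAML or per run, e.g. by appending float32_qk_product=true float32_logits=true to a train.py or decode.py invocation, assuming MaxText's usual key=value config-override convention.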