From 3a046a4cf7d99c56853957a5dc11749921d18fcb Mon Sep 17 00:00:00 2001 From: geetu040 Date: Wed, 22 Jan 2025 11:42:36 +0500 Subject: [PATCH 01/11] end-to-end architecture --- .../models/minimax_text_01/__init__.py | 68 + .../configuration_minimax_text_01.py | 169 ++ .../modeling_minimax_text_01.py | 1604 +++++++++++++++++ .../modular_minimax_text_01.py | 520 ++++++ 4 files changed, 2361 insertions(+) create mode 100644 src/transformers/models/minimax_text_01/__init__.py create mode 100644 src/transformers/models/minimax_text_01/configuration_minimax_text_01.py create mode 100644 src/transformers/models/minimax_text_01/modeling_minimax_text_01.py create mode 100644 src/transformers/models/minimax_text_01/modular_minimax_text_01.py diff --git a/src/transformers/models/minimax_text_01/__init__.py b/src/transformers/models/minimax_text_01/__init__.py new file mode 100644 index 000000000000..1d65a515cf17 --- /dev/null +++ b/src/transformers/models/minimax_text_01/__init__.py @@ -0,0 +1,68 @@ +# coding=utf-8 +# Copyright 2025 MiniMaxAI and HuggingFace Inc. teams. All rights reserved. +# +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import TYPE_CHECKING + +from ...utils import ( + OptionalDependencyNotAvailable, + _LazyModule, + is_torch_available, +) + + +_import_structure = { + "configuration_minimax_text_01": ["MiniMaxText01Config"], +} + + +try: + if not is_torch_available(): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + pass +else: + _import_structure["modeling_minimax_text_01"] = [ + "MiniMaxText01ForCausalLM", + "MiniMaxText01ForQuestionAnswering", + "MiniMaxText01Model", + "MiniMaxText01PreTrainedModel", + "MiniMaxText01ForSequenceClassification", + "MiniMaxText01ForTokenClassification", + ] + + +if TYPE_CHECKING: + from .configuration_minimax_text_01 import MiniMaxText01Config + + try: + if not is_torch_available(): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + pass + else: + from .modeling_minimax_text_01 import ( + MiniMaxText01ForCausalLM, + MiniMaxText01ForQuestionAnswering, + MiniMaxText01ForSequenceClassification, + MiniMaxText01ForTokenClassification, + MiniMaxText01Model, + MiniMaxText01PreTrainedModel, + ) + + +else: + import sys + + sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__) diff --git a/src/transformers/models/minimax_text_01/configuration_minimax_text_01.py b/src/transformers/models/minimax_text_01/configuration_minimax_text_01.py new file mode 100644 index 000000000000..00ce0f3919e9 --- /dev/null +++ b/src/transformers/models/minimax_text_01/configuration_minimax_text_01.py @@ -0,0 +1,169 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/minimax_text_01/modular_minimax_text_01.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_minimax_text_01.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# coding=utf-8 +# Copyright 2025 MiniMaxAI and HuggingFace Inc. teams. All rights reserved. +# +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from ...configuration_utils import PretrainedConfig + + +class MiniMaxText01Config(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`MiniMaxText01Model`]. It is used to instantiate an + MiniMaxText01 model according to the specified arguments, defining the model architecture. Instantiating a configuration + with the defaults will yield a similar configuration to that of the MiniMaxText01. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + + Args: + vocab_size (`int`, *optional*, defaults to 32000): + Vocabulary size of the MiniMaxText01 model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`MiniMaxText01Model`] + hidden_size (`int`, *optional*, defaults to 4096): + Dimension of the hidden representations. + intermediate_size (`int`, *optional*, defaults to 14336): + Dimension of the MLP representations. + num_hidden_layers (`int`, *optional*, defaults to 32): + Number of hidden layers in the Transformer encoder. + num_attention_heads (`int`, *optional*, defaults to 32): + Number of attention heads for each attention layer in the Transformer encoder. + num_key_value_heads (`int`, *optional*, defaults to 8): + This is the number of key_value heads that should be used to implement Grouped Query Attention. If + `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if + `num_key_value_heads=1 the model will use Multi Query Attention (MQA) otherwise GQA is used. When + converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed + by meanpooling all the original heads within that group. For more details checkout [this + paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to `8`. + hidden_act (`str` or `function`, *optional*, defaults to `"silu"`): + The non-linear activation function (function or string) in the decoder. + max_position_embeddings (`int`, *optional*, defaults to `4096*32`): + The maximum sequence length that this model might ever be used with. MiniMaxText01's sliding window attention + allows sequence of up to 4096*32 tokens. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + rms_norm_eps (`float`, *optional*, defaults to 1e-05): + The epsilon used by the rms normalization layers. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models). Only + relevant if `config.is_decoder=True`. + pad_token_id (`int`, *optional*): + The id of the padding token. + bos_token_id (`int`, *optional*, defaults to 1): + The id of the "beginning-of-sequence" token. + eos_token_id (`int`, *optional*, defaults to 2): + The id of the "end-of-sequence" token. + tie_word_embeddings (`bool`, *optional*, defaults to `False`): + Whether the model's input and output word embeddings should be tied. + rope_theta (`float`, *optional*, defaults to 1000000.0): + The base period of the RoPE embeddings. + sliding_window (`int`, *optional*): + Sliding window attention window size. If not specified, will default to `4096`. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + num_experts_per_tok (`int`, *optional*, defaults to 2): + The number of experts to route per-token, can be also interpreted as the `top-k` routing + parameter + num_local_experts (`int`, *optional*, defaults to 8): + Number of experts per Sparse MLP layer. + output_router_logits (`bool`, *optional*, defaults to `False`): + Whether or not the router logits should be returned by the model. Enabeling this will also + allow the model to output the auxiliary loss. See [here]() for more details + router_aux_loss_coef (`float`, *optional*, defaults to 0.001): + The aux loss factor for the total loss. + router_jitter_noise (`float`, *optional*, defaults to 0.0): + Amount of noise to add to the router. + + ```python + >>> from transformers import MiniMaxText01Model, MiniMaxText01Config + + >>> # Initializing a MiniMaxText01 style configuration + >>> configuration = MiniMaxText01Config() + + >>> # Initializing a model from the MiniMaxText01 style configuration + >>> model = MiniMaxText01Model(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "minimax_text_01" + keys_to_ignore_at_inference = ["past_key_values"] + + def __init__( + self, + vocab_size=32000, + hidden_size=4096, + intermediate_size=14336, + num_hidden_layers=32, + num_attention_heads=32, + num_key_value_heads=8, + head_dim=None, + hidden_act="silu", + max_position_embeddings=4096 * 32, + initializer_range=0.02, + rms_norm_eps=1e-5, + use_cache=True, + pad_token_id=None, + bos_token_id=None, + eos_token_id=None, + tie_word_embeddings=False, + rope_theta=1e6, + sliding_window=None, + attention_dropout=0.0, + num_experts_per_tok=2, + num_local_experts=8, + output_router_logits=False, + router_aux_loss_coef=0.001, + router_jitter_noise=0.0, + **kwargs, + ): + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.sliding_window = sliding_window + + # for backward compatibility + if num_key_value_heads is None: + num_key_value_heads = num_attention_heads + + self.num_key_value_heads = num_key_value_heads + self.hidden_act = hidden_act + self.initializer_range = initializer_range + self.rms_norm_eps = rms_norm_eps + self.use_cache = use_cache + self.rope_theta = rope_theta + self.attention_dropout = attention_dropout + self.head_dim = head_dim if head_dim is not None else self.hidden_size // self.num_attention_heads + + self.num_experts_per_tok = num_experts_per_tok + self.num_local_experts = num_local_experts + self.output_router_logits = output_router_logits + self.router_aux_loss_coef = router_aux_loss_coef + self.router_jitter_noise = router_jitter_noise + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) diff --git a/src/transformers/models/minimax_text_01/modeling_minimax_text_01.py b/src/transformers/models/minimax_text_01/modeling_minimax_text_01.py new file mode 100644 index 000000000000..7bc7194783c3 --- /dev/null +++ b/src/transformers/models/minimax_text_01/modeling_minimax_text_01.py @@ -0,0 +1,1604 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/minimax_text_01/modular_minimax_text_01.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_minimax_text_01.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# coding=utf-8 +# Copyright 2025 MiniMaxAI and HuggingFace Inc. teams. All rights reserved. +# +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from typing import Callable, List, Optional, Tuple, Union + +import torch +import torch.nn.functional as F +from einops import rearrange +from torch import nn + +from ...activations import ACT2FN +from ...cache_utils import Cache, DynamicCache, SlidingWindowCache, StaticCache +from ...generation import GenerationMixin +from ...modeling_attn_mask_utils import AttentionMaskConverter +from ...modeling_flash_attention_utils import FlashAttentionKwargs +from ...modeling_outputs import ( + BaseModelOutputWithPast, + CausalLMOutputWithPast, + MoeCausalLMOutputWithPast, + MoeModelOutputWithPast, + QuestionAnsweringModelOutput, + SequenceClassifierOutputWithPast, + TokenClassifierOutput, +) +from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from ...processing_utils import Unpack +from ...utils import ( + LossKwargs, + add_code_sample_docstrings, + add_start_docstrings, + add_start_docstrings_to_model_forward, + logging, + replace_return_docstrings, +) +from .configuration_minimax_text_01 import MiniMaxText01Config + + +logger = logging.get_logger(__name__) + + +_CHECKPOINT_FOR_DOC = "mistralai/Mixtral-8x7B-v0.1" +_CONFIG_FOR_DOC = "MixtralConfig" + + +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`, *optional*): + Deprecated and unused. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. + """ + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: float, + dropout: float = 0.0, + **kwargs, +): + key_states = repeat_kv(key, module.num_key_value_groups) + value_states = repeat_kv(value, module.num_key_value_groups) + + attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling + if attention_mask is not None: + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + + # print() + # ic(module.layer_idx) + # show_tensor(query, False, True) + # show_tensor(key_states, False, True) + # show_tensor(value_states, False, True) + # show_tensor(attn_weights, False, True) + + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + attn_output = torch.matmul(attn_weights, value_states) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights + + +class MiniMaxText01Attention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config: MiniMaxText01Config, layer_idx: int): + super().__init__() + self.config = config + self.layer_idx = layer_idx + self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) + self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads + self.scaling = self.head_dim**-0.5 + self.attention_dropout = config.attention_dropout + self.is_causal = True + self.q_proj = nn.Linear(config.hidden_size, config.num_attention_heads * self.head_dim, bias=False) + self.k_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=False) + self.v_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=False) + self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=False) + + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: Tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor], + past_key_value: Optional[Cache] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs: Unpack[FlashAttentionKwargs], + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.head_dim) + + query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2) + key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) + + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + # print(self.layer_idx) + # show_tensor(query_states, end=False, only_shapes=False) + # show_tensor(key_states, end=False, only_shapes=True) + # show_tensor(value_states, end=True, only_shapes=True) + + # print() + # print() + # ic(self.layer_idx) + # show_tensor(key_states, False, True) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + # show_tensor(key_states, False, True) + + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False): + logger.warning_once( + "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to " + 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' + ) + else: + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0 if not self.training else self.attention_dropout, + scaling=self.scaling, + sliding_window=getattr(self.config, "sliding_window", None), # main diff with Llama + **kwargs, + ) + + attn_output = attn_output.reshape(*input_shape, -1).contiguous() + attn_output = self.o_proj(attn_output) + + # ic(self.layer_idx) + # show_tensor(attn_output, False, True) + + return attn_output, attn_weights + + +def get_slopes(head_dim): + equ = lambda x: 1 / (2 ** (8 / x)) + + log2 = math.log2(head_dim) + if log2.is_integer(): + return [equ(head_dim) ** i for i in range(1, head_dim + 1)] + + lower_bound = 2 ** math.floor(log2) + upper_bound = 2 ** math.ceil(log2) + + lower_bound_slopes = get_slopes(lower_bound) + upper_bound_slopes = get_slopes(upper_bound) + slopes = lower_bound_slopes + upper_bound_slopes[::2][: head_dim - lower_bound] + + return slopes + + +# TODO: clean and refactor +# TODO: lightning + eager = attention_mask is not None = fails +def lightning_attention_forward( + module, + query_states, + key_states, + value_states, + attention_mask, + **kwargs, +): + batch_size, hidden_size, seq_len, head_dim = query_states.shape + batch_size, hidden_size, seq_len, kv_head_dim = value_states.shape + + BLOCK = 256 + num_blocks = (seq_len + BLOCK - 1) // BLOCK + + if attention_mask is not None: + value_states = value_states.masked_fill((1 - attention_mask).unsqueeze(1).unsqueeze(-1).to(torch.bool), 0) + + slope_rate = get_slopes(head_dim) + slope_rate = torch.tensor(slope_rate, device=query_states.device, dtype=torch.float32) + # TODO: check for a different batch size + slope_rate = slope_rate.unsqueeze(1).unsqueeze(1) + slope_rate *= 1 - module.layer_idx / (module.num_hidden_layers - 1) + 1e-5 + + array = torch.arange(BLOCK).to(query_states) + 1 + query_states_decay = torch.exp(-slope_rate * array.reshape(-1, 1)) + key_states_decay = torch.exp(-slope_rate * (BLOCK - array.reshape(-1, 1))) + index = array[:, None] - array[None, :] + s_index = ( + slope_rate + * index[ + None, + None, + ] + ) + s_index = torch.where(index >= 0, -s_index, float("-inf")) + diag_decay = torch.exp(s_index) + + # TODO: remove unused kv + kv = torch.zeros(batch_size, hidden_size, head_dim, kv_head_dim).to(torch.float32).to(query_states.device) + output = torch.empty( + (batch_size, hidden_size, seq_len, kv_head_dim), dtype=query_states.dtype, device=query_states.device + ) + + for i in range(num_blocks): + si = i * BLOCK + ei = min(si + BLOCK, seq_len) + m = ei - si + query_states_i = query_states[:, :, si:ei].contiguous() + key_states_i = key_states[:, :, si:ei].contiguous() + value_states_i = value_states[:, :, si:ei].contiguous() + qkv_none_diag = torch.matmul(query_states_i * query_states_decay[:, :m], kv).to(torch.float32) + + # diag + qk = torch.matmul(query_states_i, key_states_i.transpose(-1, -2)).to(torch.float32) * diag_decay[:, :, :m, :m] + qkv_diag = torch.matmul(qk, value_states_i.to(torch.float32)) + block_decay = torch.exp(-slope_rate * m) + output[:, :, si:ei] = qkv_none_diag + qkv_diag + kv = block_decay * kv + torch.matmul( + (key_states_i * key_states_decay[:, -m:]).transpose(-1, -2).to(value_states_i.dtype), value_states_i + ) + + return output, None + + +# TODO +class MiniMaxText01LightningAttention(nn.Module): + def __init__(self, config: MiniMaxText01Config, layer_idx: int): + super().__init__() + self.config = config + self.layer_idx = layer_idx + self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) + self.num_heads = config.num_attention_heads + self.num_hidden_layers = config.num_hidden_layers + + self.act_fn = ACT2FN[config.hidden_act] + self.norm = MixtralRMSNorm(self.head_dim * self.num_heads) + self.qkv_proj = nn.Linear(config.hidden_size, self.num_heads * self.head_dim * 3, bias=False) + # TODO: separate q,k,v + # self.q_proj = nn.Linear(config.hidden_size, self.num_heads * self.head_dim, bias=False) + # self.k_proj = nn.Linear(config.hidden_size, self.num_heads * self.head_dim, bias=False) + # self.v_proj = nn.Linear(config.hidden_size, self.num_heads * self.head_dim, bias=False) + self.out_proj = nn.Linear(self.num_heads * self.head_dim, config.hidden_size, bias=False) + self.output_gate = nn.Linear(config.hidden_size, self.num_heads * self.num_heads, bias=False) + + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: Tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor], + past_key_value: Optional[Cache] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs: Unpack[FlashAttentionKwargs], + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.head_dim) + + # TODO: separate q,k,v + # query_states = self.act_fn(self.q_proj(hidden_states)).view(hidden_shape).transpose(1, 2) + # key_states = self.act_fn(self.k_proj(hidden_states)).view(hidden_shape).transpose(1, 2) + # value_states = self.act_fn(self.v_proj(hidden_states)).view(hidden_shape).transpose(1, 2) + + qkv_mixed = self.act_fn(self.qkv_proj(hidden_states)) + new_shape = qkv_mixed.size()[:-1] + (self.num_heads, -1) + qkv_mixed = qkv_mixed.view(*new_shape) + query_states, key_states, value_states = torch.split(qkv_mixed, [self.head_dim] * 3, dim=3) + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + + if past_key_value is not None: + cache_kwargs = {"cache_position": cache_position} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + # show_tensor(key_states, False, True) + + # TODO: store following computed in cache + attn_output, attn_weights = lightning_attention_forward( + self, + query_states, + key_states, + value_states, + attention_mask, + **kwargs, + ) + + attn_output = rearrange(attn_output, "b h n d -> b n (h d)") + attn_output = self.norm(attn_output) + attn_output = F.sigmoid(self.output_gate(hidden_states)) * attn_output + attn_output = self.out_proj(attn_output) + + # ic(self.layer_idx) + # show_tensor(attn_output, False, True) + + return attn_output, attn_weights + + +class MiniMaxText01BlockSparseTop2MLP(nn.Module): + def __init__(self, config: MiniMaxText01Config): + super().__init__() + self.ffn_dim = config.intermediate_size + self.hidden_dim = config.hidden_size + + self.w1 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False) + self.w2 = nn.Linear(self.ffn_dim, self.hidden_dim, bias=False) + self.w3 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False) + + self.act_fn = ACT2FN[config.hidden_act] + + def forward(self, hidden_states): + current_hidden_states = self.act_fn(self.w1(hidden_states)) * self.w3(hidden_states) + current_hidden_states = self.w2(current_hidden_states) + return current_hidden_states + + +class MiniMaxText01SparseMoeBlock(nn.Module): + """ + This implementation is + strictly equivalent to standard MoE with full capacity (no + dropped tokens). It's faster since it formulates MoE operations + in terms of block-sparse operations to accommodate imbalanced + assignments of tokens to experts, whereas standard MoE either + (1) drop tokens at the cost of reduced performance or (2) set + capacity factor to number of experts and thus waste computation + and memory on padding. + """ + + def __init__(self, config): + super().__init__() + self.hidden_dim = config.hidden_size + self.ffn_dim = config.intermediate_size + self.num_experts = config.num_local_experts + self.top_k = config.num_experts_per_tok + + # gating + self.gate = nn.Linear(self.hidden_dim, self.num_experts, bias=False) + + self.experts = nn.ModuleList([MiniMaxText01BlockSparseTop2MLP(config) for _ in range(self.num_experts)]) + + # Jitter parameters + self.jitter_noise = config.router_jitter_noise + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + """ """ + batch_size, sequence_length, hidden_dim = hidden_states.shape + if self.training and self.jitter_noise > 0: + hidden_states *= torch.empty_like(hidden_states).uniform_(1.0 - self.jitter_noise, 1.0 + self.jitter_noise) + hidden_states = hidden_states.view(-1, hidden_dim) + # router_logits: (batch * sequence_length, n_experts) + router_logits = self.gate(hidden_states) + + routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float) + routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1) + routing_weights /= routing_weights.sum(dim=-1, keepdim=True) + # we cast back to the input dtype + routing_weights = routing_weights.to(hidden_states.dtype) + + final_hidden_states = torch.zeros( + (batch_size * sequence_length, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device + ) + + # One hot encode the selected experts to create an expert mask + # this will be used to easily index which expert is going to be sollicitated + expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0) + + # Loop over all available experts in the model and perform the computation on each expert + for expert_idx in range(self.num_experts): + expert_layer = self.experts[expert_idx] + idx, top_x = torch.where(expert_mask[expert_idx]) + + # Index the correct hidden states and compute the expert hidden state for + # the current expert. We need to make sure to multiply the output hidden + # states by `routing_weights` on the corresponding tokens (top-1 and top-2) + current_state = hidden_states[None, top_x].reshape(-1, hidden_dim) + current_hidden_states = expert_layer(current_state) * routing_weights[top_x, idx, None] + + # However `index_add_` only support torch tensors for indexing so we'll use + # the `top_x` tensor here. + final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype)) + final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim) + return final_hidden_states, router_logits + + +class MiniMaxText01RMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + MiniMaxText01RMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + def extra_repr(self): + return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" + + +class MiniMaxText01DecoderLayer(nn.Module): + def __init__(self, config: MiniMaxText01Config, layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size + + self.self_attn = MiniMaxText01Attention(config, layer_idx) + + self.block_sparse_moe = MiniMaxText01SparseMoeBlock(config) + self.input_layernorm = MiniMaxText01RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = MiniMaxText01RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + # TODO: add each of these to config + self.residual_post_norm = getattr(config, "residual_post_norm", False) + self.layernorm_attention_alpha = getattr(config, "layernorm_attention_alpha", 1) + self.layernorm_attention_beta = getattr(config, "layernorm_attention_beta", 1) + self.layernorm_lightning_attention_alpha = getattr(config, "layernorm_lightning_attention_alpha", 1) + self.layernorm_lightning_attention_beta = getattr(config, "layernorm_lightning_attention_beta", 1) + self.layernorm_mlp_alpha = getattr(config, "layernorm_mlp_alpha", 1) + self.layernorm_mlp_beta = getattr(config, "layernorm_mlp_beta", 1) + + # TODO: remove these + self.layer_idx = layer_idx + self.residual_post_norm = True + self.layernorm_attention_alpha = 3.5565588200778455 + self.layernorm_attention_beta = 1.0 + self.layernorm_lightning_attention_alpha = 3.5565588200778455 + self.layernorm_lightning_attention_beta = 1.0 + self.layernorm_mlp_alpha = 3.5565588200778455 + self.layernorm_mlp_beta = 1.0 + + # TODO: attn_type_list to config + if config.attn_type_list[layer_idx] == 0: + self.self_attn = MiniMaxText01LightningAttention(config, layer_idx) + self.layernorm_alpha = self.layernorm_lightning_attention_alpha + self.layernorm_beta = self.layernorm_lightning_attention_beta + else: + self.self_attn = MiniMaxText01Attention(config, layer_idx) + self.layernorm_alpha = self.layernorm_attention_alpha + self.layernorm_beta = self.layernorm_attention_beta + + # TODO: shared_moe + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: Optional[bool] = False, + output_router_logits: Optional[bool] = False, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC + **kwargs: Unpack[FlashAttentionKwargs], + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`, *optional*): attention mask of size + `(batch, sequence_length)` where padding elements are indicated by 0. + past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + output_router_logits (`bool`, *optional*): + Whether or not to return the logits of all the routers. They are useful for computing the router loss, and + should not be returned during inference. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence. + kwargs (`dict`, *optional*): + Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code + into the model + """ + # print() + # ic(self.layer_idx) + # show_tensor(hidden_states, False, True) + + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + if self.residual_post_norm: + residual = hidden_states + + # Self Attention + hidden_states, self_attn_weights = self.self_attn( + hidden_states=hidden_states, + position_embeddings=position_embeddings, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + **kwargs, + ) + hidden_states = residual * self.layernorm_alpha + hidden_states * self.layernorm_beta + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + if self.residual_post_norm: + residual = hidden_states + hidden_states, router_logits = self.block_sparse_moe(hidden_states) + hidden_states = residual * self.layernorm_mlp_alpha + hidden_states * self.layernorm_mlp_beta + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights,) + + if output_router_logits: + outputs += (router_logits,) + + # show_tensor(hidden_states, False, True) + + return outputs + + +class MiniMaxText01RotaryEmbedding(nn.Module): + def __init__(self, config: MiniMaxText01Config, device=None): + super().__init__() + # BC: "rope_type" was originally "type" + if hasattr(config, "rope_scaling") and config.rope_scaling is not None: + self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) + else: + self.rope_type = "default" + self.max_seq_len_cached = config.max_position_embeddings + self.original_max_seq_len = config.max_position_embeddings + + self.config = config + self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] + + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self.original_inv_freq = self.inv_freq + + def _dynamic_frequency_update(self, position_ids, device): + """ + dynamic RoPE layers should recompute `inv_freq` in the following situations: + 1 - growing beyond the cached sequence length (allow scaling) + 2 - the current sequence length is in the original scale (avoid losing precision with small sequences) + """ + seq_len = torch.max(position_ids) + 1 + if seq_len > self.max_seq_len_cached: # growth + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, seq_len=seq_len) + self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation + self.max_seq_len_cached = seq_len + + if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset + # This .to() is needed if the model has been moved to a device after being initialized (because + # the buffer is automatically moved, but not the original copy) + self.original_inv_freq = self.original_inv_freq.to(device) + self.register_buffer("inv_freq", self.original_inv_freq, persistent=False) + self.max_seq_len_cached = self.original_max_seq_len + + @torch.no_grad() + def forward(self, x, position_ids): + if "dynamic" in self.rope_type: + self._dynamic_frequency_update(position_ids, device=x.device) + + # Core RoPE block + inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) + position_ids_expanded = position_ids[:, None, :].float() + # Force float32 (see https://github.com/huggingface/transformers/pull/29285) + device_type = x.device.type + device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" + with torch.autocast(device_type=device_type, enabled=False): + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() + sin = emb.sin() + + # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention + cos = cos * self.attention_scaling + sin = sin * self.attention_scaling + + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + +MINI_MAX_TEXT01_START_DOCSTRING = r""" + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config ([`MiniMaxText01Config`]): + Model configuration class with all the parameters of the model. Initializing with a config file does not + load the weights associated with the model, only the configuration. Check out the + [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + + +@add_start_docstrings( + "The bare MiniMaxText01 Model outputting raw hidden-states without any specific head on top.", + MINI_MAX_TEXT01_START_DOCSTRING, +) +class MiniMaxText01PreTrainedModel(PreTrainedModel): + config_class = MiniMaxText01Config + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["MiniMaxText01DecoderLayer"] + _skip_keys_device_placement = ["past_key_values"] + _supports_flash_attn_2 = True + _supports_sdpa = True + _supports_flex_attn = True + _supports_cache_class = True + _supports_quantized_cache = True + _supports_static_cache = True + + def _init_weights(self, module): + std = self.config.initializer_range + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + +MINI_MAX_TEXT01_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + If `past_key_values` is used, optionally only the last `input_ids` have to be input (see + `past_key_values`). + + If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] + and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more + information on the default strategy. + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.n_positions - 1]`. + + [What are position IDs?](../glossary#position-ids) + past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*): + Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention + blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values` + returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`. + + Two formats are allowed: + - a [`~cache_utils.Cache`] instance, see our + [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache); + - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of + shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy + cache format. + + The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the + legacy cache format will be returned. + + If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't + have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids` + of shape `(batch_size, sequence_length)`. + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`, + this tensor is not affected by padding. It is used to update the cache in the correct position and to infer + the complete sequence length. +""" + + +@add_start_docstrings( + "The bare MiniMaxText01 Model outputting raw hidden-states without any specific head on top.", + MINI_MAX_TEXT01_START_DOCSTRING, +) +class MiniMaxText01Model(MiniMaxText01PreTrainedModel): + """ + Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`MiniMaxText01DecoderLayer`] + + Args: + config: MiniMaxText01Config + """ + + def __init__(self, config: MiniMaxText01Config): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + self.layers = nn.ModuleList( + [MiniMaxText01DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + self.norm = MiniMaxText01RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.rotary_emb = MiniMaxText01RotaryEmbedding(config=config) + self.gradient_checkpointing = False + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + @add_start_docstrings_to_model_forward(MINI_MAX_TEXT01_INPUTS_DOCSTRING) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + output_router_logits: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + **flash_attn_kwargs: Unpack[FlashAttentionKwargs], + ) -> Union[Tuple, BaseModelOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_router_logits = ( + output_router_logits if output_router_logits is not None else self.config.output_router_logits + ) + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + + if use_cache and past_key_values is None: + past_key_values = DynamicCache() + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + ) + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + causal_mask = self._update_causal_mask( + attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions + ) + + hidden_states = inputs_embeds + + # create position embeddings to be shared across the decoder layers + position_embeddings = self.rotary_emb(hidden_states, position_ids) + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + all_router_logits = () if output_router_logits else None + + for decoder_layer in self.layers: + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + decoder_layer.__call__, + hidden_states, + causal_mask, + position_ids, + past_key_values, + output_attentions, + output_router_logits, + use_cache, + cache_position, + position_embeddings, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + output_router_logits=output_router_logits, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **flash_attn_kwargs, + ) + + hidden_states = layer_outputs[0] + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + if output_router_logits: + all_router_logits += (layer_outputs[-1],) + + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + output = MoeModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=past_key_values, + hidden_states=all_hidden_states, + attentions=all_self_attns, + router_logits=all_router_logits, + ) + return output if return_dict else output.to_tuple() + + def _update_causal_mask( + self, + attention_mask: torch.Tensor, + input_tensor: torch.Tensor, + cache_position: torch.Tensor, + past_key_values: Cache, + output_attentions: bool, + ): + if self.config._attn_implementation == "flash_attention_2": + if attention_mask is not None and past_key_values is not None: + is_padding_right = attention_mask[:, -1].sum().item() != input_tensor.size()[0] + if is_padding_right: + raise ValueError( + "You are attempting to perform batched generation with padding_side='right'" + " this may lead to unexpected behaviour for Flash Attention version of MiniMaxText01. Make sure to " + " call `tokenizer.padding_side = 'left'` before tokenizing the input. " + ) + if attention_mask is not None and 0.0 in attention_mask: + return attention_mask + return None + + # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in + # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail + # to infer the attention mask. + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + using_static_cache = isinstance(past_key_values, StaticCache) + using_sliding_window_cache = isinstance(past_key_values, SlidingWindowCache) + + # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward + if ( + self.config._attn_implementation == "sdpa" + and not (using_static_cache or using_sliding_window_cache) + and not output_attentions + ): + if AttentionMaskConverter._ignore_causal_mask_sdpa( + attention_mask, + inputs_embeds=input_tensor, + past_key_values_length=past_seen_tokens, + sliding_window=self.config.sliding_window, + is_training=self.training, + ): + return None + + dtype, device = input_tensor.dtype, input_tensor.device + min_dtype = torch.finfo(dtype).min + sequence_length = input_tensor.shape[1] + # SlidingWindowCache or StaticCache + if using_sliding_window_cache or using_static_cache: + target_length = past_key_values.get_max_cache_shape() + # DynamicCache or no cache + else: + target_length = ( + attention_mask.shape[-1] + if isinstance(attention_mask, torch.Tensor) + else past_seen_tokens + sequence_length + 1 + ) + + # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). + causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position( + attention_mask, + sequence_length=sequence_length, + target_length=target_length, + dtype=dtype, + device=device, + cache_position=cache_position, + batch_size=input_tensor.shape[0], + config=self.config, + past_key_values=past_key_values, + ) + + if ( + self.config._attn_implementation == "sdpa" + and attention_mask is not None + and attention_mask.device.type == "cuda" + and not output_attentions + ): + # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when + # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. + # Details: https://github.com/pytorch/pytorch/issues/110213 + causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) + + return causal_mask + + @staticmethod + def _prepare_4d_causal_attention_mask_with_cache_position( + attention_mask: torch.Tensor, + sequence_length: int, + target_length: int, + dtype: torch.dtype, + device: torch.device, + cache_position: torch.Tensor, + batch_size: int, + config: MiniMaxText01Config, + past_key_values: Cache, + ): + """ + Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape + `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. + + Args: + attention_mask (`torch.Tensor`): + A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`. + sequence_length (`int`): + The sequence length being processed. + target_length (`int`): + The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet. + dtype (`torch.dtype`): + The dtype to use for the 4D attention mask. + device (`torch.device`): + The device to plcae the 4D attention mask on. + cache_position (`torch.Tensor`): + Indices depicting the position of the input sequence tokens in the sequence. + batch_size (`torch.Tensor`): + Batch size. + config (`MiniMaxText01Config`): + The model's configuration class + past_key_values (`Cache`): + The cache class that is being used currently to generate + """ + if attention_mask is not None and attention_mask.dim() == 4: + # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. + causal_mask = attention_mask + else: + min_dtype = torch.finfo(dtype).min + causal_mask = torch.full( + (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device + ) + diagonal_attend_mask = torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) + if config.sliding_window is not None: + # if we have sliding window, we should not attend to tokens beyond sliding window length, so we mask them out also + # the check is needed to verify is current checkpoint was trained with sliding window or not + if not isinstance(past_key_values, SlidingWindowCache) or sequence_length > target_length: + sliding_attend_mask = torch.arange(target_length, device=device) <= ( + cache_position.reshape(-1, 1) - config.sliding_window + ) + diagonal_attend_mask.bitwise_or_(sliding_attend_mask) + causal_mask *= diagonal_attend_mask + causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) + if attention_mask is not None: + causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit + if attention_mask.shape[-1] > target_length: + attention_mask = attention_mask[:, :target_length] + mask_length = attention_mask.shape[-1] + padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :] + padding_mask = padding_mask == 0 + causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( + padding_mask, min_dtype + ) + return causal_mask + + +class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ... + + +def load_balancing_loss_func( + gate_logits: Union[torch.Tensor, Tuple[torch.Tensor], None], + num_experts: Optional[int] = None, + top_k=2, + attention_mask: Optional[torch.Tensor] = None, +) -> Union[torch.Tensor, int]: + r""" + Computes auxiliary load balancing loss as in Switch Transformer - implemented in Pytorch. + + See Switch Transformer (https://arxiv.org/abs/2101.03961) for more details. This function implements the loss + function presented in equations (4) - (6) of the paper. It aims at penalizing cases where the routing between + experts is too unbalanced. + + Args: + gate_logits: + Logits from the `gate`, should be a tuple of model.config.num_hidden_layers tensors of + shape [batch_size X sequence_length, num_experts]. + num_experts: + Number of experts + top_k: + The number of experts to route per-token, can be also interpreted as the `top-k` routing + parameter. + attention_mask (`torch.Tensor`, *optional*): + The attention_mask used in forward function + shape [batch_size X sequence_length] if not None. + + Returns: + The auxiliary loss. + """ + if gate_logits is None or not isinstance(gate_logits, tuple): + return 0 + + if isinstance(gate_logits, tuple): + compute_device = gate_logits[0].device + concatenated_gate_logits = torch.cat([layer_gate.to(compute_device) for layer_gate in gate_logits], dim=0) + + routing_weights = torch.nn.functional.softmax(concatenated_gate_logits, dim=-1) + + _, selected_experts = torch.topk(routing_weights, top_k, dim=-1) + + expert_mask = torch.nn.functional.one_hot(selected_experts, num_experts) + + if attention_mask is None: + # Compute the percentage of tokens routed to each experts + tokens_per_expert = torch.mean(expert_mask.float(), dim=0) + + # Compute the average probability of routing to these experts + router_prob_per_expert = torch.mean(routing_weights, dim=0) + else: + batch_size, sequence_length = attention_mask.shape + num_hidden_layers = concatenated_gate_logits.shape[0] // (batch_size * sequence_length) + + # Compute the mask that masks all padding tokens as 0 with the same shape of expert_mask + expert_attention_mask = ( + attention_mask[None, :, :, None, None] + .expand((num_hidden_layers, batch_size, sequence_length, top_k, num_experts)) + .reshape(-1, top_k, num_experts) + .to(compute_device) + ) + + # Compute the percentage of tokens routed to each experts + tokens_per_expert = torch.sum(expert_mask.float() * expert_attention_mask, dim=0) / torch.sum( + expert_attention_mask, dim=0 + ) + + # Compute the mask that masks all padding tokens as 0 with the same shape of tokens_per_expert + router_per_expert_attention_mask = ( + attention_mask[None, :, :, None] + .expand((num_hidden_layers, batch_size, sequence_length, num_experts)) + .reshape(-1, num_experts) + .to(compute_device) + ) + + # Compute the average probability of routing to these experts + router_prob_per_expert = torch.sum(routing_weights * router_per_expert_attention_mask, dim=0) / torch.sum( + router_per_expert_attention_mask, dim=0 + ) + + overall_loss = torch.sum(tokens_per_expert * router_prob_per_expert.unsqueeze(0)) + return overall_loss * num_experts + + +class MiniMaxText01ForCausalLM(MiniMaxText01PreTrainedModel, GenerationMixin): + _tied_weights_keys = ["lm_head.weight"] + _tp_plan = {"lm_head": "colwise_rep"} + + def __init__(self, config): + super().__init__(config) + self.model = MiniMaxText01Model(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + self.router_aux_loss_coef = config.router_aux_loss_coef + self.num_experts = config.num_local_experts + self.num_experts_per_tok = config.num_experts_per_tok + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + + @add_start_docstrings_to_model_forward(MINI_MAX_TEXT01_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + output_router_logits: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + num_logits_to_keep: int = 0, + **kwargs: Unpack[KwargsForCausalLM], + ) -> Union[Tuple, CausalLMOutputWithPast]: + r""" + Args: + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + num_logits_to_keep (`int`, *optional*): + Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all + `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that + token can save memory, which becomes pretty significant for long sequences or large vocabulary size. + + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, MiniMaxText01ForCausalLM + + >>> model = MiniMaxText01ForCausalLM.from_pretrained("mistralai/MiniMaxText01-8x7B-v0.1") + >>> tokenizer = AutoTokenizer.from_pretrained("mistralai/MiniMaxText01-8x7B-v0.1") + + >>> prompt = "Hey, are you conscious? Can you talk to me?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." + ```""" + # ic(input_ids.shape, input_ids) + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_router_logits = ( + output_router_logits if output_router_logits is not None else self.config.output_router_logits + ) + + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + output_router_logits=output_router_logits, + return_dict=return_dict, + cache_position=cache_position, + **kwargs, + ) + + hidden_states = outputs[0] + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]) + + loss = None + if labels is not None: + loss = self.loss_function(logits, labels, self.vocab_size, **kwargs) + + aux_loss = None + if output_router_logits: + aux_loss = load_balancing_loss_func( + outputs.router_logits if return_dict else outputs[-1], + self.num_experts, + self.num_experts_per_tok, + attention_mask, + ) + if labels is not None: + loss += self.router_aux_loss_coef * aux_loss.to(loss.device) # make sure to reside in the same device + + if not return_dict: + output = (logits,) + outputs[1:] + if output_router_logits: + output = (aux_loss,) + output + return (loss,) + output if loss is not None else output + + return MoeCausalLMOutputWithPast( + loss=loss, + aux_loss=aux_loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + router_logits=outputs.router_logits, + ) + + +@add_start_docstrings( + """ + The MiniMaxText01 Model transformer with a sequence classification head on top (linear layer). + + [`MiniMaxText01ForSequenceClassification`] uses the last token in order to do the classification, as other causal models + (e.g. GPT-2) do. + + Since it does classification on the last token, it requires to know the position of the last token. If a + `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If + no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the + padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in + each row of the batch). + """, + MINI_MAX_TEXT01_START_DOCSTRING, +) +class MiniMaxText01ForSequenceClassification(MiniMaxText01PreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + self.model = MiniMaxText01Model(config) + self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + @add_start_docstrings_to_model_forward(MINI_MAX_TEXT01_INPUTS_DOCSTRING) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, SequenceClassifierOutputWithPast]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + transformer_outputs = self.model( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + hidden_states = transformer_outputs[0] + logits = self.score(hidden_states) + + if input_ids is not None: + batch_size = input_ids.shape[0] + else: + batch_size = inputs_embeds.shape[0] + + if self.config.pad_token_id is None and batch_size != 1: + raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.") + if self.config.pad_token_id is None: + sequence_lengths = -1 + else: + if input_ids is not None: + # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility + sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1 + sequence_lengths = sequence_lengths % input_ids.shape[-1] + sequence_lengths = sequence_lengths.to(logits.device) + else: + sequence_lengths = -1 + + pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths] + + loss = None + if labels is not None: + loss = self.loss_function(logits=logits, labels=labels, pooled_logits=pooled_logits, config=self.config) + + if not return_dict: + output = (pooled_logits,) + transformer_outputs[1:] + return ((loss,) + output) if loss is not None else output + + return SequenceClassifierOutputWithPast( + loss=loss, + logits=pooled_logits, + past_key_values=transformer_outputs.past_key_values, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + ) + + +@add_start_docstrings( + """ + The MiniMaxText01 Model transformer with a token classification head on top (a linear layer on top of the hidden-states + output) e.g. for Named-Entity-Recognition (NER) tasks. + """, + MINI_MAX_TEXT01_START_DOCSTRING, +) +class MiniMaxText01ForTokenClassification(MiniMaxText01PreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + self.model = MiniMaxText01Model(config) + if getattr(config, "classifier_dropout", None) is not None: + classifier_dropout = config.classifier_dropout + elif getattr(config, "hidden_dropout", None) is not None: + classifier_dropout = config.hidden_dropout + else: + classifier_dropout = 0.1 + self.dropout = nn.Dropout(classifier_dropout) + self.score = nn.Linear(config.hidden_size, config.num_labels) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + @add_start_docstrings_to_model_forward(MINI_MAX_TEXT01_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=TokenClassifierOutput, + config_class=_CONFIG_FOR_DOC, + ) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, TokenClassifierOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.model( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + sequence_output = outputs[0] + sequence_output = self.dropout(sequence_output) + logits = self.score(sequence_output) + + loss = None + if labels is not None: + loss = self.loss_function(logits, labels, self.config) + + if not return_dict: + output = (logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return TokenClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@add_start_docstrings( + """ +The MiniMaxText01 Model transformer with a span classification head on top for extractive question-answering tasks like +SQuAD (a linear layer on top of the hidden-states output to compute `span start logits` and `span end logits`). + """, + MINI_MAX_TEXT01_START_DOCSTRING, +) +class MiniMaxText01ForQuestionAnswering(MiniMaxText01PreTrainedModel): + base_model_prefix = "model" + + def __init__(self, config): + super().__init__(config) + self.qa_outputs = nn.Linear(config.hidden_size, 2) + self.model = MiniMaxText01Model(config) # diff with Llama: transformer->model + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + @add_start_docstrings_to_model_forward(MINI_MAX_TEXT01_INPUTS_DOCSTRING) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + start_positions: Optional[torch.LongTensor] = None, + end_positions: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + **kwargs, + ) -> Union[Tuple, QuestionAnsweringModelOutput]: + r""" + start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for position (index) of the start of the labelled span for computing the token classification loss. + Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence + are not taken into account for computing the loss. + end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for position (index) of the end of the labelled span for computing the token classification loss. + Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence + are not taken into account for computing the loss. + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.model( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = outputs[0] + + logits = self.qa_outputs(sequence_output) + start_logits, end_logits = logits.split(1, dim=-1) + start_logits = start_logits.squeeze(-1).contiguous() + end_logits = end_logits.squeeze(-1).contiguous() + + loss = None + if start_positions is not None and end_positions is not None: + loss = self.loss_function(start_logits, end_logits, start_positions, end_positions, **kwargs) + + if not return_dict: + output = (start_logits, end_logits) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return QuestionAnsweringModelOutput( + loss=loss, + start_logits=start_logits, + end_logits=end_logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) diff --git a/src/transformers/models/minimax_text_01/modular_minimax_text_01.py b/src/transformers/models/minimax_text_01/modular_minimax_text_01.py new file mode 100644 index 000000000000..ce00c4dd1081 --- /dev/null +++ b/src/transformers/models/minimax_text_01/modular_minimax_text_01.py @@ -0,0 +1,520 @@ +# coding=utf-8 +# Copyright 2025 MiniMaxAI and HuggingFace Inc. teams. All rights reserved. +# +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch MiniMax-Text-01 model.""" + +# TODO: remove these +from icecream import ic +from pack_minimax import show_tensor + +from typing import Callable, List, Optional, Tuple, Union + +import math +from einops import rearrange +import torch +import torch.nn.functional as F +import torch.utils.checkpoint +from torch import nn + +from ...activations import ACT2FN +from ...cache_utils import Cache +from ...configuration_utils import PretrainedConfig +from ...modeling_flash_attention_utils import FlashAttentionKwargs +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS +from ...processing_utils import Unpack +from ...utils import ( + logging, +) +from ..mixtral.modeling_mixtral import ( + eager_attention_forward, + MixtralRMSNorm, + MixtralAttention, + MixtralDecoderLayer, + MixtralModel, + MixtralForCausalLM, + MixtralForSequenceClassification, + MixtralForTokenClassification, + MixtralForQuestionAnswering, +) + + +logger = logging.get_logger(__name__) + + +_CHECKPOINT_FOR_DOC = "mistralai/Mixtral-8x7B-v0.1" +_CONFIG_FOR_DOC = "MixtralConfig" + + +class MiniMaxText01Config(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`MiniMaxText01Model`]. It is used to instantiate an + MiniMaxText01 model according to the specified arguments, defining the model architecture. Instantiating a configuration + with the defaults will yield a similar configuration to that of the MiniMaxText01. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + + Args: + vocab_size (`int`, *optional*, defaults to 32000): + Vocabulary size of the MiniMaxText01 model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`MiniMaxText01Model`] + hidden_size (`int`, *optional*, defaults to 4096): + Dimension of the hidden representations. + intermediate_size (`int`, *optional*, defaults to 14336): + Dimension of the MLP representations. + num_hidden_layers (`int`, *optional*, defaults to 32): + Number of hidden layers in the Transformer encoder. + num_attention_heads (`int`, *optional*, defaults to 32): + Number of attention heads for each attention layer in the Transformer encoder. + num_key_value_heads (`int`, *optional*, defaults to 8): + This is the number of key_value heads that should be used to implement Grouped Query Attention. If + `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if + `num_key_value_heads=1 the model will use Multi Query Attention (MQA) otherwise GQA is used. When + converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed + by meanpooling all the original heads within that group. For more details checkout [this + paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to `8`. + hidden_act (`str` or `function`, *optional*, defaults to `"silu"`): + The non-linear activation function (function or string) in the decoder. + max_position_embeddings (`int`, *optional*, defaults to `4096*32`): + The maximum sequence length that this model might ever be used with. MiniMaxText01's sliding window attention + allows sequence of up to 4096*32 tokens. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + rms_norm_eps (`float`, *optional*, defaults to 1e-05): + The epsilon used by the rms normalization layers. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models). Only + relevant if `config.is_decoder=True`. + pad_token_id (`int`, *optional*): + The id of the padding token. + bos_token_id (`int`, *optional*, defaults to 1): + The id of the "beginning-of-sequence" token. + eos_token_id (`int`, *optional*, defaults to 2): + The id of the "end-of-sequence" token. + tie_word_embeddings (`bool`, *optional*, defaults to `False`): + Whether the model's input and output word embeddings should be tied. + rope_theta (`float`, *optional*, defaults to 1000000.0): + The base period of the RoPE embeddings. + sliding_window (`int`, *optional*): + Sliding window attention window size. If not specified, will default to `4096`. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + num_experts_per_tok (`int`, *optional*, defaults to 2): + The number of experts to route per-token, can be also interpreted as the `top-k` routing + parameter + num_local_experts (`int`, *optional*, defaults to 8): + Number of experts per Sparse MLP layer. + output_router_logits (`bool`, *optional*, defaults to `False`): + Whether or not the router logits should be returned by the model. Enabeling this will also + allow the model to output the auxiliary loss. See [here]() for more details + router_aux_loss_coef (`float`, *optional*, defaults to 0.001): + The aux loss factor for the total loss. + router_jitter_noise (`float`, *optional*, defaults to 0.0): + Amount of noise to add to the router. + + ```python + >>> from transformers import MiniMaxText01Model, MiniMaxText01Config + + >>> # Initializing a MiniMaxText01 style configuration + >>> configuration = MiniMaxText01Config() + + >>> # Initializing a model from the MiniMaxText01 style configuration + >>> model = MiniMaxText01Model(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "minimax_text_01" + keys_to_ignore_at_inference = ["past_key_values"] + + def __init__( + self, + vocab_size=32000, + hidden_size=4096, + intermediate_size=14336, + num_hidden_layers=32, + num_attention_heads=32, + num_key_value_heads=8, + head_dim=None, + hidden_act="silu", + max_position_embeddings=4096 * 32, + initializer_range=0.02, + rms_norm_eps=1e-5, + use_cache=True, + pad_token_id=None, + bos_token_id=None, + eos_token_id=None, + tie_word_embeddings=False, + rope_theta=1e6, + sliding_window=None, + attention_dropout=0.0, + num_experts_per_tok=2, + num_local_experts=8, + output_router_logits=False, + router_aux_loss_coef=0.001, + router_jitter_noise=0.0, + **kwargs, + ): + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.sliding_window = sliding_window + + # for backward compatibility + if num_key_value_heads is None: + num_key_value_heads = num_attention_heads + + self.num_key_value_heads = num_key_value_heads + self.hidden_act = hidden_act + self.initializer_range = initializer_range + self.rms_norm_eps = rms_norm_eps + self.use_cache = use_cache + self.rope_theta = rope_theta + self.attention_dropout = attention_dropout + self.head_dim = head_dim if head_dim is not None else self.hidden_size // self.num_attention_heads + + self.num_experts_per_tok = num_experts_per_tok + self.num_local_experts = num_local_experts + self.output_router_logits = output_router_logits + self.router_aux_loss_coef = router_aux_loss_coef + self.router_jitter_noise = router_jitter_noise + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) + +# ---------------------- +# TODO: see if rotary_emb works at Model level rather than attention level +# checked: it works + + +# TODO +class MiniMaxText01Attention(MixtralAttention): + pass + + +def get_slopes(head_dim): + equ = lambda x: 1 / (2 ** (8/x)) + + log2 = math.log2(head_dim) + if log2.is_integer(): + return [equ(head_dim) ** i for i in range(1, head_dim+1)] + + lower_bound = 2 ** math.floor(log2) + upper_bound = 2 ** math.ceil(log2) + + lower_bound_slopes = get_slopes(lower_bound) + upper_bound_slopes = get_slopes(upper_bound) + slopes = lower_bound_slopes + upper_bound_slopes[::2][:head_dim-lower_bound] + + return slopes + + +# TODO: clean and refactor +# TODO: lightning + eager = attention_mask is not None = fails +def lightning_attention_forward( + module, + query_states, + key_states, + value_states, + attention_mask, + **kwargs, + ): + + batch_size, hidden_size, seq_len, head_dim = query_states.shape + batch_size, hidden_size, seq_len, kv_head_dim = value_states.shape + + BLOCK = 256 + num_blocks = (seq_len + BLOCK - 1) // BLOCK + + if attention_mask is not None: + value_states = value_states.masked_fill((1 - attention_mask).unsqueeze(1).unsqueeze(-1).to(torch.bool), 0) + + slope_rate = get_slopes(head_dim) + slope_rate = torch.tensor(slope_rate, device=query_states.device, dtype=torch.float32) + # TODO: check for a different batch size + slope_rate = slope_rate.unsqueeze(1).unsqueeze(1) + slope_rate *= 1 - module.layer_idx / (module.num_hidden_layers - 1) + 1e-5 + + array = torch.arange(BLOCK).to(query_states) + 1 + query_states_decay = torch.exp(-slope_rate * array.reshape(-1, 1)) + key_states_decay = torch.exp(-slope_rate * (BLOCK - array.reshape(-1, 1))) + index = array[:, None] - array[None, :] + s_index = slope_rate * index[ + None, + None, + ] + s_index = torch.where(index >= 0, -s_index, float("-inf")) + diag_decay = torch.exp(s_index) + + # TODO: remove unused kv + kv = torch.zeros(batch_size, hidden_size, head_dim, kv_head_dim).to(torch.float32).to(query_states.device) + output = torch.empty( + (batch_size, hidden_size, seq_len, kv_head_dim), + dtype=query_states.dtype, + device=query_states.device + ) + + for i in range(num_blocks): + si = i * BLOCK + ei = min(si + BLOCK, seq_len) + m = ei - si + query_states_i = query_states[:, :, si:ei].contiguous() + key_states_i = key_states[:, :, si:ei].contiguous() + value_states_i = value_states[:, :, si:ei].contiguous() + qkv_none_diag = torch.matmul( + query_states_i * query_states_decay[:, :m], kv + ).to(torch.float32) + + # diag + qk = torch.matmul( + query_states_i, + key_states_i.transpose(-1, -2) + ).to(torch.float32) * diag_decay[:, :, :m, :m] + qkv_diag = torch.matmul(qk, value_states_i.to(torch.float32)) + block_decay = torch.exp(-slope_rate * m) + output[:, :, si:ei] = qkv_none_diag + qkv_diag + kv = ( + block_decay * kv + + + torch.matmul( + ( + key_states_i * key_states_decay[:, -m:] + ).transpose(-1, -2).to(value_states_i.dtype), + value_states_i + ) + ) + + return output, None + + +# TODO +class MiniMaxText01LightningAttention(nn.Module): + def __init__(self, config: MiniMaxText01Config, layer_idx: int): + super().__init__() + self.config = config + self.layer_idx = layer_idx + self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) + self.num_heads = config.num_attention_heads + self.num_hidden_layers = config.num_hidden_layers + + self.act_fn = ACT2FN[config.hidden_act] + self.norm = MixtralRMSNorm(self.head_dim * self.num_heads) + self.qkv_proj = nn.Linear(config.hidden_size, self.num_heads * self.head_dim * 3, bias=False) + # TODO: separate q,k,v + # self.q_proj = nn.Linear(config.hidden_size, self.num_heads * self.head_dim, bias=False) + # self.k_proj = nn.Linear(config.hidden_size, self.num_heads * self.head_dim, bias=False) + # self.v_proj = nn.Linear(config.hidden_size, self.num_heads * self.head_dim, bias=False) + self.out_proj = nn.Linear(self.num_heads * self.head_dim, config.hidden_size, bias=False) + self.output_gate = nn.Linear(config.hidden_size, self.num_heads * self.num_heads, bias=False) + + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: Tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor], + past_key_value: Optional[Cache] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs: Unpack[FlashAttentionKwargs], + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.head_dim) + + # TODO: separate q,k,v + # query_states = self.act_fn(self.q_proj(hidden_states)).view(hidden_shape).transpose(1, 2) + # key_states = self.act_fn(self.k_proj(hidden_states)).view(hidden_shape).transpose(1, 2) + # value_states = self.act_fn(self.v_proj(hidden_states)).view(hidden_shape).transpose(1, 2) + + qkv_mixed = self.act_fn(self.qkv_proj(hidden_states)) + new_shape = qkv_mixed.size()[:-1] + (self.num_heads, -1) + qkv_mixed = qkv_mixed.view(*new_shape) + query_states, key_states, value_states = torch.split(qkv_mixed, [self.head_dim] * 3, dim=3) + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + + if past_key_value is not None: + cache_kwargs = {"cache_position": cache_position} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + # show_tensor(key_states, False, True) + + # TODO: store following computed in cache + attn_output, attn_weights = lightning_attention_forward( + self, + query_states, + key_states, + value_states, + attention_mask, + **kwargs, + ) + + attn_output = rearrange(attn_output, "b h n d -> b n (h d)") + attn_output = self.norm(attn_output) + attn_output = F.sigmoid(self.output_gate(hidden_states)) * attn_output + attn_output = self.out_proj(attn_output) + + # ic(self.layer_idx) + # show_tensor(attn_output, False, True) + + return attn_output, attn_weights + + +class MiniMaxText01DecoderLayer(MixtralDecoderLayer): + def __init__(self, config: MiniMaxText01Config, layer_idx: int): + super().__init__(config, layer_idx) + + # TODO: add each of these to config + self.residual_post_norm = getattr(config, "residual_post_norm", False) + self.layernorm_attention_alpha = getattr(config, "layernorm_attention_alpha", 1) + self.layernorm_attention_beta = getattr(config, "layernorm_attention_beta", 1) + self.layernorm_lightning_attention_alpha = getattr(config, "layernorm_lightning_attention_alpha", 1) + self.layernorm_lightning_attention_beta = getattr(config, "layernorm_lightning_attention_beta", 1) + self.layernorm_mlp_alpha = getattr(config, "layernorm_mlp_alpha", 1) + self.layernorm_mlp_beta = getattr(config, "layernorm_mlp_beta", 1) + + # TODO: remove these + self.layer_idx = layer_idx + self.residual_post_norm = True + self.layernorm_attention_alpha = 3.5565588200778455 + self.layernorm_attention_beta = 1.0 + self.layernorm_lightning_attention_alpha = 3.5565588200778455 + self.layernorm_lightning_attention_beta = 1.0 + self.layernorm_mlp_alpha = 3.5565588200778455 + self.layernorm_mlp_beta = 1.0 + + # TODO: attn_type_list to config + if config.attn_type_list[layer_idx] == 0: + self.self_attn = MiniMaxText01LightningAttention(config, layer_idx) + self.layernorm_alpha = self.layernorm_lightning_attention_alpha + self.layernorm_beta = self.layernorm_lightning_attention_beta + else: + self.self_attn = MiniMaxText01Attention(config, layer_idx) + self.layernorm_alpha = self.layernorm_attention_alpha + self.layernorm_beta = self.layernorm_attention_beta + + # TODO: shared_moe + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: Optional[bool] = False, + output_router_logits: Optional[bool] = False, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC + **kwargs: Unpack[FlashAttentionKwargs], + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`, *optional*): attention mask of size + `(batch, sequence_length)` where padding elements are indicated by 0. + past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + output_router_logits (`bool`, *optional*): + Whether or not to return the logits of all the routers. They are useful for computing the router loss, and + should not be returned during inference. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence. + kwargs (`dict`, *optional*): + Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code + into the model + """ + # print() + # ic(self.layer_idx) + # show_tensor(hidden_states, False, True) + + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + if self.residual_post_norm: + residual = hidden_states + + # Self Attention + hidden_states, self_attn_weights = self.self_attn( + hidden_states=hidden_states, + position_embeddings=position_embeddings, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + **kwargs, + ) + hidden_states = residual * self.layernorm_alpha + hidden_states * self.layernorm_beta + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + if self.residual_post_norm: + residual = hidden_states + hidden_states, router_logits = self.block_sparse_moe(hidden_states) + hidden_states = residual * self.layernorm_mlp_alpha + hidden_states * self.layernorm_mlp_beta + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights,) + + if output_router_logits: + outputs += (router_logits,) + + # show_tensor(hidden_states, False, True) + + return outputs + + +class MiniMaxText01Model(MixtralModel): + def __init__(self, config: MiniMaxText01Config): + super().__init__(config) + self.layers = nn.ModuleList( + [MiniMaxText01DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + + +class MiniMaxText01ForCausalLM(MixtralForCausalLM): + def __init__(self, config): + super().__init__(config) + self.model = MiniMaxText01Model(config) + + +class MiniMaxText01ForSequenceClassification(MixtralForSequenceClassification): + pass + + +class MiniMaxText01ForTokenClassification(MixtralForTokenClassification): + pass + + +class MiniMaxText01ForQuestionAnswering(MixtralForQuestionAnswering): + pass From 519eda3d2f5b2c1400095c3d10695043878f3c0d Mon Sep 17 00:00:00 2001 From: geetu040 Date: Fri, 24 Jan 2025 01:19:24 +0500 Subject: [PATCH 02/11] lightning-attn: refactor, clean, optimize --- .../configuration_minimax_text_01.py | 41 ++- .../modeling_minimax_text_01.py | 328 ++++++++---------- .../modular_minimax_text_01.py | 292 +++++++--------- 3 files changed, 313 insertions(+), 348 deletions(-) diff --git a/src/transformers/models/minimax_text_01/configuration_minimax_text_01.py b/src/transformers/models/minimax_text_01/configuration_minimax_text_01.py index 00ce0f3919e9..a1235ffbb278 100644 --- a/src/transformers/models/minimax_text_01/configuration_minimax_text_01.py +++ b/src/transformers/models/minimax_text_01/configuration_minimax_text_01.py @@ -132,6 +132,15 @@ def __init__( output_router_logits=False, router_aux_loss_coef=0.001, router_jitter_noise=0.0, + attn_type_list=None, + block_size=256, + residual_post_norm=False, + layernorm_attention_alpha=1, + layernorm_attention_beta=1, + layernorm_lightning_attention_alpha=1, + layernorm_lightning_attention_beta=1, + layernorm_mlp_alpha=1, + layernorm_mlp_beta=1, **kwargs, ): self.vocab_size = vocab_size @@ -141,11 +150,6 @@ def __init__( self.num_hidden_layers = num_hidden_layers self.num_attention_heads = num_attention_heads self.sliding_window = sliding_window - - # for backward compatibility - if num_key_value_heads is None: - num_key_value_heads = num_attention_heads - self.num_key_value_heads = num_key_value_heads self.hidden_act = hidden_act self.initializer_range = initializer_range @@ -160,6 +164,33 @@ def __init__( self.output_router_logits = output_router_logits self.router_aux_loss_coef = router_aux_loss_coef self.router_jitter_noise = router_jitter_noise + + # use softmax-attention after `interval` lightning-attentions + interval = num_hidden_layers // 10 + self.attn_type_list = ( + [1 if i % interval == interval - 1 else 0 for i in range(num_hidden_layers)] + if attn_type_list is None + else attn_type_list + ) + + self.block_size = block_size + self.residual_post_norm = residual_post_norm + self.layernorm_attention_alpha = layernorm_attention_alpha + self.layernorm_attention_beta = layernorm_attention_beta + self.layernorm_lightning_attention_alpha = layernorm_lightning_attention_alpha + self.layernorm_lightning_attention_beta = layernorm_lightning_attention_beta + self.layernorm_mlp_alpha = layernorm_mlp_alpha + self.layernorm_mlp_beta = layernorm_mlp_beta + + # TODO: move these to saved config + self.residual_post_norm = True + self.layernorm_attention_alpha = 3.5565588200778455 + self.layernorm_attention_beta = 1.0 + self.layernorm_lightning_attention_alpha = 3.5565588200778455 + self.layernorm_lightning_attention_beta = 1.0 + self.layernorm_mlp_alpha = 3.5565588200778455 + self.layernorm_mlp_beta = 1.0 + super().__init__( pad_token_id=pad_token_id, bos_token_id=bos_token_id, diff --git a/src/transformers/models/minimax_text_01/modeling_minimax_text_01.py b/src/transformers/models/minimax_text_01/modeling_minimax_text_01.py index 7bc7194783c3..9e33a18a58a3 100644 --- a/src/transformers/models/minimax_text_01/modeling_minimax_text_01.py +++ b/src/transformers/models/minimax_text_01/modeling_minimax_text_01.py @@ -20,12 +20,10 @@ # See the License for the specific language governing permissions and # limitations under the License. -import math from typing import Callable, List, Optional, Tuple, Union import torch import torch.nn.functional as F -from einops import rearrange from torch import nn from ...activations import ACT2FN @@ -63,6 +61,140 @@ _CONFIG_FOR_DOC = "MixtralConfig" +class MiniMaxText01LightningAttentionDecay(nn.Module): + def __init__(self, config: MiniMaxText01Config, layer_idx: int): + super().__init__() + self.config = config + self.layer_idx = layer_idx + self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) + self.num_heads = config.num_attention_heads + self.num_hidden_layers = config.num_hidden_layers + self.block_size = config.block_size + + def forward(self, x, seq_len): + num_blocks = (seq_len + self.block_size - 1) // self.block_size + padding = num_blocks * self.block_size - seq_len + + num_heads_range = torch.arange(self.num_heads).to(x) + 1 + block_size_range = torch.arange(self.block_size).to(x) + 1 + + slope_rate = (1 / (2 ** (8 / self.num_heads))) ** num_heads_range + slope_rate *= 1 - self.layer_idx / (self.num_hidden_layers - 1) + 1e-5 # check small addition + slope_rate = slope_rate[:, None, None] + + query_decay = torch.exp(-slope_rate * block_size_range[:, None]) + query_decay = query_decay[:, None, :, :] + + key_decay = torch.exp(-slope_rate * (self.block_size - block_size_range[:, None])) + key_decay = key_decay[:, None, :, :] + key_decay = key_decay.repeat(1, num_blocks, 1, 1) + key_decay[:, -1, : self.block_size - padding] = key_decay[:, -1, padding:] + + diagonal_decay = block_size_range[:, None] - block_size_range[None, :] + diagonal_decay = slope_rate * diagonal_decay[None, :, :] + diagonal_decay = torch.where(diagonal_decay >= 0, -diagonal_decay, float("-inf")) + diagonal_decay = torch.exp(diagonal_decay) + diagonal_decay = diagonal_decay[:, None, :, :] + + block_lengths = torch.cat( + (torch.full((num_blocks - 1,), self.block_size), torch.tensor([self.block_size - padding])) + ).to(x) + block_decay = torch.exp(-slope_rate[:, None, :, :] * block_lengths[:, None, None]) + + return key_decay, query_decay, diagonal_decay, block_decay + + +class MiniMaxText01LightningAttention(nn.Module): + def __init__(self, config: MiniMaxText01Config, layer_idx: int): + super().__init__() + self.config = config + self.layer_idx = layer_idx + self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) + self.num_heads = config.num_attention_heads + self.num_hidden_layers = config.num_hidden_layers + self.block_size = config.block_size + + self.act_fn = ACT2FN[config.hidden_act] + self.norm = MixtralRMSNorm(self.head_dim * self.num_heads) + self.qkv_proj = nn.Linear(config.hidden_size, self.num_heads * self.head_dim * 3, bias=False) + self.out_proj = nn.Linear(self.num_heads * self.head_dim, config.hidden_size, bias=False) + self.output_gate = nn.Linear(config.hidden_size, self.num_heads * self.num_heads, bias=False) + self.decay_factors = MiniMaxText01LightningAttentionDecay(config, layer_idx) + + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: Tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor], + past_key_value: Optional[Cache] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs: Unpack[FlashAttentionKwargs], + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + batch_size, seq_len, hidden_size = hidden_states.shape + num_blocks = (seq_len + self.block_size - 1) // self.block_size + padding = num_blocks * self.block_size - seq_len + + qkv_states = self.act_fn(self.qkv_proj(hidden_states)) + qkv_states = qkv_states.reshape(batch_size, seq_len, self.num_heads, 3 * self.head_dim) + + query_states, key_states, value_states = torch.split(qkv_states, [self.head_dim] * 3, dim=3) + + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + + # TODO: apply attention_mask + + query_states = F.pad(query_states, (0, 0, 0, padding)) + key_states = F.pad(key_states, (0, 0, 0, padding)) + value_states = F.pad(value_states, (0, 0, 0, padding)) + + query_states = query_states.reshape(batch_size, self.num_heads, num_blocks, self.block_size, self.head_dim) + key_states = key_states.reshape(batch_size, self.num_heads, num_blocks, self.block_size, self.head_dim) + value_states = value_states.reshape(batch_size, self.num_heads, num_blocks, self.block_size, self.head_dim) + + # TODO: get from past_key_value[layer_idx] + next_cache = torch.zeros(batch_size, self.num_heads, 1, self.head_dim, self.head_dim).to(value_states) + + # get decay factors + key_decay, query_decay, diagonal_decay, block_decay = self.decay_factors(query_states, seq_len) + + # intra: ( Q @ K.T ) @ V -> QK * V + attn_weights_intra = torch.matmul(query_states, key_states.transpose(-1, -2)) + attn_output_intra = torch.matmul(attn_weights_intra * diagonal_decay, value_states) + + # inter: Q @ ( K.T @ V ) -> Q * KV + attn_weights_inter = torch.matmul((key_states * key_decay).transpose(-1, -2), value_states) + attn_weights_inter = torch.cat([next_cache, attn_weights_inter], dim=2) + for i in range(num_blocks): + attn_weights_inter[:, :, i + 1, :, :] += attn_weights_inter[:, :, i, :, :] * block_decay[:, i, :, :] + next_cache = attn_weights_inter[:, :, -1, :, :] + attn_weights_inter = attn_weights_inter[:, :, :-1, :, :] + attn_output_inter = torch.matmul(query_states * query_decay, attn_weights_inter) + + # inter + intra + attn_output = attn_output_inter + attn_output_intra + attn_output = attn_output.reshape(batch_size, self.num_heads, seq_len + padding, self.head_dim) + attn_output = attn_output[:, :, :seq_len, :] + + # final output projection + attn_output = attn_output.transpose(1, 2) + attn_output = attn_output.reshape(batch_size, seq_len, self.num_heads * self.head_dim) + attn_output = self.norm(attn_output) + attn_output = F.sigmoid(self.output_gate(hidden_states)) * attn_output + attn_output = self.out_proj(attn_output) + + # TODO: put to past_key_value[layer_idx] + next_cache + + # TODO: remove these + print() + print(self.layer_idx) + print(next_cache) + + return attn_output, None + + def rotate_half(x): """Rotates half the hidden dims of the input.""" x1 = x[..., : x.shape[-1] // 2] @@ -226,161 +358,6 @@ def forward( return attn_output, attn_weights -def get_slopes(head_dim): - equ = lambda x: 1 / (2 ** (8 / x)) - - log2 = math.log2(head_dim) - if log2.is_integer(): - return [equ(head_dim) ** i for i in range(1, head_dim + 1)] - - lower_bound = 2 ** math.floor(log2) - upper_bound = 2 ** math.ceil(log2) - - lower_bound_slopes = get_slopes(lower_bound) - upper_bound_slopes = get_slopes(upper_bound) - slopes = lower_bound_slopes + upper_bound_slopes[::2][: head_dim - lower_bound] - - return slopes - - -# TODO: clean and refactor -# TODO: lightning + eager = attention_mask is not None = fails -def lightning_attention_forward( - module, - query_states, - key_states, - value_states, - attention_mask, - **kwargs, -): - batch_size, hidden_size, seq_len, head_dim = query_states.shape - batch_size, hidden_size, seq_len, kv_head_dim = value_states.shape - - BLOCK = 256 - num_blocks = (seq_len + BLOCK - 1) // BLOCK - - if attention_mask is not None: - value_states = value_states.masked_fill((1 - attention_mask).unsqueeze(1).unsqueeze(-1).to(torch.bool), 0) - - slope_rate = get_slopes(head_dim) - slope_rate = torch.tensor(slope_rate, device=query_states.device, dtype=torch.float32) - # TODO: check for a different batch size - slope_rate = slope_rate.unsqueeze(1).unsqueeze(1) - slope_rate *= 1 - module.layer_idx / (module.num_hidden_layers - 1) + 1e-5 - - array = torch.arange(BLOCK).to(query_states) + 1 - query_states_decay = torch.exp(-slope_rate * array.reshape(-1, 1)) - key_states_decay = torch.exp(-slope_rate * (BLOCK - array.reshape(-1, 1))) - index = array[:, None] - array[None, :] - s_index = ( - slope_rate - * index[ - None, - None, - ] - ) - s_index = torch.where(index >= 0, -s_index, float("-inf")) - diag_decay = torch.exp(s_index) - - # TODO: remove unused kv - kv = torch.zeros(batch_size, hidden_size, head_dim, kv_head_dim).to(torch.float32).to(query_states.device) - output = torch.empty( - (batch_size, hidden_size, seq_len, kv_head_dim), dtype=query_states.dtype, device=query_states.device - ) - - for i in range(num_blocks): - si = i * BLOCK - ei = min(si + BLOCK, seq_len) - m = ei - si - query_states_i = query_states[:, :, si:ei].contiguous() - key_states_i = key_states[:, :, si:ei].contiguous() - value_states_i = value_states[:, :, si:ei].contiguous() - qkv_none_diag = torch.matmul(query_states_i * query_states_decay[:, :m], kv).to(torch.float32) - - # diag - qk = torch.matmul(query_states_i, key_states_i.transpose(-1, -2)).to(torch.float32) * diag_decay[:, :, :m, :m] - qkv_diag = torch.matmul(qk, value_states_i.to(torch.float32)) - block_decay = torch.exp(-slope_rate * m) - output[:, :, si:ei] = qkv_none_diag + qkv_diag - kv = block_decay * kv + torch.matmul( - (key_states_i * key_states_decay[:, -m:]).transpose(-1, -2).to(value_states_i.dtype), value_states_i - ) - - return output, None - - -# TODO -class MiniMaxText01LightningAttention(nn.Module): - def __init__(self, config: MiniMaxText01Config, layer_idx: int): - super().__init__() - self.config = config - self.layer_idx = layer_idx - self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) - self.num_heads = config.num_attention_heads - self.num_hidden_layers = config.num_hidden_layers - - self.act_fn = ACT2FN[config.hidden_act] - self.norm = MixtralRMSNorm(self.head_dim * self.num_heads) - self.qkv_proj = nn.Linear(config.hidden_size, self.num_heads * self.head_dim * 3, bias=False) - # TODO: separate q,k,v - # self.q_proj = nn.Linear(config.hidden_size, self.num_heads * self.head_dim, bias=False) - # self.k_proj = nn.Linear(config.hidden_size, self.num_heads * self.head_dim, bias=False) - # self.v_proj = nn.Linear(config.hidden_size, self.num_heads * self.head_dim, bias=False) - self.out_proj = nn.Linear(self.num_heads * self.head_dim, config.hidden_size, bias=False) - self.output_gate = nn.Linear(config.hidden_size, self.num_heads * self.num_heads, bias=False) - - def forward( - self, - hidden_states: torch.Tensor, - position_embeddings: Tuple[torch.Tensor, torch.Tensor], - attention_mask: Optional[torch.Tensor], - past_key_value: Optional[Cache] = None, - cache_position: Optional[torch.LongTensor] = None, - **kwargs: Unpack[FlashAttentionKwargs], - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - input_shape = hidden_states.shape[:-1] - hidden_shape = (*input_shape, -1, self.head_dim) - - # TODO: separate q,k,v - # query_states = self.act_fn(self.q_proj(hidden_states)).view(hidden_shape).transpose(1, 2) - # key_states = self.act_fn(self.k_proj(hidden_states)).view(hidden_shape).transpose(1, 2) - # value_states = self.act_fn(self.v_proj(hidden_states)).view(hidden_shape).transpose(1, 2) - - qkv_mixed = self.act_fn(self.qkv_proj(hidden_states)) - new_shape = qkv_mixed.size()[:-1] + (self.num_heads, -1) - qkv_mixed = qkv_mixed.view(*new_shape) - query_states, key_states, value_states = torch.split(qkv_mixed, [self.head_dim] * 3, dim=3) - query_states = query_states.transpose(1, 2) - key_states = key_states.transpose(1, 2) - value_states = value_states.transpose(1, 2) - - if past_key_value is not None: - cache_kwargs = {"cache_position": cache_position} - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - - # show_tensor(key_states, False, True) - - # TODO: store following computed in cache - attn_output, attn_weights = lightning_attention_forward( - self, - query_states, - key_states, - value_states, - attention_mask, - **kwargs, - ) - - attn_output = rearrange(attn_output, "b h n d -> b n (h d)") - attn_output = self.norm(attn_output) - attn_output = F.sigmoid(self.output_gate(hidden_states)) * attn_output - attn_output = self.out_proj(attn_output) - - # ic(self.layer_idx) - # show_tensor(attn_output, False, True) - - return attn_output, attn_weights - - class MiniMaxText01BlockSparseTop2MLP(nn.Module): def __init__(self, config: MiniMaxText01Config): super().__init__() @@ -498,26 +475,15 @@ def __init__(self, config: MiniMaxText01Config, layer_idx: int): self.input_layernorm = MiniMaxText01RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.post_attention_layernorm = MiniMaxText01RMSNorm(config.hidden_size, eps=config.rms_norm_eps) - # TODO: add each of these to config - self.residual_post_norm = getattr(config, "residual_post_norm", False) - self.layernorm_attention_alpha = getattr(config, "layernorm_attention_alpha", 1) - self.layernorm_attention_beta = getattr(config, "layernorm_attention_beta", 1) - self.layernorm_lightning_attention_alpha = getattr(config, "layernorm_lightning_attention_alpha", 1) - self.layernorm_lightning_attention_beta = getattr(config, "layernorm_lightning_attention_beta", 1) - self.layernorm_mlp_alpha = getattr(config, "layernorm_mlp_alpha", 1) - self.layernorm_mlp_beta = getattr(config, "layernorm_mlp_beta", 1) - - # TODO: remove these self.layer_idx = layer_idx - self.residual_post_norm = True - self.layernorm_attention_alpha = 3.5565588200778455 - self.layernorm_attention_beta = 1.0 - self.layernorm_lightning_attention_alpha = 3.5565588200778455 - self.layernorm_lightning_attention_beta = 1.0 - self.layernorm_mlp_alpha = 3.5565588200778455 - self.layernorm_mlp_beta = 1.0 - - # TODO: attn_type_list to config + self.residual_post_norm = config.residual_post_norm + self.layernorm_attention_alpha = config.layernorm_attention_alpha + self.layernorm_attention_beta = config.layernorm_attention_beta + self.layernorm_lightning_attention_alpha = config.layernorm_lightning_attention_alpha + self.layernorm_lightning_attention_beta = config.layernorm_lightning_attention_beta + self.layernorm_mlp_alpha = config.layernorm_mlp_alpha + self.layernorm_mlp_beta = config.layernorm_mlp_beta + if config.attn_type_list[layer_idx] == 0: self.self_attn = MiniMaxText01LightningAttention(config, layer_idx) self.layernorm_alpha = self.layernorm_lightning_attention_alpha @@ -527,8 +493,6 @@ def __init__(self, config: MiniMaxText01Config, layer_idx: int): self.layernorm_alpha = self.layernorm_attention_alpha self.layernorm_beta = self.layernorm_attention_beta - # TODO: shared_moe - def forward( self, hidden_states: torch.Tensor, @@ -563,9 +527,6 @@ def forward( Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code into the model """ - # print() - # ic(self.layer_idx) - # show_tensor(hidden_states, False, True) residual = hidden_states @@ -603,8 +564,6 @@ def forward( if output_router_logits: outputs += (router_logits,) - # show_tensor(hidden_states, False, True) - return outputs @@ -633,7 +592,7 @@ def _dynamic_frequency_update(self, position_ids, device): 2 - the current sequence length is in the original scale (avoid losing precision with small sequences) """ seq_len = torch.max(position_ids) + 1 - if seq_len > self.max_seq_len_cached: # growth + if seq_len > self.max_seq_len_cached: # growth_dynamic_frequency_update inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, seq_len=seq_len) self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation self.max_seq_len_cached = seq_len @@ -861,6 +820,7 @@ def forward( ) use_cache = False + # TODO: raise exception here? if use_cache and past_key_values is None: past_key_values = DynamicCache() diff --git a/src/transformers/models/minimax_text_01/modular_minimax_text_01.py b/src/transformers/models/minimax_text_01/modular_minimax_text_01.py index ce00c4dd1081..502a42344d44 100644 --- a/src/transformers/models/minimax_text_01/modular_minimax_text_01.py +++ b/src/transformers/models/minimax_text_01/modular_minimax_text_01.py @@ -22,7 +22,6 @@ from typing import Callable, List, Optional, Tuple, Union import math -from einops import rearrange import torch import torch.nn.functional as F import torch.utils.checkpoint @@ -167,6 +166,15 @@ def __init__( output_router_logits=False, router_aux_loss_coef=0.001, router_jitter_noise=0.0, + attn_type_list = None, + block_size = 256, + residual_post_norm=False, + layernorm_attention_alpha=1, + layernorm_attention_beta=1, + layernorm_lightning_attention_alpha=1, + layernorm_lightning_attention_beta=1, + layernorm_mlp_alpha=1, + layernorm_mlp_beta=1, **kwargs, ): self.vocab_size = vocab_size @@ -176,11 +184,6 @@ def __init__( self.num_hidden_layers = num_hidden_layers self.num_attention_heads = num_attention_heads self.sliding_window = sliding_window - - # for backward compatibility - if num_key_value_heads is None: - num_key_value_heads = num_attention_heads - self.num_key_value_heads = num_key_value_heads self.hidden_act = hidden_act self.initializer_range = initializer_range @@ -195,6 +198,32 @@ def __init__( self.output_router_logits = output_router_logits self.router_aux_loss_coef = router_aux_loss_coef self.router_jitter_noise = router_jitter_noise + + # use softmax-attention after `interval` lightning-attentions + interval = num_hidden_layers // 10 + self.attn_type_list = ( + [1 if i%interval==interval-1 else 0 for i in range(num_hidden_layers)] + if attn_type_list is None else attn_type_list + ) + + self.block_size = block_size + self.residual_post_norm = residual_post_norm + self.layernorm_attention_alpha = layernorm_attention_alpha + self.layernorm_attention_beta = layernorm_attention_beta + self.layernorm_lightning_attention_alpha = layernorm_lightning_attention_alpha + self.layernorm_lightning_attention_beta = layernorm_lightning_attention_beta + self.layernorm_mlp_alpha = layernorm_mlp_alpha + self.layernorm_mlp_beta = layernorm_mlp_beta + + # TODO: move these to saved config + self.residual_post_norm = True + self.layernorm_attention_alpha = 3.5565588200778455 + self.layernorm_attention_beta = 1.0 + self.layernorm_lightning_attention_alpha = 3.5565588200778455 + self.layernorm_lightning_attention_beta = 1.0 + self.layernorm_mlp_alpha = 3.5565588200778455 + self.layernorm_mlp_beta = 1.0 + super().__init__( pad_token_id=pad_token_id, bos_token_id=bos_token_id, @@ -203,112 +232,51 @@ def __init__( **kwargs, ) -# ---------------------- -# TODO: see if rotary_emb works at Model level rather than attention level -# checked: it works +class MiniMaxText01LightningAttentionDecay(nn.Module): + def __init__(self, config: MiniMaxText01Config, layer_idx: int): + super().__init__() + self.config = config + self.layer_idx = layer_idx + self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) + self.num_heads = config.num_attention_heads + self.num_hidden_layers = config.num_hidden_layers + self.block_size = config.block_size -# TODO -class MiniMaxText01Attention(MixtralAttention): - pass - - -def get_slopes(head_dim): - equ = lambda x: 1 / (2 ** (8/x)) - - log2 = math.log2(head_dim) - if log2.is_integer(): - return [equ(head_dim) ** i for i in range(1, head_dim+1)] + def forward(self, x, seq_len): + num_blocks = (seq_len + self.block_size - 1) // self.block_size + padding = num_blocks * self.block_size - seq_len - lower_bound = 2 ** math.floor(log2) - upper_bound = 2 ** math.ceil(log2) + num_heads_range = torch.arange(self.num_heads).to(x) + 1 + block_size_range = torch.arange(self.block_size).to(x) + 1 - lower_bound_slopes = get_slopes(lower_bound) - upper_bound_slopes = get_slopes(upper_bound) - slopes = lower_bound_slopes + upper_bound_slopes[::2][:head_dim-lower_bound] + slope_rate = ( 1 / (2 ** (8/self.num_heads)) ) ** num_heads_range + slope_rate *= 1 - self.layer_idx / (self.num_hidden_layers - 1) + 1e-5 # check small addition + slope_rate = slope_rate[:, None, None] - return slopes + query_decay = torch.exp(-slope_rate * block_size_range[:, None]) + query_decay = query_decay[:, None, :, :] + key_decay = torch.exp(-slope_rate * (self.block_size - block_size_range[:, None])) + key_decay = key_decay[:, None, :, :] + key_decay = key_decay.repeat(1, num_blocks, 1, 1) + key_decay[:, -1, :self.block_size-padding] = key_decay[:, -1, padding:] -# TODO: clean and refactor -# TODO: lightning + eager = attention_mask is not None = fails -def lightning_attention_forward( - module, - query_states, - key_states, - value_states, - attention_mask, - **kwargs, - ): + diagonal_decay = block_size_range[:, None] - block_size_range[None, :] + diagonal_decay = slope_rate * diagonal_decay[None, :, :] + diagonal_decay = torch.where(diagonal_decay >= 0, -diagonal_decay, float("-inf")) + diagonal_decay = torch.exp(diagonal_decay) + diagonal_decay = diagonal_decay[:, None, :, :] - batch_size, hidden_size, seq_len, head_dim = query_states.shape - batch_size, hidden_size, seq_len, kv_head_dim = value_states.shape - - BLOCK = 256 - num_blocks = (seq_len + BLOCK - 1) // BLOCK - - if attention_mask is not None: - value_states = value_states.masked_fill((1 - attention_mask).unsqueeze(1).unsqueeze(-1).to(torch.bool), 0) - - slope_rate = get_slopes(head_dim) - slope_rate = torch.tensor(slope_rate, device=query_states.device, dtype=torch.float32) - # TODO: check for a different batch size - slope_rate = slope_rate.unsqueeze(1).unsqueeze(1) - slope_rate *= 1 - module.layer_idx / (module.num_hidden_layers - 1) + 1e-5 - - array = torch.arange(BLOCK).to(query_states) + 1 - query_states_decay = torch.exp(-slope_rate * array.reshape(-1, 1)) - key_states_decay = torch.exp(-slope_rate * (BLOCK - array.reshape(-1, 1))) - index = array[:, None] - array[None, :] - s_index = slope_rate * index[ - None, - None, - ] - s_index = torch.where(index >= 0, -s_index, float("-inf")) - diag_decay = torch.exp(s_index) - - # TODO: remove unused kv - kv = torch.zeros(batch_size, hidden_size, head_dim, kv_head_dim).to(torch.float32).to(query_states.device) - output = torch.empty( - (batch_size, hidden_size, seq_len, kv_head_dim), - dtype=query_states.dtype, - device=query_states.device - ) - - for i in range(num_blocks): - si = i * BLOCK - ei = min(si + BLOCK, seq_len) - m = ei - si - query_states_i = query_states[:, :, si:ei].contiguous() - key_states_i = key_states[:, :, si:ei].contiguous() - value_states_i = value_states[:, :, si:ei].contiguous() - qkv_none_diag = torch.matmul( - query_states_i * query_states_decay[:, :m], kv - ).to(torch.float32) - - # diag - qk = torch.matmul( - query_states_i, - key_states_i.transpose(-1, -2) - ).to(torch.float32) * diag_decay[:, :, :m, :m] - qkv_diag = torch.matmul(qk, value_states_i.to(torch.float32)) - block_decay = torch.exp(-slope_rate * m) - output[:, :, si:ei] = qkv_none_diag + qkv_diag - kv = ( - block_decay * kv - + - torch.matmul( - ( - key_states_i * key_states_decay[:, -m:] - ).transpose(-1, -2).to(value_states_i.dtype), - value_states_i - ) - ) + block_lengths = torch.cat(( + torch.full((num_blocks-1,), self.block_size), + torch.tensor([self.block_size-padding]) + )).to(x) + block_decay = torch.exp(-slope_rate[:, None, :, :] * block_lengths[:, None, None]) - return output, None + return key_decay, query_decay, diagonal_decay, block_decay -# TODO class MiniMaxText01LightningAttention(nn.Module): def __init__(self, config: MiniMaxText01Config, layer_idx: int): super().__init__() @@ -317,16 +285,14 @@ def __init__(self, config: MiniMaxText01Config, layer_idx: int): self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) self.num_heads = config.num_attention_heads self.num_hidden_layers = config.num_hidden_layers + self.block_size = config.block_size self.act_fn = ACT2FN[config.hidden_act] self.norm = MixtralRMSNorm(self.head_dim * self.num_heads) self.qkv_proj = nn.Linear(config.hidden_size, self.num_heads * self.head_dim * 3, bias=False) - # TODO: separate q,k,v - # self.q_proj = nn.Linear(config.hidden_size, self.num_heads * self.head_dim, bias=False) - # self.k_proj = nn.Linear(config.hidden_size, self.num_heads * self.head_dim, bias=False) - # self.v_proj = nn.Linear(config.hidden_size, self.num_heads * self.head_dim, bias=False) self.out_proj = nn.Linear(self.num_heads * self.head_dim, config.hidden_size, bias=False) self.output_gate = nn.Linear(config.hidden_size, self.num_heads * self.num_heads, bias=False) + self.decay_factors = MiniMaxText01LightningAttentionDecay(config, layer_idx) def forward( self, @@ -337,73 +303,88 @@ def forward( cache_position: Optional[torch.LongTensor] = None, **kwargs: Unpack[FlashAttentionKwargs], ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - input_shape = hidden_states.shape[:-1] - hidden_shape = (*input_shape, -1, self.head_dim) - - # TODO: separate q,k,v - # query_states = self.act_fn(self.q_proj(hidden_states)).view(hidden_shape).transpose(1, 2) - # key_states = self.act_fn(self.k_proj(hidden_states)).view(hidden_shape).transpose(1, 2) - # value_states = self.act_fn(self.v_proj(hidden_states)).view(hidden_shape).transpose(1, 2) - - qkv_mixed = self.act_fn(self.qkv_proj(hidden_states)) - new_shape = qkv_mixed.size()[:-1] + (self.num_heads, -1) - qkv_mixed = qkv_mixed.view(*new_shape) - query_states, key_states, value_states = torch.split(qkv_mixed, [self.head_dim] * 3, dim=3) + batch_size, seq_len, hidden_size = hidden_states.shape + num_blocks = (seq_len + self.block_size - 1) // self.block_size + padding = num_blocks * self.block_size - seq_len + + qkv_states = self.act_fn(self.qkv_proj(hidden_states)) + qkv_states = qkv_states.reshape(batch_size, seq_len, self.num_heads, 3 * self.head_dim) + + query_states, key_states, value_states = torch.split(qkv_states, [self.head_dim] * 3, dim=3) + query_states = query_states.transpose(1, 2) key_states = key_states.transpose(1, 2) value_states = value_states.transpose(1, 2) - if past_key_value is not None: - cache_kwargs = {"cache_position": cache_position} - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + # TODO: apply attention_mask - # show_tensor(key_states, False, True) + query_states = F.pad(query_states, (0, 0, 0, padding)) + key_states = F.pad(key_states, (0, 0, 0, padding)) + value_states = F.pad(value_states, (0, 0, 0, padding)) - # TODO: store following computed in cache - attn_output, attn_weights = lightning_attention_forward( - self, - query_states, - key_states, - value_states, - attention_mask, - **kwargs, - ) + query_states = query_states.reshape(batch_size, self.num_heads, num_blocks, self.block_size, self.head_dim) + key_states = key_states.reshape(batch_size, self.num_heads, num_blocks, self.block_size, self.head_dim) + value_states = value_states.reshape(batch_size, self.num_heads, num_blocks, self.block_size, self.head_dim) + + # TODO: get from past_key_value[layer_idx] + next_cache = torch.zeros(batch_size, self.num_heads, 1, self.head_dim, self.head_dim).to(value_states) - attn_output = rearrange(attn_output, "b h n d -> b n (h d)") + # get decay factors + key_decay, query_decay, diagonal_decay, block_decay = self.decay_factors(query_states, seq_len) + + # intra: ( Q @ K.T ) @ V -> QK * V + attn_weights_intra = torch.matmul(query_states, key_states.transpose(-1, -2)) + attn_output_intra = torch.matmul(attn_weights_intra * diagonal_decay, value_states) + + # inter: Q @ ( K.T @ V ) -> Q * KV + attn_weights_inter = torch.matmul((key_states * key_decay).transpose(-1, -2), value_states) + attn_weights_inter = torch.cat([next_cache, attn_weights_inter], dim=2) + for i in range(num_blocks): + attn_weights_inter[:, :, i+1, :, :] += attn_weights_inter[:, :, i, :, :] * block_decay[:, i, :, :] + next_cache = attn_weights_inter[:, :, -1, :, :] + attn_weights_inter = attn_weights_inter[:, :, :-1, :, :] + attn_output_inter = torch.matmul(query_states * query_decay, attn_weights_inter) + + # inter + intra + attn_output = attn_output_inter + attn_output_intra + attn_output = attn_output.reshape(batch_size, self.num_heads, seq_len+padding, self.head_dim) + attn_output = attn_output[:, :, :seq_len, :] + + # final output projection + attn_output = attn_output.transpose(1, 2) + attn_output = attn_output.reshape(batch_size, seq_len, self.num_heads * self.head_dim) attn_output = self.norm(attn_output) attn_output = F.sigmoid(self.output_gate(hidden_states)) * attn_output attn_output = self.out_proj(attn_output) - # ic(self.layer_idx) - # show_tensor(attn_output, False, True) + # TODO: put to past_key_value[layer_idx] + next_cache + + # TODO: remove these + print() + print(self.layer_idx) + print(next_cache) + + return attn_output, None - return attn_output, attn_weights + +class MiniMaxText01Attention(MixtralAttention): + pass class MiniMaxText01DecoderLayer(MixtralDecoderLayer): def __init__(self, config: MiniMaxText01Config, layer_idx: int): super().__init__(config, layer_idx) - # TODO: add each of these to config - self.residual_post_norm = getattr(config, "residual_post_norm", False) - self.layernorm_attention_alpha = getattr(config, "layernorm_attention_alpha", 1) - self.layernorm_attention_beta = getattr(config, "layernorm_attention_beta", 1) - self.layernorm_lightning_attention_alpha = getattr(config, "layernorm_lightning_attention_alpha", 1) - self.layernorm_lightning_attention_beta = getattr(config, "layernorm_lightning_attention_beta", 1) - self.layernorm_mlp_alpha = getattr(config, "layernorm_mlp_alpha", 1) - self.layernorm_mlp_beta = getattr(config, "layernorm_mlp_beta", 1) - - # TODO: remove these self.layer_idx = layer_idx - self.residual_post_norm = True - self.layernorm_attention_alpha = 3.5565588200778455 - self.layernorm_attention_beta = 1.0 - self.layernorm_lightning_attention_alpha = 3.5565588200778455 - self.layernorm_lightning_attention_beta = 1.0 - self.layernorm_mlp_alpha = 3.5565588200778455 - self.layernorm_mlp_beta = 1.0 + self.residual_post_norm = config.residual_post_norm + self.layernorm_attention_alpha = config.layernorm_attention_alpha + self.layernorm_attention_beta = config.layernorm_attention_beta + self.layernorm_lightning_attention_alpha = config.layernorm_lightning_attention_alpha + self.layernorm_lightning_attention_beta = config.layernorm_lightning_attention_beta + self.layernorm_mlp_alpha = config.layernorm_mlp_alpha + self.layernorm_mlp_beta = config.layernorm_mlp_beta - # TODO: attn_type_list to config if config.attn_type_list[layer_idx] == 0: self.self_attn = MiniMaxText01LightningAttention(config, layer_idx) self.layernorm_alpha = self.layernorm_lightning_attention_alpha @@ -413,8 +394,6 @@ def __init__(self, config: MiniMaxText01Config, layer_idx: int): self.layernorm_alpha = self.layernorm_attention_alpha self.layernorm_beta = self.layernorm_attention_beta - # TODO: shared_moe - def forward( self, hidden_states: torch.Tensor, @@ -449,9 +428,6 @@ def forward( Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code into the model """ - # print() - # ic(self.layer_idx) - # show_tensor(hidden_states, False, True) residual = hidden_states @@ -489,8 +465,6 @@ def forward( if output_router_logits: outputs += (router_logits,) - # show_tensor(hidden_states, False, True) - return outputs From c54f8045ec7ca36e01e0f469a53f6e7177bbfb40 Mon Sep 17 00:00:00 2001 From: geetu040 Date: Fri, 24 Jan 2025 20:13:31 +0500 Subject: [PATCH 03/11] put minimax_text_01 in other files --- docs/source/ar/_toctree.yml | 2 + docs/source/ar/conversations.md | 2 +- docs/source/ar/index.md | 1 + docs/source/ar/trainer.md | 2 +- docs/source/en/_toctree.yml | 2 + docs/source/en/conversations.md | 2 +- docs/source/en/index.md | 1 + docs/source/en/perf_infer_gpu_one.md | 3 ++ docs/source/ko/conversations.md | 2 +- docs/source/zh/index.md | 1 + src/transformers/__init__.py | 20 +++++++++ src/transformers/models/__init__.py | 1 + .../models/auto/configuration_auto.py | 2 + src/transformers/models/auto/modeling_auto.py | 5 +++ .../models/auto/tokenization_auto.py | 7 ++++ src/transformers/utils/dummy_pt_objects.py | 42 +++++++++++++++++++ src/transformers/utils/fx.py | 1 + .../modular/test_conversion_order.py | 2 + tests/test_modeling_common.py | 2 +- utils/not_doctested.txt | 3 ++ 20 files changed, 98 insertions(+), 5 deletions(-) diff --git a/docs/source/ar/_toctree.yml b/docs/source/ar/_toctree.yml index 30e247eb54e1..f09cba37265b 100644 --- a/docs/source/ar/_toctree.yml +++ b/docs/source/ar/_toctree.yml @@ -454,6 +454,8 @@ # title: Mistral # - local: model_doc/mixtral # title: Mixtral +# - local: model_doc/minimax_text_01 +# title: MiniMaxText01 # - local: model_doc/mluke # title: mLUKE # - local: model_doc/mobilebert diff --git a/docs/source/ar/conversations.md b/docs/source/ar/conversations.md index 00e6fe814ea0..dffb1836de20 100644 --- a/docs/source/ar/conversations.md +++ b/docs/source/ar/conversations.md @@ -201,4 +201,4 @@ pipe = pipeline("text-generation", "meta-llama/Meta-Llama-3-8B-Instruct", device لذلك، إذا كنت تريد تحسين سرعة توليد النص، فإن الحل الأسهل هو إما تقليل حجم النموذج في الذاكرة (عادةً عن طريق التكميم)، أو الحصول على عتاد بسرعة أكبر في الذاكرة. بالنسبة للمستخدمين المتقدمين، هناك عدة تقنيات أخرى للتغلب على هذه القيود. الأكثر شيوعًا هي المتغيرات على [التوليد بمساعدة](https://huggingface.co/blog/assisted-generation)، المعروف أيضًا باسم "العينات التخمينية (speculative sampling)". تحاول هذه التقنيات تخمين عدة رموز مستقبلية في وقت واحد، غالبًا باستخدام نموذج "مسودة (draft model)" أصغر، ثم تأكيد هذه التوليدات باستخدام نموذج الدردشة. إذا تم التحقق من صحة التخمينات بواسطة نموذج الدردشة، فيمكن إنشاء أكثر من رمز واحد لكل تمرير للأمام، مما يخفف بشكل كبير من القيود المتعلقة بالسعة ويحسن سرعة التوليد. -أخيرًا، يجب أن نلاحظ أيضًا تأثير نماذج "مزيج الخبراء" "Mixture of Experts" (MoE) هنا. العديد من نماذج المحادثة الشهيرة، مثل Mixtral وQwen-MoE وDBRX، هي نماذج MoE. في هذه النماذج، لا تكون كل معلمة نشطة لكل رمز يتم إنشاؤه. ونتيجة لذلك، فإن نماذج MoE لديها عمومًا متطلبات ذاكرة أقل بكثير، على الرغم من أن حجمها الإجمالي يمكن أن يكون كبيرًا جدًا. لذلك يمكن أن تكون أسرع عدة مرات من نموذج "كثيف" عادي بنفس الحجم. ومع ذلك، فإن التقنيات مثل التوليد المساعد غير فعالة بشكل عام لهذه النماذج لأن المزيد من المعلمات ستصبح نشطة مع كل رمز جديد يتم التكهن به، والذي سيبطل فوائد السعة والسرعة التي توفرها بنية MoE. \ No newline at end of file +أخيرًا، يجب أن نلاحظ أيضًا تأثير نماذج "مزيج الخبراء" "Mixture of Experts" (MoE) هنا. العديد من نماذج المحادثة الشهيرة، مثل Mixtral وMiniMaxText01وQwen-MoE وDBRX، هي نماذج MoE. في هذه النماذج، لا تكون كل معلمة نشطة لكل رمز يتم إنشاؤه. ونتيجة لذلك، فإن نماذج MoE لديها عمومًا متطلبات ذاكرة أقل بكثير، على الرغم من أن حجمها الإجمالي يمكن أن يكون كبيرًا جدًا. لذلك يمكن أن تكون أسرع عدة مرات من نموذج "كثيف" عادي بنفس الحجم. ومع ذلك، فإن التقنيات مثل التوليد المساعد غير فعالة بشكل عام لهذه النماذج لأن المزيد من المعلمات ستصبح نشطة مع كل رمز جديد يتم التكهن به، والذي سيبطل فوائد السعة والسرعة التي توفرها بنية MoE. \ No newline at end of file diff --git a/docs/source/ar/index.md b/docs/source/ar/index.md index c37dbd1c6d9f..299245d789f3 100644 --- a/docs/source/ar/index.md +++ b/docs/source/ar/index.md @@ -196,6 +196,7 @@ | [MGP-STR](model_doc/mgp-str) | ✅ | ❌ | ❌ | | [Mistral](model_doc/mistral) | ✅ | ✅ | ✅ | | [Mixtral](model_doc/mixtral) | ✅ | ❌ | ❌ | +| [MiniMaxText01](model_doc/minimax_text_01) | ✅ | ❌ | ❌ | | [mLUKE](model_doc/mluke) | ✅ | ❌ | ❌ | | [MMS](model_doc/mms) | ✅ | ✅ | ✅ | | [MobileBERT](model_doc/mobilebert) | ✅ | ✅ | ❌ | diff --git a/docs/source/ar/trainer.md b/docs/source/ar/trainer.md index 7da7cbf4e171..ddf62a5fe0de 100644 --- a/docs/source/ar/trainer.md +++ b/docs/source/ar/trainer.md @@ -265,7 +265,7 @@ training_args = TrainingArguments( ) ``` -تدعم النواة معماريات نماذج Llama و Gemma و Mistral و Mixtral. يُمكن العثور على أحدث قائمة بالنمائج المدعومة [هنا](https://github.com/linkedin/Liger-Kernel). عندما يتم تعيين `use_liger_kernel` إلى `True`، سيتم تصحيح الطبقات المُقابلة في النموذج الأصلي باستخدام تطبيق Liger الفعال، لذلك لا تحتاج إلى فعل أي شيء إضافي بخلاف تعيين قيمة المعامل. +تدعم النواة معماريات نماذج Llama و Gemma و Mistral و Mixtralو MiniMaxText01. يُمكن العثور على أحدث قائمة بالنمائج المدعومة [هنا](https://github.com/linkedin/Liger-Kernel). عندما يتم تعيين `use_liger_kernel` إلى `True`، سيتم تصحيح الطبقات المُقابلة في النموذج الأصلي باستخدام تطبيق Liger الفعال، لذلك لا تحتاج إلى فعل أي شيء إضافي بخلاف تعيين قيمة المعامل. ## المُحسِّنات يمكنك اختيار مُحسِّن مدمج للتدريب باستخدام: diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index 3e8bcd9ece1f..dfca7a06d52e 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -498,6 +498,8 @@ title: Mistral - local: model_doc/mixtral title: Mixtral + - local: model_doc/minimax_text_01 + title: MiniMaxText01 - local: model_doc/mluke title: mLUKE - local: model_doc/mobilebert diff --git a/docs/source/en/conversations.md b/docs/source/en/conversations.md index a48c046b4949..21fae659ac58 100644 --- a/docs/source/en/conversations.md +++ b/docs/source/en/conversations.md @@ -281,7 +281,7 @@ confirm these generations with the chat model. If the guesses are validated by t be generated per forward pass, which greatly alleviates the bandwidth bottleneck and improves generation speed. Finally, we should also note the impact of "Mixture of Experts" (MoE) models here. Several popular chat models, -such as Mixtral, Qwen-MoE and DBRX, are MoE models. In these models, not every parameter is active for every token generated. +such as Mixtral, MiniMaxText01, Qwen-MoE and DBRX, are MoE models. In these models, not every parameter is active for every token generated. As a result, MoE models generally have much lower memory bandwidth requirements, even though their total size can be quite large. They can therefore be several times faster than a normal "dense" model of the same size. However, techniques like assisted generation are generally ineffective for these models because more parameters will become diff --git a/docs/source/en/index.md b/docs/source/en/index.md index ace8f76f7d01..8615c0eef569 100644 --- a/docs/source/en/index.md +++ b/docs/source/en/index.md @@ -228,6 +228,7 @@ Flax), PyTorch, and/or TensorFlow. | [Mimi](model_doc/mimi) | ✅ | ❌ | ❌ | | [Mistral](model_doc/mistral) | ✅ | ✅ | ✅ | | [Mixtral](model_doc/mixtral) | ✅ | ❌ | ❌ | +| [MiniMaxText01](model_doc/minimax_text_01) | ✅ | ❌ | ❌ | | [Mllama](model_doc/mllama) | ✅ | ❌ | ❌ | | [mLUKE](model_doc/mluke) | ✅ | ❌ | ❌ | | [MMS](model_doc/mms) | ✅ | ✅ | ✅ | diff --git a/docs/source/en/perf_infer_gpu_one.md b/docs/source/en/perf_infer_gpu_one.md index d9bdf6f6e484..3fff55e25059 100644 --- a/docs/source/en/perf_infer_gpu_one.md +++ b/docs/source/en/perf_infer_gpu_one.md @@ -77,6 +77,7 @@ FlashAttention-2 is currently supported for the following architectures: * [MBart](https://huggingface.co/docs/transformers/model_doc/mbart#transformers.MBartModel) * [Mistral](https://huggingface.co/docs/transformers/model_doc/mistral#transformers.MistralModel) * [Mixtral](https://huggingface.co/docs/transformers/model_doc/mixtral#transformers.MixtralModel) +* [MiniMaxText01](https://huggingface.co/docs/transformers/model_doc/minimax_text_01#transformers.MiniMaxText01Model) * [ModernBert](https://huggingface.co/docs/transformers/model_doc/modernbert#transformers.ModernBert) * [Moshi](https://huggingface.co/docs/transformers/model_doc/moshi#transformers.MoshiModel) * [Musicgen](https://huggingface.co/docs/transformers/model_doc/musicgen#transformers.MusicgenModel) @@ -274,6 +275,7 @@ For now, Transformers supports SDPA inference and training for the following arc * [Mistral](https://huggingface.co/docs/transformers/model_doc/mistral#transformers.MistralModel) * [Mllama](https://huggingface.co/docs/transformers/model_doc/mllama#transformers.MllamaForConditionalGeneration) * [Mixtral](https://huggingface.co/docs/transformers/model_doc/mixtral#transformers.MixtralModel) +* [MiniMaxText01](https://huggingface.co/docs/transformers/model_doc/minimax_text_01#transformers.MiniMaxText01Model) * [ModernBert](https://huggingface.co/docs/transformers/model_doc/modernbert#transformers.ModernBert) * [Moshi](https://huggingface.co/docs/transformers/model_doc/moshi#transformers.MoshiModel) * [Musicgen](https://huggingface.co/docs/transformers/model_doc/musicgen#transformers.MusicgenModel) @@ -292,6 +294,7 @@ For now, Transformers supports SDPA inference and training for the following arc * [Moonshine](https://huggingface.co/docs/transformers/model_doc/moonshine#transformers.MoonshineModel) * [Mistral](https://huggingface.co/docs/transformers/model_doc/mistral#transformers.MistralModel) * [Mixtral](https://huggingface.co/docs/transformers/model_doc/mixtral#transformers.MixtralModel) +* [MiniMaxText01](https://huggingface.co/docs/transformers/model_doc/minimax_text_01#transformers.MiniMaxText01Model) * [StableLm](https://huggingface.co/docs/transformers/model_doc/stablelm#transformers.StableLmModel) * [Starcoder2](https://huggingface.co/docs/transformers/model_doc/starcoder2#transformers.Starcoder2Model) * [Qwen2](https://huggingface.co/docs/transformers/model_doc/qwen2#transformers.Qwen2Model) diff --git a/docs/source/ko/conversations.md b/docs/source/ko/conversations.md index 920cb1387860..daae2fb69ae6 100644 --- a/docs/source/ko/conversations.md +++ b/docs/source/ko/conversations.md @@ -296,7 +296,7 @@ pipe = pipeline("text-generation", "meta-llama/Meta-Llama-3-8B-Instruct", device 병목 현상이 크게 줄어들고 생성 속도가 빨라집니다. 마지막으로, "Mixture of Experts" (MoE) 모델에 대해서도 짚고 넘어가 보도록 합니다. -Mixtral, Qwen-MoE, DBRX와 같은 인기 있는 채팅 모델이 바로 MoE 모델입니다. +Mixtral, MiniMaxText01, Qwen-MoE, DBRX와 같은 인기 있는 채팅 모델이 바로 MoE 모델입니다. 이 모델들은 토큰을 생성할 때 모든 파라미터가 사용되지 않습니다. 이로 인해 MoE 모델은 전체 크기가 상당히 클 수 있지만, 차지하는 메모리 대역폭은 낮은 편입니다. diff --git a/docs/source/zh/index.md b/docs/source/zh/index.md index 3750e506b0ea..3f032cedffd4 100644 --- a/docs/source/zh/index.md +++ b/docs/source/zh/index.md @@ -191,6 +191,7 @@ rendered properly in your Markdown viewer. | [MGP-STR](../en/model_doc/mgp-str) | ✅ | ❌ | ❌ | | [Mistral](../en/model_doc/mistral) | ✅ | ❌ | ✅ | | [Mixtral](../en/model_doc/mixtral) | ✅ | ❌ | ❌ | +| [MiniMaxText01](../en/model_doc/minimax_text_01) | ✅ | ❌ | ❌ | | [mLUKE](../en/model_doc/mluke) | ✅ | ❌ | ❌ | | [MMS](../en/model_doc/mms) | ✅ | ✅ | ✅ | | [MobileBERT](../en/model_doc/mobilebert) | ✅ | ✅ | ❌ | diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index 64b5f4b52c40..1566a02dc2cb 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -602,6 +602,7 @@ "models.mimi": ["MimiConfig"], "models.mistral": ["MistralConfig"], "models.mixtral": ["MixtralConfig"], + "models.minimax_text_01": ["MiniMaxText01Config"], "models.mllama": [ "MllamaConfig", "MllamaProcessor", @@ -2868,6 +2869,16 @@ "MixtralPreTrainedModel", ] ) + _import_structure["models.minimax_text_01"].extend( + [ + "MiniMaxText01ForCausalLM", + "MiniMaxText01ForQuestionAnswering", + "MiniMaxText01ForSequenceClassification", + "MiniMaxText01ForTokenClassification", + "MiniMaxText01Model", + "MiniMaxText01PreTrainedModel", + ] + ) _import_structure["models.mllama"].extend( [ "MllamaForCausalLM", @@ -5657,6 +5668,7 @@ ) from .models.mistral import MistralConfig from .models.mixtral import MixtralConfig + from .models.minimax_text_01 import MiniMaxText01Config from .models.mllama import ( MllamaConfig, MllamaProcessor, @@ -7661,6 +7673,14 @@ MixtralModel, MixtralPreTrainedModel, ) + from .models.minimax_text_01 import ( + MiniMaxText01Model, + MiniMaxText01ForCausalLM, + MiniMaxText01ForSequenceClassification, + MiniMaxText01ForTokenClassification, + MiniMaxText01ForQuestionAnswering, + MiniMaxText01PreTrainedModel, + ) from .models.mllama import ( MllamaForCausalLM, MllamaForConditionalGeneration, diff --git a/src/transformers/models/__init__.py b/src/transformers/models/__init__.py index 69c3c26cfa2a..086d23800e48 100644 --- a/src/transformers/models/__init__.py +++ b/src/transformers/models/__init__.py @@ -164,6 +164,7 @@ mimi, mistral, mixtral, + minimax_text_01, mllama, mluke, mobilebert, diff --git a/src/transformers/models/auto/configuration_auto.py b/src/transformers/models/auto/configuration_auto.py index e93c201bb0cc..c540c45f6d12 100644 --- a/src/transformers/models/auto/configuration_auto.py +++ b/src/transformers/models/auto/configuration_auto.py @@ -185,6 +185,7 @@ ("mimi", "MimiConfig"), ("mistral", "MistralConfig"), ("mixtral", "MixtralConfig"), + ("minimax_text_01", "MiniMaxText01Config"), ("mllama", "MllamaConfig"), ("mobilebert", "MobileBertConfig"), ("mobilenet_v1", "MobileNetV1Config"), @@ -516,6 +517,7 @@ ("mimi", "Mimi"), ("mistral", "Mistral"), ("mixtral", "Mixtral"), + ("minimax_text_01", "MiniMaxText01"), ("mllama", "Mllama"), ("mluke", "mLUKE"), ("mms", "MMS"), diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py index fc24fd4dcfa3..0b8be71c7120 100644 --- a/src/transformers/models/auto/modeling_auto.py +++ b/src/transformers/models/auto/modeling_auto.py @@ -174,6 +174,7 @@ ("mimi", "MimiModel"), ("mistral", "MistralModel"), ("mixtral", "MixtralModel"), + ("minimax_text_01", "MiniMaxText01Model"), ("mobilebert", "MobileBertModel"), ("mobilenet_v1", "MobileNetV1Model"), ("mobilenet_v2", "MobileNetV2Model"), @@ -531,6 +532,7 @@ ("megatron-bert", "MegatronBertForCausalLM"), ("mistral", "MistralForCausalLM"), ("mixtral", "MixtralForCausalLM"), + ("minimax_text_01", "MiniMaxText01ForCausalLM"), ("mllama", "MllamaForCausalLM"), ("moshi", "MoshiForCausalLM"), ("mpt", "MptForCausalLM"), @@ -1010,6 +1012,7 @@ ("megatron-bert", "MegatronBertForSequenceClassification"), ("mistral", "MistralForSequenceClassification"), ("mixtral", "MixtralForSequenceClassification"), + ("minimax_text_01", "MiniMaxText01ForSequenceClassification"), ("mobilebert", "MobileBertForSequenceClassification"), ("modernbert", "ModernBertForSequenceClassification"), ("mpnet", "MPNetForSequenceClassification"), @@ -1098,6 +1101,7 @@ ("megatron-bert", "MegatronBertForQuestionAnswering"), ("mistral", "MistralForQuestionAnswering"), ("mixtral", "MixtralForQuestionAnswering"), + ("minimax_text_01", "MiniMaxText01ForQuestionAnswering"), ("mobilebert", "MobileBertForQuestionAnswering"), ("mpnet", "MPNetForQuestionAnswering"), ("mpt", "MptForQuestionAnswering"), @@ -1200,6 +1204,7 @@ ("megatron-bert", "MegatronBertForTokenClassification"), ("mistral", "MistralForTokenClassification"), ("mixtral", "MixtralForTokenClassification"), + ("minimax_text_01", "MiniMaxText01ForTokenClassification"), ("mobilebert", "MobileBertForTokenClassification"), ("modernbert", "ModernBertForTokenClassification"), ("mpnet", "MPNetForTokenClassification"), diff --git a/src/transformers/models/auto/tokenization_auto.py b/src/transformers/models/auto/tokenization_auto.py index 9ce9edd06cb2..8e2ae7452dcf 100644 --- a/src/transformers/models/auto/tokenization_auto.py +++ b/src/transformers/models/auto/tokenization_auto.py @@ -319,6 +319,13 @@ "LlamaTokenizerFast" if is_tokenizers_available() else None, ), ), + ( + "minimax_text_01", + ( + "GPT2Tokenizer" if is_sentencepiece_available() else None, + "GPT2TokenizerFast" if is_tokenizers_available() else None, + ), + ), ("mllama", ("LlamaTokenizer", "LlamaTokenizerFast" if is_tokenizers_available() else None)), ("mluke", ("MLukeTokenizer" if is_sentencepiece_available() else None, None)), ("mobilebert", ("MobileBertTokenizer", "MobileBertTokenizerFast" if is_tokenizers_available() else None)), diff --git a/src/transformers/utils/dummy_pt_objects.py b/src/transformers/utils/dummy_pt_objects.py index 2574af7e8a41..799467cd0845 100644 --- a/src/transformers/utils/dummy_pt_objects.py +++ b/src/transformers/utils/dummy_pt_objects.py @@ -6315,6 +6315,48 @@ def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) +class MiniMaxText01ForCausalLM(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + +class MiniMaxText01ForQuestionAnswering(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + +class MiniMaxText01ForSequenceClassification(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + +class MiniMaxText01ForTokenClassification(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + +class MiniMaxText01Model(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + +class MiniMaxText01PreTrainedModel(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + class MllamaForCausalLM(metaclass=DummyObject): _backends = ["torch"] diff --git a/src/transformers/utils/fx.py b/src/transformers/utils/fx.py index 45fa3d9ca68c..c70b45a4b13a 100755 --- a/src/transformers/utils/fx.py +++ b/src/transformers/utils/fx.py @@ -150,6 +150,7 @@ def _generate_supported_model_class_names( "megatron-bert", "mistral", "mixtral", + "minimax_text_01", "mobilebert", "mt5", "nezha", diff --git a/tests/repo_utils/modular/test_conversion_order.py b/tests/repo_utils/modular/test_conversion_order.py index f5e133ce1fea..1dbd18f71a41 100644 --- a/tests/repo_utils/modular/test_conversion_order.py +++ b/tests/repo_utils/modular/test_conversion_order.py @@ -19,6 +19,7 @@ os.path.join(MODEL_ROOT, "granite", "modular_granite.py"), os.path.join(MODEL_ROOT, "gemma2", "modular_gemma2.py"), os.path.join(MODEL_ROOT, "mixtral", "modular_mixtral.py"), + os.path.join(MODEL_ROOT, "minimax_text_01", "modular_minimax_text_01.py"), os.path.join(MODEL_ROOT, "olmo", "modular_olmo.py"), os.path.join(MODEL_ROOT, "rt_detr", "modular_rt_detr.py"), os.path.join(MODEL_ROOT, "qwen2", "modular_qwen2.py"), @@ -53,6 +54,7 @@ def test_conversion_order(self): model_priority_list = [file.rsplit("modular_")[-1].replace(".py", "") for file in priority_list] # These are based on what the current library order should be (as of 09/01/2025) + self.assertTrue(appear_after("minimax_text_01", "mixtral", model_priority_list)) self.assertTrue(appear_after("mixtral", "mistral", model_priority_list)) self.assertTrue(appear_after("gemma2", "gemma", model_priority_list)) self.assertTrue(appear_after("starcoder2", "mistral", model_priority_list)) diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index cf259fabe302..3ce9a84adb3d 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -4327,7 +4327,7 @@ def test_sdpa_matches_eager_sliding_window(self): if not self.has_attentions: self.skipTest(reason="Model architecture does not support attentions") - WINDOW_ATTENTION_MODELS = ["mistral", "mixtral", "qwen2", "qwen_moe", "starcoder2"] + WINDOW_ATTENTION_MODELS = ["mistral", "mixtral", "minimax_text_01", "qwen2", "qwen_moe", "starcoder2"] if len(self.all_generative_model_classes) == 0: self.skipTest(f"No generative model classes for {self.__class__.__name__}") diff --git a/utils/not_doctested.txt b/utils/not_doctested.txt index 0a36fcbd8a5f..0a9c635f1973 100644 --- a/utils/not_doctested.txt +++ b/utils/not_doctested.txt @@ -165,6 +165,7 @@ docs/source/en/model_doc/megatron_gpt2.md docs/source/en/model_doc/mgp-str.md docs/source/en/model_doc/mistral.md docs/source/en/model_doc/mixtral.md +docs/source/en/model_doc/minimax_text_01.md docs/source/en/model_doc/mluke.md docs/source/en/model_doc/mms.md docs/source/en/model_doc/mobilebert.md @@ -675,6 +676,8 @@ src/transformers/models/mistral/configuration_mistral.py src/transformers/models/mistral/modeling_mistral.py src/transformers/models/mixtral/configuration_mixtral.py src/transformers/models/mixtral/modeling_mixtral.py +src/transformers/models/minimax_text_01/configuration_minimax_text_01.py +src/transformers/models/minimax_text_01/modeling_minimax_text_01.py src/transformers/models/mluke/convert_mluke_original_pytorch_checkpoint_to_pytorch.py src/transformers/models/mobilebert/convert_mobilebert_original_tf_checkpoint_to_pytorch.py src/transformers/models/mobilenet_v1/configuration_mobilenet_v1.py From d8d3c409d89e335c98a8cd36f47304a76eac7493 Mon Sep 17 00:00:00 2001 From: geetu040 Date: Mon, 27 Jan 2025 07:48:21 +0500 Subject: [PATCH 04/11] use latest __init__ standards and auto-generate modular --- .../models/minimax_text_01/__init__.py | 51 +++---------------- .../modeling_minimax_text_01.py | 51 +++++-------------- 2 files changed, 19 insertions(+), 83 deletions(-) diff --git a/src/transformers/models/minimax_text_01/__init__.py b/src/transformers/models/minimax_text_01/__init__.py index 1d65a515cf17..2f92703446d5 100644 --- a/src/transformers/models/minimax_text_01/__init__.py +++ b/src/transformers/models/minimax_text_01/__init__.py @@ -15,54 +15,15 @@ # limitations under the License. from typing import TYPE_CHECKING -from ...utils import ( - OptionalDependencyNotAvailable, - _LazyModule, - is_torch_available, -) - - -_import_structure = { - "configuration_minimax_text_01": ["MiniMaxText01Config"], -} - - -try: - if not is_torch_available(): - raise OptionalDependencyNotAvailable() -except OptionalDependencyNotAvailable: - pass -else: - _import_structure["modeling_minimax_text_01"] = [ - "MiniMaxText01ForCausalLM", - "MiniMaxText01ForQuestionAnswering", - "MiniMaxText01Model", - "MiniMaxText01PreTrainedModel", - "MiniMaxText01ForSequenceClassification", - "MiniMaxText01ForTokenClassification", - ] +from ...utils import _LazyModule +from ...utils.import_utils import define_import_structure if TYPE_CHECKING: - from .configuration_minimax_text_01 import MiniMaxText01Config - - try: - if not is_torch_available(): - raise OptionalDependencyNotAvailable() - except OptionalDependencyNotAvailable: - pass - else: - from .modeling_minimax_text_01 import ( - MiniMaxText01ForCausalLM, - MiniMaxText01ForQuestionAnswering, - MiniMaxText01ForSequenceClassification, - MiniMaxText01ForTokenClassification, - MiniMaxText01Model, - MiniMaxText01PreTrainedModel, - ) - - + from .configuration_minimax_text_01 import * + from .modeling_minimax_text_01 import * else: import sys - sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__) + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/src/transformers/models/minimax_text_01/modeling_minimax_text_01.py b/src/transformers/models/minimax_text_01/modeling_minimax_text_01.py index 9e33a18a58a3..b0d01400a6bb 100644 --- a/src/transformers/models/minimax_text_01/modeling_minimax_text_01.py +++ b/src/transformers/models/minimax_text_01/modeling_minimax_text_01.py @@ -259,13 +259,6 @@ def eager_attention_forward( causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] attn_weights = attn_weights + causal_mask - # print() - # ic(module.layer_idx) - # show_tensor(query, False, True) - # show_tensor(key_states, False, True) - # show_tensor(value_states, False, True) - # show_tensor(attn_weights, False, True) - attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) attn_output = torch.matmul(attn_weights, value_states) @@ -310,23 +303,11 @@ def forward( cos, sin = position_embeddings query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) - # print(self.layer_idx) - # show_tensor(query_states, end=False, only_shapes=False) - # show_tensor(key_states, end=False, only_shapes=True) - # show_tensor(value_states, end=True, only_shapes=True) - - # print() - # print() - # ic(self.layer_idx) - # show_tensor(key_states, False, True) - if past_key_value is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - # show_tensor(key_states, False, True) - attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False): @@ -351,10 +332,6 @@ def forward( attn_output = attn_output.reshape(*input_shape, -1).contiguous() attn_output = self.o_proj(attn_output) - - # ic(self.layer_idx) - # show_tensor(attn_output, False, True) - return attn_output, attn_weights @@ -592,7 +569,7 @@ def _dynamic_frequency_update(self, position_ids, device): 2 - the current sequence length is in the original scale (avoid losing precision with small sequences) """ seq_len = torch.max(position_ids) + 1 - if seq_len > self.max_seq_len_cached: # growth_dynamic_frequency_update + if seq_len > self.max_seq_len_cached: # growth inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, seq_len=seq_len) self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation self.max_seq_len_cached = seq_len @@ -628,7 +605,7 @@ def forward(self, x, position_ids): return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) -MINI_MAX_TEXT01_START_DOCSTRING = r""" +MINIMAX_TEXT_01_START_DOCSTRING = r""" This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads etc.) @@ -647,7 +624,7 @@ def forward(self, x, position_ids): @add_start_docstrings( "The bare MiniMaxText01 Model outputting raw hidden-states without any specific head on top.", - MINI_MAX_TEXT01_START_DOCSTRING, + MINIMAX_TEXT_01_START_DOCSTRING, ) class MiniMaxText01PreTrainedModel(PreTrainedModel): config_class = MiniMaxText01Config @@ -674,7 +651,7 @@ def _init_weights(self, module): module.weight.data[module.padding_idx].zero_() -MINI_MAX_TEXT01_INPUTS_DOCSTRING = r""" +MINIMAX_TEXT_01_INPUTS_DOCSTRING = r""" Args: input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide @@ -751,7 +728,7 @@ def _init_weights(self, module): @add_start_docstrings( "The bare MiniMaxText01 Model outputting raw hidden-states without any specific head on top.", - MINI_MAX_TEXT01_START_DOCSTRING, + MINIMAX_TEXT_01_START_DOCSTRING, ) class MiniMaxText01Model(MiniMaxText01PreTrainedModel): """ @@ -783,7 +760,7 @@ def get_input_embeddings(self): def set_input_embeddings(self, value): self.embed_tokens = value - @add_start_docstrings_to_model_forward(MINI_MAX_TEXT01_INPUTS_DOCSTRING) + @add_start_docstrings_to_model_forward(MINIMAX_TEXT_01_INPUTS_DOCSTRING) def forward( self, input_ids: torch.LongTensor = None, @@ -820,7 +797,6 @@ def forward( ) use_cache = False - # TODO: raise exception here? if use_cache and past_key_values is None: past_key_values = DynamicCache() @@ -1173,7 +1149,7 @@ def set_decoder(self, decoder): def get_decoder(self): return self.model - @add_start_docstrings_to_model_forward(MINI_MAX_TEXT01_INPUTS_DOCSTRING) + @add_start_docstrings_to_model_forward(MINIMAX_TEXT_01_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) def forward( self, @@ -1222,7 +1198,6 @@ def forward( >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." ```""" - # ic(input_ids.shape, input_ids) output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_router_logits = ( @@ -1299,7 +1274,7 @@ def forward( padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in each row of the batch). """, - MINI_MAX_TEXT01_START_DOCSTRING, + MINIMAX_TEXT_01_START_DOCSTRING, ) class MiniMaxText01ForSequenceClassification(MiniMaxText01PreTrainedModel): def __init__(self, config): @@ -1317,7 +1292,7 @@ def get_input_embeddings(self): def set_input_embeddings(self, value): self.model.embed_tokens = value - @add_start_docstrings_to_model_forward(MINI_MAX_TEXT01_INPUTS_DOCSTRING) + @add_start_docstrings_to_model_forward(MINIMAX_TEXT_01_INPUTS_DOCSTRING) def forward( self, input_ids: Optional[torch.LongTensor] = None, @@ -1395,7 +1370,7 @@ def forward( The MiniMaxText01 Model transformer with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for Named-Entity-Recognition (NER) tasks. """, - MINI_MAX_TEXT01_START_DOCSTRING, + MINIMAX_TEXT_01_START_DOCSTRING, ) class MiniMaxText01ForTokenClassification(MiniMaxText01PreTrainedModel): def __init__(self, config): @@ -1420,7 +1395,7 @@ def get_input_embeddings(self): def set_input_embeddings(self, value): self.model.embed_tokens = value - @add_start_docstrings_to_model_forward(MINI_MAX_TEXT01_INPUTS_DOCSTRING) + @add_start_docstrings_to_model_forward(MINIMAX_TEXT_01_INPUTS_DOCSTRING) @add_code_sample_docstrings( checkpoint=_CHECKPOINT_FOR_DOC, output_type=TokenClassifierOutput, @@ -1483,7 +1458,7 @@ def forward( The MiniMaxText01 Model transformer with a span classification head on top for extractive question-answering tasks like SQuAD (a linear layer on top of the hidden-states output to compute `span start logits` and `span end logits`). """, - MINI_MAX_TEXT01_START_DOCSTRING, + MINIMAX_TEXT_01_START_DOCSTRING, ) class MiniMaxText01ForQuestionAnswering(MiniMaxText01PreTrainedModel): base_model_prefix = "model" @@ -1502,7 +1477,7 @@ def get_input_embeddings(self): def set_input_embeddings(self, value): self.model.embed_tokens = value - @add_start_docstrings_to_model_forward(MINI_MAX_TEXT01_INPUTS_DOCSTRING) + @add_start_docstrings_to_model_forward(MINIMAX_TEXT_01_INPUTS_DOCSTRING) def forward( self, input_ids: Optional[torch.LongTensor] = None, From 8d654d8ee7882aad0c2f548e0a07eee932719414 Mon Sep 17 00:00:00 2001 From: geetu040 Date: Mon, 27 Jan 2025 10:10:14 +0500 Subject: [PATCH 05/11] support attention_mask for lightning-attn --- .../modular_minimax_text_01.py | 30 +++++++++++++++---- 1 file changed, 25 insertions(+), 5 deletions(-) diff --git a/src/transformers/models/minimax_text_01/modular_minimax_text_01.py b/src/transformers/models/minimax_text_01/modular_minimax_text_01.py index 502a42344d44..fcbb9def8257 100644 --- a/src/transformers/models/minimax_text_01/modular_minimax_text_01.py +++ b/src/transformers/models/minimax_text_01/modular_minimax_text_01.py @@ -316,7 +316,9 @@ def forward( key_states = key_states.transpose(1, 2) value_states = value_states.transpose(1, 2) - # TODO: apply attention_mask + # apply attention_mask + if attention_mask is not None: + value_states = value_states.masked_fill((1 - attention_mask).unsqueeze(1).unsqueeze(-1).to(torch.bool), 0) query_states = F.pad(query_states, (0, 0, 0, padding)) key_states = F.pad(key_states, (0, 0, 0, padding)) @@ -384,8 +386,9 @@ def __init__(self, config: MiniMaxText01Config, layer_idx: int): self.layernorm_lightning_attention_beta = config.layernorm_lightning_attention_beta self.layernorm_mlp_alpha = config.layernorm_mlp_alpha self.layernorm_mlp_beta = config.layernorm_mlp_beta + self.attn_type = config.attn_type_list[layer_idx] - if config.attn_type_list[layer_idx] == 0: + if self.attn_type == 0: self.self_attn = MiniMaxText01LightningAttention(config, layer_idx) self.layernorm_alpha = self.layernorm_lightning_attention_alpha self.layernorm_beta = self.layernorm_lightning_attention_beta @@ -397,7 +400,7 @@ def __init__(self, config: MiniMaxText01Config, layer_idx: int): def forward( self, hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, + attention_mask: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Tuple[torch.Tensor]] = None, output_attentions: Optional[bool] = False, @@ -410,7 +413,7 @@ def forward( """ Args: hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` - attention_mask (`torch.FloatTensor`, *optional*): attention mask of size + attention_mask (`Tuple[torch.Tensor, torch.Tensor]`, *optional*): attention mask of size `(batch, sequence_length)` where padding elements are indicated by 0. past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states output_attentions (`bool`, *optional*): @@ -439,7 +442,7 @@ def forward( hidden_states, self_attn_weights = self.self_attn( hidden_states=hidden_states, position_embeddings=position_embeddings, - attention_mask=attention_mask, + attention_mask=attention_mask[self.attn_type], position_ids=position_ids, past_key_value=past_key_value, output_attentions=output_attentions, @@ -475,6 +478,23 @@ def __init__(self, config: MiniMaxText01Config): [MiniMaxText01DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] ) + def _update_causal_mask( + self, + attention_mask: torch.Tensor, + input_tensor: torch.Tensor, + cache_position: torch.Tensor, + past_key_values: Cache, + output_attentions: bool, + ): + causal_mask = super()._update_causal_mask( + attention_mask=attention_mask, + input_tensor=input_tensor, + cache_position=cache_position, + past_key_values=past_key_values, + output_attentions=output_attentions, + ) + return (attention_mask, causal_mask) + class MiniMaxText01ForCausalLM(MixtralForCausalLM): def __init__(self, config): From 5b40a5cd33ce393d0d18c354c46cbf569b3f188a Mon Sep 17 00:00:00 2001 From: geetu040 Date: Mon, 27 Jan 2025 10:26:40 +0500 Subject: [PATCH 06/11] Revert "use latest __init__ standards and auto-generate modular" This reverts commit d8d3c409d89e335c98a8cd36f47304a76eac7493. --- .../models/minimax_text_01/__init__.py | 51 ++++++++++++++++--- .../modeling_minimax_text_01.py | 51 ++++++++++++++----- 2 files changed, 83 insertions(+), 19 deletions(-) diff --git a/src/transformers/models/minimax_text_01/__init__.py b/src/transformers/models/minimax_text_01/__init__.py index 2f92703446d5..1d65a515cf17 100644 --- a/src/transformers/models/minimax_text_01/__init__.py +++ b/src/transformers/models/minimax_text_01/__init__.py @@ -15,15 +15,54 @@ # limitations under the License. from typing import TYPE_CHECKING -from ...utils import _LazyModule -from ...utils.import_utils import define_import_structure +from ...utils import ( + OptionalDependencyNotAvailable, + _LazyModule, + is_torch_available, +) + + +_import_structure = { + "configuration_minimax_text_01": ["MiniMaxText01Config"], +} + + +try: + if not is_torch_available(): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + pass +else: + _import_structure["modeling_minimax_text_01"] = [ + "MiniMaxText01ForCausalLM", + "MiniMaxText01ForQuestionAnswering", + "MiniMaxText01Model", + "MiniMaxText01PreTrainedModel", + "MiniMaxText01ForSequenceClassification", + "MiniMaxText01ForTokenClassification", + ] if TYPE_CHECKING: - from .configuration_minimax_text_01 import * - from .modeling_minimax_text_01 import * + from .configuration_minimax_text_01 import MiniMaxText01Config + + try: + if not is_torch_available(): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + pass + else: + from .modeling_minimax_text_01 import ( + MiniMaxText01ForCausalLM, + MiniMaxText01ForQuestionAnswering, + MiniMaxText01ForSequenceClassification, + MiniMaxText01ForTokenClassification, + MiniMaxText01Model, + MiniMaxText01PreTrainedModel, + ) + + else: import sys - _file = globals()["__file__"] - sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) + sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__) diff --git a/src/transformers/models/minimax_text_01/modeling_minimax_text_01.py b/src/transformers/models/minimax_text_01/modeling_minimax_text_01.py index b0d01400a6bb..9e33a18a58a3 100644 --- a/src/transformers/models/minimax_text_01/modeling_minimax_text_01.py +++ b/src/transformers/models/minimax_text_01/modeling_minimax_text_01.py @@ -259,6 +259,13 @@ def eager_attention_forward( causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] attn_weights = attn_weights + causal_mask + # print() + # ic(module.layer_idx) + # show_tensor(query, False, True) + # show_tensor(key_states, False, True) + # show_tensor(value_states, False, True) + # show_tensor(attn_weights, False, True) + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) attn_output = torch.matmul(attn_weights, value_states) @@ -303,11 +310,23 @@ def forward( cos, sin = position_embeddings query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + # print(self.layer_idx) + # show_tensor(query_states, end=False, only_shapes=False) + # show_tensor(key_states, end=False, only_shapes=True) + # show_tensor(value_states, end=True, only_shapes=True) + + # print() + # print() + # ic(self.layer_idx) + # show_tensor(key_states, False, True) + if past_key_value is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + # show_tensor(key_states, False, True) + attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False): @@ -332,6 +351,10 @@ def forward( attn_output = attn_output.reshape(*input_shape, -1).contiguous() attn_output = self.o_proj(attn_output) + + # ic(self.layer_idx) + # show_tensor(attn_output, False, True) + return attn_output, attn_weights @@ -569,7 +592,7 @@ def _dynamic_frequency_update(self, position_ids, device): 2 - the current sequence length is in the original scale (avoid losing precision with small sequences) """ seq_len = torch.max(position_ids) + 1 - if seq_len > self.max_seq_len_cached: # growth + if seq_len > self.max_seq_len_cached: # growth_dynamic_frequency_update inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, seq_len=seq_len) self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation self.max_seq_len_cached = seq_len @@ -605,7 +628,7 @@ def forward(self, x, position_ids): return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) -MINIMAX_TEXT_01_START_DOCSTRING = r""" +MINI_MAX_TEXT01_START_DOCSTRING = r""" This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads etc.) @@ -624,7 +647,7 @@ def forward(self, x, position_ids): @add_start_docstrings( "The bare MiniMaxText01 Model outputting raw hidden-states without any specific head on top.", - MINIMAX_TEXT_01_START_DOCSTRING, + MINI_MAX_TEXT01_START_DOCSTRING, ) class MiniMaxText01PreTrainedModel(PreTrainedModel): config_class = MiniMaxText01Config @@ -651,7 +674,7 @@ def _init_weights(self, module): module.weight.data[module.padding_idx].zero_() -MINIMAX_TEXT_01_INPUTS_DOCSTRING = r""" +MINI_MAX_TEXT01_INPUTS_DOCSTRING = r""" Args: input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide @@ -728,7 +751,7 @@ def _init_weights(self, module): @add_start_docstrings( "The bare MiniMaxText01 Model outputting raw hidden-states without any specific head on top.", - MINIMAX_TEXT_01_START_DOCSTRING, + MINI_MAX_TEXT01_START_DOCSTRING, ) class MiniMaxText01Model(MiniMaxText01PreTrainedModel): """ @@ -760,7 +783,7 @@ def get_input_embeddings(self): def set_input_embeddings(self, value): self.embed_tokens = value - @add_start_docstrings_to_model_forward(MINIMAX_TEXT_01_INPUTS_DOCSTRING) + @add_start_docstrings_to_model_forward(MINI_MAX_TEXT01_INPUTS_DOCSTRING) def forward( self, input_ids: torch.LongTensor = None, @@ -797,6 +820,7 @@ def forward( ) use_cache = False + # TODO: raise exception here? if use_cache and past_key_values is None: past_key_values = DynamicCache() @@ -1149,7 +1173,7 @@ def set_decoder(self, decoder): def get_decoder(self): return self.model - @add_start_docstrings_to_model_forward(MINIMAX_TEXT_01_INPUTS_DOCSTRING) + @add_start_docstrings_to_model_forward(MINI_MAX_TEXT01_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) def forward( self, @@ -1198,6 +1222,7 @@ def forward( >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." ```""" + # ic(input_ids.shape, input_ids) output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_router_logits = ( @@ -1274,7 +1299,7 @@ def forward( padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in each row of the batch). """, - MINIMAX_TEXT_01_START_DOCSTRING, + MINI_MAX_TEXT01_START_DOCSTRING, ) class MiniMaxText01ForSequenceClassification(MiniMaxText01PreTrainedModel): def __init__(self, config): @@ -1292,7 +1317,7 @@ def get_input_embeddings(self): def set_input_embeddings(self, value): self.model.embed_tokens = value - @add_start_docstrings_to_model_forward(MINIMAX_TEXT_01_INPUTS_DOCSTRING) + @add_start_docstrings_to_model_forward(MINI_MAX_TEXT01_INPUTS_DOCSTRING) def forward( self, input_ids: Optional[torch.LongTensor] = None, @@ -1370,7 +1395,7 @@ def forward( The MiniMaxText01 Model transformer with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for Named-Entity-Recognition (NER) tasks. """, - MINIMAX_TEXT_01_START_DOCSTRING, + MINI_MAX_TEXT01_START_DOCSTRING, ) class MiniMaxText01ForTokenClassification(MiniMaxText01PreTrainedModel): def __init__(self, config): @@ -1395,7 +1420,7 @@ def get_input_embeddings(self): def set_input_embeddings(self, value): self.model.embed_tokens = value - @add_start_docstrings_to_model_forward(MINIMAX_TEXT_01_INPUTS_DOCSTRING) + @add_start_docstrings_to_model_forward(MINI_MAX_TEXT01_INPUTS_DOCSTRING) @add_code_sample_docstrings( checkpoint=_CHECKPOINT_FOR_DOC, output_type=TokenClassifierOutput, @@ -1458,7 +1483,7 @@ def forward( The MiniMaxText01 Model transformer with a span classification head on top for extractive question-answering tasks like SQuAD (a linear layer on top of the hidden-states output to compute `span start logits` and `span end logits`). """, - MINIMAX_TEXT_01_START_DOCSTRING, + MINI_MAX_TEXT01_START_DOCSTRING, ) class MiniMaxText01ForQuestionAnswering(MiniMaxText01PreTrainedModel): base_model_prefix = "model" @@ -1477,7 +1502,7 @@ def get_input_embeddings(self): def set_input_embeddings(self, value): self.model.embed_tokens = value - @add_start_docstrings_to_model_forward(MINIMAX_TEXT_01_INPUTS_DOCSTRING) + @add_start_docstrings_to_model_forward(MINI_MAX_TEXT01_INPUTS_DOCSTRING) def forward( self, input_ids: Optional[torch.LongTensor] = None, From a93ee3fc47a321cb3b9ad323df4097e737abcd75 Mon Sep 17 00:00:00 2001 From: geetu040 Date: Mon, 27 Jan 2025 12:12:08 +0500 Subject: [PATCH 07/11] fix modular conversion --- .../modeling_minimax_text_01.py | 122 +++++++-------- .../modular_minimax_text_01.py | 147 ++++++++++++++++-- 2 files changed, 186 insertions(+), 83 deletions(-) diff --git a/src/transformers/models/minimax_text_01/modeling_minimax_text_01.py b/src/transformers/models/minimax_text_01/modeling_minimax_text_01.py index 9e33a18a58a3..0cb1da75fb5a 100644 --- a/src/transformers/models/minimax_text_01/modeling_minimax_text_01.py +++ b/src/transformers/models/minimax_text_01/modeling_minimax_text_01.py @@ -51,6 +51,7 @@ logging, replace_return_docstrings, ) +from ...utils.deprecation import deprecate_kwarg from .configuration_minimax_text_01 import MiniMaxText01Config @@ -61,6 +62,26 @@ _CONFIG_FOR_DOC = "MixtralConfig" +class MiniMaxText01RMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + MiniMaxText01RMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + def extra_repr(self): + return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" + + class MiniMaxText01LightningAttentionDecay(nn.Module): def __init__(self, config: MiniMaxText01Config, layer_idx: int): super().__init__() @@ -115,7 +136,7 @@ def __init__(self, config: MiniMaxText01Config, layer_idx: int): self.block_size = config.block_size self.act_fn = ACT2FN[config.hidden_act] - self.norm = MixtralRMSNorm(self.head_dim * self.num_heads) + self.norm = MiniMaxText01RMSNorm(self.head_dim * self.num_heads) self.qkv_proj = nn.Linear(config.hidden_size, self.num_heads * self.head_dim * 3, bias=False) self.out_proj = nn.Linear(self.num_heads * self.head_dim, config.hidden_size, bias=False) self.output_gate = nn.Linear(config.hidden_size, self.num_heads * self.num_heads, bias=False) @@ -143,7 +164,9 @@ def forward( key_states = key_states.transpose(1, 2) value_states = value_states.transpose(1, 2) - # TODO: apply attention_mask + # apply attention_mask + if attention_mask is not None: + value_states = value_states.masked_fill((1 - attention_mask).unsqueeze(1).unsqueeze(-1).to(torch.bool), 0) query_states = F.pad(query_states, (0, 0, 0, padding)) key_states = F.pad(key_states, (0, 0, 0, padding)) @@ -259,13 +282,6 @@ def eager_attention_forward( causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] attn_weights = attn_weights + causal_mask - # print() - # ic(module.layer_idx) - # show_tensor(query, False, True) - # show_tensor(key_states, False, True) - # show_tensor(value_states, False, True) - # show_tensor(attn_weights, False, True) - attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) attn_output = torch.matmul(attn_weights, value_states) @@ -310,23 +326,11 @@ def forward( cos, sin = position_embeddings query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) - # print(self.layer_idx) - # show_tensor(query_states, end=False, only_shapes=False) - # show_tensor(key_states, end=False, only_shapes=True) - # show_tensor(value_states, end=True, only_shapes=True) - - # print() - # print() - # ic(self.layer_idx) - # show_tensor(key_states, False, True) - if past_key_value is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - # show_tensor(key_states, False, True) - attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False): @@ -351,10 +355,6 @@ def forward( attn_output = attn_output.reshape(*input_shape, -1).contiguous() attn_output = self.o_proj(attn_output) - - # ic(self.layer_idx) - # show_tensor(attn_output, False, True) - return attn_output, attn_weights @@ -444,26 +444,6 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: return final_hidden_states, router_logits -class MiniMaxText01RMSNorm(nn.Module): - def __init__(self, hidden_size, eps=1e-6): - """ - MiniMaxText01RMSNorm is equivalent to T5LayerNorm - """ - super().__init__() - self.weight = nn.Parameter(torch.ones(hidden_size)) - self.variance_epsilon = eps - - def forward(self, hidden_states): - input_dtype = hidden_states.dtype - hidden_states = hidden_states.to(torch.float32) - variance = hidden_states.pow(2).mean(-1, keepdim=True) - hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) - return self.weight * hidden_states.to(input_dtype) - - def extra_repr(self): - return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" - - class MiniMaxText01DecoderLayer(nn.Module): def __init__(self, config: MiniMaxText01Config, layer_idx: int): super().__init__() @@ -483,8 +463,9 @@ def __init__(self, config: MiniMaxText01Config, layer_idx: int): self.layernorm_lightning_attention_beta = config.layernorm_lightning_attention_beta self.layernorm_mlp_alpha = config.layernorm_mlp_alpha self.layernorm_mlp_beta = config.layernorm_mlp_beta + self.attn_type = config.attn_type_list[layer_idx] - if config.attn_type_list[layer_idx] == 0: + if self.attn_type == 0: self.self_attn = MiniMaxText01LightningAttention(config, layer_idx) self.layernorm_alpha = self.layernorm_lightning_attention_alpha self.layernorm_beta = self.layernorm_lightning_attention_beta @@ -496,7 +477,7 @@ def __init__(self, config: MiniMaxText01Config, layer_idx: int): def forward( self, hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, + attention_mask: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Tuple[torch.Tensor]] = None, output_attentions: Optional[bool] = False, @@ -509,7 +490,7 @@ def forward( """ Args: hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` - attention_mask (`torch.FloatTensor`, *optional*): attention mask of size + attention_mask (`Tuple[torch.Tensor, torch.Tensor]`, *optional*): attention mask of size `(batch, sequence_length)` where padding elements are indicated by 0. past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states output_attentions (`bool`, *optional*): @@ -538,7 +519,7 @@ def forward( hidden_states, self_attn_weights = self.self_attn( hidden_states=hidden_states, position_embeddings=position_embeddings, - attention_mask=attention_mask, + attention_mask=attention_mask[self.attn_type], position_ids=position_ids, past_key_value=past_key_value, output_attentions=output_attentions, @@ -592,7 +573,7 @@ def _dynamic_frequency_update(self, position_ids, device): 2 - the current sequence length is in the original scale (avoid losing precision with small sequences) """ seq_len = torch.max(position_ids) + 1 - if seq_len > self.max_seq_len_cached: # growth_dynamic_frequency_update + if seq_len > self.max_seq_len_cached: # growth inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, seq_len=seq_len) self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation self.max_seq_len_cached = seq_len @@ -628,7 +609,7 @@ def forward(self, x, position_ids): return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) -MINI_MAX_TEXT01_START_DOCSTRING = r""" +MINIMAX_TEXT_01_START_DOCSTRING = r""" This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads etc.) @@ -647,7 +628,7 @@ def forward(self, x, position_ids): @add_start_docstrings( "The bare MiniMaxText01 Model outputting raw hidden-states without any specific head on top.", - MINI_MAX_TEXT01_START_DOCSTRING, + MINIMAX_TEXT_01_START_DOCSTRING, ) class MiniMaxText01PreTrainedModel(PreTrainedModel): config_class = MiniMaxText01Config @@ -661,6 +642,7 @@ class MiniMaxText01PreTrainedModel(PreTrainedModel): _supports_cache_class = True _supports_quantized_cache = True _supports_static_cache = True + _supports_attention_backend = True def _init_weights(self, module): std = self.config.initializer_range @@ -674,7 +656,7 @@ def _init_weights(self, module): module.weight.data[module.padding_idx].zero_() -MINI_MAX_TEXT01_INPUTS_DOCSTRING = r""" +MINIMAX_TEXT_01_INPUTS_DOCSTRING = r""" Args: input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide @@ -751,7 +733,7 @@ def _init_weights(self, module): @add_start_docstrings( "The bare MiniMaxText01 Model outputting raw hidden-states without any specific head on top.", - MINI_MAX_TEXT01_START_DOCSTRING, + MINIMAX_TEXT_01_START_DOCSTRING, ) class MiniMaxText01Model(MiniMaxText01PreTrainedModel): """ @@ -783,7 +765,7 @@ def get_input_embeddings(self): def set_input_embeddings(self, value): self.embed_tokens = value - @add_start_docstrings_to_model_forward(MINI_MAX_TEXT01_INPUTS_DOCSTRING) + @add_start_docstrings_to_model_forward(MINIMAX_TEXT_01_INPUTS_DOCSTRING) def forward( self, input_ids: torch.LongTensor = None, @@ -820,7 +802,6 @@ def forward( ) use_cache = False - # TODO: raise exception here? if use_cache and past_key_values is None: past_key_values = DynamicCache() @@ -838,6 +819,8 @@ def forward( causal_mask = self._update_causal_mask( attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions ) + # TODO: commments + causal_mask = (attention_mask, causal_mask) hidden_states = inputs_embeds @@ -1173,7 +1156,8 @@ def set_decoder(self, decoder): def get_decoder(self): return self.model - @add_start_docstrings_to_model_forward(MINI_MAX_TEXT01_INPUTS_DOCSTRING) + @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") + @add_start_docstrings_to_model_forward(MINIMAX_TEXT_01_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) def forward( self, @@ -1189,7 +1173,7 @@ def forward( output_router_logits: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, - num_logits_to_keep: int = 0, + logits_to_keep: Union[int, torch.Tensor] = 0, **kwargs: Unpack[KwargsForCausalLM], ) -> Union[Tuple, CausalLMOutputWithPast]: r""" @@ -1199,10 +1183,12 @@ def forward( config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. - num_logits_to_keep (`int`, *optional*): - Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all + logits_to_keep (`int` or `torch.Tensor`, *optional*): + If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that token can save memory, which becomes pretty significant for long sequences or large vocabulary size. + If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension. + This is useful when using packed tensor format (single dimension for batch and sequence length). Returns: @@ -1222,7 +1208,6 @@ def forward( >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." ```""" - # ic(input_ids.shape, input_ids) output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_router_logits = ( @@ -1252,7 +1237,8 @@ def forward( hidden_states = outputs[0] # Only compute necessary logits, and do not upcast them to float if we are not computing the loss - logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]) + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.lm_head(hidden_states[:, slice_indices, :]) loss = None if labels is not None: @@ -1299,7 +1285,7 @@ def forward( padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in each row of the batch). """, - MINI_MAX_TEXT01_START_DOCSTRING, + MINIMAX_TEXT_01_START_DOCSTRING, ) class MiniMaxText01ForSequenceClassification(MiniMaxText01PreTrainedModel): def __init__(self, config): @@ -1317,7 +1303,7 @@ def get_input_embeddings(self): def set_input_embeddings(self, value): self.model.embed_tokens = value - @add_start_docstrings_to_model_forward(MINI_MAX_TEXT01_INPUTS_DOCSTRING) + @add_start_docstrings_to_model_forward(MINIMAX_TEXT_01_INPUTS_DOCSTRING) def forward( self, input_ids: Optional[torch.LongTensor] = None, @@ -1395,7 +1381,7 @@ def forward( The MiniMaxText01 Model transformer with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for Named-Entity-Recognition (NER) tasks. """, - MINI_MAX_TEXT01_START_DOCSTRING, + MINIMAX_TEXT_01_START_DOCSTRING, ) class MiniMaxText01ForTokenClassification(MiniMaxText01PreTrainedModel): def __init__(self, config): @@ -1420,7 +1406,7 @@ def get_input_embeddings(self): def set_input_embeddings(self, value): self.model.embed_tokens = value - @add_start_docstrings_to_model_forward(MINI_MAX_TEXT01_INPUTS_DOCSTRING) + @add_start_docstrings_to_model_forward(MINIMAX_TEXT_01_INPUTS_DOCSTRING) @add_code_sample_docstrings( checkpoint=_CHECKPOINT_FOR_DOC, output_type=TokenClassifierOutput, @@ -1483,7 +1469,7 @@ def forward( The MiniMaxText01 Model transformer with a span classification head on top for extractive question-answering tasks like SQuAD (a linear layer on top of the hidden-states output to compute `span start logits` and `span end logits`). """, - MINI_MAX_TEXT01_START_DOCSTRING, + MINIMAX_TEXT_01_START_DOCSTRING, ) class MiniMaxText01ForQuestionAnswering(MiniMaxText01PreTrainedModel): base_model_prefix = "model" @@ -1502,7 +1488,7 @@ def get_input_embeddings(self): def set_input_embeddings(self, value): self.model.embed_tokens = value - @add_start_docstrings_to_model_forward(MINI_MAX_TEXT01_INPUTS_DOCSTRING) + @add_start_docstrings_to_model_forward(MINIMAX_TEXT_01_INPUTS_DOCSTRING) def forward( self, input_ids: Optional[torch.LongTensor] = None, diff --git a/src/transformers/models/minimax_text_01/modular_minimax_text_01.py b/src/transformers/models/minimax_text_01/modular_minimax_text_01.py index fcbb9def8257..f338d63115d8 100644 --- a/src/transformers/models/minimax_text_01/modular_minimax_text_01.py +++ b/src/transformers/models/minimax_text_01/modular_minimax_text_01.py @@ -28,16 +28,22 @@ from torch import nn from ...activations import ACT2FN -from ...cache_utils import Cache +from ...cache_utils import Cache, DynamicCache from ...configuration_utils import PretrainedConfig from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_utils import ALL_ATTENTION_FUNCTIONS +from ...modeling_outputs import ( + BaseModelOutputWithPast, + MoeModelOutputWithPast, +) from ...processing_utils import Unpack from ...utils import ( logging, + add_start_docstrings_to_model_forward, ) from ..mixtral.modeling_mixtral import ( eager_attention_forward, + MIXTRAL_INPUTS_DOCSTRING, MixtralRMSNorm, MixtralAttention, MixtralDecoderLayer, @@ -233,6 +239,10 @@ def __init__( ) +class MiniMaxText01RMSNorm(MixtralRMSNorm): + pass + + class MiniMaxText01LightningAttentionDecay(nn.Module): def __init__(self, config: MiniMaxText01Config, layer_idx: int): super().__init__() @@ -288,7 +298,7 @@ def __init__(self, config: MiniMaxText01Config, layer_idx: int): self.block_size = config.block_size self.act_fn = ACT2FN[config.hidden_act] - self.norm = MixtralRMSNorm(self.head_dim * self.num_heads) + self.norm = MiniMaxText01RMSNorm(self.head_dim * self.num_heads) self.qkv_proj = nn.Linear(config.hidden_size, self.num_heads * self.head_dim * 3, bias=False) self.out_proj = nn.Linear(self.num_heads * self.head_dim, config.hidden_size, bias=False) self.output_gate = nn.Linear(config.hidden_size, self.num_heads * self.num_heads, bias=False) @@ -471,6 +481,9 @@ def forward( return outputs +MINIMAX_TEXT_01_INPUTS_DOCSTRING = MIXTRAL_INPUTS_DOCSTRING + + class MiniMaxText01Model(MixtralModel): def __init__(self, config: MiniMaxText01Config): super().__init__(config) @@ -478,22 +491,126 @@ def __init__(self, config: MiniMaxText01Config): [MiniMaxText01DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] ) - def _update_causal_mask( + @add_start_docstrings_to_model_forward(MINIMAX_TEXT_01_INPUTS_DOCSTRING) + def forward( self, - attention_mask: torch.Tensor, - input_tensor: torch.Tensor, - cache_position: torch.Tensor, - past_key_values: Cache, - output_attentions: bool, - ): - causal_mask = super()._update_causal_mask( - attention_mask=attention_mask, - input_tensor=input_tensor, - cache_position=cache_position, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + output_router_logits: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + **flash_attn_kwargs: Unpack[FlashAttentionKwargs], + ) -> Union[Tuple, BaseModelOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_router_logits = ( + output_router_logits if output_router_logits is not None else self.config.output_router_logits + ) + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + + if use_cache and past_key_values is None: + past_key_values = DynamicCache() + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + ) + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + causal_mask = self._update_causal_mask( + attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions + ) + # TODO: commments + causal_mask = (attention_mask, causal_mask) + + hidden_states = inputs_embeds + + # create position embeddings to be shared across the decoder layers + position_embeddings = self.rotary_emb(hidden_states, position_ids) + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + all_router_logits = () if output_router_logits else None + + for decoder_layer in self.layers: + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + decoder_layer.__call__, + hidden_states, + causal_mask, + position_ids, + past_key_values, + output_attentions, + output_router_logits, + use_cache, + cache_position, + position_embeddings, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + output_router_logits=output_router_logits, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **flash_attn_kwargs, + ) + + hidden_states = layer_outputs[0] + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + if output_router_logits: + all_router_logits += (layer_outputs[-1],) + + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + output = MoeModelOutputWithPast( + last_hidden_state=hidden_states, past_key_values=past_key_values, - output_attentions=output_attentions, + hidden_states=all_hidden_states, + attentions=all_self_attns, + router_logits=all_router_logits, ) - return (attention_mask, causal_mask) + return output if return_dict else output.to_tuple() class MiniMaxText01ForCausalLM(MixtralForCausalLM): From 92f79639ca9a543e5581219e84e04fb4337daf66 Mon Sep 17 00:00:00 2001 From: geetu040 Date: Mon, 27 Jan 2025 14:23:35 +0500 Subject: [PATCH 08/11] pass both attention masks instead of tuple --- .../minimax_text_01/modular_minimax_text_01.py | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/src/transformers/models/minimax_text_01/modular_minimax_text_01.py b/src/transformers/models/minimax_text_01/modular_minimax_text_01.py index f338d63115d8..4920baaa13c9 100644 --- a/src/transformers/models/minimax_text_01/modular_minimax_text_01.py +++ b/src/transformers/models/minimax_text_01/modular_minimax_text_01.py @@ -410,7 +410,8 @@ def __init__(self, config: MiniMaxText01Config, layer_idx: int): def forward( self, hidden_states: torch.Tensor, - attention_mask: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + attention_mask: Optional[torch.Tensor] = None, + causal_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Tuple[torch.Tensor]] = None, output_attentions: Optional[bool] = False, @@ -423,8 +424,10 @@ def forward( """ Args: hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` - attention_mask (`Tuple[torch.Tensor, torch.Tensor]`, *optional*): attention mask of size + attention_mask (`torch.Tensor`, *optional*): attention mask of size `(batch, sequence_length)` where padding elements are indicated by 0. + causal_mask (`torch.Tensor`, *optional*): causal attention mask of size + `(batch, 1, query_length, key_value_length)` where padding elements are indicated by 0. past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states output_attentions (`bool`, *optional*): Whether or not to return the attentions tensors of all attention layers. See `attentions` under @@ -449,10 +452,11 @@ def forward( residual = hidden_states # Self Attention + attention_mask = attention_mask if self.attn_type==0 else causal_mask hidden_states, self_attn_weights = self.self_attn( hidden_states=hidden_states, position_embeddings=position_embeddings, - attention_mask=attention_mask[self.attn_type], + attention_mask=attention_mask, position_ids=position_ids, past_key_value=past_key_value, output_attentions=output_attentions, @@ -545,8 +549,6 @@ def forward( causal_mask = self._update_causal_mask( attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions ) - # TODO: commments - causal_mask = (attention_mask, causal_mask) hidden_states = inputs_embeds @@ -566,6 +568,7 @@ def forward( layer_outputs = self._gradient_checkpointing_func( decoder_layer.__call__, hidden_states, + attention_mask, causal_mask, position_ids, past_key_values, @@ -578,7 +581,8 @@ def forward( else: layer_outputs = decoder_layer( hidden_states, - attention_mask=causal_mask, + attention_mask=attention_mask, + causal_mask=causal_mask, position_ids=position_ids, past_key_value=past_key_values, output_attentions=output_attentions, From ecdc0eb816f11ed80f9f073da01c9a49627d0db0 Mon Sep 17 00:00:00 2001 From: geetu040 Date: Mon, 27 Jan 2025 15:02:59 +0500 Subject: [PATCH 09/11] formatting --- docs/source/en/index.md | 2 +- docs/source/en/model_doc/minimax_text_01.md | 0 .../configuration_minimax_text_01.py | 10 +++++ .../modeling_minimax_text_01.py | 16 +++++--- .../modular_minimax_text_01.py | 38 ++++++++++++------- src/transformers/utils/dummy_pt_objects.py | 6 +-- tests/models/minimax_text_01/__init__.py | 0 .../test_image_processing_minimax_text_01.py | 0 .../test_modeling_minimax_text_01.py | 0 utils/not_doctested.txt | 6 +-- 10 files changed, 51 insertions(+), 27 deletions(-) create mode 100644 docs/source/en/model_doc/minimax_text_01.md create mode 100644 tests/models/minimax_text_01/__init__.py create mode 100644 tests/models/minimax_text_01/test_image_processing_minimax_text_01.py create mode 100644 tests/models/minimax_text_01/test_modeling_minimax_text_01.py diff --git a/docs/source/en/index.md b/docs/source/en/index.md index d3bdc304078e..bc48e96f7c8f 100644 --- a/docs/source/en/index.md +++ b/docs/source/en/index.md @@ -226,9 +226,9 @@ Flax), PyTorch, and/or TensorFlow. | [Megatron-GPT2](model_doc/megatron_gpt2) | ✅ | ✅ | ✅ | | [MGP-STR](model_doc/mgp-str) | ✅ | ❌ | ❌ | | [Mimi](model_doc/mimi) | ✅ | ❌ | ❌ | +| [MiniMaxText01](model_doc/minimax_text_01) | ✅ | ❌ | ❌ | | [Mistral](model_doc/mistral) | ✅ | ✅ | ✅ | | [Mixtral](model_doc/mixtral) | ✅ | ❌ | ❌ | -| [MiniMaxText01](model_doc/minimax_text_01) | ✅ | ❌ | ❌ | | [Mllama](model_doc/mllama) | ✅ | ❌ | ❌ | | [mLUKE](model_doc/mluke) | ✅ | ❌ | ❌ | | [MMS](model_doc/mms) | ✅ | ✅ | ✅ | diff --git a/docs/source/en/model_doc/minimax_text_01.md b/docs/source/en/model_doc/minimax_text_01.md new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/src/transformers/models/minimax_text_01/configuration_minimax_text_01.py b/src/transformers/models/minimax_text_01/configuration_minimax_text_01.py index a1235ffbb278..9e350bf58c2b 100644 --- a/src/transformers/models/minimax_text_01/configuration_minimax_text_01.py +++ b/src/transformers/models/minimax_text_01/configuration_minimax_text_01.py @@ -51,6 +51,7 @@ class MiniMaxText01Config(PretrainedConfig): converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed by meanpooling all the original heads within that group. For more details checkout [this paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to `8`. + head_dim (``, *optional*): hidden_act (`str` or `function`, *optional*, defaults to `"silu"`): The non-linear activation function (function or string) in the decoder. max_position_embeddings (`int`, *optional*, defaults to `4096*32`): @@ -89,6 +90,15 @@ class MiniMaxText01Config(PretrainedConfig): The aux loss factor for the total loss. router_jitter_noise (`float`, *optional*, defaults to 0.0): Amount of noise to add to the router. + attn_type_list (``, *optional*): + block_size (``, *optional*, defaults to 256): + residual_post_norm (``, *optional*, defaults to `False`): + layernorm_attention_alpha (``, *optional*, defaults to 1): + layernorm_attention_beta (``, *optional*, defaults to 1): + layernorm_lightning_attention_alpha (``, *optional*, defaults to 1): + layernorm_lightning_attention_beta (``, *optional*, defaults to 1): + layernorm_mlp_alpha (``, *optional*, defaults to 1): + layernorm_mlp_beta (``, *optional*, defaults to 1): ```python >>> from transformers import MiniMaxText01Model, MiniMaxText01Config diff --git a/src/transformers/models/minimax_text_01/modeling_minimax_text_01.py b/src/transformers/models/minimax_text_01/modeling_minimax_text_01.py index 0cb1da75fb5a..232acadbf06c 100644 --- a/src/transformers/models/minimax_text_01/modeling_minimax_text_01.py +++ b/src/transformers/models/minimax_text_01/modeling_minimax_text_01.py @@ -477,7 +477,8 @@ def __init__(self, config: MiniMaxText01Config, layer_idx: int): def forward( self, hidden_states: torch.Tensor, - attention_mask: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + attention_mask: Optional[torch.Tensor] = None, + causal_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Tuple[torch.Tensor]] = None, output_attentions: Optional[bool] = False, @@ -490,8 +491,10 @@ def forward( """ Args: hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` - attention_mask (`Tuple[torch.Tensor, torch.Tensor]`, *optional*): attention mask of size + attention_mask (`torch.Tensor`, *optional*): attention mask of size `(batch, sequence_length)` where padding elements are indicated by 0. + causal_mask (`torch.Tensor`, *optional*): causal attention mask of size + `(batch, 1, query_length, key_value_length)` where padding elements are indicated by 0. past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states output_attentions (`bool`, *optional*): Whether or not to return the attentions tensors of all attention layers. See `attentions` under @@ -516,10 +519,11 @@ def forward( residual = hidden_states # Self Attention + attention_mask = attention_mask if self.attn_type == 0 else causal_mask hidden_states, self_attn_weights = self.self_attn( hidden_states=hidden_states, position_embeddings=position_embeddings, - attention_mask=attention_mask[self.attn_type], + attention_mask=attention_mask, position_ids=position_ids, past_key_value=past_key_value, output_attentions=output_attentions, @@ -819,8 +823,6 @@ def forward( causal_mask = self._update_causal_mask( attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions ) - # TODO: commments - causal_mask = (attention_mask, causal_mask) hidden_states = inputs_embeds @@ -840,6 +842,7 @@ def forward( layer_outputs = self._gradient_checkpointing_func( decoder_layer.__call__, hidden_states, + attention_mask, causal_mask, position_ids, past_key_values, @@ -852,7 +855,8 @@ def forward( else: layer_outputs = decoder_layer( hidden_states, - attention_mask=causal_mask, + attention_mask=attention_mask, + causal_mask=causal_mask, position_ids=position_ids, past_key_value=past_key_values, output_attentions=output_attentions, diff --git a/src/transformers/models/minimax_text_01/modular_minimax_text_01.py b/src/transformers/models/minimax_text_01/modular_minimax_text_01.py index 4920baaa13c9..0d9f1a1e8348 100644 --- a/src/transformers/models/minimax_text_01/modular_minimax_text_01.py +++ b/src/transformers/models/minimax_text_01/modular_minimax_text_01.py @@ -91,6 +91,7 @@ class MiniMaxText01Config(PretrainedConfig): converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed by meanpooling all the original heads within that group. For more details checkout [this paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to `8`. + head_dim (``, *optional*): hidden_act (`str` or `function`, *optional*, defaults to `"silu"`): The non-linear activation function (function or string) in the decoder. max_position_embeddings (`int`, *optional*, defaults to `4096*32`): @@ -129,6 +130,15 @@ class MiniMaxText01Config(PretrainedConfig): The aux loss factor for the total loss. router_jitter_noise (`float`, *optional*, defaults to 0.0): Amount of noise to add to the router. + attn_type_list (``, *optional*): + block_size (``, *optional*, defaults to 256): + residual_post_norm (``, *optional*, defaults to `False`): + layernorm_attention_alpha (``, *optional*, defaults to 1): + layernorm_attention_beta (``, *optional*, defaults to 1): + layernorm_lightning_attention_alpha (``, *optional*, defaults to 1): + layernorm_lightning_attention_beta (``, *optional*, defaults to 1): + layernorm_mlp_alpha (``, *optional*, defaults to 1): + layernorm_mlp_beta (``, *optional*, defaults to 1): ```python >>> from transformers import MiniMaxText01Model, MiniMaxText01Config @@ -172,8 +182,8 @@ def __init__( output_router_logits=False, router_aux_loss_coef=0.001, router_jitter_noise=0.0, - attn_type_list = None, - block_size = 256, + attn_type_list=None, + block_size=256, residual_post_norm=False, layernorm_attention_alpha=1, layernorm_attention_beta=1, @@ -208,8 +218,9 @@ def __init__( # use softmax-attention after `interval` lightning-attentions interval = num_hidden_layers // 10 self.attn_type_list = ( - [1 if i%interval==interval-1 else 0 for i in range(num_hidden_layers)] - if attn_type_list is None else attn_type_list + [1 if i % interval == interval - 1 else 0 for i in range(num_hidden_layers)] + if attn_type_list is None + else attn_type_list ) self.block_size = block_size @@ -260,8 +271,8 @@ def forward(self, x, seq_len): num_heads_range = torch.arange(self.num_heads).to(x) + 1 block_size_range = torch.arange(self.block_size).to(x) + 1 - slope_rate = ( 1 / (2 ** (8/self.num_heads)) ) ** num_heads_range - slope_rate *= 1 - self.layer_idx / (self.num_hidden_layers - 1) + 1e-5 # check small addition + slope_rate = (1 / (2 ** (8 / self.num_heads))) ** num_heads_range + slope_rate *= 1 - self.layer_idx / (self.num_hidden_layers - 1) + 1e-5 # check small addition slope_rate = slope_rate[:, None, None] query_decay = torch.exp(-slope_rate * block_size_range[:, None]) @@ -270,7 +281,7 @@ def forward(self, x, seq_len): key_decay = torch.exp(-slope_rate * (self.block_size - block_size_range[:, None])) key_decay = key_decay[:, None, :, :] key_decay = key_decay.repeat(1, num_blocks, 1, 1) - key_decay[:, -1, :self.block_size-padding] = key_decay[:, -1, padding:] + key_decay[:, -1, : self.block_size - padding] = key_decay[:, -1, padding:] diagonal_decay = block_size_range[:, None] - block_size_range[None, :] diagonal_decay = slope_rate * diagonal_decay[None, :, :] @@ -278,10 +289,9 @@ def forward(self, x, seq_len): diagonal_decay = torch.exp(diagonal_decay) diagonal_decay = diagonal_decay[:, None, :, :] - block_lengths = torch.cat(( - torch.full((num_blocks-1,), self.block_size), - torch.tensor([self.block_size-padding]) - )).to(x) + block_lengths = torch.cat( + (torch.full((num_blocks - 1,), self.block_size), torch.tensor([self.block_size - padding])) + ).to(x) block_decay = torch.exp(-slope_rate[:, None, :, :] * block_lengths[:, None, None]) return key_decay, query_decay, diagonal_decay, block_decay @@ -352,14 +362,14 @@ def forward( attn_weights_inter = torch.matmul((key_states * key_decay).transpose(-1, -2), value_states) attn_weights_inter = torch.cat([next_cache, attn_weights_inter], dim=2) for i in range(num_blocks): - attn_weights_inter[:, :, i+1, :, :] += attn_weights_inter[:, :, i, :, :] * block_decay[:, i, :, :] + attn_weights_inter[:, :, i + 1, :, :] += attn_weights_inter[:, :, i, :, :] * block_decay[:, i, :, :] next_cache = attn_weights_inter[:, :, -1, :, :] attn_weights_inter = attn_weights_inter[:, :, :-1, :, :] attn_output_inter = torch.matmul(query_states * query_decay, attn_weights_inter) # inter + intra attn_output = attn_output_inter + attn_output_intra - attn_output = attn_output.reshape(batch_size, self.num_heads, seq_len+padding, self.head_dim) + attn_output = attn_output.reshape(batch_size, self.num_heads, seq_len + padding, self.head_dim) attn_output = attn_output[:, :, :seq_len, :] # final output projection @@ -452,7 +462,7 @@ def forward( residual = hidden_states # Self Attention - attention_mask = attention_mask if self.attn_type==0 else causal_mask + attention_mask = attention_mask if self.attn_type == 0 else causal_mask hidden_states, self_attn_weights = self.self_attn( hidden_states=hidden_states, position_embeddings=position_embeddings, diff --git a/src/transformers/utils/dummy_pt_objects.py b/src/transformers/utils/dummy_pt_objects.py index bf3102725da3..4562da7aa86b 100644 --- a/src/transformers/utils/dummy_pt_objects.py +++ b/src/transformers/utils/dummy_pt_objects.py @@ -6315,14 +6315,14 @@ def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) -class MiniMaxText01ForCausalLM(metaclass=DummyObject): +class MiniMaxText01Model(metaclass=DummyObject): _backends = ["torch"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) -class MiniMaxText01ForQuestionAnswering(metaclass=DummyObject): +class MiniMaxText01ForCausalLM(metaclass=DummyObject): _backends = ["torch"] def __init__(self, *args, **kwargs): @@ -6343,7 +6343,7 @@ def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) -class MiniMaxText01Model(metaclass=DummyObject): +class MiniMaxText01ForQuestionAnswering(metaclass=DummyObject): _backends = ["torch"] def __init__(self, *args, **kwargs): diff --git a/tests/models/minimax_text_01/__init__.py b/tests/models/minimax_text_01/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/tests/models/minimax_text_01/test_image_processing_minimax_text_01.py b/tests/models/minimax_text_01/test_image_processing_minimax_text_01.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/tests/models/minimax_text_01/test_modeling_minimax_text_01.py b/tests/models/minimax_text_01/test_modeling_minimax_text_01.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/utils/not_doctested.txt b/utils/not_doctested.txt index 0a9c635f1973..c6ac3c139e96 100644 --- a/utils/not_doctested.txt +++ b/utils/not_doctested.txt @@ -163,9 +163,9 @@ docs/source/en/model_doc/mega.md docs/source/en/model_doc/megatron-bert.md docs/source/en/model_doc/megatron_gpt2.md docs/source/en/model_doc/mgp-str.md +docs/source/en/model_doc/minimax_text_01.md docs/source/en/model_doc/mistral.md docs/source/en/model_doc/mixtral.md -docs/source/en/model_doc/minimax_text_01.md docs/source/en/model_doc/mluke.md docs/source/en/model_doc/mms.md docs/source/en/model_doc/mobilebert.md @@ -672,12 +672,12 @@ src/transformers/models/megatron_gpt2/checkpoint_reshaping_and_interoperability. src/transformers/models/megatron_gpt2/convert_megatron_gpt2_checkpoint.py src/transformers/models/mgp_str/configuration_mgp_str.py src/transformers/models/mgp_str/modeling_mgp_str.py +src/transformers/models/minimax_text_01/configuration_minimax_text_01.py +src/transformers/models/minimax_text_01/modeling_minimax_text_01.py src/transformers/models/mistral/configuration_mistral.py src/transformers/models/mistral/modeling_mistral.py src/transformers/models/mixtral/configuration_mixtral.py src/transformers/models/mixtral/modeling_mixtral.py -src/transformers/models/minimax_text_01/configuration_minimax_text_01.py -src/transformers/models/minimax_text_01/modeling_minimax_text_01.py src/transformers/models/mluke/convert_mluke_original_pytorch_checkpoint_to_pytorch.py src/transformers/models/mobilebert/convert_mobilebert_original_tf_checkpoint_to_pytorch.py src/transformers/models/mobilenet_v1/configuration_mobilenet_v1.py From 0027c0ea318be18e02d02d64ee107f77332c28c9 Mon Sep 17 00:00:00 2001 From: Shakib-IO Date: Wed, 29 Jan 2025 23:23:28 -0500 Subject: [PATCH 10/11] Updated Dynamic Cache --- .../modular_minimax_text_01.py | 54 +++++++++++++------ 1 file changed, 38 insertions(+), 16 deletions(-) diff --git a/src/transformers/models/minimax_text_01/modular_minimax_text_01.py b/src/transformers/models/minimax_text_01/modular_minimax_text_01.py index 0d9f1a1e8348..b72d4eb5c529 100644 --- a/src/transformers/models/minimax_text_01/modular_minimax_text_01.py +++ b/src/transformers/models/minimax_text_01/modular_minimax_text_01.py @@ -17,7 +17,7 @@ # TODO: remove these from icecream import ic -from pack_minimax import show_tensor +# from pack_minimax import show_tensor from typing import Callable, List, Optional, Tuple, Union @@ -348,24 +348,46 @@ def forward( key_states = key_states.reshape(batch_size, self.num_heads, num_blocks, self.block_size, self.head_dim) value_states = value_states.reshape(batch_size, self.num_heads, num_blocks, self.block_size, self.head_dim) - # TODO: get from past_key_value[layer_idx] - next_cache = torch.zeros(batch_size, self.num_heads, 1, self.head_dim, self.head_dim).to(value_states) + # Dynamic Cache + if past_key_value is None: + kv_cache = DynamicCache() + else: + kv_cache = past_key_value + + # update the kv_cache with the new key and value states + kv_cache.update( + key_states, value_states, layer_idx=self.layer_idx + ) + + # Retrieve the cached key and value states + cached_keys, cached_values = kv_cache[self.layer_idx] # Shape: [batch_size, num_heads, seq_len, head_dim] + + # # TODO: get from past_key_value[layer_idx] + # next_cache = torch.zeros(batch_size, self.num_heads, 1, self.head_dim, self.head_dim).to(value_states) # get decay factors key_decay, query_decay, diagonal_decay, block_decay = self.decay_factors(query_states, seq_len) # intra: ( Q @ K.T ) @ V -> QK * V - attn_weights_intra = torch.matmul(query_states, key_states.transpose(-1, -2)) - attn_output_intra = torch.matmul(attn_weights_intra * diagonal_decay, value_states) + # attn_weights_intra = torch.matmul(query_states, key_states.transpose(-1, -2)) + # Calculate the attention weights using the cached key and value states + attn_weights_intra = torch.matmul(query_states, cached_keys.transpose(-1, -2)) + attn_output_intra = torch.matmul(attn_weights_intra * diagonal_decay, cached_values) # inter: Q @ ( K.T @ V ) -> Q * KV - attn_weights_inter = torch.matmul((key_states * key_decay).transpose(-1, -2), value_states) - attn_weights_inter = torch.cat([next_cache, attn_weights_inter], dim=2) + attn_weights_inter = torch.matmul((cached_keys * key_decay).transpose(-1, -2), cached_values) + attn_weights_inter = torch.cat([torch.zeros_like(attn_weights_inter[:, :, :1]), attn_weights_inter], dim=2) + + # attn_weights_inter = torch.cat([next_cache, attn_weights_inter], dim=2) + for i in range(num_blocks): attn_weights_inter[:, :, i + 1, :, :] += attn_weights_inter[:, :, i, :, :] * block_decay[:, i, :, :] - next_cache = attn_weights_inter[:, :, -1, :, :] - attn_weights_inter = attn_weights_inter[:, :, :-1, :, :] - attn_output_inter = torch.matmul(query_states * query_decay, attn_weights_inter) + + # next_cache = attn_weights_inter[:, :, -1, :, :] + # attn_weights_inter = attn_weights_inter[:, :, :-1, :, :] + # attn_output_inter = torch.matmul(query_states * query_decay, attn_weights_inter) + + attn_output_inter = torch.matmul(query_states * query_decay, attn_weights_inter[:, :, :-1, :, :]) # inter + intra attn_output = attn_output_inter + attn_output_intra @@ -379,17 +401,17 @@ def forward( attn_output = F.sigmoid(self.output_gate(hidden_states)) * attn_output attn_output = self.out_proj(attn_output) - # TODO: put to past_key_value[layer_idx] - next_cache + print("KV Cache:", kv_cache) + # # TODO: put to past_key_value[layer_idx] + # next_cache # TODO: remove these - print() - print(self.layer_idx) - print(next_cache) + # print() + # print(self.layer_idx) + # print("Next Cache:",next_cache) return attn_output, None - class MiniMaxText01Attention(MixtralAttention): pass From e117d2660e36800515fde33a3d16110d5fe9ba56 Mon Sep 17 00:00:00 2001 From: geetu040 Date: Fri, 31 Jan 2025 19:01:58 +0500 Subject: [PATCH 11/11] created MiniMaxText01Cache --- .../modular_minimax_text_01.py | 141 ++++++++++-------- 1 file changed, 76 insertions(+), 65 deletions(-) diff --git a/src/transformers/models/minimax_text_01/modular_minimax_text_01.py b/src/transformers/models/minimax_text_01/modular_minimax_text_01.py index b72d4eb5c529..f9982857c8e2 100644 --- a/src/transformers/models/minimax_text_01/modular_minimax_text_01.py +++ b/src/transformers/models/minimax_text_01/modular_minimax_text_01.py @@ -17,7 +17,7 @@ # TODO: remove these from icecream import ic -# from pack_minimax import show_tensor +from pack_minimax import show_tensor from typing import Callable, List, Optional, Tuple, Union @@ -254,6 +254,18 @@ class MiniMaxText01RMSNorm(MixtralRMSNorm): pass +class MiniMaxText01Cache(DynamicCache): + def __init__(self): + super().__init__() + self.kv_cache: dict[int: torch.Tensor] = {} + + def set_kv_cache(self, kv_cache, layer_idx): + self.kv_cache[layer_idx] = kv_cache + + def get_kv_cache(self, layer_idx): + return self.kv_cache.get(layer_idx) + + class MiniMaxText01LightningAttentionDecay(nn.Module): def __init__(self, config: MiniMaxText01Config, layer_idx: int): super().__init__() @@ -336,63 +348,55 @@ def forward( key_states = key_states.transpose(1, 2) value_states = value_states.transpose(1, 2) - # apply attention_mask - if attention_mask is not None: - value_states = value_states.masked_fill((1 - attention_mask).unsqueeze(1).unsqueeze(-1).to(torch.bool), 0) + kv_cache = past_key_value.get_kv_cache(self.layer_idx) - query_states = F.pad(query_states, (0, 0, 0, padding)) - key_states = F.pad(key_states, (0, 0, 0, padding)) - value_states = F.pad(value_states, (0, 0, 0, padding)) + if kv_cache is None: + kv_cache = torch.zeros(batch_size, self.num_heads, 1, self.head_dim, self.head_dim).to(value_states) - query_states = query_states.reshape(batch_size, self.num_heads, num_blocks, self.block_size, self.head_dim) - key_states = key_states.reshape(batch_size, self.num_heads, num_blocks, self.block_size, self.head_dim) - value_states = value_states.reshape(batch_size, self.num_heads, num_blocks, self.block_size, self.head_dim) - - # Dynamic Cache - if past_key_value is None: - kv_cache = DynamicCache() - else: - kv_cache = past_key_value - - # update the kv_cache with the new key and value states - kv_cache.update( - key_states, value_states, layer_idx=self.layer_idx - ) + # apply attention_mask + if attention_mask is not None: + value_states = value_states.masked_fill((1 - attention_mask).unsqueeze(1).unsqueeze(-1).to(torch.bool), 0) - # Retrieve the cached key and value states - cached_keys, cached_values = kv_cache[self.layer_idx] # Shape: [batch_size, num_heads, seq_len, head_dim] - - # # TODO: get from past_key_value[layer_idx] - # next_cache = torch.zeros(batch_size, self.num_heads, 1, self.head_dim, self.head_dim).to(value_states) + query_states = F.pad(query_states, (0, 0, 0, padding)) + key_states = F.pad(key_states, (0, 0, 0, padding)) + value_states = F.pad(value_states, (0, 0, 0, padding)) - # get decay factors - key_decay, query_decay, diagonal_decay, block_decay = self.decay_factors(query_states, seq_len) + query_states = query_states.reshape(batch_size, self.num_heads, num_blocks, self.block_size, self.head_dim) + key_states = key_states.reshape(batch_size, self.num_heads, num_blocks, self.block_size, self.head_dim) + value_states = value_states.reshape(batch_size, self.num_heads, num_blocks, self.block_size, self.head_dim) - # intra: ( Q @ K.T ) @ V -> QK * V - # attn_weights_intra = torch.matmul(query_states, key_states.transpose(-1, -2)) - # Calculate the attention weights using the cached key and value states - attn_weights_intra = torch.matmul(query_states, cached_keys.transpose(-1, -2)) - attn_output_intra = torch.matmul(attn_weights_intra * diagonal_decay, cached_values) + # get decay factors + key_decay, query_decay, diagonal_decay, block_decay = self.decay_factors(query_states, seq_len) - # inter: Q @ ( K.T @ V ) -> Q * KV - attn_weights_inter = torch.matmul((cached_keys * key_decay).transpose(-1, -2), cached_values) - attn_weights_inter = torch.cat([torch.zeros_like(attn_weights_inter[:, :, :1]), attn_weights_inter], dim=2) + # intra: ( Q @ K.T ) @ V -> QK * V + attn_weights_intra = torch.matmul(query_states, key_states.transpose(-1, -2)) + attn_output_intra = torch.matmul(attn_weights_intra * diagonal_decay, value_states) - # attn_weights_inter = torch.cat([next_cache, attn_weights_inter], dim=2) + # inter: Q @ ( K.T @ V ) -> Q * KV + attn_weights_inter = torch.matmul((key_states * key_decay).transpose(-1, -2), value_states) + attn_weights_inter = torch.cat([kv_cache, attn_weights_inter], dim=2) + for i in range(num_blocks): + attn_weights_inter[:, :, i + 1, :, :] += attn_weights_inter[:, :, i, :, :] * block_decay[:, i, :, :] + kv_cache = attn_weights_inter[:, :, -1, :, :] + attn_weights_inter = attn_weights_inter[:, :, :-1, :, :] + attn_output_inter = torch.matmul(query_states * query_decay, attn_weights_inter) - for i in range(num_blocks): - attn_weights_inter[:, :, i + 1, :, :] += attn_weights_inter[:, :, i, :, :] * block_decay[:, i, :, :] - - # next_cache = attn_weights_inter[:, :, -1, :, :] - # attn_weights_inter = attn_weights_inter[:, :, :-1, :, :] - # attn_output_inter = torch.matmul(query_states * query_decay, attn_weights_inter) - - attn_output_inter = torch.matmul(query_states * query_decay, attn_weights_inter[:, :, :-1, :, :]) - - # inter + intra - attn_output = attn_output_inter + attn_output_intra - attn_output = attn_output.reshape(batch_size, self.num_heads, seq_len + padding, self.head_dim) - attn_output = attn_output[:, :, :seq_len, :] + # inter + intra + attn_output = attn_output_inter + attn_output_intra + attn_output = attn_output.reshape(batch_size, self.num_heads, seq_len + padding, self.head_dim) + attn_output = attn_output[:, :, :seq_len, :] + else: + ratio = 0.23 # TODO: get from decay + attn_output = [] + for i in range(seq_len): + kv_cache = ratio * kv_cache + torch.einsum( + "... n d, ... n e -> ... d e", + key_states[:, :, i : i + 1], + value_states[:, :, i : i + 1], + ) + attn_output_i = torch.einsum("... n e, ... e d -> ... n d", query_states[:, :, i : i + 1], kv_cache) + attn_output.append(attn_output_i) + attn_output = torch.concat(attn_output, dim=-2) # final output projection attn_output = attn_output.transpose(1, 2) @@ -401,17 +405,17 @@ def forward( attn_output = F.sigmoid(self.output_gate(hidden_states)) * attn_output attn_output = self.out_proj(attn_output) - print("KV Cache:", kv_cache) - # # TODO: put to past_key_value[layer_idx] - # next_cache + # update cache + past_key_value.set_kv_cache(kv_cache, self.layer_idx) # TODO: remove these - # print() - # print(self.layer_idx) - # print("Next Cache:",next_cache) + print() + print(self.layer_idx) + print(kv_cache) return attn_output, None + class MiniMaxText01Attention(MixtralAttention): pass @@ -443,7 +447,6 @@ def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, - causal_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Tuple[torch.Tensor]] = None, output_attentions: Optional[bool] = False, @@ -458,8 +461,6 @@ def forward( hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` attention_mask (`torch.Tensor`, *optional*): attention mask of size `(batch, sequence_length)` where padding elements are indicated by 0. - causal_mask (`torch.Tensor`, *optional*): causal attention mask of size - `(batch, 1, query_length, key_value_length)` where padding elements are indicated by 0. past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states output_attentions (`bool`, *optional*): Whether or not to return the attentions tensors of all attention layers. See `attentions` under @@ -484,7 +485,6 @@ def forward( residual = hidden_states # Self Attention - attention_mask = attention_mask if self.attn_type == 0 else causal_mask hidden_states, self_attn_weights = self.self_attn( hidden_states=hidden_states, position_embeddings=position_embeddings, @@ -565,7 +565,7 @@ def forward( use_cache = False if use_cache and past_key_values is None: - past_key_values = DynamicCache() + past_key_values = MiniMaxText01Cache() if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) @@ -596,12 +596,17 @@ def forward( if output_hidden_states: all_hidden_states += (hidden_states,) + if decoder_layer.attn_type == 0: + # lightning attention uses original attention_mask, and uses it only for the first step + input_attention_mask = attention_mask + else: + input_attention_mask = causal_mask + if self.gradient_checkpointing and self.training: layer_outputs = self._gradient_checkpointing_func( decoder_layer.__call__, hidden_states, - attention_mask, - causal_mask, + input_attention_mask, position_ids, past_key_values, output_attentions, @@ -613,8 +618,7 @@ def forward( else: layer_outputs = decoder_layer( hidden_states, - attention_mask=attention_mask, - causal_mask=causal_mask, + attention_mask=input_attention_mask, position_ids=position_ids, past_key_value=past_key_values, output_attentions=output_attentions, @@ -654,6 +658,13 @@ def __init__(self, config): super().__init__(config) self.model = MiniMaxText01Model(config) + def _prepare_cache_for_generation( + self, generation_config, model_kwargs, assistant_model, batch_size, max_cache_length, device + ): + if model_kwargs.get("past_key_values") is None: + model_kwargs["past_key_values"] = MiniMaxText01Cache() + super()._prepare_cache_for_generation(generation_config, model_kwargs, assistant_model, batch_size, max_cache_length, device) + class MiniMaxText01ForSequenceClassification(MixtralForSequenceClassification): pass