From bc90220086fa0ee8b15be5ee2c7fbb178e3d65fd Mon Sep 17 00:00:00 2001 From: calpt Date: Sun, 4 Aug 2024 10:03:53 +0200 Subject: [PATCH] Upgrade Transformers to v4.43.x (#727) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Changes required for sync: - re-copy Llama & Beit attention - add clip sdp & flash attn - fix tie_weights method - upgrade torch version in tests --------- Co-authored-by: Leon Engländer --- .github/workflows/tests_torch.yml | 8 +- hf_transformers | 2 +- setup.py | 4 +- src/adapters/heads/model_mixin.py | 2 + src/adapters/models/beit/modeling_beit.py | 11 +- src/adapters/models/clip/modeling_clip.py | 167 +++++++++++++++++++- src/adapters/models/llama/modeling_llama.py | 65 ++++++-- tests/methods/base.py | 2 +- 8 files changed, 236 insertions(+), 25 deletions(-) diff --git a/.github/workflows/tests_torch.yml b/.github/workflows/tests_torch.yml index 668beb9e62..f7a394ce4a 100644 --- a/.github/workflows/tests_torch.yml +++ b/.github/workflows/tests_torch.yml @@ -39,7 +39,7 @@ jobs: key: ${{ runner.os }}-pip-${{ hashFiles('setup.py') }} - name: Install run: | - pip install torch==2.1.2 + pip install torch==2.3 pip install .[quality] - name: Check Quality and Repo Consistency run: | @@ -62,7 +62,7 @@ jobs: ${{ runner.os }}-pip- - name: Install run: | - pip install torch==2.1.2 + pip install torch==2.3 pip install .[sklearn,testing,sentencepiece] - name: Test run: | @@ -85,7 +85,7 @@ jobs: ${{ runner.os }}-pip- - name: Install run: | - pip install torch==2.1.2 + pip install torch==2.3 pip install .[sklearn,testing,sentencepiece] - name: Test run: | @@ -108,7 +108,7 @@ jobs: ${{ runner.os }}-pip- - name: Install run: | - pip install torch==2.1.2 + pip install torch==2.3 pip install .[sklearn,testing,sentencepiece] pip install conllu seqeval - name: Test Examples diff --git a/hf_transformers b/hf_transformers index fc35907f95..47c29ccfaf 160000 --- a/hf_transformers +++ b/hf_transformers @@ -1 +1 @@ -Subproject commit fc35907f95459d7a6c5281dfadd680b6f7b620e3 +Subproject commit 47c29ccfaf56947d845971a439cbe75a764b63d7 diff --git a/setup.py b/setup.py index 78c526fdf6..39d994e999 100644 --- a/setup.py +++ b/setup.py @@ -57,8 +57,8 @@ "sphinx-intl==2.1.0", "sphinx-multiversion==0.2.4", "timeout-decorator", - "torch>=1.10,!=1.12.0", - "transformers~=4.42.4", + "torch", + "transformers~=4.43.3", ] diff --git a/src/adapters/heads/model_mixin.py b/src/adapters/heads/model_mixin.py index 11a194ef92..9a27bbd764 100644 --- a/src/adapters/heads/model_mixin.py +++ b/src/adapters/heads/model_mixin.py @@ -134,6 +134,8 @@ def tie_weights(self): self = getattr(self, self.base_model_prefix) self._tie_encoder_decoder_weights(self.encoder, self.decoder, self.base_model_prefix) + super().tie_weights() + def _resize_token_embeddings(self, new_num_tokens, pad_to_multiple_of=None): old_embeddings = self.get_input_embeddings() new_embeddings = self._get_resized_embeddings(old_embeddings, new_num_tokens, pad_to_multiple_of) diff --git a/src/adapters/models/beit/modeling_beit.py b/src/adapters/models/beit/modeling_beit.py index 6e56d2b864..865fcdeae5 100644 --- a/src/adapters/models/beit/modeling_beit.py +++ b/src/adapters/models/beit/modeling_beit.py @@ -35,6 +35,7 @@ def forward( output_attentions: bool = False, relative_position_bias: Optional["BeitRelativePositionBias"] = None, interpolate_pos_encoding: bool = False, + resolution: Optional[Tuple[int]] = None, ) -> Union[Tuple[torch.Tensor], Tuple[torch.Tensor, torch.Tensor]]: mixed_query_layer = 
self.query(hidden_states) @@ -51,9 +52,11 @@ def forward( # Add relative position bias if present. if self.relative_position_bias is not None: + height, width = resolution + window_size = (height // self.config.patch_size, width // self.config.patch_size) attention_scores = attention_scores + self.relative_position_bias( - interpolate_pos_encoding, attention_scores.shape[2] - ).unsqueeze(0) + window_size, interpolate_pos_encoding, dim_size=hidden_states.shape[1] + ) # Add shared relative position bias if provided. if relative_position_bias is not None: @@ -89,8 +92,9 @@ def forward( hidden_states: torch.Tensor, head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False, - relative_position_bias: Optional[BeitRelativePositionBias] = None, + relative_position_bias: Optional["BeitRelativePositionBias"] = None, interpolate_pos_encoding: bool = False, + resolution: Optional[Tuple[int]] = None, ) -> Union[Tuple[torch.Tensor], Tuple[torch.Tensor, torch.Tensor]]: self_attention_outputs = self.attention( self.layernorm_before(hidden_states), # in BEiT, layernorm is applied before self-attention @@ -98,6 +102,7 @@ def forward( output_attentions=output_attentions, relative_position_bias=relative_position_bias, interpolate_pos_encoding=interpolate_pos_encoding, + resolution=resolution, ) attention_output = self_attention_outputs[0] outputs = self_attention_outputs[1:] # add self attentions if we output attention weights diff --git a/src/adapters/models/clip/modeling_clip.py b/src/adapters/models/clip/modeling_clip.py index fecbb105c8..7328e532c4 100644 --- a/src/adapters/models/clip/modeling_clip.py +++ b/src/adapters/models/clip/modeling_clip.py @@ -21,11 +21,25 @@ import torch.utils.checkpoint from torch import nn -from transformers.models.clip.modeling_clip import CLIPAttention, CLIPEncoderLayer +from transformers.models.clip.modeling_clip import ( + CLIPAttention, + CLIPEncoderLayer, + CLIPFlashAttention2, + CLIPSdpaAttention, +) +from transformers.pytorch_utils import is_torch_greater_or_equal_than_2_2 +from transformers.utils import is_flash_attn_2_available, logging + + +if is_flash_attn_2_available(): + from transformers.modeling_flash_attention_utils import _flash_attention_forward from .mixin_clip import CLIPAttentionAdaptersMixin, CLIPEncoderLayerAdaptersMixin +logger = logging.get_logger(__name__) + + class CLIPAttentionWithAdapters(CLIPAttentionAdaptersMixin, CLIPAttention): def forward( self, @@ -46,9 +60,11 @@ def forward( proj_shape = (bsz * self.num_heads, -1, self.head_dim) query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape) + # >>> START AH Changes <<< key_states, value_states, attention_mask = self.prefix_tuning( key_states, value_states, hidden_states, attention_mask ) + # >>> END AH Changes <<< key_states = key_states.view(*proj_shape) value_states = value_states.view(*proj_shape) @@ -115,6 +131,155 @@ def forward( return attn_output, attn_weights_reshaped +class CLIPFlashAttention2WithAdapters(CLIPAttentionAdaptersMixin, CLIPFlashAttention2): + # Adapted from transformers.models.llama.modeling_llama.LlamaFlashAttention2.forward + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + causal_attention_mask: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + output_attentions = False + + batch_size, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + 
value_states = self.v_proj(hidden_states) + + # >>> START AH Changes <<< + key_states, value_states, attention_mask = self.prefix_tuning( + key_states, value_states, hidden_states, attention_mask + ) + # >>> END AH Changes <<< + + # Flash attention requires the input to have the shape + # batch_size x seq_length x head_dim x hidden_dim + # therefore we just need to keep the original shape + query_states = query_states.view(batch_size, q_len, self.num_heads, self.head_dim) + key_states = key_states.view(batch_size, q_len, self.num_heads, self.head_dim) + value_states = value_states.view(batch_size, q_len, self.num_heads, self.head_dim) + + dropout_rate = self.dropout if self.training else 0.0 + + # In PEFT, usually we cast the layer norms in float32 for training stability reasons + # therefore the input hidden states gets silently casted in float32. Hence, we need + # cast them back in the correct dtype just to be sure everything works as expected. + # This might slowdown training & inference so it is recommended to not cast the LayerNorms + # in fp32. + + input_dtype = query_states.dtype + if input_dtype == torch.float32: + if torch.is_autocast_enabled(): + target_dtype = torch.get_autocast_gpu_dtype() + # Handle the case where the model is quantized + elif hasattr(self.config, "_pre_quantization_dtype"): + target_dtype = self.config._pre_quantization_dtype + else: + target_dtype = self.q_proj.weight.dtype + + logger.warning_once( + f"The input hidden states seems to be silently casted in float32, this might be related to" + f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" + f" {target_dtype}." + ) + + query_states = query_states.to(target_dtype) + key_states = key_states.to(target_dtype) + value_states = value_states.to(target_dtype) + + attn_output = _flash_attention_forward( + query_states, + key_states, + value_states, + attention_mask, + q_len, + dropout=dropout_rate, + is_causal=causal_attention_mask is not None, + use_top_left_mask=self._flash_attn_uses_top_left_mask, + ) + + attn_output = attn_output.reshape(batch_size, q_len, self.embed_dim).contiguous() + attn_output = self.out_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights + + +class CLIPSdpaAttentionWithAdapters(CLIPAttentionAdaptersMixin, CLIPSdpaAttention): + # Adapted from CLIPAttention.forward + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + causal_attention_mask: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + if output_attentions: + # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented. + logger.warning_once( + "CLIPModel is using CLIPSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not " + "support `output_attentions=True`. Falling back to the manual attention implementation, but specifying " + "the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can " + 'be removed using the argument `attn_implementation="eager"` when loading the model.' 
+ ) + return super().forward( + hidden_states=hidden_states, + attention_mask=attention_mask, + causal_attention_mask=causal_attention_mask, + output_attentions=output_attentions, + ) + + # CLIP text model uses both `causal_attention_mask` and `attention_mask` + if attention_mask is not None and causal_attention_mask is not None: + attn_mask = attention_mask + causal_attention_mask + elif causal_attention_mask is not None: + attn_mask = causal_attention_mask + else: + attn_mask = attention_mask + + bsz, tgt_len, embed_dim = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2) + + # >>> START AH Changes <<< + key_states, value_states, attn_mask = self.prefix_tuning(key_states, value_states, hidden_states, attn_mask) + # >>> END AH Changes <<< + + # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask, + # Reference: https://github.com/pytorch/pytorch/issues/112577. + if not is_torch_greater_or_equal_than_2_2 and query_states.device.type == "cuda" and attn_mask is not None: + query_states = query_states.contiguous() + key_states = key_states.contiguous() + value_states = value_states.contiguous() + + # CLIP text model uses both `causal_attention_mask` and `attention_mask` sequentially. + attn_output = torch.nn.functional.scaled_dot_product_attention( + query_states, + key_states, + value_states, + attn_mask=attn_mask, + dropout_p=self.dropout if self.training else 0.0, + scale=self.scale, + ) + + attn_output = attn_output.transpose(1, 2) + attn_output = attn_output.reshape(bsz, tgt_len, embed_dim) + + attn_output = self.out_proj(attn_output) + + return attn_output, None + + class CLIPEncoderLayerWithAdapters(CLIPEncoderLayerAdaptersMixin, CLIPEncoderLayer): def forward( self, diff --git a/src/adapters/models/llama/modeling_llama.py b/src/adapters/models/llama/modeling_llama.py index f8d01df50b..461cdde2b8 100644 --- a/src/adapters/models/llama/modeling_llama.py +++ b/src/adapters/models/llama/modeling_llama.py @@ -29,6 +29,7 @@ from adapters.composition import adjust_tensors_for_parallel, match_attn_matrices_for_parallel from transformers.cache_utils import Cache, StaticCache +from transformers.modeling_flash_attention_utils import _flash_attention_forward from transformers.models.llama.modeling_llama import ( LlamaAttention, LlamaDecoderLayer, @@ -57,6 +58,7 @@ def forward( output_attentions: bool = False, use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.45 **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: bsz, q_len, _ = hidden_states.size() @@ -94,8 +96,16 @@ def forward( (attention_mask,) = adjust_tensors_for_parallel(query_states, attention_mask) # >>> END AH Changes <<< - past_key_value = getattr(self, "past_key_value", past_key_value) - cos, sin = self.rotary_emb(value_states, position_ids) + if position_embeddings is None: + logger.warning_once( + "The attention layers in this model are transitioning from computing the RoPE embeddings internally " + "through `position_ids` (2D tensor 
with the indexes of the tokens), to using externally computed " + "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.45 `position_ids` will be " + "removed and `position_embeddings` will be mandatory." + ) + cos, sin = self.rotary_emb(value_states, position_ids) + else: + cos, sin = position_embeddings query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) if past_key_value is not None: @@ -133,7 +143,7 @@ def forward( attn_output = attn_output.transpose(1, 2).contiguous() - attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) + attn_output = attn_output.reshape(bsz, q_len, -1) if self.config.pretraining_tp > 1: attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2) @@ -158,7 +168,7 @@ def forward( output_attentions: bool = False, use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, - **kwargs, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.45 ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: if isinstance(past_key_value, StaticCache): raise ValueError( @@ -189,7 +199,16 @@ def forward( (attention_mask,) = adjust_tensors_for_parallel(query_states, attention_mask) # >>> END AH Changes <<< - cos, sin = self.rotary_emb(value_states, position_ids) + if position_embeddings is None: + logger.warning_once( + "The attention layers in this model are transitioning from computing the RoPE embeddings internally " + "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed " + "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.45 `position_ids` will be " + "removed and `position_embeddings` will be mandatory." + ) + cos, sin = self.rotary_emb(value_states, position_ids) + else: + cos, sin = position_embeddings query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) if past_key_value is not None: @@ -221,7 +240,7 @@ def forward( # in fp32. (LlamaRMSNorm handles it correctly) input_dtype = query_states.dtype - if input_dtype == torch.float32 or key_states.dtype == torch.float32: + if input_dtype == torch.float32: if torch.is_autocast_enabled(): target_dtype = torch.get_autocast_gpu_dtype() # Handle the case where the model is quantized @@ -240,11 +259,19 @@ def forward( key_states = key_states.to(target_dtype) value_states = value_states.to(target_dtype) - attn_output = self._flash_attention_forward( - query_states, key_states, value_states, attention_mask, q_len, dropout=dropout_rate + attn_output = _flash_attention_forward( + query_states, + key_states, + value_states, + attention_mask, + q_len, + dropout=dropout_rate, + sliding_window=getattr(self, "sliding_window", None), + use_top_left_mask=self._flash_attn_uses_top_left_mask, + is_causal=self.is_causal, ) - attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous() + attn_output = attn_output.reshape(bsz, q_len, -1).contiguous() attn_output = self.o_proj(attn_output) if not output_attentions: @@ -265,6 +292,8 @@ def forward( output_attentions: bool = False, use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.45 + **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: if output_attentions: # TODO: Improve this warning with e.g. 
`model.config.attn_implementation = "manual"` once this is implemented. @@ -282,6 +311,7 @@ def forward( output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, + position_embeddings=position_embeddings, ) bsz, q_len, _ = hidden_states.size() @@ -301,7 +331,16 @@ def forward( (attention_mask,) = adjust_tensors_for_parallel(query_states, attention_mask) # >>> END AH Changes <<< - cos, sin = self.rotary_emb(value_states, position_ids) + if position_embeddings is None: + logger.warning_once( + "The attention layers in this model are transitioning from computing the RoPE embeddings internally " + "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed " + "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.45 `position_ids` will be " + "removed and `position_embeddings` will be mandatory." + ) + cos, sin = self.rotary_emb(value_states, position_ids) + else: + cos, sin = position_embeddings query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) if past_key_value is not None: @@ -332,8 +371,8 @@ def forward( key_states = key_states.contiguous() value_states = value_states.contiguous() - # We dispatch to SDPA's Flash Attention or Efficient kernels via this if statement instead of an - # inline conditional assignment to support both torch.compile's `dynamic=True` and `fullgraph=True` + # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment + # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. is_causal = True if causal_mask is None and q_len > 1 else False attn_output = torch.nn.functional.scaled_dot_product_attention( @@ -346,7 +385,7 @@ def forward( ) attn_output = attn_output.transpose(1, 2).contiguous() - attn_output = attn_output.view(bsz, q_len, self.hidden_size) + attn_output = attn_output.view(bsz, q_len, -1) attn_output = self.o_proj(attn_output) diff --git a/tests/methods/base.py b/tests/methods/base.py index fd445a64c6..6ede68f2f3 100644 --- a/tests/methods/base.py +++ b/tests/methods/base.py @@ -245,7 +245,7 @@ def run_full_model_load_test(self, adapter_config): output1 = model1(**input_data) output2 = model2(**input_data) self.assertEqual(len(output1), len(output2)) - self.assertTrue(torch.equal(output1[0], output2[0])) + self.assertTrue(torch.allclose(output1[0], output2[0], atol=1e-4)) def trainings_run(self, model, lr=1.0, steps=8): # setup dataset