Minor fixes for granite models (#1503)
* Update granite.py

Grab the residual multiplier directly from the layer when it is present, falling back to the config value (see the sketch just before the diff below)

* Update llama.py

The version check should read >= 4.47.1, as that is the release that requires these changes (a sketch of such a version gate follows the diff below)

* Update granite.py

* Update llama.py

---------

Co-authored-by: Daniel Han <[email protected]>
CoffeeVampir3 and danielhanchen authored Jan 7, 2025
1 parent 422c033 commit 83b48a8
Showing 1 changed file with 15 additions and 6 deletions.
21 changes: 15 additions & 6 deletions unsloth/models/granite.py
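
The granite.py change below reads residual_multiplier from the decoder layer itself when the attribute is present, and only falls back to self.config.residual_multiplier otherwise. A minimal sketch of that fallback pattern follows; DummyConfig and DummyLayer are placeholders for illustration, not Unsloth classes:

    # Minimal sketch of the attribute-with-config-fallback pattern used in this commit.
    # DummyConfig / DummyLayer are placeholders, not Unsloth classes.
    class DummyConfig:
        residual_multiplier = 0.22  # placeholder value

    class DummyLayer:
        def __init__(self, config, residual_multiplier = None):
            self.config = config
            if residual_multiplier is not None:
                # A layer may carry the multiplier directly on the module.
                self.residual_multiplier = residual_multiplier

        def effective_multiplier(self):
            # Prefer the layer attribute, fall back to the config value,
            # mirroring the hasattr check in the diff below.
            return self.residual_multiplier \
                if hasattr(self, "residual_multiplier") else \
                self.config.residual_multiplier

    print(DummyLayer(DummyConfig()).effective_multiplier())       # 0.22, config fallback
    print(DummyLayer(DummyConfig(), 0.5).effective_multiplier())  # 0.5, layer attribute wins

In every branch of the diff, the scaled residual add torch.add(residual, hidden_states, alpha = residual_multiplier) computes residual + residual_multiplier * hidden_states.
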
@@ -182,6 +182,11 @@ def GraniteDecoderLayer_fast_forward(
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
*args, **kwargs,
):
+ residual_multiplier = \
+ self.residual_multiplier \
+ if hasattr(self, "residual_multiplier") else \
+ self.config.residual_multiplier
+
if use_cache and hasattr(self, "_flag_for_generation"): #past_key_value is not None:
residual = hidden_states
hidden_states = fast_rms_layernorm_inference(self.input_layernorm, hidden_states)
@@ -197,13 +202,13 @@ def GraniteDecoderLayer_fast_forward(
position_embeddings = position_embeddings,
_flag_for_generation=self._flag_for_generation,
)
- hidden_states = torch.add(residual, hidden_states, alpha = self.config.residual_multiplier)
+ hidden_states = torch.add(residual, hidden_states, alpha = residual_multiplier)

# Fully Connected
residual = hidden_states
hidden_states = fast_rms_layernorm_inference(self.post_attention_layernorm, hidden_states)
hidden_states = fast_swiglu_inference(self.mlp, hidden_states)
- hidden_states = torch.add(residual, hidden_states, alpha = self.config.residual_multiplier)
+ hidden_states = torch.add(residual, hidden_states, alpha = residual_multiplier)
else:
residual = hidden_states
hidden_states = fast_rms_layernorm(self.input_layernorm, hidden_states)
@@ -218,13 +223,13 @@ def GraniteDecoderLayer_fast_forward(
padding_mask=padding_mask,
position_embeddings = position_embeddings,
)
- hidden_states = torch.add(residual, hidden_states, alpha = self.config.residual_multiplier)
+ hidden_states = torch.add(residual, hidden_states, alpha = residual_multiplier)

# Fully Connected
residual = hidden_states
hidden_states = fast_rms_layernorm(self.post_attention_layernorm, hidden_states)
hidden_states = self.mlp(hidden_states)
- hidden_states = torch.add(residual, hidden_states, alpha = self.config.residual_multiplier)
+ hidden_states = torch.add(residual, hidden_states, alpha = residual_multiplier)
pass

outputs = (hidden_states,)
@@ -370,6 +375,10 @@ def GraniteModel_fast_forward_inference(
hidden_states = self.model.embed_tokens(input_ids)
hidden_states = hidden_states.to(self.config.torch_dtype)
hidden_states *= self.model.embedding_multiplier
+ residual_multiplier = \
+ self.residual_multiplier \
+ if hasattr(self, "residual_multiplier") else \
+ self.config.residual_multiplier

bsz, q_len, hd = hidden_states.shape
seq_len = past_key_values[0][0].shape[-2]
@@ -401,12 +410,12 @@ def GraniteModel_fast_forward_inference(
position_embeddings = position_embeddings,
)

- hidden_states = torch.add(residual, hidden_states, alpha = self.config.residual_multiplier)
+ hidden_states = torch.add(residual, hidden_states, alpha = residual_multiplier)

residual = hidden_states
hidden_states = fast_rms_layernorm_inference(decoder_layer.post_attention_layernorm, hidden_states)
hidden_states = fast_swiglu_inference(decoder_layer.mlp, hidden_states)
- hidden_states = torch.add(residual, hidden_states, alpha = self.config.residual_multiplier)
+ hidden_states = torch.add(residual, hidden_states, alpha = residual_multiplier)

next_decoder_cache.append(present_key_value)
pass
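
The companion llama.py change mentioned in the commit message is not part of this diff; it bumps a transformers version check to >= 4.47.1. A minimal sketch of such a version gate, assuming the packaging and transformers packages are installed (the flag name is a placeholder, not the actual variable in unsloth/models/llama.py):

    # Illustrative only: the real check lives in unsloth/models/llama.py.
    from packaging.version import Version
    import transformers

    # Placeholder flag name; gate behaviour on the release that requires the change.
    TRANSFORMERS_AT_LEAST_4_47_1 = Version(transformers.__version__) >= Version("4.47.1")
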
