DeepSeek_v3 support #1735
base: main
Conversation
@srajabos FYI there is an open PR to add DeepSeek V3 to Transformers: huggingface/transformers#35926. We won't be able to rely on the Transformers implementation before Transformers v4.49 is released, but I thought this might be interesting to you.
@regiss, I'll keep this as draft until verified with Transformers v4.49.
DeepSeek V3 (and hence R1) requirements.txt says the minimum required version of transformers is 4.46.3.
@anishagartia, currently we are adding the model files and optimizing for Gaudi. Once we have performance data, the plan is to get it in. Thanks for the link.
Copied from transformers v4.48.2 for DeepSeek-R1 support. Delete after upgrading transformers from v4.45.2 to v4.48.
17f62bd to b2b1715
@yao-matrix @gyou2021 @IT-Forrest - kindly review the code.
[explanatory] comments are just there to help follow the HPU code; no changes are required for them. Sorry for spamming comments in this category; I thought they might be useful for future readers going through the change and for others looking to port similar models.
[clarification] comments are questions from my end. These are marked with [minor] when they are minor nitpicks.
    from habana_frameworks.torch.hpex.kernels import FusedSDPA
except ImportError:
    print("Not using HPU fused scaled dot-product attention kernel.")
    FusedSDPA = None
[explanatory] Import HPU fused ops.
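For readers following along, a minimal sketch of the guarded-import pattern this refers to: each HPU fused kernel is imported inside a try/except and set to None when unavailable, so the eager path is used as a fallback. The exact set of kernels and module paths below follow other OH modeling files and are assumptions for this PR's file.

try:
    from habana_frameworks.torch.hpex.kernels import FusedSDPA
except ImportError:
    print("Not using HPU fused scaled dot-product attention kernel.")
    FusedSDPA = None

try:
    from habana_frameworks.torch.hpex.normalization import FusedRMSNorm
except ImportError:
    print("Not using HPU fused RMSNorm kernel.")
    FusedRMSNorm = None

try:
    from habana_frameworks.torch.hpex.kernels import RotaryPosEmbeddingHelperV2 as FusedRoPE
except ImportError:
    print("Not using HPU fused RoPE kernel.")
    FusedRoPE = None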
def forward(self, hidden_states):
    if hidden_states.device.type == "hpu" and FusedRMSNorm:
        # mixed dtypes are not good for FusedRMSNorm, both inputs need to have same dtype
[explanatory] Use fused ops on HPU.
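A hedged sketch of the fused-vs-eager dispatch this comment points at; attribute names (`weight`, `variance_epsilon`) and the dtype-cast behavior are assumptions based on the fragment above and similar OH RMSNorm modules.

def forward(self, hidden_states):
    if hidden_states.device.type == "hpu" and FusedRMSNorm is not None:
        # FusedRMSNorm wants matching dtypes, so run it in the weight's dtype and cast back.
        orig_dtype = hidden_states.dtype
        out = FusedRMSNorm.apply(hidden_states.to(self.weight.dtype), self.weight, self.variance_epsilon)
        return out.to(orig_dtype)
    # Eager fallback: standard RMSNorm computed in float32 for numerical stability.
    input_dtype = hidden_states.dtype
    hidden_states = hidden_states.to(torch.float32)
    variance = hidden_states.pow(2).mean(-1, keepdim=True)
    hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
    return (self.weight * hidden_states).to(input_dtype)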
self.register_buffer("inv_freq", inv_freq, persistent=False)

# Build here to make `torch.jit.trace` work.
self.max_seq_len_cached = max_position_embeddings
[explanatory] Make the cache static (sized to max_position_embeddings) instead of growing it to the longest seq_len seen so far ("seq_len > self.max_seq_len_cached").
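Roughly, the static variant builds the cos/sin tables once at init for max_position_embeddings instead of re-extending them when a longer sequence shows up. The method and buffer names below are illustrative assumptions, not necessarily the ones in this file.

def _set_cos_sin_cache(self, seq_len, device, dtype):
    # Called once from __init__ with seq_len = max_position_embeddings; never re-extended.
    self.max_seq_len_cached = seq_len
    t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
    freqs = torch.outer(t, self.inv_freq)
    emb = torch.cat((freqs, freqs), dim=-1)
    self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
    self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)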
def apply_customized_rope(q, k, cos, sin, position_ids):
    if q.device.type == "hpu" and FusedRoPE:
        return FusedRoPE.apply(
[explanatory] Fused HPU RoPE op.
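A hedged sketch of the full dispatch: use the HPU FusedRoPE kernel when the tensors live on HPU and the kernel imported, otherwise fall back to the stock rotary helper. The unsqueeze pattern and the eager helper name (`apply_rotary_pos_emb`) follow other OH modeling files and are assumptions here.

def apply_customized_rope(q, k, cos, sin, position_ids):
    if q.device.type == "hpu" and FusedRoPE is not None:
        # Fused HPU kernel path.
        return (
            FusedRoPE.apply(q, cos.unsqueeze(0).unsqueeze(0), sin.unsqueeze(0).unsqueeze(0), position_ids),
            FusedRoPE.apply(k, cos.unsqueeze(0).unsqueeze(0), sin.unsqueeze(0).unsqueeze(0), position_ids),
        )
    # Eager fallback to the stock rotary embedding helper.
    return apply_rotary_pos_emb(q, k, cos, sin, position_ids)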
[clarification][minor] Could we call apply_customized_rope here?
def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
    return tensor.view(bsz, seq_len, self.num_heads, self.v_head_dim).transpose(1, 2).contiguous()

def split_kv_b_proj(self):
[clarification] This is present only in DeepSeek attention (v2/v3). Can we add a comment about this?
self.q_absorb = kv_b_proj_weight[:, : self.qk_nope_head_dim, :].unsqueeze(0).transpose(0, 1)
self.out_absorb = kv_b_proj_weight[:, self.qk_nope_head_dim :, :].unsqueeze(0)

def compress_kv(
[clarification] This is present only in DeepSeek attention (v2/v3). Can we add a comment about this? In the original DeepSeek code this is not a separate function; is there a particular reason for factoring it out? Just want to clarify whether making this a function is a stylistic choice or there is another reason.
    key_states, value_states, self.layer_idx, cache_kwargs
)
# optimization
if use_flash_attention and FusedSDPA is not None:
[explanatory] HPU-specific, similar to other modeling files in OH.
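For illustration, a hedged sketch of the fast-path dispatch: FusedSDPA is used when flash attention is requested and the kernel imported, with plain scaled dot-product attention as the eager fallback. The FusedSDPA.apply argument order, `flash_attention_fast_softmax`, and `self.softmax_scale` follow other OH modeling files and are assumptions for this file.

if use_flash_attention and FusedSDPA is not None:
    # Fused SDPA fast path (assumed signature: q, k, v, attn_mask, dropout_p, is_causal, scale, softmax_mode).
    softmax_mode = "fast" if flash_attention_fast_softmax else "None"
    attn_output = FusedSDPA.apply(
        query_states, key_states, value_states, attention_mask, 0.0, False, None, softmax_mode
    )
else:
    # Eager fallback: plain scaled dot-product attention.
    attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * self.softmax_scale
    if attention_mask is not None:
        attn_weights = attn_weights + attention_mask
    attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
    attn_output = torch.matmul(attn_weights, value_states)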
past_key_values_length = 0
if past_key_values is not None:
    past_key_values_length = past_key_values[0][0].shape[2]
[explanatory] HPU KV cache management, similar to other OH models.
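A small sketch of the bookkeeping this refers to: with a legacy tuple cache, the cached length is read from the key tensor's sequence dimension and used to offset position_ids. The position_ids derivation below is an assumed, standard pattern rather than a quote from this PR.

past_key_values_length = 0
if past_key_values is not None:
    # Legacy tuple cache: per-layer key tensor shaped (bsz, num_heads, cached_seq_len, head_dim).
    past_key_values_length = past_key_values[0][0].shape[2]

if position_ids is None:
    device = input_ids.device if input_ids is not None else inputs_embeds.device
    position_ids = torch.arange(
        past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
    ).unsqueeze(0)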
    and not self.training
    and (torch.distributed.is_initialized() is False or torch.distributed.get_world_size() == 1)
):
    htcore.mark_step()
[clarification] Are these mark_step graph breaks at layer boundaries meant to fit the model in memory, or do they bring some perf benefit?
The mark_step is a memory optimization copied from Llama (#875).
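Roughly, the idea is: on single-card HPU inference, htcore.mark_step() is called at decoder-layer boundaries so the lazy graph is cut up and intermediate activations can be freed earlier, lowering peak memory. The loop below is a hedged sketch under those assumptions; the exact guard conditions and call placement in the PR may differ slightly.

import habana_frameworks.torch.core as htcore

for decoder_layer in self.layers:
    if (
        hidden_states.device.type == "hpu"
        and not self.training
        and (not torch.distributed.is_initialized() or torch.distributed.get_world_size() == 1)
    ):
        # Break the lazy graph between layers so earlier activations can be released.
        htcore.mark_step()
    layer_outputs = decoder_layer(
        hidden_states, attention_mask=attention_mask, position_ids=position_ids, past_key_value=past_key_values
    )
    hidden_states = layer_outputs[0]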
What does this PR do?
DeepSeek V3 support on OH (Optimum Habana).
Before submitting