[feature] add deepseekv2 edp support
Vincent-syr committed Dec 16, 2024
1 parent 1c2abf7 commit e28be90
Showing 13 changed files with 497 additions and 53 deletions.
1 change: 1 addition & 0 deletions lightllm/common/basemodel/basemodel.py
@@ -65,6 +65,7 @@ def __init__(self, kvargs):
         self.quant_type = kvargs.get("quant_type", None)
         self.quant_cfg_path = kvargs.get("quant_cfg", None)
         self.mem_fraction = kvargs.get("mem_fraction", 0.9)
+        self.expert_parallel_mode = kvargs.get("expert_parallel_mode", "etp")

         self._init_datatype()
         self._init_config()
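
Note: the new kvarg defaults to "etp", so existing expert-tensor-parallel deployments are unchanged; passing "edp" opts a DeepSeek-V2 deployment into expert data parallelism. A minimal caller-side sketch — the class name TpPartBaseModel and all kvargs other than expert_parallel_mode are illustrative assumptions, not part of this commit:

    # Hypothetical launch sketch; only expert_parallel_mode comes from this commit.
    kvargs = {
        "weight_dir": "/path/to/DeepSeek-V2",  # placeholder checkpoint path
        "mem_fraction": 0.9,
        "expert_parallel_mode": "edp",         # "etp" (default) or "edp"
    }
    model = TpPartBaseModel(kvargs)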
@@ -11,7 +11,7 @@
 class TransformerLayerInferTpl(TransformerLayerInfer):
     """ """

-    def __init__(self, layer_num, tp_rank, world_size, network_config, mode):
+    def __init__(self, layer_num, tp_rank, world_size, network_config, mode, tp_split=True):
         super().__init__(layer_num, tp_rank, world_size, network_config, mode)
         # need to set by subclass
         self.eps_ = 1e-5
@@ -21,6 +21,7 @@ def __init__(self, layer_num, tp_rank, world_size, network_config, mode):
         self.tp_o_head_num_ = -1
         self.head_dim_ = -1
         self.embed_dim_ = -1
+        self.tp_split_ = tp_split
         return

     def _att_norm(self, input, infer_state: InferStateInfo, layer_weight) -> torch.Tensor:
@@ -79,7 +80,7 @@ def _context_attention(self, input_embding, infer_state: InferStateInfo, layer_weight):
         o = self._context_attention_kernel(q, cache_kv, infer_state, layer_weight)
         q = None
         o = self._get_o(o, infer_state, layer_weight)
-        if self.world_size_ > 1:
+        if self.world_size_ > 1 and self.tp_split_:
             dist.all_reduce(o, op=dist.ReduceOp.SUM, async_op=False)
         input_embding.add_(o.view(-1, self.embed_dim_))
         return
@@ -88,7 +89,7 @@ def _context_ffn(self, input_embdings, infer_state: InferStateInfo, layer_weight):
         input1 = self._ffn_norm(input_embdings, infer_state, layer_weight)
         ffn_out = self._ffn(input1, infer_state, layer_weight)
         input1 = None
-        if self.world_size_ > 1:
+        if self.world_size_ > 1 and self.tp_split_:
             dist.all_reduce(ffn_out, op=dist.ReduceOp.SUM, async_op=False)
         input_embdings.add_(ffn_out.view(-1, self.embed_dim_))
         return
@@ -102,7 +103,7 @@ def _token_attention(self, input_embding, infer_state: InferStateInfo, layer_weight):
         o = self._token_attention_kernel(q, infer_state, layer_weight)
         q = None
         o = self._get_o(o, infer_state, layer_weight)
-        if self.world_size_ > 1:
+        if self.world_size_ > 1 and self.tp_split_:
             dist.all_reduce(o, op=dist.ReduceOp.SUM, async_op=False)
         input_embding.add_(o.view(-1, self.embed_dim_))
         return
@@ -111,7 +112,7 @@ def _token_ffn(self, input_embdings, infer_state: InferStateInfo, layer_weight):
         input1 = self._ffn_norm(input_embdings, infer_state, layer_weight)
         ffn_out = self._ffn(input1, infer_state, layer_weight)
         input1 = None
-        if self.world_size_ > 1:
+        if self.world_size_ > 1 and self.tp_split_:
             dist.all_reduce(ffn_out, op=dist.ReduceOp.SUM, async_op=False)
         input_embdings.add_(ffn_out.view(-1, self.embed_dim_))
         return
@@ -125,7 +126,7 @@ def _splitfuse_attention(self, input_embding, infer_state: SplitFuseInferStateInfo, layer_weight):
         o = self._splitfuse_attention_kernel(q, infer_state, layer_weight)
         q = None
         o = self._get_o(o, infer_state, layer_weight)
-        if self.world_size_ > 1:
+        if self.world_size_ > 1 and self.tp_split_:
             dist.all_reduce(o, op=dist.ReduceOp.SUM, async_op=False)
         input_embding.add_(o.view(-1, self.embed_dim_))
         return
@@ -134,7 +135,7 @@ def _splitfuse_ffn(self, input_embdings, infer_state: SplitFuseInferStateInfo, layer_weight):
         input1 = self._ffn_norm(input_embdings, infer_state, layer_weight)
         ffn_out = self._ffn(input1, infer_state, layer_weight)
         input1 = None
-        if self.world_size_ > 1:
+        if self.world_size_ > 1 and self.tp_split_:
             dist.all_reduce(ffn_out, op=dist.ReduceOp.SUM, async_op=False)
         input_embdings.add_(ffn_out.view(-1, self.embed_dim_))
         return
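
Note: every per-layer all-reduce in TransformerLayerInferTpl is now gated on the new tp_split_ flag. With tp_split=True (the default) behavior is unchanged: each rank holds a weight shard, so partial attention/FFN outputs must be summed across ranks. With tp_split=False — as EDP uses for layers whose weights are kept whole on every rank — the local output is already complete, and reducing would double-count it. A minimal sketch of the pattern, assuming an initialized torch.distributed process group:

    import torch
    import torch.distributed as dist

    def reduce_if_tp_split(out: torch.Tensor, world_size: int, tp_split: bool) -> torch.Tensor:
        # Sum partial results across ranks only when the layer's weights are
        # tensor-parallel sharded; otherwise each rank's output is already full.
        if world_size > 1 and tp_split:
            dist.all_reduce(out, op=dist.ReduceOp.SUM, async_op=False)
        return out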
4 changes: 4 additions & 0 deletions lightllm/common/basemodel/layer_weights/meta_weights/__init__.py
@@ -11,6 +11,10 @@
     MultiCOLMMWeight,
     ROWBMMWeight,
     COLBMMWeight,
+    MultiCOLMMWeightNoTp,
+    ROWBMMWeightNoTp,
+    COLBMMWeightNoTp,
+    COLMMWeightNoTp
 )
 from .norm_weight import NormWeight, GEMMANormWeight, TpNormWeight
 from .fused_moe_weight import FusedMoeWeight
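
Note: re-exporting the NoTp variants from the meta_weights package keeps the import style used for the existing weight classes, so model code can pull them from one place. A sketch, with the package path inferred from the file headers in this commit:

    from lightllm.common.basemodel.layer_weights.meta_weights import (
        COLMMWeightNoTp,
        MultiCOLMMWeightNoTp,
        ROWBMMWeightNoTp,
        COLBMMWeightNoTp,
    )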
7 changes: 4 additions & 3 deletions lightllm/common/basemodel/layer_weights/meta_weights/fused_moe_weight.py
@@ -15,7 +15,7 @@

 class FusedMoeWeight(BaseWeight):
     def __init__(
-        self, gate_proj_name, down_proj_name, up_proj_name, weight_prefix, n_routed_experts, split_inter_size, data_type
+        self, gate_proj_name, down_proj_name, up_proj_name, weight_prefix, n_routed_experts, split_inter_size, data_type, expert_parallel_mode="etp"
     ):
         super().__init__()
         assert HAS_VLLM, "vllm is not installed, you can't use FusedMoeWeight"
@@ -33,6 +33,7 @@ def __init__(
         self.expert_down_proj_etp = None
         self.w2_list = [None] * self.n_routed_experts
         self.quant_method = None
+        self.expert_parallel_mode = expert_parallel_mode
         self.lock = threading.Lock()

     def set_quant_method(self, quant_method):
@@ -159,7 +160,7 @@ def _load_hf_weights_etp(self, weights):
             self.expert_down_proj_etp[i_experts_ep, :] = self.experts_up_projs[i_experts_ep]

     def load_hf_weights(self, weights):
-        if os.environ.get("ETP_MODE_ENABLED") == "true":
+        if self.expert_parallel_mode == "etp" or self.expert_parallel_mode == "edp":
             self._load_hf_weights_etp(weights)
         else:
             for i_experts in range(self.n_routed_experts):
@@ -190,7 +191,7 @@ def _cuda(self, cpu_tensor):
         return cpu_tensor.contiguous().to(self.data_type_).cuda(self.tp_rank_)

     def verify_load(self):
-        if os.environ.get("ETP_MODE_ENABLED") == "true":
+        if self.expert_parallel_mode == "etp" or self.expert_parallel_mode == "edp":
             return True
         else:
             return self.w1 is not None and self.w2 is not None
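
Note: mode selection moves from the ETP_MODE_ENABLED environment variable into the expert_parallel_mode constructor argument, and the existing etp loading/verification path is reused for "edp" as well. A hedged construction sketch — the projection names, prefix, and sizes are illustrative placeholders; only the signature itself comes from this commit:

    import torch

    moe_weight = FusedMoeWeight(
        gate_proj_name="gate_proj",                  # placeholder
        down_proj_name="down_proj",                  # placeholder
        up_proj_name="up_proj",                      # placeholder
        weight_prefix="model.layers.1.mlp.experts",  # placeholder
        n_routed_experts=160,                        # placeholder expert count
        split_inter_size=1536,                       # placeholder per-rank size
        data_type=torch.bfloat16,
        expert_parallel_mode="edp",                  # replaces ETP_MODE_ENABLED=true
    )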
64 changes: 64 additions & 0 deletions lightllm/common/basemodel/layer_weights/meta_weights/mm_weight.py
@@ -96,6 +96,24 @@ def load_hf_weights(self, weights):
         self._post_load_weights()
         return

+class COLMMWeightNoTp(MMWeight):
+    def __init__(self, weight_name, data_type, split_n_embed, bias_name=None):
+        super().__init__(weight_name, data_type, split_n_embed, bias_name)
+        self.start = 0
+        self.end = split_n_embed
+
+    def load_hf_weights(self, weights):
+        weight = None
+        if self.weight_name in weights:
+            weight = weights[self.weight_name].to(self.data_type_)
+            self.weight = weight[:, self.start : self.end]
+        if self.bias_name in weights:
+            bias = weights[self.bias_name]
+            self.bias = bias.to(self.data_type_).cuda(self.tp_rank_)
+        if weight is None:
+            return
+        self._post_load_weights()
+        return

 class MultiMMWeight(MMWeightTpl):
     def __init__(self, weight_names, data_type, split_n_embeds, bias_names=[]):
@@ -172,6 +190,21 @@ def load_hf_weights(self, weights):
         self._fuse()
         return

+class MultiCOLMMWeightNoTp(MultiROWMMWeightNoTP):
+    def __init__(self, weight_names, data_type, split_n_embed, bias_names=[]):
+        super().__init__(weight_names, data_type, split_n_embed, bias_names)
+
+    def load_hf_weights(self, weights):
+        weight = None
+        for i in range(len(self.weight_names)):
+            if self.weight_names[i] in weights:
+                weight = weights[self.weight_names[i]].to(self.data_type_)
+                self.weights[i] = weight[:, self.starts[i] : self.ends[i]]
+            if self.has_bias and self.bias_names[i] in weights:
+                bias = weights[self.bias_names[i]].to(self.data_type_)
+                self.biases[i] = bias[:, self.starts[i] : self.ends[i]]
+        self._fuse()
+        return

 class BMMWeightTpl(BaseWeightTpl):
     def __init__(self, data_type):
@@ -233,6 +266,19 @@ def __init__(
     ):
         super().__init__(weight_name, data_type, split_n_embed, bias_name)

+class ROWBMMWeightNoTp(BMMWeight):
+    load_hf_weights = ROWMMWeight.load_hf_weights
+
+    def __init__(
+        self,
+        weight_name,
+        data_type,
+        split_n_embed,
+        bias_name=None,
+    ):
+        super().__init__(weight_name, data_type, split_n_embed, bias_name)
+        self.start = 0
+        self.end = split_n_embed

 class COLBMMWeight(BMMWeight):
     load_hf_weights = COLMMWeight.load_hf_weights
@@ -248,3 +294,21 @@ def __init__(

     def _post_load_weights(self):
         self.weight = self.weight.transpose(0, 1).cuda(self.tp_rank_)
+
+class COLBMMWeightNoTp(BMMWeight):
+    load_hf_weights = COLMMWeightNoTp.load_hf_weights
+
+    def __init__(
+        self,
+        weight_name,
+        data_type,
+        split_n_embed,
+        bias_name=None,
+    ):
+        super().__init__(weight_name, data_type, split_n_embed, bias_name)
+        self.start = 0
+        self.end = split_n_embed
+
+    def _post_load_weights(self):
+        self.weight = self.weight.transpose(0, 1).cuda(self.tp_rank_)

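Note: the NoTp variants reuse their tensor-parallel parents' loaders but pin start = 0 and end = split_n_embed, so when the caller passes the full dimension every rank loads the entire matrix rather than a per-rank slice — which is what EDP needs, since experts are partitioned across ranks while each expert's weights stay whole. A small sketch of the difference, with hypothetical shapes:

    import torch

    full_weight = torch.randn(4096, 1024)  # illustrative [in_dim, out_dim] weight

    # Tensor-parallel path: rank tp_rank keeps only its column shard.
    world_size, tp_rank = 2, 1
    shard = full_weight.shape[1] // world_size
    tp_slice = full_weight[:, shard * tp_rank : shard * (tp_rank + 1)]  # [4096, 512]

    # NoTp path: start=0, end=split_n_embed with the full width passed in,
    # so the slice is the identity and every rank holds the whole matrix.
    start, end = 0, full_weight.shape[1]
    no_tp_slice = full_weight[:, start:end]  # [4096, 1024]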