diff --git a/megatron/core/tensor_parallel/layers.py b/megatron/core/tensor_parallel/layers.py
index 2245113c9c..29d84ba26a 100644
--- a/megatron/core/tensor_parallel/layers.py
+++ b/megatron/core/tensor_parallel/layers.py
@@ -132,7 +132,14 @@ def _initialize_affine_weight_cpu(weight, output_size, input_size,
     my_weight_list = weight_list[rank::world_size]
 
     with torch.no_grad():
-        torch.cat(my_weight_list, dim=partition_dim, out=weight)
+        if master_weight.shape[partition_dim] > per_partition_size:
+            torch.cat(my_weight_list, dim=partition_dim, out=weight)
+        else:
+            # When non-expert layers are tensor-parallel but expert layers are not,
+            # per_partition_size equals master_weight.shape[partition_dim], so
+            # my_weight_list is empty on every rank except 0 and torch.cat cannot
+            # be used. Copy in place: `weight = master_weight` only rebinds a local.
+            weight.copy_(master_weight)
     if return_master_weight:
         return master_weight
     return None