
Commit

fix bugs
zhaoting committed Feb 17, 2025
1 parent e749110 commit e7c89b4
Showing 3 changed files with 25 additions and 17 deletions.
2 changes: 1 addition & 1 deletion docs/tools/_toctree.yml
@@ -1,4 +1,4 @@
 - sections:
   - local: zero
-    title: ZeROs
+    title: ZeRO
   title: Get started
2 changes: 1 addition & 1 deletion mindone/models/modules/parallel/conv.py
@@ -33,7 +33,7 @@ def set_param_wrapper(self, zero_stage, op_group, cell_type=None):
             split_op = ops.Split(0, op_group_size)
             if self.param_wrapper_w.need_rewrite:
                 self.net.weight.assign_value(split_op(self.net.weight)[op_rank_id])
-            if self.net.has_bias:
+            if self.net.bias:
                 self.param_wrapper_b = ZeroParamWrapper(self.net.bias, zero_stage, op_group, cell_type)
                 if self.param_wrapper_b.need_rewrite:
                     self.net.bias.assign_value(split_op(self.net.bias)[op_rank_id])
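For context, the change suggests the wrapped cell exposes its bias directly as a `bias` attribute (a Parameter, or None when absent) rather than through a `has_bias` flag. Below is a minimal, hedged sketch of that check; the helper name and the simplified divisibility test are illustrative, not the repository's API:

# Hedged sketch (illustrative names, not mindone's code): split a cell's bias
# along dim 0 across the optimizer-parallel group, but only when the cell
# actually defines a bias parameter.
from mindspore import nn, ops


def split_bias_if_present(net: nn.Cell, group_size: int, rank_id: int) -> None:
    bias = getattr(net, "bias", None)  # mint-style cells store the bias itself, or None
    if bias is None:
        return
    if bias.shape[0] % group_size != 0:
        return  # simplified stand-in for the wrapper's need_rewrite check
    shard = ops.Split(0, group_size)(bias)[rank_id]
    bias.assign_value(shard)  # reshape the parameter in place, as the wrapper does


# Usage sketch: nn.Dense defines both `has_bias` and `bias`; checking `bias`
# also covers cells that never define `has_bias`.
dense = nn.Dense(4, 8, has_bias=True)
split_bias_if_present(dense, group_size=2, rank_id=0)
print(dense.bias.shape)  # (4,)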
38 changes: 23 additions & 15 deletions mindone/trainers/zero.py
@@ -313,26 +313,34 @@ def get_need_parameter_split(self):
         self.need_parameter_split = tuple(self.need_parameter_split)
 
     def split_params(self):
-        if not (self.zero_stage in [1, 2] and self.is_parallel):
-            _logger.info("No need to split optimizer parameters standalone.")
+        if not (self.zero_stage in [1, 2, 3] and self.is_parallel):
             return
-        _logger.info("Clone optimizer.parameters, will increase memory.")
-        # Because the first input of MindSpore optimizer must be ms.Parameter,
-        # copy optimizer.parameters for optimizer parameters update.
-        # It will increase 1/n parameters' memory.
-        self.optimizer.parameters = self.optimizer.parameters.clone(prefix="wrapper", init="same")
-        self.optimizer._parameters = self.optimizer.parameters
-        self.last_assign = True
+        if self.zero_stage in [1, 2]:
+            _logger.info("Clone optimizer.parameters, will increase memory.")
+            # Because the first input of MindSpore optimizer must be ms.Parameter,
+            # copy optimizer.parameters for optimizer parameters update.
+            # It will increase 1/n parameters' memory.
+            self.optimizer.parameters = self.optimizer.parameters.clone(prefix="wrapper", init="same")
+            self.optimizer._parameters = self.optimizer.parameters
+            self.last_assign = True
 
         param_tuples = self.get_optimizer_param_tuples()
         for i, param in enumerate(self.optimizer._parameters):
             _logger.debug(f"Split optimizer param {param.name} {param.shape}")
-            B = param.shape[0]
-            if param.parallel_optimizer and B >= self.op_group_size and B % self.op_group_size == 0:
-                _logger.debug(f"Do split with zero_stage {self.zero_stage}")
-                ori_shape = param.shape
-                param.assign_value(self.split_param(param))
-                _logger.debug(f"Optimizer {param.name} from {ori_shape} to {param.shape}")
+            # If zero_stage is 3, the parameters in train network have been split,
+            # use parameter in param_tuples to get batch size.
+            _logger.debug(f"Do split with zero_stage {self.zero_stage}")
+            if self.zero_stage in [1, 2]:
+                B = param.shape[0]
+                if self.ori_parameters[i] and B >= self.op_group_size and B % self.op_group_size == 0:
+                    param.parallel_optimizer = True
+                else:
+                    param.parallel_optimizer = False
+            if param.parallel_optimizer:
+                if self.zero_stage in [1, 2]:
+                    ori_shape = param.shape
+                    param.assign_value(self.split_param(param))
+                    _logger.debug(f"Optimizer {param.name} from {ori_shape} to {param.shape}")
                 for param_tuple in param_tuples:
                     ori_shape = param_tuple[i].shape
                     param_tuple[i].assign_value(self.split_param(param_tuple[i]))
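The net effect of this hunk: for ZeRO stages 1 and 2 the optimizer parameters are cloned and then split, while for stage 3 the network parameters are already sharded, so only the optimizer state tensors gathered in param_tuples are split; in the stage 1/2 path a parameter is only flagged for splitting when its first dimension is at least the group size and divisible by it. A small, hedged sketch of that partition rule (function and variable names are illustrative, not the trainer's API):

# Hedged sketch of the dim-0 partition rule used above; not mindone's implementation.
import numpy as np
import mindspore as ms
from mindspore import ops


def shard_along_dim0(param: ms.Tensor, group_size: int, rank_id: int) -> ms.Tensor:
    """Return this rank's shard of ``param``, or the full tensor if it cannot be split."""
    b = param.shape[0]
    if b < group_size or b % group_size != 0:
        # Mirrors the diff's check: such parameters keep parallel_optimizer=False
        # and are left whole on every rank.
        return param
    return ops.Split(0, group_size)(param)[rank_id]


# Example: an (8, 4) weight on a 4-way optimizer-parallel group
# -> each rank holds a (2, 4) shard.
w = ms.Tensor(np.arange(32, dtype=np.float32).reshape(8, 4))
print(shard_along_dim0(w, group_size=4, rank_id=1).shape)  # (2, 4)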
