From 28375450739018e9a777985d1510be89352d6966 Mon Sep 17 00:00:00 2001
From: zhaoting
Date: Mon, 17 Feb 2025 16:35:41 +0800
Subject: [PATCH] fix bugs

---
 mindone/models/modules/parallel/conv.py |  2 +-
 mindone/trainers/zero.py                | 40 +++++++++++++++----------
 2 files changed, 25 insertions(+), 17 deletions(-)

diff --git a/mindone/models/modules/parallel/conv.py b/mindone/models/modules/parallel/conv.py
index 1d306ebff8..a3a46a5133 100644
--- a/mindone/models/modules/parallel/conv.py
+++ b/mindone/models/modules/parallel/conv.py
@@ -33,7 +33,7 @@ def set_param_wrapper(self, zero_stage, op_group, cell_type=None):
             split_op = ops.Split(0, op_group_size)
             if self.param_wrapper_w.need_rewrite:
                 self.net.weight.assign_value(split_op(self.net.weight)[op_rank_id])
-            if self.net.has_bias:
+            if self.net.bias:
                 self.param_wrapper_b = ZeroParamWrapper(self.net.bias, zero_stage, op_group, cell_type)
                 if self.param_wrapper_b.need_rewrite:
                     self.net.bias.assign_value(split_op(self.net.bias)[op_rank_id])
diff --git a/mindone/trainers/zero.py b/mindone/trainers/zero.py
index 6cf3555d0b..f340982e38 100644
--- a/mindone/trainers/zero.py
+++ b/mindone/trainers/zero.py
@@ -313,30 +313,38 @@ def get_need_parameter_split(self):
         self.need_parameter_split = tuple(self.need_parameter_split)
 
     def split_params(self):
-        if not (self.zero_stage in [1, 2] and self.is_parallel):
-            _logger.info("No need to split optimizer parameters standalone.")
+        if not (self.zero_stage in [1, 2, 3] and self.is_parallel):
             return
-        _logger.info("Clone optimizer.parameters, will increase memory.")
-        # Because the first input of MindSpore optimizer must be ms.Parameter,
-        # copy optimizer.parameters for optimizer parameters update.
-        # It will increase 1/n parameters' memory.
-        self.optimizer.parameters = self.optimizer.parameters.clone(prefix="wrapper", init="same")
-        self.optimizer._parameters = self.optimizer.parameters
-        self.last_assign = True
+        if self.zero_stage in [1, 2]:
+            _logger.info("Clone optimizer.parameters, will increase memory.")
+            # Because the first input of MindSpore optimizer must be ms.Parameter,
+            # copy optimizer.parameters for optimizer parameters update.
+            # It will increase 1/n parameters' memory.
+            self.optimizer.parameters = self.optimizer.parameters.clone(prefix="wrapper", init="same")
+            self.optimizer._parameters = self.optimizer.parameters
+            self.last_assign = True
         param_tuples = self.get_optimizer_param_tuples()
         for i, param in enumerate(self.optimizer._parameters):
             _logger.debug(f"Split optimizer param {param.name} {param.shape}")
-            B = param.shape[0]
-            if param.parallel_optimizer and B >= self.op_group_size and B % self.op_group_size == 0:
-                _logger.debug(f"Do split with zero_stage {self.zero_stage}")
-                ori_shape = param.shape
-                param.assign_value(self.split_param(param))
-                _logger.debug(f"Optimizer {param.name} from {ori_shape} to {param.shape}")
+            # If zero_stage is 3, the parameters in train network have been split,
+            # use parameter in param_tuples to get batch size.
+            _logger.debug(f"Do split with zero_stage {self.zero_stage}")
+            if self.zero_stage in [1, 2]:
+                B = param.shape[0]
+                if self.ori_parameters[i] and B >= self.op_group_size and B % self.op_group_size == 0:
+                    param.parallel_optimizer = True
+                else:
+                    param.parallel_optimizer = False
+            if param.parallel_optimizer:
+                if self.zero_stage in [1, 2]:
+                    ori_shape = param.shape
+                    param.assign_value(self.split_param(param))
+                    _logger.debug(f"Optimizer {param.name} from {ori_shape} to {param.shape}")
                 for param_tuple in param_tuples:
                     ori_shape = param_tuple[i].shape
                     param_tuple[i].assign_value(self.split_param(param_tuple[i]))
-                    _logger.debug(f"Optimizer {param_tuple[i].name} from {ori_shape} to {param_tuple[i].shape}")
+                    _logger.debug(f"Optimizer {param_tuple[i].name} " f"from {ori_shape} to {param_tuple[i].shape}")
 
     def reduce_scatter_gradients(self, gradients):
         dtype = gradients[0].dtype