diff --git a/docs/index.md b/docs/index.md
index 3e60f5014a..04ae4aac16 100644
--- a/docs/index.md
+++ b/docs/index.md
@@ -32,6 +32,14 @@ hide:
 
   [:octicons-arrow-right-24: Start tuning!](peft/index.md)
 
+- :star2: __Tools__
+
+    ---
+
+    Training tools, including the Trainer, ZeRO, and image/video data filtering strategies...
+
+    [:octicons-arrow-right-24: Get started!](tools/zero.md)
+
 - > :rocket: __Accelerate__
 > ---
diff --git a/docs/tools/_toctree.yml b/docs/tools/_toctree.yml
new file mode 100644
index 0000000000..8c398ae5ae
--- /dev/null
+++ b/docs/tools/_toctree.yml
@@ -0,0 +1,4 @@
+- sections:
+  - local: zero
+    title: ZeRO
+  title: Get started
diff --git a/docs/tools/zero.md b/docs/tools/zero.md
new file mode 100644
index 0000000000..ca96ea43cc
--- /dev/null
+++ b/docs/tools/zero.md
@@ -0,0 +1,155 @@
+# Zero Redundancy Optimizer (ZeRO) on MindOne
+
+Zero Redundancy Optimizer (ZeRO) is a method for reducing the memory usage of data-parallel training, proposed in the paper [ZeRO: Memory Optimizations Toward Training Trillion Parameter Models](https://arxiv.org/pdf/1910.02054.pdf).
+
+ZeRO eliminates memory redundancies in data- and model-parallel training while retaining low communication volume and high computational
+granularity, allowing us to scale the model size in proportion to the number of devices with sustained high efficiency.
+
+This tutorial walks you through how to train larger models faster with ZeRO on MindOne.
+
+## Build a Train Network with ZeRO
+
+The following example builds a train network with ZeRO enabled.
+
+```python
+import mindspore as ms
+from mindspore import nn
+from mindspore.communication import init
+from mindspore.communication.management import GlobalComm
+from mindone.trainers.zero import prepare_train_network
+
+# Initialize the distributed environment
+def init_env(mode, distribute):
+    ms.set_context(mode=mode)
+    if distribute:
+        init()
+        # ZeRO only takes effect under the DATA_PARALLEL mode
+        ms.set_auto_parallel_context(
+            parallel_mode=ms.ParallelMode.DATA_PARALLEL,
+            gradients_mean=True,
+        )
+
+init_env(ms.GRAPH_MODE, True)
+
+# Net is your train network
+net = Net()
+# opt must be a subclass of the MindSpore Optimizer
+opt = nn.AdamWeightDecay(net.trainable_params(), learning_rate=1e-3)
+
+# Build the train network with ZeRO
+train_net = prepare_train_network(net, opt, zero_stage=2, optimizer_parallel_group=GlobalComm.WORLD_COMM_GROUP)
+```
+
+!!! tip
+
+    `optimizer_parallel_group` does not have to be `GlobalComm.WORLD_COMM_GROUP`. Use [create_group](https://www.mindspore.cn/docs/zh-CN/master/api_python/mindspore.communication.html#mindspore.communication.create_group) to create your own `optimizer_parallel_group`.
+
+More details:
+
+::: mindone.trainers.zero.prepare_train_network
+
+[Here](https://github.com/mindspore-lab/mindone/blob/master/tests/others/test_zero.py) is a complete example.
+
+## Memory Analysis
+
+The memory consumption during training can be divided into two main parts:
+
+- Residual states. Mainly activations, temporary buffers, and unusable memory fragments.
+- Model states. Mainly three parts: optimizer states (AdamW, fp32), gradients (fp16), and parameters (fp16), abbreviated as OGP. Assuming the number of model parameters is Φ,
+the total size of the model states is 2Φ (parameters) + 2Φ (gradients) + (4Φ + 4Φ + 4Φ) (optimizer states) = 16Φ, with the AdamW states accounting for 75%.
+
+Residual states can be greatly reduced through [recompute](https://www.mindspore.cn/docs/en/master/model_train/parallel/recompute.html) and [model parallel](https://www.mindspore.cn/docs/en/master/model_train/parallel/strategy_select.html).
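+
+As a minimal sketch of the recompute side (the `MyBlock`/`MyNet` cells and their sizes below are hypothetical stand-ins for your own network), MindSpore's `Cell.recompute()` can be enabled per sub-cell to trade extra backward-pass computation for activation memory:
+
+```python
+import mindspore.nn as nn
+
+class MyBlock(nn.Cell):
+    def __init__(self):
+        super().__init__()
+        self.dense = nn.Dense(128, 128)
+        self.act = nn.GELU()
+
+    def construct(self, x):
+        return self.act(self.dense(x))
+
+class MyNet(nn.Cell):
+    def __init__(self, num_blocks=4):
+        super().__init__()
+        self.blocks = nn.CellList([MyBlock() for _ in range(num_blocks)])
+
+    def construct(self, x):
+        for block in self.blocks:
+            x = block(x)
+        return x
+
+net = MyNet()
+for block in net.blocks:
+    # Recompute this block's activations in the backward pass instead of storing them.
+    block.recompute()
+```
+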
+Then the ZeRO algorithm can be used to reduce the model states.
+
+To remove the redundancy in the model states, ZeRO partitions them, so that each card only stores 1/N of the data.
+
+ZeRO has three main optimization stages (as depicted in Figure 1 of the ZeRO paper), which correspond to the partitioning of optimizer states, gradients, and parameters. When enabled cumulatively:
+
+1) Optimizer State Partitioning (Pos): the optimizer states are partitioned to 1/N, while the model parameters and gradients are still kept in full on each card. The model states per card are 4Φ + 12Φ/N; when N is very large this tends to 4Φ, i.e. 1/4 of the original memory;
+2) Add Gradient Partitioning (Pos+g): the gradients are additionally partitioned to 1/N. The model states per card are 2Φ + (2Φ + 12Φ)/N; when N is very large this tends to 2Φ, i.e. 1/8 of the original memory;
+3) Add Parameter Partitioning (Pos+g+p): the parameters are additionally partitioned to 1/N. The model states per card are 16Φ/N; when N is very large this tends to 0.
+
+For example, with Φ = 7.5B and N = 64, the model states shrink from 120 GB per card to about 31.4 GB, 16.6 GB, and 1.9 GB for the three stages respectively, which matches Figure 1 of the paper.
+
+Pos corresponds to ZeRO-1, Pos+g corresponds to ZeRO-2, and Pos+g+p corresponds to ZeRO-3.
+
+## Communication Analysis
+
+The commonly used AllReduce implementation is Ring AllReduce, which is divided into two steps: ReduceScatter and AllGather. The communication data volume (send + receive) of each card is approximately 2Φ.
+
+| zero stage | forward + backward | gradient | optimizer update | communication |
+| --- |--------------------|---------------------|------------------|---------------|
+| 0 | NA | AllReduce | NA | 2Φ |
+| 1 | NA | 1/N ReduceScatter | 1/N AllGather | 2Φ |
+| 2 | NA | 1/N ReduceScatter | 1/N AllGather | 2Φ |
+| 3 | 2 AllGather | ReduceScatter | NA | 3Φ |
+
+ZeRO-3 therefore has additional communication. However, computation and communication run on parallel streams in MindSpore, so when the computation overlapped with this communication is relatively large, ZeRO-3 may even be faster.
+
+## Checkpoint Saving & Loading
+
+Because the parameters (and optimizer states) are split, each card may need to save its own checkpoint files.
+
+### Resume
+
+Checkpoints saved under each ZeRO stage:
+
+| zero stage | parameters | optimizer states | ema |
+|------------|------------| --- | --- |
+| 0 | one card | one card | one card |
+| 1 | one card | each card | each card |
+| 2 | one card | each card | each card |
+| 3 | each card | each card | each card |
+
+!!! tip
+
+    💡 It is recommended to use the rank id to distinguish checkpoints saved on different cards.
+
+```python
+rank_id = get_rank_id()
+zero_stage = 2
+train_net = prepare_train_network(net, opt, zero_stage=zero_stage, optimizer_parallel_group=GlobalComm.WORLD_COMM_GROUP)
+if resume:
+    network_ckpt = "network.ckpt" if zero_stage != 3 else f"network_{rank_id}.ckpt"
+    ms.load_checkpoint(network_ckpt, net=train_net.network)
+    optimizer_ckpt = "optimizer.ckpt" if zero_stage == 0 else f"optimizer_{rank_id}.ckpt"
+    ms.load_checkpoint(optimizer_ckpt, net=train_net.optimizer)
+    ema_ckpt = "ema.ckpt" if zero_stage == 0 else f"ema_{rank_id}.ckpt"
+    ms.load_checkpoint(ema_ckpt, net=train_net.ema)
+```
+
+### Inference
+
+Inference needs the complete model parameters when ZeRO-3 is used. There are two ways (online and offline) to obtain them.
+
+#### Online Checkpoint Combine
+
+```python
+def do_ckpt_combine_online(params, optimizer_parallel_group):
+    new_net_to_save = []
+    all_gather_op = ops.AllGather(optimizer_parallel_group)
+    for param in params:
+        if param.parallel_optimizer:
+            new_data = ms.Tensor(all_gather_op(param).asnumpy())
+        else:
+            new_data = ms.Tensor(param.asnumpy())
+        new_net_to_save.append({"name": param.name, "data": new_data})
+    return new_net_to_save
+
+net_to_save = [{"name": p.name, "data": p} for p in network.trainable_params()]
+net_to_save = net_to_save if zero_stage != 3 else do_ckpt_combine_online(network.trainable_params(), optimizer_parallel_group)
+ms.save_checkpoint(net_to_save, "network.ckpt")
+```
+
+Add this code at the point where you need to save the model parameters.
+
+#### Offline Checkpoint Combine
+
+The parameter split information is saved when `ZeroHelper` is used, and `convert_checkpoints` can use it to combine the checkpoints offline.
+
+```python
+from mindone.trainers.zero import convert_checkpoints
+
+src_checkpoint = "save_checkpoint_dir/ckpt_{}.ckpt"
+src_param_split_info_json = "params_info/params_split_info_{}.json"
+group_size = 2
+convert_checkpoints(src_checkpoint, src_param_split_info_json, group_size)
+```
+
+The combined checkpoint with the complete model parameters is then available at `save_checkpoint_dir/ckpt_all_2.ckpt`.
diff --git a/examples/opensora_pku/opensora/train/train_t2v_diffusers.py b/examples/opensora_pku/opensora/train/train_t2v_diffusers.py
index 1d196db63c..fce13cc621 100644
--- a/examples/opensora_pku/opensora/train/train_t2v_diffusers.py
+++ b/examples/opensora_pku/opensora/train/train_t2v_diffusers.py
@@ -535,7 +535,7 @@ def main(args):
         latent_diffusion_with_loss,
         optimizer,
         zero_stage=args.zero_stage,
-        op_group=GlobalComm.WORLD_COMM_GROUP,
+        optimizer_parallel_group=GlobalComm.WORLD_COMM_GROUP,
         comm_fusion=comm_fusion_dict,
         scale_sense=loss_scaler,
         drop_overflow_update=args.drop_overflow_update,
diff --git a/examples/opensora_pku/tools/ckpt/combine_ckpt.py b/examples/opensora_pku/tools/ckpt/combine_ckpt.py
index 57fe2d1fe4..b4110ea86d 100644
--- a/examples/opensora_pku/tools/ckpt/combine_ckpt.py
+++ b/examples/opensora_pku/tools/ckpt/combine_ckpt.py
@@ -25,7 +25,7 @@ def main():
         else args.strategy_ckpt
     )
     assert os.path.exists(strategy_file), f"{strategy_file} does not exist!"
-    ms.transform_checkpoints(args.src, args.dest, "full_", strategy_file, None)
+    ms.convert_checkpoints(args.src, args.dest, "full_", strategy_file, None)
 
     output_path = os.path.join(args.dest, "rank_0", "full_0.ckpt")
     assert os.path.isfile(output_path)
diff --git a/mindone/diffusers/training_utils.py b/mindone/diffusers/training_utils.py
index 9a491a4bff..e823c9867d 100644
--- a/mindone/diffusers/training_utils.py
+++ b/mindone/diffusers/training_utils.py
@@ -861,7 +861,7 @@ def prepare_train_network(
     verbose: bool = False,
     zero_stage: int = 0,
     optimizer_offload: bool = False,
-    op_group: str = None,
+    optimizer_parallel_group: str = None,
     dp_group: str = None,
     comm_fusion: dict = None,
     parallel_modules=None,
@@ -878,7 +878,7 @@ def prepare_train_network(
             the shape should be :math:`()` or :math:`(1,)`.
         zero_stage (`int`, *optional*): Stage setting of ZeRO, default is 0.
         optimizer_offload (`bool`, *optional*): Only take effect when optimizer is AdamWeightDecay, default is False.
-        op_group (`str`, *optional*): The name of the optimizer parallel communication group, default is None.
+        optimizer_parallel_group (`str`, *optional*): The name of the optimizer parallel communication group, default is None.
dp_group (`str`, *optional*): The name of the data parallel communication group, default is None. comm_fusion (`dict`, *optional*): A dict contains the types and configurations for setting the communication fusion, default is None, turn off the communication fusion. If set a dict, @@ -891,19 +891,23 @@ def prepare_train_network( """ if zero_stage not in [0, 1, 2, 3]: raise ValueError("Not support zero_stage {zero_stage}") - if op_group is None: + if optimizer_parallel_group is None: logger.warning("Not set zero group, set it WORLD_COMM_GROUP.") - op_group = GlobalComm.WORLD_COMM_GROUP - if op_group != GlobalComm.WORLD_COMM_GROUP and dp_group is None: - raise ValueError("op_group {op_group} and dp_group {dp_group} not full network hccl group coverage") + optimizer_parallel_group = GlobalComm.WORLD_COMM_GROUP + if optimizer_parallel_group != GlobalComm.WORLD_COMM_GROUP and dp_group is None: + raise ValueError( + "optimizer_parallel_group {optimizer_parallel_group} and dp_group {dp_group} not full network hccl group coverage" + ) is_parallel = _get_parallel_mode() == ParallelMode.DATA_PARALLEL if not is_parallel and zero_stage == 0: logger.info("No need prepare train_network with zero.") zero_helper = None else: - network = prepare_network(network, zero_stage, op_group, parallel_modules=parallel_modules) - zero_helper = ZeroHelper(optimizer, zero_stage, op_group, dp_group, optimizer_offload, comm_fusion) + network = prepare_network(network, zero_stage, optimizer_parallel_group, parallel_modules=parallel_modules) + zero_helper = ZeroHelper( + optimizer, zero_stage, optimizer_parallel_group, dp_group, optimizer_offload, comm_fusion + ) if isinstance(scale_sense, float): scale_sense = ms.Tensor(scale_sense, ms.float32) @@ -931,7 +935,7 @@ def use_zero(self): return self.zero_helper is not None and self.zero_stage != 0 def need_save_optimizer(self, args): - # TODO: Now we save optimizer in every process, try to save depend on self.zero_helper.op_group + # TODO: Now we save optimizer in every process, try to save depend on self.zero_helper.optimizer_parallel_group return True if self.use_zero else is_local_master(args) def save_state(self, args, output_dir, optimizer_state_filter=lambda x: True): diff --git a/mindone/models/modules/parallel/__init__.py b/mindone/models/modules/parallel/__init__.py index 101c1a958a..5240aeb9c6 100644 --- a/mindone/models/modules/parallel/__init__.py +++ b/mindone/models/modules/parallel/__init__.py @@ -1,6 +1,6 @@ from mindspore import mint, nn -from .conv import Conv1d, Conv2d, Conv3d +from .conv import Conv1d, Conv2d, Conv3d, Mint_Conv2d, Mint_Conv3d from .dense import Dense, Linear # {Original MindSpore Cell: New Cell in ZeRO3} @@ -9,7 +9,9 @@ nn.Conv2d: Conv2d, nn.Conv3d: Conv3d, nn.Dense: Dense, + mint.nn.Conv2d: Mint_Conv2d, + mint.nn.Conv3d: Mint_Conv3d, mint.nn.Linear: Linear, } -__all__ = ["Conv1d", "Conv2d", "Conv3d", "Dense", "Linear"] +__all__ = ["Conv1d", "Conv2d", "Conv3d", "Mint_Conv2d", "Mint_Conv3d", "Dense", "Linear"] diff --git a/mindone/models/modules/parallel/conv.py b/mindone/models/modules/parallel/conv.py index e47f958456..b8bd00be84 100644 --- a/mindone/models/modules/parallel/conv.py +++ b/mindone/models/modules/parallel/conv.py @@ -1,4 +1,4 @@ -from mindspore import nn, ops +from mindspore import mint, nn, ops from mindspore.communication import get_group_size, get_rank from mindspore.communication.management import GlobalComm from mindspore.context import ParallelMode @@ -8,25 +8,35 @@ class _Conv(nn.Cell): - def __init__(self, net, 
zero_stage: int = 0, op_group: str = GlobalComm.WORLD_COMM_GROUP, cell_type=None):
+    def __init__(
+        self, net, zero_stage: int = 0, optimizer_parallel_group: str = GlobalComm.WORLD_COMM_GROUP, cell_type=None
+    ):
         super(_Conv, self).__init__(auto_prefix=False)
         self.net = net
-        self.set_param_wrapper(zero_stage, op_group, cell_type)
+        self.set_param_wrapper(zero_stage, optimizer_parallel_group, cell_type)
 
-    def set_param_wrapper(self, zero_stage, op_group, cell_type=None):
+    @property
+    def weight(self):
+        return self.net.weight
+
+    @property
+    def bias(self):
+        return self.net.bias
+
+    def set_param_wrapper(self, zero_stage, optimizer_parallel_group, cell_type=None):
         self.param_wrapper_w = nn.Identity()
         self.param_wrapper_b = nn.Identity()
         if zero_stage == 3:
             # Init parallel settings
             is_parallel = _get_parallel_mode() == ParallelMode.DATA_PARALLEL
-            op_group_size = get_group_size(op_group) if is_parallel else 1
-            op_rank_id = get_rank(op_group) if is_parallel else 0
-            self.param_wrapper_w = ZeroParamWrapper(self.net.weight, zero_stage, op_group, cell_type)
+            op_group_size = get_group_size(optimizer_parallel_group) if is_parallel else 1
+            op_rank_id = get_rank(optimizer_parallel_group) if is_parallel else 0
+            self.param_wrapper_w = ZeroParamWrapper(self.net.weight, zero_stage, optimizer_parallel_group, cell_type)
             split_op = ops.Split(0, op_group_size)
             if self.param_wrapper_w.need_rewrite:
                 self.net.weight.assign_value(split_op(self.net.weight)[op_rank_id])
-            if self.net.has_bias:
-                self.param_wrapper_b = ZeroParamWrapper(self.net.bias, zero_stage, op_group, cell_type)
+            if self.net.bias is not None:
+                self.param_wrapper_b = ZeroParamWrapper(self.net.bias, zero_stage, optimizer_parallel_group, cell_type)
                 if self.param_wrapper_b.need_rewrite:
                     self.net.bias.assign_value(split_op(self.net.bias)[op_rank_id])
@@ -71,3 +81,45 @@ def construct(self, x):
             new_shape[1] = self.net.out_channels
             out = out + bias.reshape(new_shape)
         return out
+
+
+class Mint_Conv2d(_Conv):
+    def construct(self, x):
+        weight = self.param_wrapper_w(self.net.weight)
+        bias = self.param_wrapper_b(self.net.bias)
+        if self.net.padding_mode != "zeros":
+            output = self.net.conv2d(
+                mint.pad(x, self.net._reversed_padding, mode=self.net.padding_mode),
+                weight,
+                bias,
+                self.net.stride,
+                (0, 0),
+                self.net.dilation,
+                self.net.groups,
+            )
+        else:
+            output = self.net.conv2d(
+                x, weight, bias, self.net.stride, self.net.padding, self.net.dilation, self.net.groups
+            )
+        return output
+
+
+class Mint_Conv3d(_Conv):
+    def construct(self, x):
+        weight = self.param_wrapper_w(self.net.weight)
+        bias = self.param_wrapper_b(self.net.bias)
+        if self.net.padding_mode != "zeros":
+            output = self.net.conv3d(
+                mint.pad(x, self.net._reversed_padding, mode=self.net.padding_mode),
+                weight,
+                bias,
+                self.net.stride,
+                (0, 0, 0),
+                self.net.dilation,
+                self.net.groups,
+            )
+        else:
+            output = self.net.conv3d(
+                x, weight, bias, self.net.stride, self.net.padding, self.net.dilation, self.net.groups
+            )
+        return output
diff --git a/mindone/models/modules/parallel/dense.py b/mindone/models/modules/parallel/dense.py
index 8d31690fff..dc24ef0bba 100644
--- a/mindone/models/modules/parallel/dense.py
+++ b/mindone/models/modules/parallel/dense.py
@@ -16,27 +16,35 @@ def __init__(
         self,
         net: Union[nn.Dense, mint.nn.Linear],
         zero_stage: Literal[0, 1, 2, 3] = 0,
-        op_group: str = GlobalComm.WORLD_COMM_GROUP,
+        optimizer_parallel_group: str = GlobalComm.WORLD_COMM_GROUP,
         cell_type: Optional[mstype.Type] = None,
     ):
         super().__init__(auto_prefix=False)
self.net = net - self.set_param_wrapper(zero_stage, op_group, cell_type) + self.set_param_wrapper(zero_stage, optimizer_parallel_group, cell_type) - def set_param_wrapper(self, zero_stage, op_group, cell_type=None): + @property + def weight(self): + return self.net.weight + + @property + def bias(self): + return self.net.bias + + def set_param_wrapper(self, zero_stage, optimizer_parallel_group, cell_type=None): self.param_wrapper_w = nn.Identity() self.param_wrapper_b = nn.Identity() if zero_stage == 3: # Init parallel settings is_parallel = _get_parallel_mode() == ParallelMode.DATA_PARALLEL - op_group_size = get_group_size(op_group) if is_parallel else 1 - op_rank_id = get_rank(op_group) if is_parallel else 0 - self.param_wrapper_w = ZeroParamWrapper(self.net.weight, zero_stage, op_group, cell_type) + op_group_size = get_group_size(optimizer_parallel_group) if is_parallel else 1 + op_rank_id = get_rank(optimizer_parallel_group) if is_parallel else 0 + self.param_wrapper_w = ZeroParamWrapper(self.net.weight, zero_stage, optimizer_parallel_group, cell_type) split_op = ops.Split(0, op_group_size) if self.param_wrapper_w.need_rewrite: self.net.weight.assign_value(split_op(self.net.weight)[op_rank_id]) if self.net.has_bias: - self.param_wrapper_b = ZeroParamWrapper(self.net.bias, zero_stage, op_group, cell_type) + self.param_wrapper_b = ZeroParamWrapper(self.net.bias, zero_stage, optimizer_parallel_group, cell_type) if self.param_wrapper_b.need_rewrite: self.net.bias.assign_value(split_op(self.net.bias)[op_rank_id]) diff --git a/mindone/models/modules/parallel/param_wrapper.py b/mindone/models/modules/parallel/param_wrapper.py index 1ca8d753b7..b007b97ab7 100644 --- a/mindone/models/modules/parallel/param_wrapper.py +++ b/mindone/models/modules/parallel/param_wrapper.py @@ -12,10 +12,14 @@ class ZeroParamWrapper(nn.Cell): """ def __init__( - self, param: ms.Parameter, zero_stage: int = 0, op_group: str = GlobalComm.WORLD_COMM_GROUP, cell_type=None + self, + param: ms.Parameter, + zero_stage: int = 0, + optimizer_parallel_group: str = GlobalComm.WORLD_COMM_GROUP, + cell_type=None, ): super().__init__(auto_prefix=False) - self.op_group = op_group + self.optimizer_parallel_group = optimizer_parallel_group self.zero_stage = zero_stage self.cell_type = cell_type if zero_stage != 3: @@ -23,16 +27,16 @@ def __init__( # Init parallel settings self.is_parallel = _get_parallel_mode() == ParallelMode.DATA_PARALLEL - self.op_group_size = get_group_size(self.op_group) if self.is_parallel else 1 + self.op_group_size = get_group_size(self.optimizer_parallel_group) if self.is_parallel else 1 self.allgather = ops.Identity() self.reduce_scatter = None self.dtype = param.dtype - self.allreduce = ops.AllReduce(group=self.op_group, op=ops.ReduceOp.SUM) + self.allreduce = ops.AllReduce(group=self.optimizer_parallel_group, op=ops.ReduceOp.SUM) self.need_rewrite = self.check_rewrite(param) if self.need_rewrite: - self.op_allgather = ops.AllGather(group=self.op_group) - self.op_reduce_scatter = ops.ReduceScatter(group=self.op_group, op=ops.ReduceOp.SUM) + self.op_allgather = ops.AllGather(group=self.optimizer_parallel_group) + self.op_reduce_scatter = ops.ReduceScatter(group=self.optimizer_parallel_group, op=ops.ReduceOp.SUM) def check_rewrite(self, param): """Check the parameter need to split or not.""" @@ -40,6 +44,7 @@ def check_rewrite(self, param): B = param.shape[0] if not param.parallel_optimizer or B < self.op_group_size or B % self.op_group_size != 0: need_rewrite = False + param.parallel_optimizer = 
need_rewrite
         return need_rewrite
 
     def construct(self, param):
diff --git a/mindone/trainers/callback.py b/mindone/trainers/callback.py
index 5e0cc15244..a553342317 100755
--- a/mindone/trainers/callback.py
+++ b/mindone/trainers/callback.py
@@ -3,8 +3,10 @@
 import time
 from typing import List, Literal, Optional, Tuple, Union
 
-from mindspore import Profiler, Tensor, nn, save_checkpoint
+import mindspore as ms
+from mindspore import Profiler, Tensor, nn, ops, save_checkpoint
 from mindspore.communication import get_rank
+from mindspore.communication.management import GlobalComm
 from mindspore.train.callback._callback import Callback, _handle_loss
 
 from .checkpoint import CheckpointManager
@@ -65,6 +67,9 @@ def __init__(
         save_training_resume: bool = True,
         train_steps: int = -1,
         prefer_low_perf: bool = False,
+        zero_stage: int = 0,
+        optimizer_parallel_group: str = None,
+        ckpt_combine_online: bool = False,
     ):
         """
         Args:
@@ -73,9 +78,29 @@
                 Otherwise, only params that contain one of the keyword in param_save_filter list will be saved.
             resume_prefix_blacklist: exclude parameters with one of these prefixes to be saved in resume checkpoint,
                 e.g. ('swap.', 'vae.').
+            zero_stage (`int`, *optional*): Stage setting of ZeRO, default is 0.
+            optimizer_parallel_group (`str`, *optional*): The name of the optimizer parallel communication group, default is None.
+            ckpt_combine_online (`bool`, *optional*): whether to combine the trainable parameters when saving checkpoints with zero_stage=3, \
+                using AllGather ops to combine the checkpoint online if `ckpt_combine_online=True`, \
+                or saving the parameters of every device if `ckpt_combine_online=False`, \
+                in which case `convert_checkpoints` is needed to combine the checkpoints offline. Default is False.
         """
         self.rank_id = rank_id
         self.is_main_device = rank_id in [0, None]
+        self.use_zero = zero_stage in [1, 2, 3]
+        self.ckpt_combine_online = (zero_stage == 3) and ckpt_combine_online
+        if self.ckpt_combine_online and ema is not None:
+            _logger.warning("Cannot enable ckpt_combine_online when using EMA, setting `ckpt_combine_online=False`.")
+            self.ckpt_combine_online = False
+
+        self.need_save_network = self.is_main_device or (zero_stage == 3 and not self.ckpt_combine_online)
+        self.need_save_optimizer = self.is_main_device or self.use_zero
+        if self.use_zero:
+            if optimizer_parallel_group is None:
+                _logger.warning("EvalSaveCallback: optimizer_parallel_group is not set, using WORLD_COMM_GROUP.")
+                optimizer_parallel_group = GlobalComm.WORLD_COMM_GROUP
+            self.optimizer_parallel_group = optimizer_parallel_group
+            self.op_rank_id = get_rank(optimizer_parallel_group)
         self.ema = ema
         if output_dir is not None:
             self.output_dir = output_dir
@@ -97,7 +122,7 @@ def __init__(
         self.record_lr = record_lr
         self.save_ema_only = save_ema_only
 
-        if self.is_main_device:
+        if self.need_save_network:
             self.ckpt_save_policy = ckpt_save_policy
             self.monitor_metric = monitor_metric
             self.ckpt_manager = CheckpointManager(
@@ -143,6 +168,17 @@ def __init__(
                 resume_prefix_blacklist = (resume_prefix_blacklist,)
             self.choice_func = lambda x: not x.startswith(resume_prefix_blacklist)
 
+    def _do_ckpt_combine_online(self):
+        new_net_to_save = []
+        all_gather_op = ops.AllGather(self.optimizer_parallel_group)
+        for param in self.net_to_save:
+            if param.parallel_optimizer:
+                new_data = ms.Tensor(all_gather_op(param).asnumpy())
+            else:
+                new_data = ms.Tensor(param.asnumpy())
+            new_net_to_save.append({"name": param.name, "data": new_data})
+        return new_net_to_save
+
     def on_train_step_end(self, run_context):
         cb_params =
run_context.original_args() loss = _handle_loss(cb_params.net_outputs) @@ -158,7 +194,31 @@ def on_train_step_end(self, run_context): else: cur_epoch = cb_params.cur_epoch_num - 1 - if self.is_main_device: + if self.save_training_resume and self.need_save_optimizer: + # TODO: resume training for step. + ckpt_name = f"train_resume_op_rank_{self.op_rank_id}.ckpt" if self.use_zero else "train_resume.ckpt" + save_checkpoint( + cb_params.train_network, + os.path.join(self.ckpt_save_dir, ckpt_name), + choice_func=self.choice_func, + append_dict={ + "epoch_num": cur_epoch, + "cur_step": cur_step, + "loss_scale": self._get_scaling_value_from_cbp(cb_params), + }, + ) + if self.ema is not None: + ckpt_name = f"ema_resume_op_rank_{self.op_rank_id}.ckpt" if self.use_zero else "ema_resume.ckpt" + save_checkpoint( + self.ema, + os.path.join(self.ckpt_save_dir, ckpt_name), + choice_func=self.choice_func, + ) + + if self.ckpt_combine_online: + new_net_to_save = self._do_ckpt_combine_online() + + if self.need_save_network: # if data sink, train step callback will not be invokded if self.step_mode and (cur_step % self.ckpt_save_interval == 0 or cur_step == step_num): ckpt_name = ( @@ -166,9 +226,13 @@ def on_train_step_end(self, run_context): if self.use_step_unit else f"{self.model_name}-e{cur_epoch}.ckpt" ) + if self.use_zero and not self.ckpt_combine_online: + file_extension = os.path.splitext(ckpt_name) + ckpt_name = f"{file_extension[0]}_op_rank_{self.op_rank_id}{file_extension[1]}" append_dict = {"lora_rank": self.lora_rank} if self.use_lora else None perf = cb_params.get("eval_results") + net_to_save = new_net_to_save if self.ckpt_combine_online else self.net_to_save if perf or self.ckpt_save_policy != "top_k": if perf: perf = perf[self.monitor_metric] @@ -184,20 +248,7 @@ def on_train_step_end(self, run_context): self.ema.swap_before_eval() # save history checkpoints - self.ckpt_manager.save(self.net_to_save, perf, ckpt_name=ckpt_name, append_dict=append_dict) - - if self.save_training_resume: - # TODO: resume training for step. - save_checkpoint( - cb_params.train_network, - os.path.join(self.ckpt_save_dir, "train_resume.ckpt"), - choice_func=self.choice_func, - append_dict={ - "epoch_num": cur_epoch, - "cur_step": cur_step, - "loss_scale": self._get_scaling_value_from_cbp(cb_params), - }, - ) + self.ckpt_manager.save(net_to_save, perf, ckpt_name=ckpt_name, append_dict=append_dict) # swap back network weight and ema weight. MUST execute after model saving and before next-step training if self.ema is not None: @@ -260,15 +311,42 @@ def on_train_epoch_end(self, run_context): opt = self._get_optimizer_from_cbp(cb_params) cur_step = int(opt.global_step.asnumpy().item()) - if self.is_main_device and (not self.step_mode): + if self.save_training_resume and self.need_save_optimizer: + # TODO: resume training for step. 
+ ckpt_name = f"train_resume_op_rank_{self.op_rank_id}.ckpt" if self.use_zero else "train_resume.ckpt" + save_checkpoint( + cb_params.train_network, + os.path.join(self.ckpt_save_dir, ckpt_name), + choice_func=self.choice_func, + append_dict={ + "epoch_num": cur_epoch, + "loss_scale": self._get_scaling_value_from_cbp(cb_params), + }, + ) + if self.ema is not None: + ckpt_name = f"ema_resume_op_rank_{self.op_rank_id}.ckpt" if self.use_zero else "ema_resume.ckpt" + save_checkpoint( + self.ema, + os.path.join(self.ckpt_save_dir, ckpt_name), + choice_func=self.choice_func, + ) + + if self.ckpt_combine_online: + new_net_to_save = self._do_ckpt_combine_online() + + if self.need_save_network and (not self.step_mode): if (cur_epoch % self.ckpt_save_interval == 0) or (cur_epoch == epoch_num): ckpt_name = ( f"{self.model_name}-s{cur_step}.ckpt" if self.use_step_unit else f"{self.model_name}-e{cur_epoch}.ckpt" ) + if self.use_zero and not self.ckpt_combine_online: + file_extension = os.path.splitext(ckpt_name) + ckpt_name = f"{file_extension[0]}_op_rank_{self.op_rank_id}{file_extension[1]}" append_dict = {"lora_rank": self.lora_rank} if self.use_lora else None + net_to_save = new_net_to_save if self.ckpt_combine_online else self.net_to_save if self.ema is not None: if not self.save_ema_only: self.ckpt_manager.save( @@ -282,20 +360,9 @@ def on_train_epoch_end(self, run_context): # save history checkpoints self.ckpt_manager.save( - self.net_to_save, perf=cb_params["net_outputs"], ckpt_name=ckpt_name, append_dict=append_dict + net_to_save, perf=cb_params["net_outputs"], ckpt_name=ckpt_name, append_dict=append_dict ) - if self.save_training_resume: - save_checkpoint( - cb_params.train_network, - os.path.join(self.ckpt_save_dir, "train_resume.ckpt"), - choice_func=self.choice_func, - append_dict={ - "epoch_num": cur_epoch, - "loss_scale": self._get_scaling_value_from_cbp(cb_params), - }, - ) - # swap back network weight and ema weight. MUST execute after model saving and before next-step training if self.ema is not None: self.ema.swap_after_eval() diff --git a/mindone/trainers/zero.py b/mindone/trainers/zero.py index 42bbc4e326..845d5e28b4 100644 --- a/mindone/trainers/zero.py +++ b/mindone/trainers/zero.py @@ -71,7 +71,7 @@ def split_np(x, num, idx): return ms.Tensor(x.asnumpy()[start:end]) -@ms.ms_class +@ms.jit_class class ZeroHelper: """ Zero redundancy optimizer(ZeRO) build helper. @@ -85,7 +85,7 @@ class ZeroHelper: Args: optimizer (`nn.Optimizer`): Must be the subclass of MindSpore Optimizer. zero_stage (`int`, *optional*): Stage setting of ZeRO, default is 0. - op_group (`str`, *optional*): The name of the optimizer parallel communication group, default is None. + optimizer_parallel_group (`str`, *optional*): The name of the optimizer parallel communication group, default is None. dp_group (`str`, *optional*): The name of the data parallel communication group, default is None. optimizer_offload (`bool`, *optional*): Only take effect when optimizer is AdamWeightDecay, default is False. 
comm_fusion (`dict`, *optional*): A dict contains the types and configurations @@ -102,7 +102,7 @@ def __init__( self, optimizer: nn.Optimizer, zero_stage: int = 0, - op_group: str = None, + optimizer_parallel_group: str = None, dp_group: str = None, optimizer_offload: bool = False, comm_fusion: dict = None, @@ -110,7 +110,7 @@ def __init__( ): self.optimizer = optimizer self.zero_stage = zero_stage - self.op_group = op_group + self.optimizer_parallel_group = optimizer_parallel_group if isinstance(optimizer, ms.experimental.optim.optimizer.Optimizer): self.optimizer._parameters = self.optimizer.parameters self.ori_parameters = self.optimizer._parameters @@ -124,8 +124,8 @@ def __init__( self.op_reduce_scatter = ops.Identity() self.op_allreduce = ops.Identity() self.dp_allreduce = ops.Identity() - self.op_group_size = get_group_size(self.op_group) if self.is_parallel else 1 - self.op_rank_id = get_rank(self.op_group) if self.is_parallel else 0 + self.op_group_size = get_group_size(self.optimizer_parallel_group) if self.is_parallel else 1 + self.op_rank_id = get_rank(self.optimizer_parallel_group) if self.is_parallel else 0 self.need_dp = False self.dp_group = dp_group self.last_assign = False @@ -173,14 +173,14 @@ def __init__( def set_comm_ops( self, ): - self.op_allreduce = ops.AllReduce(op=ops.ReduceOp.SUM, group=self.op_group) - self.op_reduce_scatter = ops.ReduceScatter(op=ops.ReduceOp.SUM, group=self.op_group) + self.op_allreduce = ops.AllReduce(op=ops.ReduceOp.SUM, group=self.optimizer_parallel_group) + self.op_reduce_scatter = ops.ReduceScatter(op=ops.ReduceOp.SUM, group=self.optimizer_parallel_group) # AllGather the parameters after optimizer calculate to update the parameters in train network. - self.op_allgather = ops.AllGather(group=self.op_group) + self.op_allgather = ops.AllGather(group=self.optimizer_parallel_group) self.need_dp = self.dp_group is not None if self.need_dp: - # Set it when op_group is not the WORLD_COMM_GROUP. + # Set it when optimizer_parallel_group is not the WORLD_COMM_GROUP. 
self.dp_allreduce = ops.AllReduce(op=ops.ReduceOp.SUM, group=self.dp_group) self.dp_group_size = ms.Tensor(get_group_size(group=self.dp_group), ms.float32) @@ -201,7 +201,7 @@ def set_zero1_allreduce_fusion_comm_list(self, comm_fusion): param_size = param.itemsize * param.size param_name = param.name self.update_comm_op_info(allreduce_info, comm_fusion["allreduce"]["bucket_size"], param_size, param_name) - comm_op = ops.AllReduce(op=ops.ReduceOp.SUM, group=self.op_group) + comm_op = ops.AllReduce(op=ops.ReduceOp.SUM, group=self.optimizer_parallel_group) comm_op.add_prim_attr("fusion", allreduce_info[-1]["fusion_id"]) self.zero1_allreduce_list.append(comm_op) _logger.info(f"zero1_allreduce_fusion: {allreduce_info}") @@ -224,11 +224,11 @@ def set_zero2_reduce_scatter_fusion_comm_list(self, comm_fusion): self.update_comm_op_info( allreduce_info, comm_fusion["allreduce"]["bucket_size"], param_size, param_name ) - comm_op = ops.ReduceScatter(op=ops.ReduceOp.SUM, group=self.op_group) + comm_op = ops.ReduceScatter(op=ops.ReduceOp.SUM, group=self.optimizer_parallel_group) comm_op.add_prim_attr("fusion", reduce_scatter_info[-1]["fusion_id"]) self.zero2_reduce_scatter_list.append(comm_op) - comm_op = ops.AllReduce(op=ops.ReduceOp.SUM, group=self.op_group) + comm_op = ops.AllReduce(op=ops.ReduceOp.SUM, group=self.optimizer_parallel_group) comm_op.add_prim_attr("fusion", allreduce_info[-1]["fusion_id"]) self.zero2_allreduce_list.append(comm_op) _logger.info(f"zero2_reduce_scatter_fusion: {reduce_scatter_info}") @@ -245,7 +245,7 @@ def set_optimizer_allgather_fusion_comm_list(self, comm_fusion): self.update_comm_op_info( allgather_info, comm_fusion["allgather"]["bucket_size"], param_size, param_name ) - comm_op = ops.AllGather(group=self.op_group) + comm_op = ops.AllGather(group=self.optimizer_parallel_group) comm_op.add_prim_attr("fusion", allgather_info[-1]["fusion_id"]) self.optimizer_allgather_list.append(comm_op) _logger.info(f"optimizer_allgather_fusion: {allgather_info}") @@ -261,7 +261,7 @@ def set_dp_allreduce_comm_list(self, comm_fusion): self.update_comm_op_info( dp_allreduce_info, comm_fusion["allreduce"]["bucket_size"], param_size, param_name ) - comm_op = ops.AllGather(group=self.op_group) + comm_op = ops.AllGather(group=self.optimizer_parallel_group) comm_op.add_prim_attr("fusion", dp_allreduce_info[-1]["fusion_id"]) self.dp_allreduce_list.append(comm_op) _logger.info(f"dp_allreduce_fusion: {dp_allreduce_info}") @@ -302,22 +302,20 @@ def dump_params_split_info(self, params_split_info): def get_need_parameter_split(self): self.need_parameter_split = [False] * len(self.optimizer._parameters) - param_tuples = self.get_optimizer_param_tuples() for i, param in enumerate(self.optimizer._parameters): if self.zero_stage == 3: - if param_tuples: - B = param_tuples[0][i].shape[0] - else: - continue + self.need_parameter_split[i] = param.parallel_optimizer else: B = param.shape[0] - if param.parallel_optimizer and B >= self.op_group_size and B % self.op_group_size == 0: - if self.zero_stage in [1, 2]: - self.need_parameter_split[i] = True + if param.parallel_optimizer and B >= self.op_group_size and B % self.op_group_size == 0: + if self.zero_stage in [1, 2]: + self.need_parameter_split[i] = True self.need_parameter_split = tuple(self.need_parameter_split) def split_params(self): - if self.zero_stage in [1, 2] and self.is_parallel: + if not (self.zero_stage in [1, 2, 3] and self.is_parallel): + return + if self.zero_stage in [1, 2]: _logger.info("Clone optimizer.parameters, will increase memory.") 
# Because the first input of MindSpore optimizer must be ms.Parameter, # copy optimizer.parameters for optimizer parameters update. @@ -331,15 +329,14 @@ def split_params(self): _logger.debug(f"Split optimizer param {param.name} {param.shape}") # If zero_stage is 3, the parameters in train network have been split, # use parameter in param_tuples to get batch size. - if self.zero_stage == 3: - if param_tuples: - B = param_tuples[0][i].shape[0] - else: - continue - else: - B = param.shape[0] _logger.debug(f"Do split with zero_stage {self.zero_stage}") - if param.parallel_optimizer and B >= self.op_group_size and B % self.op_group_size == 0: + if self.zero_stage in [1, 2]: + B = param.shape[0] + if self.ori_parameters[i] and B >= self.op_group_size and B % self.op_group_size == 0: + param.parallel_optimizer = True + else: + param.parallel_optimizer = False + if param.parallel_optimizer: if self.zero_stage in [1, 2]: ori_shape = param.shape param.assign_value(self.split_param(param)) @@ -347,7 +344,7 @@ def split_params(self): for param_tuple in param_tuples: ori_shape = param_tuple[i].shape param_tuple[i].assign_value(self.split_param(param_tuple[i])) - _logger.debug(f"Optimizer {param_tuple[i].name} " f"from {ori_shape} to {param_tuple[i].shape}") + _logger.debug(f"Optimizer {param_tuple[i].name} from {ori_shape} to {param_tuple[i].shape}") def reduce_scatter_gradients(self, gradients): dtype = gradients[0].dtype @@ -470,11 +467,11 @@ def get_cell_dtype(cell): return None -def _init_parallel_settings(net, op_group, parallel_modules=None): +def _init_parallel_settings(net, optimizer_parallel_group, parallel_modules=None): for module, parallel_module in parallel_modules.items(): if isinstance(net, module): cell_type = get_cell_dtype(net) - new_net = parallel_module(net, 3, op_group) + new_net = parallel_module(net, 3, optimizer_parallel_group) if cell_type is not None: new_net.to_float(cell_type) return new_net @@ -488,14 +485,14 @@ def get_cell_params_fullname_dict(cell: nn.Cell): return fullname_dict -def _prepare_network(network: nn.Cell, op_group: str, parallel_modules=None): - new_net = _init_parallel_settings(network, op_group, parallel_modules) +def _prepare_network(network: nn.Cell, optimizer_parallel_group: str, parallel_modules=None): + new_net = _init_parallel_settings(network, optimizer_parallel_group, parallel_modules) if new_net is not None: return new_net for name, sub_net in network._cells.items(): if not sub_net: continue - new_sub_net = _init_parallel_settings(sub_net, op_group, parallel_modules) + new_sub_net = _init_parallel_settings(sub_net, optimizer_parallel_group, parallel_modules) if new_sub_net is not None: params_fullname_dict = get_cell_params_fullname_dict(sub_net) if isinstance(network, (nn.CellList, nn.SequentialCell)): @@ -514,27 +511,27 @@ def _prepare_network(network: nn.Cell, op_group: str, parallel_modules=None): param = getattr(sub_net, param_name) _logger.warning(f"Set param {param.name} parallel_optimizer False, param shape {param.shape}") param.parallel_optimizer = False - _prepare_network(sub_net, op_group, parallel_modules) + _prepare_network(sub_net, optimizer_parallel_group, parallel_modules) return network -def prepare_network(network: nn.Cell, zero_stage: int = 0, op_group: str = None, parallel_modules=None): +def prepare_network(network: nn.Cell, zero_stage: int = 0, optimizer_parallel_group: str = None, parallel_modules=None): if zero_stage != 3 or _get_parallel_mode() != ParallelMode.DATA_PARALLEL: _logger.info("No need rewrite network and 
return original network.") return network _logger.info("Rewrite the network, please wait...") if parallel_modules is None: parallel_modules = PARALLEL_MODULES - network = _prepare_network(network, op_group, parallel_modules) + network = _prepare_network(network, optimizer_parallel_group, parallel_modules) return network -def prepare_ema(ema, zero_stage: int = 0, op_group: str = None): +def prepare_ema(ema, zero_stage: int = 0, optimizer_parallel_group: str = None): is_parallel = _get_parallel_mode() == ParallelMode.DATA_PARALLEL if not is_parallel or zero_stage != 3: return ema - op_group_size = get_group_size(op_group) - op_rank_id = get_rank(op_group) + op_group_size = get_group_size(optimizer_parallel_group) + op_rank_id = get_rank(optimizer_parallel_group) _logger.info(f"Split EMA params: rank_id {op_rank_id}, rank_size {op_group_size}.") for net_weight, ema_weight, swap_cache in zip(ema.net_weight, ema.ema_weight, ema.swap_cache): if net_weight.shape == ema_weight.shape: @@ -557,7 +554,7 @@ def prepare_train_network( verbose: bool = False, zero_stage: Literal[0, 1, 2, 3] = 0, optimizer_offload: bool = False, - op_group: str = None, + optimizer_parallel_group: str = None, dp_group: str = None, comm_fusion: dict = None, parallel_modules=None, @@ -574,7 +571,7 @@ def prepare_train_network( the shape should be :math:`()` or :math:`(1,)`. zero_stage (`int`, *optional*): Stage setting of ZeRO, default is 0. optimizer_offload (`bool`, *optional*): Only take effect when optimizer is AdamWeightDecay, default is False. - op_group (`str`, *optional*): The name of the optimizer parallel communication group, default is None. + optimizer_parallel_group (`str`, *optional*): The name of the optimizer parallel communication group, default is None. dp_group (`str`, *optional*): The name of the data parallel communication group, default is None. comm_fusion (`dict`, *optional*): A dict contains the types and configurations for setting the communication fusion, default is None, turn off the communication fusion. 
If set a dict, @@ -587,22 +584,26 @@ def prepare_train_network( """ if zero_stage not in [0, 1, 2, 3]: raise ValueError("Not support zero_stage {zero_stage}") - if op_group is None: + if optimizer_parallel_group is None: _logger.warning("Not set zero group, set it WORLD_COMM_GROUP.") - op_group = GlobalComm.WORLD_COMM_GROUP - if op_group != GlobalComm.WORLD_COMM_GROUP and dp_group is None: - raise ValueError("op_group {op_group} and dp_group {dp_group} not full network hccl group coverage") + optimizer_parallel_group = GlobalComm.WORLD_COMM_GROUP + if optimizer_parallel_group != GlobalComm.WORLD_COMM_GROUP and dp_group is None: + raise ValueError( + "optimizer_parallel_group {optimizer_parallel_group} and dp_group {dp_group} not full network hccl group coverage" + ) is_parallel = _get_parallel_mode() == ParallelMode.DATA_PARALLEL if not is_parallel and zero_stage == 0: _logger.info("No need prepare train_network with zero.") zero_helper = None else: - network = prepare_network(network, zero_stage, op_group, parallel_modules=parallel_modules) - zero_helper = ZeroHelper(optimizer, zero_stage, op_group, dp_group, optimizer_offload, comm_fusion) + network = prepare_network(network, zero_stage, optimizer_parallel_group, parallel_modules=parallel_modules) + zero_helper = ZeroHelper( + optimizer, zero_stage, optimizer_parallel_group, dp_group, optimizer_offload, comm_fusion + ) if ema is not None: - ema = prepare_ema(ema, zero_stage, op_group) + ema = prepare_ema(ema, zero_stage, optimizer_parallel_group) if isinstance(scale_sense, float): scale_sense = ms.Tensor(scale_sense, ms.float32) train_network = TrainOneStepWrapper( @@ -621,7 +622,7 @@ def prepare_train_network( return train_network -def transform_checkpoints(src_checkpoint: str, src_param_split_info_json: str, group_size: int): +def convert_checkpoints(src_checkpoint: str, src_param_split_info_json: str, group_size: int): """ src_checkpoint (`str`): The path of checkpoints need to merge parameters. eg. "save_checkpoint_dir/ckpt_{}.ckpt", {} is placeholder of rank_id. diff --git a/tests/others/test_zero.py b/tests/others/test_zero.py index c9c99742c0..5d3a613c7f 100644 --- a/tests/others/test_zero.py +++ b/tests/others/test_zero.py @@ -78,7 +78,12 @@ def test_zero(x, y, zero_stage=0, comm_fusion=False): "allgather": {"bucket_size": 64}, } train_net = prepare_train_network( - net, opt, ema=ema, zero_stage=zero_stage, op_group=GlobalComm.WORLD_COMM_GROUP, comm_fusion=comm_fusion_dict + net, + opt, + ema=ema, + zero_stage=zero_stage, + optimizer_parallel_group=GlobalComm.WORLD_COMM_GROUP, + comm_fusion=comm_fusion_dict, ) for i in range(10):