zero doc & checkpoint save adapt #834

Merged · 9 commits · Feb 21, 2025
8 changes: 8 additions & 0 deletions docs/index.md
@@ -32,6 +32,14 @@ hide:

[:octicons-arrow-right-24: Start tuning!](peft/index.md)

- :star2: __Tools__

---

Training tools, including Trainer, ZeRO, image/video data filtering strategies...

[:octicons-arrow-right-24: Use it!](tools/zero.md)

- > :rocket: __Accelerate__
> ---
4 changes: 4 additions & 0 deletions docs/tools/_toctree.yml
@@ -0,0 +1,4 @@
- sections:
- local: zero
title: ZeRO
title: Get started
155 changes: 155 additions & 0 deletions docs/tools/zero.md
@@ -0,0 +1,155 @@
# Zero Redundancy Optimizer (ZeRO) on MindOne

Zero Redundancy Optimizer (ZeRO) is a method for reducing memory usage under the data parallelism strategy, introduced in the paper [ZeRO: Memory Optimizations Toward Training Trillion Parameter Models](https://arxiv.org/pdf/1910.02054.pdf).

ZeRO eliminates memory redundancies in data- and model-parallel training while retaining low communication volume and high computational
granularity, allowing the model size to scale in proportion to the number of devices with sustained high efficiency.

This tutorial walks you through how to train larger models faster with ZeRO on MindOne.

## Build a Training Network with ZeRO

Build a training network with ZeRO as follows.

```python
import mindspore as ms
from mindspore import nn
from mindspore.communication import init
from mindspore.communication.management import GlobalComm
from mindone.trainers.zero import prepare_train_network

# Initialize the distributed environment
def init_env(mode, distribute):
    ms.set_context(mode=mode)
    if distribute:
        init()
        # ZeRO only takes effect under the DATA_PARALLEL mode
        ms.set_auto_parallel_context(
            parallel_mode=ms.ParallelMode.DATA_PARALLEL,
            gradients_mean=True,
        )

init_env(ms.GRAPH_MODE, True)

# Net is your training network
net = Net()
# The optimizer must be a subclass of the MindSpore Optimizer
opt = nn.AdamWeightDecay(net.trainable_params(), learning_rate=1e-3)

# Build the training network with ZeRO
train_net = prepare_train_network(net, opt, zero_stage=2, optimizer_parallel_group=GlobalComm.WORLD_COMM_GROUP)
```

!!! tip

    `optimizer_parallel_group` does not have to be `GlobalComm.WORLD_COMM_GROUP`. Use [create_group](https://www.mindspore.cn/docs/zh-CN/master/api_python/mindspore.communication.html#mindspore.communication.create_group) to create your own `optimizer_parallel_group`.
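For example, here is a minimal sketch of creating a custom optimizer parallel group, assuming 8 devices split into two groups of 4; the group name is illustrative, not a fixed API value:

```python
from mindspore.communication import create_group, get_rank

# A minimal sketch, assuming 8 devices split into two optimizer-parallel groups of 4
rank = get_rank()
group_id = rank // 4
group_name = f"optimizer_parallel_group_{group_id}"
create_group(group_name, list(range(group_id * 4, (group_id + 1) * 4)))

train_net = prepare_train_network(net, opt, zero_stage=2, optimizer_parallel_group=group_name)
```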

More details:

::: mindone.trainers.zero.prepare_train_network

[Here](https://github.com/mindspore-lab/mindone/blob/master/tests/others/test_zero.py) is an example.

## Memory Analysis

The memory consumption during training can be divided into two main parts:

- Residual states. Mainly activations, temporary buffers, and unusable memory fragments.
- Model states. Mainly three parts: optimizer states (AdamW, fp32), gradients (fp16), and parameters (fp16), abbreviated as OPG. Assuming the number of model parameters is Φ,
the total model states are 2Φ (parameters) + 2Φ (gradients) + (4Φ + 4Φ + 4Φ) (optimizer states) = 16Φ, with the AdamW optimizer states accounting for 75%.

Residual states can be greatly reduced through [recompute](https://www.mindspore.cn/docs/en/master/model_train/parallel/recompute.html) and [model parallel](https://www.mindspore.cn/docs/en/master/model_train/parallel/strategy_select.html).
Then the ZeRO algorithm can be used to reduce model states.

To optimize model states (remove redundancy), ZeRO partitions them, which means each card only stores 1/N of the data.

ZeRO has three main optimization stages (as depicted in Figure 1 of the ZeRO paper), which correspond to partitioning the optimizer states, gradients, and parameters. When enabled cumulatively:

1) Optimizer State Partitioning (Pos): only 1/N of the optimizer states is kept per card, while the model parameters and gradients are still kept in full on each card. The model states per card are 4Φ + 12Φ/N; when N is very large, this tends toward 4Φ, i.e. 1/4 of the original memory;
2) Add Gradient Partitioning (Pos+g): the gradients are additionally partitioned to 1/N. The model states per card are 2Φ + (2Φ + 12Φ)/N; when N is very large, this tends toward 2Φ, i.e. 1/8 of the original memory;
3) Add Parameter Partitioning (Pos+g+p): the parameters are additionally partitioned to 1/N. The model states per card are 16Φ/N; when N is very large, this tends toward 0.

Pos corresponds to ZeRO-1, Pos+g corresponds to ZeRO-2, and Pos+g+p corresponds to ZeRO-3.
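
As a quick sanity check, here is a small sketch (illustrative only, not part of the MindOne API) that reproduces the per-card model-state memory for each stage:

```python
def model_states_per_card(phi: float, n: int, zero_stage: int) -> float:
    """Approximate per-card model-state memory in bytes (fp16 params/grads, fp32 AdamW states)."""
    params, grads, opt_states = 2 * phi, 2 * phi, 12 * phi
    if zero_stage >= 1:  # Pos: partition optimizer states
        opt_states /= n
    if zero_stage >= 2:  # Pos+g: partition gradients
        grads /= n
    if zero_stage >= 3:  # Pos+g+p: partition parameters
        params /= n
    return params + grads + opt_states


# Example: 1B parameters on 8 cards
for stage in range(4):
    print(f"stage {stage}: {model_states_per_card(1e9, 8, stage) / 2**30:.1f} GiB per card")
```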

## Communication Analysis

The most commonly used AllReduce implementation is Ring AllReduce, which is divided into two steps: ReduceScatter and AllGather. The communication data volume (send + receive) of each card is approximately 2Φ.

| zero stage | forward + backward | gradient | optimizer update | communication |
| --- |--------------------|---------------------|------------------|---------------|
| 0 | NA | AllReduce | NA | 2Φ |
| 1 | NA | 1/N ReduceScatter | 1/N AllGather | 2Φ |
| 2 | NA | 1/N ReduceScatter | 1/N AllGather | 2Φ |
| 3 | 2 AllGather | ReduceScatter | NA | 3Φ |

It can be concluded that ZeRO-3 requires extra communication. However, computation and communication run on parallel streams in MindSpore, so when the computation that follows the communication is relatively large, ZeRO-3 may still be faster.

## Checkpoint Saving & Loading

Because the model parameters are partitioned across cards, each card needs to save its own checkpoint.

### Resume

What needs to be saved for each ZeRO stage:

| zero stage | parameters | optimizer states | ema |
|------------|------------| --- | --- |
| 0 | one card | one card | one card |
| 1 | one card | each card | each card |
| 2 | one card | each card | each card |
| 3 | each card | each card | each card |

!!! tip

    💡 We recommend using `rank_id` to distinguish checkpoints saved on different cards.

```python
import mindspore as ms
from mindspore.communication import get_rank

rank_id = get_rank()
zero_stage = 2
train_net = prepare_train_network(net, opt, zero_stage=zero_stage, optimizer_parallel_group=GlobalComm.WORLD_COMM_GROUP)
if resume:
    # Parameters are full copies below stage 3; optimizer and EMA states are sharded for stage >= 1
    network_ckpt = "network.ckpt" if zero_stage != 3 else f"network_{rank_id}.ckpt"
    ms.load_checkpoint(network_ckpt, net=train_net.network)
    optimizer_ckpt = "optimizer.ckpt" if zero_stage == 0 else f"optimizer_{rank_id}.ckpt"
    ms.load_checkpoint(optimizer_ckpt, net=train_net.optimizer)
    ema_ckpt = "ema.ckpt" if zero_stage == 0 else f"ema_{rank_id}.ckpt"
    ms.load_checkpoint(ema_ckpt, net=train_net.ema)
```
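
A minimal saving sketch consistent with the table above (the filenames and the `rank_id == 0` choice are illustrative, not a fixed MindOne convention):

```python
# Network parameters: identical on every card below stage 3, so one card is enough; sharded at stage 3
if zero_stage == 3:
    ms.save_checkpoint(train_net.network, f"network_{rank_id}.ckpt")
elif rank_id == 0:
    ms.save_checkpoint(train_net.network, "network.ckpt")

# Optimizer states: sharded for stage >= 1, so every card saves its own shard (same pattern applies to EMA)
if zero_stage == 0:
    if rank_id == 0:
        ms.save_checkpoint(train_net.optimizer, "optimizer.ckpt")
else:
    ms.save_checkpoint(train_net.optimizer, f"optimizer_{rank_id}.ckpt")
```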

### Inference

Inference needs the complete model parameters when ZeRO-3 is used. There are two ways (online & offline) to obtain the complete model parameters.

#### Online Checkpoint Combine

```python
import mindspore as ms
from mindspore import ops


def do_ckpt_combine_online(params, optimizer_parallel_group):
    new_net_to_save = []
    all_gather_op = ops.AllGather(optimizer_parallel_group)
    for param in params:
        if param.parallel_optimizer:
            # The parameter is sharded across the optimizer parallel group: gather the full tensor
            new_data = ms.Tensor(all_gather_op(param).asnumpy())
        else:
            new_data = ms.Tensor(param.asnumpy())
        new_net_to_save.append({"name": param.name, "data": new_data})
    return new_net_to_save


# Pass the parameters themselves so that AllGather can rebuild the sharded ones under ZeRO-3
if zero_stage == 3:
    net_to_save = do_ckpt_combine_online(network.trainable_params(), optimizer_parallel_group)
else:
    net_to_save = [{"name": p.name, "data": p} for p in network.trainable_params()]
ms.save_checkpoint(net_to_save, "network.ckpt")
```

Add this code wherever you need to save the model parameters.

#### Offline Checkpoint Combine

The parameter split information is saved when using ZeroHelper; it can be used to combine the checkpoints offline.

```python
from mindone.trainers.zero import convert_checkpoints

# "{}" is filled with the rank id of each card
src_checkpoint = "save_checkpoint_dir/ckpt_{}.ckpt"
src_param_split_info_json = "params_info/params_split_info_{}.json"
group_size = 2  # number of cards in the optimizer parallel group
convert_checkpoints(src_checkpoint, src_param_split_info_json, group_size)
```

The checkpoint with the complete model parameters is then available at `save_checkpoint_dir/ckpt_all_2.ckpt`.
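
For reference, a minimal loading sketch for single-card inference (assuming `Net()` builds your inference network; the checkpoint path comes from the step above):

```python
import mindspore as ms

# Load the combined (full) parameters into a single-card inference network
net = Net()
ms.load_checkpoint("save_checkpoint_dir/ckpt_all_2.ckpt", net=net)
net.set_train(False)
```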
@@ -535,7 +535,7 @@ def main(args):
latent_diffusion_with_loss,
optimizer,
zero_stage=args.zero_stage,
op_group=GlobalComm.WORLD_COMM_GROUP,
optimizer_parallel_group=GlobalComm.WORLD_COMM_GROUP,
comm_fusion=comm_fusion_dict,
scale_sense=loss_scaler,
drop_overflow_update=args.drop_overflow_update,
2 changes: 1 addition & 1 deletion examples/opensora_pku/tools/ckpt/combine_ckpt.py
@@ -25,7 +25,7 @@ def main():
else args.strategy_ckpt
)
assert os.path.exists(strategy_file), f"{strategy_file} does not exist!"
ms.transform_checkpoints(args.src, args.dest, "full_", strategy_file, None)
ms.convert_checkpoints(args.src, args.dest, "full_", strategy_file, None)

output_path = os.path.join(args.dest, "rank_0", "full_0.ckpt")
assert os.path.isfile(output_path)
22 changes: 13 additions & 9 deletions mindone/diffusers/training_utils.py
@@ -861,7 +861,7 @@ def prepare_train_network(
    verbose: bool = False,
    zero_stage: int = 0,
    optimizer_offload: bool = False,
    op_group: str = None,
    optimizer_parallel_group: str = None,
    dp_group: str = None,
    comm_fusion: dict = None,
    parallel_modules=None,
@@ -878,7 +878,7 @@
            the shape should be :math:`()` or :math:`(1,)`.
        zero_stage (`int`, *optional*): Stage setting of ZeRO, default is 0.
        optimizer_offload (`bool`, *optional*): Only take effect when optimizer is AdamWeightDecay, default is False.
        op_group (`str`, *optional*): The name of the optimizer parallel communication group, default is None.
        optimizer_parallel_group (`str`, *optional*): The name of the optimizer parallel communication group, default is None.
        dp_group (`str`, *optional*): The name of the data parallel communication group, default is None.
        comm_fusion (`dict`, *optional*): A dict contains the types and configurations
            for setting the communication fusion, default is None, turn off the communication fusion. If set a dict,
@@ -891,19 +891,23 @@
"""
if zero_stage not in [0, 1, 2, 3]:
raise ValueError("Not support zero_stage {zero_stage}")
if op_group is None:
if optimizer_parallel_group is None:
logger.warning("Not set zero group, set it WORLD_COMM_GROUP.")
op_group = GlobalComm.WORLD_COMM_GROUP
if op_group != GlobalComm.WORLD_COMM_GROUP and dp_group is None:
raise ValueError("op_group {op_group} and dp_group {dp_group} not full network hccl group coverage")
optimizer_parallel_group = GlobalComm.WORLD_COMM_GROUP
if optimizer_parallel_group != GlobalComm.WORLD_COMM_GROUP and dp_group is None:
raise ValueError(
"optimizer_parallel_group {optimizer_parallel_group} and dp_group {dp_group} not full network hccl group coverage"
)

is_parallel = _get_parallel_mode() == ParallelMode.DATA_PARALLEL
if not is_parallel and zero_stage == 0:
logger.info("No need prepare train_network with zero.")
zero_helper = None
else:
network = prepare_network(network, zero_stage, op_group, parallel_modules=parallel_modules)
zero_helper = ZeroHelper(optimizer, zero_stage, op_group, dp_group, optimizer_offload, comm_fusion)
network = prepare_network(network, zero_stage, optimizer_parallel_group, parallel_modules=parallel_modules)
zero_helper = ZeroHelper(
optimizer, zero_stage, optimizer_parallel_group, dp_group, optimizer_offload, comm_fusion
)

if isinstance(scale_sense, float):
scale_sense = ms.Tensor(scale_sense, ms.float32)
@@ -931,7 +935,7 @@ def use_zero(self):
        return self.zero_helper is not None and self.zero_stage != 0

    def need_save_optimizer(self, args):
        # TODO: Now we save optimizer in every process, try to save depend on self.zero_helper.op_group
        # TODO: Now we save optimizer in every process, try to save depend on self.zero_helper.optimizer_parallel_group
        return True if self.use_zero else is_local_master(args)

    def save_state(self, args, output_dir, optimizer_state_filter=lambda x: True):
6 changes: 4 additions & 2 deletions mindone/models/modules/parallel/__init__.py
@@ -1,6 +1,6 @@
from mindspore import mint, nn

from .conv import Conv1d, Conv2d, Conv3d
from .conv import Conv1d, Conv2d, Conv3d, Mint_Conv2d, Mint_Conv3d
from .dense import Dense, Linear

# {Original MindSpore Cell: New Cell in ZeRO3}
@@ -9,7 +9,9 @@
    nn.Conv2d: Conv2d,
    nn.Conv3d: Conv3d,
    nn.Dense: Dense,
    mint.nn.Conv2d: Mint_Conv2d,
    mint.nn.Conv3d: Mint_Conv3d,
    mint.nn.Linear: Linear,
}

__all__ = ["Conv1d", "Conv2d", "Conv3d", "Dense", "Linear"]
__all__ = ["Conv1d", "Conv2d", "Conv3d", "Mint_Conv2d", "Mint_Conv3d", "Dense", "Linear"]
70 changes: 61 additions & 9 deletions mindone/models/modules/parallel/conv.py
@@ -1,4 +1,4 @@
from mindspore import nn, ops
from mindspore import mint, nn, ops
from mindspore.communication import get_group_size, get_rank
from mindspore.communication.management import GlobalComm
from mindspore.context import ParallelMode
@@ -8,25 +8,35 @@


class _Conv(nn.Cell):
    def __init__(self, net, zero_stage: int = 0, op_group: str = GlobalComm.WORLD_COMM_GROUP, cell_type=None):
    def __init__(
        self, net, zero_stage: int = 0, optimizer_parallel_group: str = GlobalComm.WORLD_COMM_GROUP, cell_type=None
    ):
        super(_Conv, self).__init__(auto_prefix=False)
        self.net = net
        self.set_param_wrapper(zero_stage, op_group, cell_type)
        self.set_param_wrapper(zero_stage, optimizer_parallel_group, cell_type)

    def set_param_wrapper(self, zero_stage, op_group, cell_type=None):
    @property
    def weight(self):
        return self.net.weight

    @property
    def bias(self):
        return self.net.bias

**Review comment (Collaborator):** may change `if self.net.has_bias` -> `if self.bias is not None` since mint.Conv does not have `has_bias` attribute

**Reply (Author):** ok
    def set_param_wrapper(self, zero_stage, optimizer_parallel_group, cell_type=None):
        self.param_wrapper_w = nn.Identity()
        self.param_wrapper_b = nn.Identity()
        if zero_stage == 3:
            # Init parallel settings
            is_parallel = _get_parallel_mode() == ParallelMode.DATA_PARALLEL
            op_group_size = get_group_size(op_group) if is_parallel else 1
            op_rank_id = get_rank(op_group) if is_parallel else 0
            self.param_wrapper_w = ZeroParamWrapper(self.net.weight, zero_stage, op_group, cell_type)
            op_group_size = get_group_size(optimizer_parallel_group) if is_parallel else 1
            op_rank_id = get_rank(optimizer_parallel_group) if is_parallel else 0
            self.param_wrapper_w = ZeroParamWrapper(self.net.weight, zero_stage, optimizer_parallel_group, cell_type)
            split_op = ops.Split(0, op_group_size)
            if self.param_wrapper_w.need_rewrite:
                self.net.weight.assign_value(split_op(self.net.weight)[op_rank_id])
            if self.net.has_bias:
                self.param_wrapper_b = ZeroParamWrapper(self.net.bias, zero_stage, op_group, cell_type)
            if self.net.bias:
                self.param_wrapper_b = ZeroParamWrapper(self.net.bias, zero_stage, optimizer_parallel_group, cell_type)
                if self.param_wrapper_b.need_rewrite:
                    self.net.bias.assign_value(split_op(self.net.bias)[op_rank_id])

@@ -71,3 +81,45 @@ def construct(self, x):
            new_shape[1] = self.net.out_channels
            out = out + bias.reshape(new_shape)
        return out


class Mint_Conv2d(_Conv):
    def construct(self, x):
        weight = self.param_wrapper_w(self.net.weight)
        bias = self.param_wrapper_b(self.net.bias)
        if self.net.padding_mode != "zeros":
            output = self.net.conv2d(
                mint.pad(x, self.net._reversed_padding, mode=self.net.padding_mode),
                weight,
                bias,
                self.net.stride,
                (0, 0),
                self.net.dilation,
                self.net.groups,
            )
        else:
            output = self.net.conv2d(
                x, weight, bias, self.net.stride, self.net.padding, self.net.dilation, self.net.groups
            )
        return output


class Mint_Conv3d(_Conv):
    def construct(self, x):
        weight = self.param_wrapper_w(self.net.weight)
        bias = self.param_wrapper_b(self.net.bias)
        if self.net.padding_mode != "zeros":
            output = self.net.conv3d(
                mint.pad(x, self.net._reversed_padding, mode=self.net.padding_mode),
                weight,
                bias,
                self.net.stride,
                (0, 0, 0),
                self.net.dilation,
                self.net.groups,
            )
        else:
            output = self.net.conv3d(
                x, weight, bias, self.net.stride, self.net.padding, self.net.dilation, self.net.groups
            )
        return output