[WIP] add cogvideox SP #838

Open · wants to merge 2 commits into base: master
8 changes: 8 additions & 0 deletions docs/index.md
@@ -32,6 +32,14 @@ hide:

[:octicons-arrow-right-24: Start tuning!](peft/index.md)

- :star2: __Tools__

---

Training tools, including Trainer, ZeRO, image/video data filtering strategies, and more.

[:octicons-arrow-right-24: Use it!](tools/zero.md)

- > :rocket: __Accelerate__
> ---
4 changes: 4 additions & 0 deletions docs/tools/_toctree.yml
@@ -0,0 +1,4 @@
- sections:
- local: zero
title: ZeRO
title: Get started
155 changes: 155 additions & 0 deletions docs/tools/zero.md
@@ -0,0 +1,155 @@
# Zero Redundancy Optimizer (ZeRO) on MindOne

Zero Redundancy Optimizer (ZeRO) is a method for reducing memory usage under the data parallelism strategy, introduced in the paper [ZeRO: Memory Optimizations Toward Training Trillion Parameter Models](https://arxiv.org/pdf/1910.02054.pdf).

ZeRO eliminates memory redundancies in data- and model-parallel training while retaining low communication volume and high computational
granularity, allowing the model size to scale in proportion to the number of devices with sustained high efficiency.

This tutorial walks you through how to train larger models faster with ZeRO on MindOne.

## Build a Training Network with ZeRO

Build a training network with ZeRO:

```python
import mindspore as ms
from mindspore import nn
from mindspore.communication import init
from mindspore.communication.management import GlobalComm
from mindone.trainers.zero import prepare_train_network

# Initialize the distributed environment
def init_env(mode, distribute):
    ms.set_context(mode=mode)
    if distribute:
        init()
        # ZeRO only takes effect under the DATA_PARALLEL mode
        ms.set_auto_parallel_context(
            parallel_mode=ms.ParallelMode.DATA_PARALLEL,
            gradients_mean=True,
        )

init_env(ms.GRAPH_MODE, True)

# Net is your training network
net = Net()
# opt must be a subclass of the MindSpore Optimizer
opt = nn.AdamWeightDecay(net.trainable_params(), learning_rate=1e-3)

# Build a training network with ZeRO
train_net = prepare_train_network(net, opt, zero_stage=2, optimizer_parallel_group=GlobalComm.WORLD_COMM_GROUP)
```

!!! tip

    `optimizer_parallel_group` does not have to be `GlobalComm.WORLD_COMM_GROUP`. Use [create_group](https://www.mindspore.cn/docs/zh-CN/master/api_python/mindspore.communication.html#mindspore.communication.create_group) to create your own `optimizer_parallel_group`.
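
For instance, a minimal sketch (the group name and rank selection are hypothetical) of creating and using a custom group:

```python
from mindspore.communication import create_group, get_group_size

# Hypothetical: shard optimizer states across the first half of the devices only
rank_ids = list(range(get_group_size() // 2))
create_group("my_optimizer_parallel_group", rank_ids)

train_net = prepare_train_network(net, opt, zero_stage=2, optimizer_parallel_group="my_optimizer_parallel_group")
```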

More details:

::: mindone.trainers.zero.prepare_train_network

[Here](https://github.com/mindspore-lab/mindone/blob/master/tests/others/test_zero.py) is an example.

## Memory Analysis

The memory consumption during training can be divided into two main parts:

- Residual states. Mainly includes activations, temporary buffers, and unusable memory fragments.
- Model states. Mainly includes three parts: optimizer states (AdamW, fp32), gradients (fp16), and parameters (fp16). The three are abbreviated as OPG. Assuming the number of model parameters is Φ,
the total model states are 2Φ (parameters) + 2Φ (gradients) + (4Φ + 4Φ + 4Φ) (optimizer states: fp32 parameter copy, momentum, and variance) = 16Φ, with the AdamW states accounting for 75%.

Residual states can be greatly reduced through [recompute](https://www.mindspore.cn/docs/en/master/model_train/parallel/recompute.html) and [model parallelism](https://www.mindspore.cn/docs/en/master/model_train/parallel/strategy_select.html).
The ZeRO algorithm can then be used to reduce the model states.

To optimize the model states (i.e., remove redundancy), ZeRO uses partitioning, meaning that each card stores only 1/N of the data.

ZeRO has three main optimization stages (as depicted in Figure 1 of the ZeRO paper), which correspond to the partitioning of optimizer states, gradients, and parameters. When enabled cumulatively:

1) Optimizer State Partitioning (Pos): each card keeps only 1/N of the optimizer states, while the model parameters and gradients are still kept in full on each card. The model states per card are 4Φ + 12Φ/N; when N is very large, this tends to 4Φ, i.e. 1/4 of the original memory;
2) Add Gradient Partitioning (Pos+g): the gradients are additionally partitioned to 1/N. The model states per card are 2Φ + (2Φ + 12Φ)/N; when N is very large, this tends to 2Φ, i.e. 1/8 of the original memory;
3) Add Parameter Partitioning (Pos+g+p): the parameters are additionally partitioned to 1/N. The model states per card are 16Φ/N; when N is very large, this tends to 0.

Pos corresponds to ZeRO-1, Pos+g corresponds to ZeRO-2, and Pos+g+p corresponds to ZeRO-3.
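
As a quick sanity check, here is a minimal sketch (plain arithmetic, assuming fp16 parameters/gradients and AdamW fp32 states as above) of the per-card model-state memory:

```python
def model_state_bytes(num_params: float, zero_stage: int, n: int) -> float:
    """Approximate per-card model-state memory (in bytes) under ZeRO."""
    p = 2 * num_params   # fp16 parameters
    g = 2 * num_params   # fp16 gradients
    o = 12 * num_params  # AdamW fp32 copy + momentum + variance
    if zero_stage == 0:
        return p + g + o        # 16Φ
    if zero_stage == 1:
        return p + g + o / n    # Pos: 4Φ + 12Φ/N
    if zero_stage == 2:
        return p + (g + o) / n  # Pos+g: 2Φ + 14Φ/N
    return (p + g + o) / n      # Pos+g+p: 16Φ/N

# Example: 7.5B parameters on 64 cards (the configuration used in the ZeRO paper)
for stage in (0, 1, 2, 3):
    print(f"stage {stage}: {model_state_bytes(7.5e9, stage, 64) / 2**30:.1f} GiB")
```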

## Communication Analysis

The commonly used AllReduce implementation is Ring AllReduce, which is divided into two steps: ReduceScatter and AllGather. The communication data volume (send + receive) of each card is approximately 2Φ.

| ZeRO stage | forward + backward | gradient          | optimizer update | communication |
|------------|--------------------|-------------------|------------------|---------------|
| 0          | NA                 | AllReduce         | NA               | 2Φ            |
| 1          | NA                 | 1/N ReduceScatter | 1/N AllGather    | 2Φ            |
| 2          | NA                 | 1/N ReduceScatter | 1/N AllGather    | 2Φ            |
| 3          | 2 AllGather        | ReduceScatter     | NA               | 3Φ            |

It follows that ZeRO-3 incurs one extra round of communication. However, computation and communication run on parallel streams in MindSpore, so when the computation that overlaps the communication is relatively large, ZeRO-3 may still be faster.
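
For intuition, a minimal sketch (plain arithmetic, no framework calls) of the per-card volume behind the table:

```python
def ring_allreduce_volume(phi: float, n: int) -> float:
    """Per-card communication volume (in elements) of one Ring AllReduce.

    ReduceScatter and AllGather each move (n - 1) / n * phi elements per card,
    so the total approaches 2 * phi as n grows.
    """
    return 2 * (n - 1) / n * phi

# Φ = 1e9 elements on 8 cards: ~1.75e9 elements, close to the 2Φ in the table.
# ZeRO-3 additionally all-gathers the parameters in forward and backward,
# bringing its total to roughly 3Φ.
print(ring_allreduce_volume(1e9, 8))
```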

## Checkpoint Saving & Loading

Because the model parameters have been split, each card needs to save its own shard of the checkpoint.

### Resume

Checkpoint saving:

| ZeRO stage | parameters | optimizer states | EMA       |
|------------|------------|------------------|-----------|
| 0          | one card   | one card         | one card  |
| 1          | one card   | each card        | each card |
| 2          | one card   | each card        | each card |
| 3          | each card  | each card        | each card |

!!! tip

    💡 We recommend using `rank_id` to distinguish the checkpoints saved on different cards.

```python
from mindspore.communication import get_rank

rank_id = get_rank()
zero_stage = 2
train_net = prepare_train_network(net, opt, zero_stage=zero_stage, optimizer_parallel_group=GlobalComm.WORLD_COMM_GROUP)
if resume:
    # Stage 3 partitions the parameters, so each card loads its own network shard
    network_ckpt = "network.ckpt" if zero_stage != 3 else f"network_{rank_id}.ckpt"
    ms.load_checkpoint(network_ckpt, net=train_net.network)
    # Stages 1-3 partition the optimizer states, so each card loads its own shard
    optimizer_ckpt = "optimizer.ckpt" if zero_stage == 0 else f"optimizer_{rank_id}.ckpt"
    ms.load_checkpoint(optimizer_ckpt, net=train_net.optimizer)
    ema_ckpt = "ema.ckpt" if zero_stage == 0 else f"ema_{rank_id}.ckpt"
    ms.load_checkpoint(ema_ckpt, net=train_net.ema)
```
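
Correspondingly, a minimal sketch (the file names are illustrative) of saving the per-card shards:

```python
# Each card saves its own shard, distinguished by its rank_id
ms.save_checkpoint(train_net.network, f"network_{rank_id}.ckpt")
ms.save_checkpoint(train_net.optimizer, f"optimizer_{rank_id}.ckpt")
```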

### Inference

Inference needs the complete model parameters when ZeRO-3 is used. There are two ways (online and offline) to obtain them.

#### Online Checkpoint Combine

```python
import mindspore as ms
from mindspore import ops


def do_ckpt_combine_online(net_to_save, optimizer_parallel_group):
    new_net_to_save = []
    all_gather_op = ops.AllGather(optimizer_parallel_group)
    for item in net_to_save:
        param = item["data"]
        if param.parallel_optimizer:
            # Gather the partitioned parameter from all cards in the group
            new_data = ms.Tensor(all_gather_op(param).asnumpy())
        else:
            new_data = ms.Tensor(param.asnumpy())
        new_net_to_save.append({"name": param.name, "data": new_data})
    return new_net_to_save


net_to_save = [{"name": p.name, "data": p} for p in network.trainable_params()]
net_to_save = net_to_save if zero_stage != 3 else do_ckpt_combine_online(net_to_save, optimizer_parallel_group)
ms.save_checkpoint(net_to_save, "network.ckpt")
```

Add this code wherever you need to save the model parameters.

#### Offline Checkpoint Combine

Parameter split information is saved when using `ZeroHelper`; it can be used to combine the checkpoints offline.

```python
from mindone.trainers.zero import convert_checkpoints

src_checkpoint = "save_checkpoint_dir/ckpt_{}.ckpt"
src_param_split_info_json = "params_info/params_split_info_{}.json"
group_size = 2
convert_checkpoints(src_checkpoint, src_param_split_info_json, group_size)
```

This produces the checkpoint with the complete model parameters at `save_checkpoint_dir/ckpt_all_2.ckpt`.
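
The combined checkpoint can then be loaded as usual, e.g. (a sketch, assuming `Net` is the same network definition used in training):

```python
import mindspore as ms

net = Net()  # hypothetical: your training network definition
ms.load_checkpoint("save_checkpoint_dir/ckpt_all_2.ckpt", net=net)
```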
5 changes: 2 additions & 3 deletions examples/diffusers/cogvideox_factory/README.md
@@ -92,7 +92,7 @@ export_to_video(video, "output.mp4", fps=8)
> [!TIP]
> Due to model and framework limitations, we currently recommend a staged training workflow: first preprocess the dataset with [`prepare_dataset.sh`](./prepare_dataset.sh), then read the preprocessed dataset and run the formal training with `train*.sh`.
>
> In the formal training stage, the `--load_tensors` flag needs to be added to support the preprocessed dataset. We recommend adding `--mindspore_mode=0` to accelerate training in static graph mode, which can be set via `MINDSPORE_MODE=0` in `train*.sh`.
> In the formal training stage, the `--embeddings_cache` flag needs to be added to support the text-embeddings preprocessing. We recommend adding `--mindspore_mode=0` to accelerate training in static graph mode, which can be set via `MINDSPORE_MODE=0` in `train*.sh`.
>
> For details, see [Differences from the original repository & feature limitations](#与原仓的差异功能限制)

@@ -202,7 +202,6 @@ DTYPE=bf16
--lr_num_cycles 1 \
--enable_slicing \
--enable_tiling \
--load_tensors \
--optimizer $optimizer \
--beta1 0.9 \
--beta2 0.95 \
@@ -211,7 +210,7 @@ DTYPE=bf16
--report_to tensorboard \
--mindspore_mode $MINDSPORE_MODE \
--amp_level $AMP_LEVEL \
--load_tensors \
--embeddings_cache \
$EXTRA_ARGS"

echo "Running command: $cmd"
2 changes: 1 addition & 1 deletion examples/diffusers/cogvideox_factory/assets/dataset.md
@@ -58,4 +58,4 @@ huggingface-cli download --repo-type dataset Wild-Heart/Disney-VideoGeneration-D

This dataset has been prepared in the expected format and can be used directly. However, directly using the video dataset may cause Out of Memory (OOM) issues on GPUs with smaller VRAM because it requires loading the [VAE](https://huggingface.co/THUDM/CogVideoX-5b/tree/main/vae) (which encodes videos into latent space) and the large [T5-XXL](https://huggingface.co/google/t5-v1_1-xxl/) text encoder. To reduce memory usage, you can use the `training/prepare_dataset.py` script to precompute latents and embeddings.

Fill or modify the parameters in `prepare_dataset.sh` and execute it to get precomputed latents and embeddings (make sure to specify `--save_latents_and_embeddings` to save the precomputed artifacts). If preparing for image-to-video training, make sure to pass `--save_image_latents`, which encodes and saves image latents along with videos. When using these artifacts during training, ensure that you specify the `--load_tensors` flag, or else the videos will be used directly, requiring the text encoder and VAE to be loaded. The script also supports PyTorch DDP so that large datasets can be encoded in parallel across multiple GPUs (modify the `NUM_GPUS` parameter).
Fill or modify the parameters in `prepare_dataset.sh` and execute it to get precomputed latents and embeddings (make sure to specify `--save_latents_and_embeddings` to save the precomputed artifacts). If preparing for image-to-video training, make sure to pass `--save_image_latents`, which encodes and saves image latents along with videos. When using these artifacts during training, ensure that you specify the `--embeddings_cache` and `--latents_cache` flags, or else the videos will be used directly, requiring the text encoder and VAE to be loaded. The script also supports PyTorch DDP so that large datasets can be encoded in parallel across multiple GPUs (modify the `NUM_GPUS` parameter).
@@ -68,5 +68,5 @@ OOM (out of memory), because it needs to load the [VAE](https://huggingface.co/THUDM

Fill in or modify the parameters in `prepare_dataset.sh` and execute it to obtain the precomputed latents and embeddings (make sure to specify `--save_latents_and_embeddings`
to save the precomputed artifacts). If preparing for image-to-video training, make sure to pass `--save_image_latents`, which encodes and saves the image latents along with the videos.
When using these artifacts during training, make sure to specify the `--load_tensors` flag; otherwise the videos will be used directly, requiring the text encoder and
When using these artifacts during training, make sure to specify the `--embeddings_cache` and `--latents_cache` flags; otherwise the videos will be used directly, requiring the text encoder and
VAE to be loaded. The script also supports PyTorch DDP so that large datasets can be encoded in parallel using multiple GPUs (modify the `NUM_GPUS` parameter).
104 changes: 104 additions & 0 deletions examples/diffusers/cogvideox_factory/infer.py
@@ -0,0 +1,104 @@
import argparse
import pathlib
import shutil

import numpy as np

import mindspore as ms

from mindone import diffusers


def parse_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--pretrained_model_name_or_path",
        type=str,
        default="THUDM/CogVideoX1.5-5b",
        help="Path to pretrained model or model identifier from huggingface.co/models.",
    )
    parser.add_argument(
        "--prompt",
        type=str,
        default=None,
        help="Prompt; if None, a default prompt will be used.",
    )
    parser.add_argument(
        "--transformer_ckpt_path",
        type=str,
        default=None,
        help="Path to the transformer checkpoint.",
    )
    parser.add_argument(
        "--height",
        type=int,
        default=768,
        help="The height of the output video.",
    )
    parser.add_argument(
        "--width",
        type=int,
        default=1360,
        help="The width of the output video.",
    )
    parser.add_argument(
        "--frame",
        type=int,
        default=77,
        help="The number of frames of the output video.",
    )
    parser.add_argument(
        "--npy_output_path",
        type=str,
        default=None,
        help="Path to save the inferred numpy arrays.",
    )
    parser.add_argument(
        "--video_output_path",
        type=str,
        default=None,
        help="Path to save the inferred video.",
    )
    return parser.parse_args()


def infer(args: argparse.Namespace) -> None:
    ms.set_context(mode=ms.GRAPH_MODE, jit_config={"jit_level": "O1"})
    pipe = diffusers.CogVideoXPipeline.from_pretrained(
        args.pretrained_model_name_or_path, mindspore_dtype=ms.bfloat16, use_safetensors=True
    )

    if args.transformer_ckpt_path is not None:
        ckpt = ms.load_checkpoint(args.transformer_ckpt_path)
        processed_ckpt = {name[12:]: value for name, value in ckpt.items()}  # remove "transformer." prefix
        param_not_load, ckpt_not_load = ms.load_param_into_net(pipe.transformer, processed_ckpt)
        if param_not_load:
            raise RuntimeError(f"{param_not_load} was not loaded into net.")
        if ckpt_not_load:
            raise RuntimeError(f"{ckpt_not_load} was not loaded from checkpoint.")
        print("Successfully loaded transformer checkpoint.")

    pipe.vae.enable_tiling()
    pipe.vae.enable_slicing()
    prompt = "A panda, dressed in a small, red jacket and a tiny hat, sits on a wooden stool in a serene bamboo forest. The panda's fluffy paws strum a miniature acoustic guitar, producing soft, melodic tunes. Nearby, a few other pandas gather, watching curiously and some clapping in rhythm. Sunlight filters through the tall bamboo, casting a gentle glow on the scene. The panda's face is expressive, showing concentration and joy as it plays. The background includes a small, flowing stream and vibrant green foliage, enhancing the peaceful and magical atmosphere of this unique musical performance."
    prompt = prompt if args.prompt is None else args.prompt
    video = pipe(prompt=prompt, height=args.height, width=args.width, num_frames=args.frame)[0][0]

    if args.npy_output_path is not None:
        path = pathlib.Path(args.npy_output_path)
        if path.exists():
            shutil.rmtree(path)
        path.mkdir()
        index_max_length = len(str(len(video)))
        for index, image in enumerate(video):
            np.save(path / f"image_{str(index).zfill(index_max_length)}", np.array(image))
        print("Successfully saved the inferred numpy arrays.")

    if args.video_output_path is not None:
        diffusers.utils.export_to_video(video, args.video_output_path, fps=8)
        print("Successfully saved the inferred video.")


if __name__ == "__main__":
    args = parse_args()
    infer(args)
20 changes: 11 additions & 9 deletions examples/diffusers/cogvideox_factory/prepare_dataset.sh
@@ -1,26 +1,27 @@
#!/bin/bash

MODEL_ID="THUDM/CogVideoX-2b"
MODEL_ID="THUDM/CogVideoX1.5-5b"

NUM_NPUS=8
if [ "$NUM_NPUS" -eq 1 ]; then
LAUNCHER="python"
EXTRA_ARGS=""
export HCCL_EXEC_TIMEOUT=1800
else
LAUNCHER="msrun --worker_num=$NUM_NPUS --local_worker_num=$NUM_NPUS"
EXTRA_ARGS="--distributed"
fi

# For more details on the expected data format, please refer to the README.
DATA_ROOT="/path/to/my/datasets/video-dataset" # This needs to be the path to the base directory where your videos are located.
CAPTION_COLUMN="prompt.txt"
CAPTION_COLUMN="prompts.txt"
VIDEO_COLUMN="videos.txt"
OUTPUT_DIR="/path/to/my/datasets/preprocessed-dataset"
HEIGHT_BUCKETS="480"
WIDTH_BUCKETS="720"
FRAME_BUCKETS="49"
MAX_NUM_FRAMES="49"
MAX_SEQUENCE_LENGTH=226
OUTPUT_DIR="preprocessed-dataset"
HEIGHT_BUCKETS="768"
WIDTH_BUCKETS="1360"
FRAME_BUCKETS="77"
MAX_NUM_FRAMES="77"
MAX_SEQUENCE_LENGTH=224
TARGET_FPS=8
BATCH_SIZE=1
DTYPE=bf16
@@ -45,7 +46,8 @@ CMD_WITHOUT_PRE_ENCODING="\
$EXTRA_ARGS
"

CMD_WITH_PRE_ENCODING="$CMD_WITHOUT_PRE_ENCODING --save_latents_and_embeddings"
CMD_WITH_PRE_ENCODING="$CMD_WITHOUT_PRE_ENCODING --save_embeddings "
CMD_WITH_PRE_ENCODING="$CMD_WITH_PRE_ENCODING --save_latents "

# Select which you'd like to run
CMD=$CMD_WITH_PRE_ENCODING
33 changes: 33 additions & 0 deletions examples/diffusers/cogvideox_factory/tests/run_3d_causal_vae.sh
@@ -0,0 +1,33 @@
#export ASCEND_RT_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
# Num of NPUs for test
export MS_DEV_RUNTIME_CONF="memory_statistics:True"
# export MS_DEV_LAZY_FUSION_FLAGS="--opt_level=1"
NUM_NPUS=4
# export DEVICE_ID=3

SCRIPT_DIR="$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )"
PROJECT_DIR="$(dirname "${SCRIPT_DIR}")"
EXAMPLE_DIR="$(dirname "${PROJECT_DIR}")"
PACKAGE_DIR="$(dirname "${EXAMPLE_DIR}")"

export PYTHONPATH="${PROJECT_DIR}:${PACKAGE_DIR}:${PYTHONPATH}"

# Prepare launch cmd according to NUM_NPUS
if [ "$NUM_NPUS" -eq 1 ]; then
cpus=`cat /proc/cpuinfo| grep "processor"| wc -l`
avg=`expr $cpus \/ 8`
gap=`expr $avg \- 1`
start=`expr $DEVICE_ID \* $avg`
end=`expr $start \+ $gap`
cmdopt=$start"-"$end
LAUNCHER="taskset -c $cmdopt python"
EXTRA_ARGS=""
else
LAUNCHER="msrun --bind_core=True --worker_num=$NUM_NPUS --local_worker_num=$NUM_NPUS --log_dir=log_test_vae --join True"
fi

echo "Start Running:"
cmd="$LAUNCHER ${SCRIPT_DIR}/test_3d_causal_vae.py"
echo "Running command: $cmd"
eval $cmd
echo "=============================================="