Skip to content

Commit

Permalink
add cogvideox SP
Browse files Browse the repository at this point in the history
  • Loading branch information
zhaoting committed Feb 17, 2025
1 parent f73a600 commit feec90d
Show file tree
Hide file tree
Showing 35 changed files with 4,349 additions and 326 deletions.
5 changes: 2 additions & 3 deletions examples/diffusers/cogvideox_factory/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ export_to_video(video, "output.mp4", fps=8)
> [!TIP]
> 由于模型和框架的限制,对于训练我们暂时推荐分阶段的训练流程,即先通过[`prepare_dateset.sh`](./prepare_dataset.sh)预处理数据集,然后读取预处理后的数据集通过`train*.sh`进行正式训练。
>
> 在正式训练阶段,需要增加`--load_tensors`参数以支持预处理数据集。建议增加参数`--mindspore_mode=0`以进行静态图训练加速,在`train*.sh`里可通过设置参数`MINDSPORE_MODE=0`实现。
> 在正式训练阶段,需要增加`--embeddings_cache`参数以支持text embeddings预处理。建议增加参数`--mindspore_mode=0`以进行静态图训练加速,在`train*.sh`里可通过设置参数`MINDSPORE_MODE=0`实现。
>
> 具体情况参见[与原仓的差异 & 功能限制](#与原仓的差异功能限制)
Expand Down Expand Up @@ -202,7 +202,6 @@ DTYPE=bf16
--lr_num_cycles 1 \
--enable_slicing \
--enable_tiling \
--load_tensors \
--optimizer $optimizer \
--beta1 0.9 \
--beta2 0.95 \
Expand All @@ -211,7 +210,7 @@ DTYPE=bf16
--report_to tensorboard \
--mindspore_mode $MINDSPORE_MODE \
--amp_level $AMP_LEVEL \
--load_tensors \
--embeddings_cache \
$EXTRA_ARGS"

echo "Running command: $cmd"
Expand Down
2 changes: 1 addition & 1 deletion examples/diffusers/cogvideox_factory/assets/dataset.md
Original file line number Diff line number Diff line change
Expand Up @@ -58,4 +58,4 @@ huggingface-cli download --repo-type dataset Wild-Heart/Disney-VideoGeneration-D

This dataset has been prepared in the expected format and can be used directly. However, directly using the video dataset may cause Out of Memory (OOM) issues on GPUs with smaller VRAM because it requires loading the [VAE](https://huggingface.co/THUDM/CogVideoX-5b/tree/main/vae) (which encodes videos into latent space) and the large [T5-XXL](https://huggingface.co/google/t5-v1_1-xxl/) text encoder. To reduce memory usage, you can use the `training/prepare_dataset.py` script to precompute latents and embeddings.

Fill or modify the parameters in `prepare_dataset.sh` and execute it to get precomputed latents and embeddings (make sure to specify `--save_latents_and_embeddings` to save the precomputed artifacts). If preparing for image-to-video training, make sure to pass `--save_image_latents`, which encodes and saves image latents along with videos. When using these artifacts during training, ensure that you specify the `--load_tensors` flag, or else the videos will be used directly, requiring the text encoder and VAE to be loaded. The script also supports PyTorch DDP so that large datasets can be encoded in parallel across multiple GPUs (modify the `NUM_GPUS` parameter).
Fill or modify the parameters in `prepare_dataset.sh` and execute it to get precomputed latents and embeddings (make sure to specify `--save_latents_and_embeddings` to save the precomputed artifacts). If preparing for image-to-video training, make sure to pass `--save_image_latents`, which encodes and saves image latents along with videos. When using these artifacts during training, ensure that you specify the `--embeddings_cache` and `--latents_cache` flag, or else the videos will be used directly, requiring the text encoder and VAE to be loaded. The script also supports PyTorch DDP so that large datasets can be encoded in parallel across multiple GPUs (modify the `NUM_GPUS` parameter).
2 changes: 1 addition & 1 deletion examples/diffusers/cogvideox_factory/assets/dataset_zh.md
Original file line number Diff line number Diff line change
Expand Up @@ -68,5 +68,5 @@ OOM(内存不足),因为它需要加载 [VAE](https://huggingface.co/THUDM

填写或修改 `prepare_dataset.sh` 中的参数并执行它以获得预先计算的潜在变量和嵌入(请确保指定 `--save_latents_and_embeddings`
以保存预计算的工件)。如果准备图像到视频的训练,请确保传递 `--save_image_latents`,它对沙子进行编码,将图像潜在值与视频一起保存。
在训练期间使用这些工件时,确保指定 `--load_tensors` 标志,否则将直接使用视频并需要加载文本编码器和
在训练期间使用这些工件时,确保指定 `--embeddings_cache``--latents_cache` 标志,否则将直接使用视频并需要加载文本编码器和
VAE。该脚本还支持 PyTorch DDP,以便可以使用多个 GPU 并行编码大型数据集(修改 `NUM_GPUS` 参数)。
103 changes: 103 additions & 0 deletions examples/diffusers/cogvideox_factory/infer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
import argparse
import pathlib
import shutil

import numpy as np
import mindspore as ms

from mindone import diffusers


def parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser()
parser.add_argument(
"--pretrained_model_name_or_path",
type=str,
default="THUDM/CogVideoX1.5-5b",
help="Path to pretrained model or model identifier from huggingface.co/models.",
)
parser.add_argument(
"--prompt",
type=str,
default=None,
help="prompt, if None, will set default prompt.",
)
parser.add_argument(
"--transformer_ckpt_path",
type=str,
default=None,
help="Path to the transformer checkpoint.",
)
parser.add_argument(
"--height",
type=int,
default=1360,
help="The height of the output video.",
)
parser.add_argument(
"--width",
type=int,
default=768,
help="The width of the output video.",
)
parser.add_argument(
"--frame",
type=int,
default=77,
help="The frame of the output video.",
)
parser.add_argument(
"--npy_output_path",
type=str,
default=None,
help="Path to save the inferred numpy array.",
)
parser.add_argument(
"--video_output_path",
type=str,
default=None,
help="Path to save the inferred video.",
)
return parser.parse_args()


def infer(args: argparse.Namespace) -> None:
ms.set_context(mode=ms.GRAPH_MODE, jit_config={"jit_level": "O1"})
pipe = diffusers.CogVideoXPipeline.from_pretrained(
args.pretrained_model_name_or_path, mindspore_dtype=ms.bfloat16, use_safetensors=True
)

if args.transformer_ckpt_path is not None:
ckpt = ms.load_checkpoint(args.transformer_ckpt_path)
processed_ckpt = {name[12:]: value for name, value in ckpt.items()} # remove "transformer." prefix
param_not_load, ckpt_not_load = ms.load_param_into_net(pipe.transformer, processed_ckpt)
if param_not_load:
raise RuntimeError(f"{param_not_load} was not loaded into net.")
if ckpt_not_load:
raise RuntimeError(f"{ckpt_not_load} was not loaded from checkpoint.")
print("Successfully loaded transformer checkpoint.")

pipe.vae.enable_tiling()
pipe.vae.enable_slicing()
prompt = "A panda, dressed in a small, red jacket and a tiny hat, sits on a wooden stool in a serene bamboo forest. The panda's fluffy paws strum a miniature acoustic guitar, producing soft, melodic tunes. Nearby, a few other pandas gather, watching curiously and some clapping in rhythm. Sunlight filters through the tall bamboo, casting a gentle glow on the scene. The panda's face is expressive, showing concentration and joy as it plays. The background includes a small, flowing stream and vibrant green foliage, enhancing the peaceful and magical atmosphere of this unique musical performance."
prompt = prompt if args.prompt is None else args.prompt
video = pipe(prompt=prompt, height=args.height, width=args.width, num_frames=args.frame)[0][0]

if args.npy_output_path is not None:
path = pathlib.Path(args.npy_output_path)
if path.exists():
shutil.rmtree(path)
path.mkdir()
index_max_length = len(str(len(video)))
for index, image in enumerate(video):
np.save(path / f"image_{str(index).zfill(index_max_length)}", np.array(image))
print("Successfully saved the inferred numpy array.")

if args.video_output_path is not None:
diffusers.utils.export_to_video(video, args.video_output_path, fps=8)
print("Successfully saved the inferred video.")


if __name__ == "__main__":
args = parse_args()
infer(args)
19 changes: 10 additions & 9 deletions examples/diffusers/cogvideox_factory/prepare_dataset.sh
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
#!/bin/bash

MODEL_ID="THUDM/CogVideoX-2b"
MODEL_ID="THUDM/CogVideoX1.5-5b"

NUM_NPUS=8
if [ "$NUM_NPUS" -eq 1 ]; then
Expand All @@ -13,14 +13,14 @@ fi

# For more details on the expected data format, please refer to the README.
DATA_ROOT="/path/to/my/datasets/video-dataset" # This needs to be the path to the base directory where your videos are located.
CAPTION_COLUMN="prompt.txt"
CAPTION_COLUMN="prompts.txt"
VIDEO_COLUMN="videos.txt"
OUTPUT_DIR="/path/to/my/datasets/preprocessed-dataset"
HEIGHT_BUCKETS="480"
WIDTH_BUCKETS="720"
FRAME_BUCKETS="49"
MAX_NUM_FRAMES="49"
MAX_SEQUENCE_LENGTH=226
OUTPUT_DIR="preprocessed-dataset"
HEIGHT_BUCKETS="768"
WIDTH_BUCKETS="1360"
FRAME_BUCKETS="77"
MAX_NUM_FRAMES="77"
MAX_SEQUENCE_LENGTH=224
TARGET_FPS=8
BATCH_SIZE=1
DTYPE=bf16
Expand All @@ -45,7 +45,8 @@ CMD_WITHOUT_PRE_ENCODING="\
$EXTRA_ARGS
"

CMD_WITH_PRE_ENCODING="$CMD_WITHOUT_PRE_ENCODING --save_latents_and_embeddings"
CMD_WITH_PRE_ENCODING="$CMD_WITHOUT_PRE_ENCODING --save_embeddings "
CMD_WITH_PRE_ENCODING="$CMD_WITH_PRE_ENCODING --save_latents "

# Select which you'd like to run
CMD=$CMD_WITH_PRE_ENCODING
Expand Down
33 changes: 33 additions & 0 deletions examples/diffusers/cogvideox_factory/tests/run_3d_causal_vae.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
#export ASCEND_RT_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
# Num of NPUs for test
export MS_DEV_RUNTIME_CONF="memory_statistics:True"
# export MS_DEV_LAZY_FUSION_FLAGS="--opt_level=1"
NUM_NPUS=4
# export DEVICE_ID=3

SCRIPT_DIR="$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )"
PROJECT_DIR="$(dirname "${SCRIPT_DIR}")"
EXAMPLE_DIR="$(dirname "${PROJECT_DIR}")"
PACKAGE_DIR="$(dirname "${EXAMPLE_DIR}")"

export PYTHONPATH="${PROJECT_DIR}:${PACKAGE_DIR}:${PYTHONPATH}"

# Prepare launch cmd according to NUM_NPUS
if [ "$NUM_NPUS" -eq 1 ]; then
cpus=`cat /proc/cpuinfo| grep "processor"| wc -l`
avg=`expr $cpus \/ 8`
gap=`expr $avg \- 1`
start=`expr $DEVICE_ID \* $avg`
end=`expr $start \+ $gap`
cmdopt=$start"-"$end
LAUNCHER="taskset -c $cmdopt python"
EXTRA_ARGS=""
else
LAUNCHER="msrun --bind_core=True --worker_num=$NUM_NPUS --local_worker_num=$NUM_NPUS --log_dir=log_test_vae --join True"
fi

echo "Start Running:"
cmd="$LAUNCHER ${SCRIPT_DIR}/test_3d_causal_vae.py"
echo "Running command: $cmd"
eval $cmd
echo "=============================================="
9 changes: 9 additions & 0 deletions examples/diffusers/cogvideox_factory/tests/run_secd_recv.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
NUM_NPUS=4
SCRIPT_DIR="$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )"
PROJECT_DIR="$(dirname "${SCRIPT_DIR}")"
EXAMPLE_DIR="$(dirname "${PROJECT_DIR}")"
PACKAGE_DIR="$(dirname "${EXAMPLE_DIR}")"

export PYTHONPATH="${PROJECT_DIR}:${PACKAGE_DIR}:${PYTHONPATH}"

msrun --master_port=2234 --worker_num=$NUM_NPUS --local_worker_num=$NUM_NPUS --bind_core=True --log_dir="./log_test_send_recv" --join True ${SCRIPT_DIR}/test_send_recv.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
#!/bin/sh
export MS_DEV_RUNTIME_CONF="memory_statistics:True"
#export MS_DEV_RUNTIME_CONF="memory_statistics:True,compile_statistics:True"
#export ASCEND_RT_VISIBLE_DEVICES=2,3
NUM_NPUS=2
SCRIPT_DIR="$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )"
PROJECT_DIR="$(dirname "${SCRIPT_DIR}")"
EXAMPLE_DIR="$(dirname "${PROJECT_DIR}")"
PACKAGE_DIR="$(dirname "${EXAMPLE_DIR}")"

export PYTHONPATH="${PROJECT_DIR}:${PACKAGE_DIR}:${PYTHONPATH}"

echo "Start Running:"
msrun --master_port=1234 --worker_num=$NUM_NPUS --local_worker_num=$NUM_NPUS --bind_core=True --log_dir="./log_test_sp_graph" --join True ${SCRIPT_DIR}/test_cogvideox_sequence_parallelism.py
echo "Done. Check the log at './log_test_sp_graph'."
echo "=============================================="
Loading

0 comments on commit feec90d

Please sign in to comment.