Commit: random broadcast
zhaoting committed Feb 21, 2025
1 parent 414c9c0 commit f6b1474
Showing 4 changed files with 36 additions and 17 deletions.
1 change: 1 addition & 0 deletions examples/diffusers/cogvideox_factory/prepare_dataset.sh
@@ -6,6 +6,7 @@ NUM_NPUS=8
if [ "$NUM_NPUS" -eq 1 ]; then
LAUNCHER="python"
EXTRA_ARGS=""
export HCCL_EXEC_TIMEOUT=1800
else
LAUNCHER="msrun --worker_num=$NUM_NPUS --local_worker_num=$NUM_NPUS"
EXTRA_ARGS="--distributed"
@@ -21,6 +21,7 @@
import shutil
from pathlib import Path
from time import time
from typing import Any, Dict, Optional

import numpy as np
@@ -484,6 +485,7 @@ def optimizer_state_filter(param_name: str):
weight_dtype=weight_dtype,
args=args,
use_rotary_positional_embeddings=transformer_config.use_rotary_positional_embeddings,
enable_sequence_parallelism=enable_sequence_parallelism,
).set_train(True)

loss_scaler = DynamicLossScaleUpdateCell(loss_scale_value=65536.0, scale_factor=2, scale_window=2000)
@@ -748,6 +750,7 @@ def __init__(
weight_dtype: ms.Type,
args: AttrJitWrapper,
use_rotary_positional_embeddings: bool,
enable_sequence_parallelism: bool = False,
):
super().__init__()

@@ -764,6 +767,11 @@

self.use_rotary_positional_embeddings = use_rotary_positional_embeddings
self.args = AttrJitWrapper(**vars(args))
self.enable_sequence_parallelism = enable_sequence_parallelism
if self.enable_sequence_parallelism:
from mindone.acceleration import get_sequence_parallel_group

self.broadcast = ops.Broadcast(0, group=get_sequence_parallel_group())

def compute_prompt_embeddings(
self,
@@ -787,7 +795,7 @@ def diagonal_gaussian_distribution_sample(self, latent_dist: ms.Tensor) -> ms.Tensor:
logvar = ops.clamp(logvar, -30.0, 20.0)
std = ops.exp(0.5 * logvar)

sample = ops.randn_like(mean, dtype=mean.dtype)
sample = self.broadcast(ops.randn_like(mean, dtype=mean.dtype))
x = mean + std * sample

return x
@@ -809,15 +817,17 @@ def construct(self, videos, text_input_ids_or_prompt_embeds, image_rotary_emb=None):
prompt_embeds = text_input_ids_or_prompt_embeds.to(dtype=self.weight_dtype)

# Sample noise that will be added to the latents
noise = ops.randn_like(model_input, dtype=model_input.dtype)
noise = self.broadcast(ops.randn_like(model_input, dtype=model_input.dtype))
batch_size, num_frames, num_channels, height, width = model_input.shape

# Sample a random timestep for each image
timesteps = ops.randint(
0,
self.scheduler_num_train_timesteps,
(batch_size,),
dtype=ms.int64,
timesteps = self.broadcast(
ops.randint(
0,
self.scheduler_num_train_timesteps,
(batch_size,),
dtype=ms.int64,
)
)

# Rotary embeds are prepared in the dataset.
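These changes thread an enable_sequence_parallelism flag into the training step wrapper: when it is set, the random tensors drawn inside construct (the VAE posterior sample, the diffusion noise, and the per-sample timesteps) are broadcast from rank 0 of the sequence-parallel group, presumably so that every rank sharing a sequence works with identical randomness. The same edit is repeated in the second training script below. What follows is a minimal sketch of that broadcast pattern, not the repository's code: it assumes a distributed launch (e.g. via msrun) and uses the default world communication group in place of the group returned by mindone.acceleration.get_sequence_parallel_group.

# Minimal sketch (assumptions: distributed launch, world group used instead of
# the sequence-parallel group that the actual training code obtains).
import mindspore as ms
from mindspore import ops
from mindspore.communication import init

init()  # initialize the communication backend for this process

# Root rank 0; ops.Broadcast's documented input/output is a tuple of tensors.
broadcast = ops.Broadcast(0)

# Each rank draws its own random values ...
noise = ops.randn((2, 16, 32), dtype=ms.float32)
timesteps = ops.randint(0, 1000, (2,), dtype=ms.int64)

# ... then rank 0's draws are broadcast so all ranks use the same randomness.
(noise,) = broadcast((noise,))
(timesteps,) = broadcast((timesteps,))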
@@ -482,6 +482,7 @@ def optimizer_state_filter(param_name: str):
weight_dtype=weight_dtype,
args=args,
use_rotary_positional_embeddings=transformer_config.use_rotary_positional_embeddings,
enable_sequence_parallelism=enable_sequence_parallelism,
).set_train(True)

loss_scaler = DynamicLossScaleUpdateCell(loss_scale_value=65536.0, scale_factor=2, scale_window=2000)
@@ -772,6 +773,7 @@ def __init__(
weight_dtype: ms.Type,
args: AttrJitWrapper,
use_rotary_positional_embeddings: bool,
enable_sequence_parallelism: bool = False,
):
super().__init__()

@@ -788,6 +790,11 @@

self.use_rotary_positional_embeddings = use_rotary_positional_embeddings
self.args = AttrJitWrapper(**vars(args))
self.enable_sequence_parallelism = enable_sequence_parallelism
if self.enable_sequence_parallelism:
from mindone.acceleration import get_sequence_parallel_group

self.broadcast = ops.Broadcast(0, group=get_sequence_parallel_group())

def compute_prompt_embeddings(
self,
@@ -811,7 +818,7 @@ def diagonal_gaussian_distribution_sample(self, latent_dist: ms.Tensor) -> ms.Tensor:
logvar = ops.clamp(logvar, -30.0, 20.0)
std = ops.exp(0.5 * logvar)

sample = ops.randn_like(mean, dtype=mean.dtype)
sample = self.broadcast(ops.randn_like(mean, dtype=mean.dtype))
x = mean + std * sample

return x
@@ -833,15 +840,17 @@ def construct(self, videos, text_input_ids_or_prompt_embeds, image_rotary_emb=None):
prompt_embeds = text_input_ids_or_prompt_embeds.to(dtype=self.weight_dtype)
# prompt_embeds shape: (1, 226, 4096)
# Sample noise that will be added to the latents
noise = ops.randn_like(model_input, dtype=model_input.dtype)
noise = self.broadcast(ops.randn_like(model_input, dtype=model_input.dtype))
batch_size, num_frames, num_channels, height, width = model_input.shape

# Sample a random timestep for each image
timesteps = ops.randint(
0,
self.scheduler_num_train_timesteps,
(batch_size,),
dtype=ms.int64,
timesteps = self.broadcast(
ops.randint(
0,
self.scheduler_num_train_timesteps,
(batch_size,),
dtype=ms.int64,
)
)

# Rotary embeds are prepared in the dataset.
@@ -510,7 +510,7 @@ def collate_fn(data):
save_future.result()

if args.world_size > 1:
ops.Barrier()()
ops.AllGather()(ops.ones((1,), dtype=ms.float32))

# 6. Combine results from each rank
if is_master(args):
@@ -545,8 +545,7 @@ def rmdir_recursive(dir: pathlib.Path) -> None:
rmdir_recursive(child)
dir.rmdir()

# rmdir_recursive(tmp_dir)
print(f"[WARNING] please delete the tmp dir {tmp_dir}")
rmdir_recursive(tmp_dir)

# Combine prompts and videos into individual text files and single jsonl
prompts_folder = output_dir.joinpath("prompts")
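The final file swaps the explicit ops.Barrier() call for an AllGather on a one-element tensor as the cross-rank synchronization point before the master rank combines results, and it re-enables rmdir_recursive(tmp_dir) so the temporary directory is cleaned up automatically instead of leaving a warning to delete it by hand. A rough sketch of the AllGather-as-barrier idiom, again assuming a distributed launch with an initialized communication backend:

# Rough sketch: an AllGather over a dummy tensor behaves like a barrier, since
# the collective cannot complete until every rank has reached this line.
import mindspore as ms
from mindspore import ops
from mindspore.communication import init, get_group_size

init()
gathered = ops.AllGather()(ops.ones((1,), dtype=ms.float32))
assert gathered.shape == (get_group_size(),)  # each rank contributes one element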
