Fix latest CANN bugs (repeat ops) on OpenSora1.1 (#595)
* fix cann0630 bugs on repeat ops

* vae default bf16

* fix repeat/tile ops in latest cann

* update docs
SamitHuang authored Jul 16, 2024
1 parent 9adb253 commit 960cfef
Showing 12 changed files with 303 additions and 20 deletions.
11 changes: 6 additions & 5 deletions examples/opensora_hpcai/README.md
@@ -489,11 +489,12 @@ We evaluated the training performance on MindSpore and Ascend NPUs. The results

| Model | Context | jit_level | Precision | BS | NPUs | Resolution(framesxHxW) | Train T. (s/step) |
|:------------|:-------------|:--------|:---------:|:--:|:----:|:----------------------:|:-----------------:|
-| STDiT2-XL/2 | D910\*-MS2.3 | O1 | BF16 | 1 | 8 | 16x512x512 | 2.00 |
-| STDiT2-XL/2 | D910\*-MS2.3 | O1 | BF16 | 1 | 8 | 64x512x512 | 8.30 |
-| STDiT2-XL/2 | D910\*-MS2.3 | O1 | BF16 | 1 | 8 | 24x576x1024 | 8.22 |
-| STDiT2-XL/2 | D910\*-MS2.3 | O1 | BF16 | 1 | 8 | 64x576x1024 | 21.15 |
-| STDiT2-XL/2 | D910\*-MS2.3 | O1 | BF16 | 1 | 8 | 24x1024x1024 | 16.98 |
+| STDiT2-XL/2 | D910\*-[CANN C18(0517)](https://repo.mindspore.cn/ascend/ascend910/20240517/)-[MS2.3_master(0615)](https://repo.mindspore.cn/mindspore/mindspore/version/202406/20240615/master_20240615020018_43ccb91e45899b64fe31d304497ab17e3ada3cea_newest/unified/) | O1 | BF16 | 1 | 8 | 16x512x512 | 2.00 |
+| STDiT2-XL/2 | D910\*-[CANN C18(0517)](https://repo.mindspore.cn/ascend/ascend910/20240517/)-[MS2.3_master(0615)](https://repo.mindspore.cn/mindspore/mindspore/version/202406/20240615/master_20240615020018_43ccb91e45899b64fe31d304497ab17e3ada3cea_newest/unified/) | O1 | BF16 | 1 | 8 | 64x512x512 | 8.30 |
+| STDiT2-XL/2 | D910\*-[CANN C18(0517)](https://repo.mindspore.cn/ascend/ascend910/20240517/)-[MS2.3_master(0615)](https://repo.mindspore.cn/mindspore/mindspore/version/202406/20240615/master_20240615020018_43ccb91e45899b64fe31d304497ab17e3ada3cea_newest/unified/) | O1 | BF16 | 1 | 8 | 24x576x1024 | 8.22 |
+| STDiT2-XL/2 | D910\*-[CANN C18(0705)](https://repo.mindspore.cn/ascend/ascend910/20240705/)-[MS2.3_master(0705)](https://repo.mindspore.cn/mindspore/mindspore/version/202407/20240705/master_20240705220018_51f414917fd9a312dd43ea62eea61cf37c3dfbd6_newest/unified/) | O1 | BF16 | 1 | 8 | 24x576x1024 | 7.82 |
+| STDiT2-XL/2 | D910\*-[CANN C18(0517)](https://repo.mindspore.cn/ascend/ascend910/20240517/)-[MS2.3_master(0615)](https://repo.mindspore.cn/mindspore/mindspore/version/202406/20240615/master_20240615020018_43ccb91e45899b64fe31d304497ab17e3ada3cea_newest/unified/) | O1 | BF16 | 1 | 8 | 64x576x1024 | 21.15 |
+| STDiT2-XL/2 | D910\*-[CANN C18(0517)](https://repo.mindspore.cn/ascend/ascend910/20240517/)-[MS2.3_master(0615)](https://repo.mindspore.cn/mindspore/mindspore/version/202406/20240615/master_20240615020018_43ccb91e45899b64fe31d304497ab17e3ada3cea_newest/unified/) | O1 | BF16 | 1 | 8 | 24x1024x1024 | 16.98 |
> Context: {G:GPU, D:Ascend}{chip type}-{mindspore version}.
>Note that the above performance uses both t5 cached embedding data and vae cached latent data.
@@ -17,6 +17,7 @@ max_rowsize: 256

# precision
amp_level: "O2"
+vae_dtype: "bf16"
dtype: "bf16"
init_loss_scale: 65536

@@ -35,8 +36,8 @@ optim: "adamw_re"
optim_eps: 1.e-8
weight_decay: 0.

-epochs: 12000
-ckpt_save_interval: 100
+epochs: 10000
+ckpt_save_interval: 200

mask_ratios:
mask_no: 0.75
@@ -19,6 +19,8 @@ max_rowsize: 256
amp_level: "O2"
dtype: "bf16"
init_loss_scale: 65536
+vae_dtype: "bf16"
+# vae_micro_batch_size: 8

# training hyper-params
scheduler: "constant"
@@ -285,6 +285,7 @@ def _patchify(self, latent: np.ndarray) -> Tuple[np.ndarray, np.ndarray, np.ndar

spatial_mask = np.ones(spatial_pos.shape[0], dtype=np.uint8)
temporal_mask = np.ones(temporal_pos.shape[0], dtype=np.uint8)

return latent, spatial_pos, spatial_mask, temporal_pos, temporal_mask

def __len__(self):
8 changes: 4 additions & 4 deletions examples/opensora_hpcai/opensora/models/layers/blocks.py
@@ -69,7 +69,8 @@ def construct(self, q: Tensor, k: Tensor, v: Tensor, mask: Optional[Tensor] = No

if mask is not None:
# (b 1 n_k) -> (b*h 1 n_k)
-mask = ops.repeat_interleave(mask, h, axis=0)
+# NOTE: due to uint8 not supported in CANN0630, cast mask to int32
+mask = ops.repeat_interleave(mask.to(ms.int32), h, axis=0)
mask = mask.to(ms.bool_)
sim = ops.masked_fill(sim, mask, -ms.numpy.inf)

@@ -174,8 +175,7 @@ def construct(self, x, cond, mask=None):
# (b n_k) -> (b 1 1 n_k), will be broadcast according to qk sim, e.g. (b num_heads n_q n_k)
mask = mask[:, None, None, :]
# (b 1 1 n_k) -> (b 1 n_q n_k)
-# mask = ops.repeat_interleave(mask.to(ms.uint8), q.shape[-2], axis=-2)
-mask = ops.repeat_interleave(mask, int(q.shape[1]), axis=-2)
+mask = ops.repeat_interleave(mask.to(ms.int32), int(q.shape[1]), axis=-2)
x = self.flash_attention(q, k, v, mask=mask)

# FA attn_mask def: retention and 1 indicates discard. Input tensor of shape :math:`(B, N1, S1, S2)`, `(B, 1, S1, S2)` `(S1, S2)`
@@ -280,7 +280,7 @@ def construct(self, x, mask=None, freqs_cis: Optional[Tensor] = None):
if mask is not None:
mask = mask[:, None, None, :]
# mask: (b n_k) -> (b 1 n_q n_k)
-mask = ops.repeat_interleave(mask, int(q.shape[1]), axis=-2)
+mask = ops.repeat_interleave(mask.to(ms.int32), int(q.shape[1]), axis=-2)
out = self.flash_attention(q, k, v, mask=mask)
else:
if mask is not None:
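For context, the pattern applied above (and repeated in stdit2.py and the training pipeline below) can be reproduced in isolation: the 0630 CANN package rejects uint8/bool inputs to repeat ops, so the boolean attention mask is cast to int32 before `ops.repeat_interleave` and converted back afterwards. The sketch below is illustrative only; the batch size, token count, and head count are assumptions, not values from the repository.

```python
import mindspore as ms
from mindspore import ops

# Assumed sizes for illustration: batch b=2, n_k=77 key tokens, h=8 attention heads.
b, n_k, h = 2, 77, 8
mask = ops.ones((b, 1, n_k), ms.uint8)  # key-padding mask, shape (b, 1, n_k)

# Workaround: cast to int32 so the repeat kernel accepts it, expand the mask
# from (b, 1, n_k) to (b*h, 1, n_k), then restore a boolean mask for masked_fill.
mask = ops.repeat_interleave(mask.to(ms.int32), h, axis=0)
mask = mask.to(ms.bool_)
```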
5 changes: 3 additions & 2 deletions examples/opensora_hpcai/opensora/models/stdit/stdit2.py
@@ -22,6 +22,7 @@
)
from opensora.models.layers.rotary_embedding import RotaryEmbedding

+import mindspore as ms
from mindspore import Parameter, Tensor, dtype, load_checkpoint, load_param_into_net, nn, ops
from mindspore.common.initializer import XavierUniform, initializer

@@ -137,7 +138,7 @@ def construct(
# spatial branch
x_s = x_m.reshape(B * T, S, C) # B (T S) C -> (B T) S C
if spatial_mask is not None:
-spatial_mask = ops.repeat_interleave(spatial_mask, T, axis=0) # B S -> (B T) S
+spatial_mask = ops.repeat_interleave(spatial_mask.to(ms.int32), T, axis=0) # B S -> (B T) S
x_s = self.attn(x_s, mask=spatial_mask)
x_s = x_s.reshape(B, T * S, C) # (B T) S C -> B (T S) C

@@ -158,7 +159,7 @@ def construct(
# temporal branch
x_t = x_m.reshape(B, T, S, C).swapaxes(1, 2).reshape(B * S, T, C) # B (T S) C -> (B S) T C
if temporal_mask is not None:
-temporal_mask = ops.repeat_interleave(temporal_mask, S, axis=0) # B T -> (B S) T
+temporal_mask = ops.repeat_interleave(temporal_mask.to(ms.int32), S, axis=0) # B T -> (B S) T
x_t = self.attn_temp(x_t, mask=temporal_mask, freqs_cis=temporal_pos)
x_t = x_t.reshape(B, S, T, C).swapaxes(1, 2).reshape(B, T * S, C) # (B S) T C -> B (T S) C

3 changes: 2 additions & 1 deletion examples/opensora_hpcai/opensora/pipelines/train_pipeline.py
@@ -352,7 +352,8 @@ def compute_loss(

# Learn the variance using the variational bound, but don't let it affect our mean prediction.
patch_mask = temporal_mask[:, :, None, None] * spatial_mask[:, None, :, None]
-patch_mask = self.unpatchify(ops.tile(patch_mask, (1, 1, 1, D))) # b c t h w
+pm_dtype = patch_mask.dtype
+patch_mask = self.unpatchify(ops.tile(patch_mask.to(ms.int32), (1, 1, 1, D)).to(pm_dtype)) # b c t h w
vb = self._cal_vb(
ops.stop_gradient(model_output),
model_var_values,
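The same restriction also affects `ops.tile`, so this change records the mask dtype, tiles an int32 copy, and casts the result back. A minimal sketch of that cast-tile-restore pattern, with purely illustrative shapes (b=1, t=4, s=16) and `D=8` standing in for the last-axis width tiled in the original code:

```python
import mindspore as ms
from mindspore import ops

D = 8  # assumed tiling width, for illustration only
patch_mask = ops.ones((1, 4, 16, 1), ms.uint8)  # (b, t, s, 1), illustrative

pm_dtype = patch_mask.dtype                              # remember the original dtype
tiled = ops.tile(patch_mask.to(ms.int32), (1, 1, 1, D))  # tile in int32 to avoid the uint8 kernel
patch_mask = tiled.to(pm_dtype)                          # restore the original dtype for later ops
```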
2 changes: 1 addition & 1 deletion examples/opensora_hpcai/scripts/args_train.py
@@ -111,7 +111,7 @@ def parse_train_args(parser):
help="Specify the [beta1, beta2] parameter for the AdamW optimizer.",
)
parser.add_argument(
-"--optim_eps", type=float, default=1e-6, help="Specify the eps parameter for the AdamW optimizer."
+"--optim_eps", type=float, default=1e-8, help="Specify the eps parameter for the AdamW optimizer."
)
parser.add_argument(
"--group_strategy",
5 changes: 4 additions & 1 deletion examples/opensora_hpcai/scripts/train.py
@@ -272,7 +272,10 @@ def main(args):
if args.dtype in ["fp16", "bf16"]:
if not args.global_bf16:
latte_model = auto_mixed_precision(
-latte_model, amp_level=args.amp_level, dtype=dtype_map[args.dtype], custom_fp32_cells=WHITELIST_OPS
+latte_model,
+amp_level=args.amp_level,
+dtype=dtype_map[args.dtype],
+custom_fp32_cells=WHITELIST_OPS,
)
# load checkpoint
if len(args.pretrained_model_path) > 0:
@@ -406,12 +406,12 @@ def construct(
if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2: # ndim == 2 means no image joint
encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
# b 1 l -> (b f) 1 l
-encoder_attention_mask = encoder_attention_mask.repeat_interleave(frame, dim=0)
+encoder_attention_mask = encoder_attention_mask.to(ms.int32).repeat_interleave(frame, dim=0)
encoder_attention_mask = encoder_attention_mask.to(self.dtype)
elif encoder_attention_mask is not None and encoder_attention_mask.ndim == 3: # ndim == 3 means image joint
encoder_attention_mask_video = encoder_attention_mask[:, :1, ...]
-encoder_attention_mask_video = encoder_attention_mask_video.repeat_interleave(frame, dim=1)
-encoder_attention_mask_image = encoder_attention_mask[:, 1:, ...]
+encoder_attention_mask_video = encoder_attention_mask_video.to(ms.int32).repeat_interleave(frame, dim=1)
+encoder_attention_mask_image = encoder_attention_mask[:, 1:, ...].to(ms.int32)
encoder_attention_mask = ops.cat([encoder_attention_mask_video, encoder_attention_mask_image], axis=1)
# b n l -> (b n) l
encoder_attention_mask = encoder_attention_mask.view(-1, encoder_attention_mask.shape[-1]).unsqueeze(1)
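The `Tensor.repeat_interleave` method form used above follows the same cast-then-restore idiom. A brief sketch with assumed sizes (batch 2, text length 120, 16 frames) and `ms.bfloat16` standing in for the model's `self.dtype`:

```python
import mindspore as ms
from mindspore import ops

frame = 16
encoder_attention_mask = ops.ones((2, 1, 120), ms.uint8)  # (b, 1, l), illustrative

# Cast to int32 before the repeat op, then back to the transformer's compute dtype.
encoder_attention_mask = encoder_attention_mask.to(ms.int32).repeat_interleave(frame, dim=0)
encoder_attention_mask = encoder_attention_mask.to(ms.bfloat16)  # stands in for self.dtype
```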