add pipe_test
Cui-yshoho committed Jan 7, 2025
1 parent 148efe8 commit 41ad433
Showing 145 changed files with 3,191 additions and 21 deletions.
@@ -27,6 +27,7 @@
 from ...loaders import FromSingleFileMixin, SD3LoraLoaderMixin
 from ...models.autoencoders import AutoencoderKL
 from ...models.controlnet_sd3 import SD3ControlNetModel, SD3MultiControlNetModel
+from ...models.layers_compat import pad
 from ...models.transformers import SD3Transformer2DModel
 from ...schedulers import FlowMatchEulerDiscreteScheduler
 from ...utils import logging, scale_lora_layers, unscale_lora_layers
@@ -458,9 +459,7 @@ def encode_prompt(
                 max_sequence_length=max_sequence_length,
             )

-            clip_prompt_embeds = ops.pad(
-                clip_prompt_embeds, (0, t5_prompt_embed.shape[-1] - clip_prompt_embeds.shape[-1])
-            )
+            clip_prompt_embeds = pad(clip_prompt_embeds, (0, t5_prompt_embed.shape[-1] - clip_prompt_embeds.shape[-1]))

             prompt_embeds = ops.cat([clip_prompt_embeds, t5_prompt_embed], axis=-2)
             pooled_prompt_embeds = ops.cat([pooled_prompt_embed, pooled_prompt_2_embed], axis=-1)
@@ -511,7 +510,7 @@ def encode_prompt(
                 max_sequence_length=max_sequence_length,
             )

-            negative_clip_prompt_embeds = ops.pad(
+            negative_clip_prompt_embeds = pad(
                 negative_clip_prompt_embeds,
                 (0, t5_negative_prompt_embed.shape[-1] - negative_clip_prompt_embeds.shape[-1]),
             )
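Both hunks above swap the direct ops.pad call for pad imported from ...models.layers_compat. The diff does not show that module, so the snippet below is only a hedged sketch of what a torch-style, last-dimension pad shim of this kind could look like; the name pad_last_dim and the ops.cat-based construction are illustrative assumptions, not the actual mindone implementation.

import mindspore as ms
from mindspore import ops


def pad_last_dim(x: ms.Tensor, padding: tuple, value: float = 0.0) -> ms.Tensor:
    # Illustrative shim only: pads the last axis given (pad_left, pad_right),
    # mirroring torch.nn.functional.pad's flat-pair convention for this case.
    # Built from ops.cat so it does not rely on a backend-specific pad kernel.
    pad_left, pad_right = padding
    pieces = []
    if pad_left > 0:
        pieces.append(ops.full(x.shape[:-1] + (pad_left,), value, dtype=x.dtype))
    pieces.append(x)
    if pad_right > 0:
        pieces.append(ops.full(x.shape[:-1] + (pad_right,), value, dtype=x.dtype))
    return ops.cat(pieces, axis=-1)


# Mirrors the call pattern in the hunk above:
# clip_prompt_embeds = pad_last_dim(clip_prompt_embeds, (0, t5_dim - clip_dim))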
@@ -35,7 +35,6 @@
     TextualInversionLoaderMixin,
 )
 from ...models import AutoencoderKL, ControlNetModel, ImageProjection, UNet2DConditionModel
-from ...models.attention_processor import AttnProcessor2_0, XFormersAttnProcessor
 from ...schedulers import KarrasDiffusionSchedulers
 from ...utils import logging, scale_lora_layers, unscale_lora_layers
 from ...utils.mindspore_utils import randn_tensor
@@ -985,21 +984,7 @@ def _get_add_time_ids(

     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_upscale.StableDiffusionUpscalePipeline.upcast_vae
     def upcast_vae(self):
-        dtype = self.vae.dtype
         self.vae.to(dtype=ms.float32)
-        use_torch_2_0_or_xformers = isinstance(
-            self.vae.decoder.mid_block.attentions[0].processor,
-            (
-                AttnProcessor2_0,
-                XFormersAttnProcessor,
-            ),
-        )
-        # if xformers or torch_2_0 is used attention block does not need
-        # to be in float32 which can save lots of memory
-        if use_torch_2_0_or_xformers:
-            self.vae.post_quant_conv.to(dtype)
-            self.vae.decoder.conv_in.to(dtype)
-            self.vae.decoder.mid_block.to(dtype)

     @property
     def guidance_scale(self):
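Net effect of the hunk above: with the attention-processor check and the partial down-cast removed (neither AttnProcessor2_0 nor XFormersAttnProcessor applies here), upcast_vae keeps only the unconditional cast of the VAE to float32, so the method now reads simply:

    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_upscale.StableDiffusionUpscalePipeline.upcast_vae
    def upcast_vae(self):
        self.vae.to(dtype=ms.float32)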
File renamed without changes.
@@ -0,0 +1,198 @@
import unittest

import numpy as np
import torch
from ddt import data, ddt, unpack
from PIL import Image
from transformers import CLIPTextConfig

import mindspore as ms

from ..pipeline_test_utils import (
    THRESHOLD_FP16,
    THRESHOLD_FP32,
    PipelineTesterMixin,
    get_module,
    get_pipeline_components,
)

test_cases = [
    {"mode": ms.PYNATIVE_MODE, "dtype": "float32"},
    {"mode": ms.PYNATIVE_MODE, "dtype": "float16"},
    {"mode": ms.GRAPH_MODE, "dtype": "float32"},
    {"mode": ms.GRAPH_MODE, "dtype": "float16"},
]


@ddt
class AnimateDiffVideoToVideoControlNetPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
    cross_attention_dim = 8
    block_out_channels = (8, 8)

    pipeline_config = [
        [
            "unet",
            "diffusers.models.unets.unet_2d_condition.UNet2DConditionModel",
            "mindone.diffusers.models.unets.unet_2d_condition.UNet2DConditionModel",
            dict(
                block_out_channels=block_out_channels,
                layers_per_block=2,
                sample_size=8,
                in_channels=4,
                out_channels=4,
                down_block_types=("CrossAttnDownBlock2D", "DownBlock2D"),
                up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
                cross_attention_dim=cross_attention_dim,
                norm_num_groups=2,
            ),
        ],
        [
            "scheduler",
            "diffusers.schedulers.scheduling_ddim.DDIMScheduler",
            "mindone.diffusers.schedulers.scheduling_ddim.DDIMScheduler",
            dict(
                beta_start=0.00085,
                beta_end=0.012,
                beta_schedule="linear",
                clip_sample=False,
            ),
        ],
        [
            "controlnet",
            "diffusers.models.controlnet.ControlNetModel",
            "mindone.diffusers.models.controlnet.ControlNetModel",
            dict(
                block_out_channels=block_out_channels,
                layers_per_block=2,
                in_channels=4,
                down_block_types=("CrossAttnDownBlock2D", "DownBlock2D"),
                cross_attention_dim=cross_attention_dim,
                conditioning_embedding_out_channels=(8, 8),
                norm_num_groups=1,
            ),
        ],
        [
            "vae",
            "diffusers.models.autoencoders.autoencoder_kl.AutoencoderKL",
            "mindone.diffusers.models.autoencoders.autoencoder_kl.AutoencoderKL",
            dict(
                block_out_channels=block_out_channels,
                in_channels=3,
                out_channels=3,
                down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"],
                up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"],
                latent_channels=4,
                norm_num_groups=2,
            ),
        ],
        [
            "text_encoder",
            "transformers.models.clip.modeling_clip.CLIPTextModel",
            "mindone.transformers.models.clip.modeling_clip.CLIPTextModel",
            dict(
                config=CLIPTextConfig(
                    bos_token_id=0,
                    eos_token_id=2,
                    hidden_size=cross_attention_dim,
                    intermediate_size=37,
                    layer_norm_eps=1e-05,
                    num_attention_heads=4,
                    num_hidden_layers=5,
                    pad_token_id=1,
                    vocab_size=1000,
                ),
            ),
        ],
        [
            "tokenizer",
            "transformers.models.clip.tokenization_clip.CLIPTokenizer",
            "transformers.models.clip.tokenization_clip.CLIPTokenizer",
            dict(
                pretrained_model_name_or_path="hf-internal-testing/tiny-random-clip",
            ),
        ],
        [
            "motion_adapter",
            "diffusers.models.unets.unet_motion_model.MotionAdapter",
            "mindone.diffusers.models.unets.unet_motion_model.MotionAdapter",
            dict(
                block_out_channels=block_out_channels,
                motion_layers_per_block=2,
                motion_norm_num_groups=2,
                motion_num_attention_heads=4,
            ),
        ],
    ]

    def get_dummy_components(self):
        components = {
            key: None
            for key in [
                "unet",
                "controlnet",
                "scheduler",
                "vae",
                "motion_adapter",
                "text_encoder",
                "tokenizer",
                "feature_extractor",
                "image_encoder",
            ]
        }

        return get_pipeline_components(components, self.pipeline_config)

    def get_dummy_inputs(self, num_frames: int = 2):
        video_height = 32
        video_width = 32
        video = [Image.new("RGB", (video_width, video_height))] * num_frames

        video_height = 32
        video_width = 32
        conditioning_frames = [Image.new("RGB", (video_width, video_height))] * num_frames

        inputs = {
            "video": video,
            "conditioning_frames": conditioning_frames,
            "prompt": "A painting of a squirrel eating a burger",
            "num_inference_steps": 2,
            "guidance_scale": 7.5,
            "output_type": "np",
        }
        return inputs

    @data(*test_cases)
    @unpack
    def test_inference(self, mode, dtype):
        ms.set_context(mode=mode)

        pt_components, ms_components = self.get_dummy_components()
        pt_pipe_cls = get_module(
            "diffusers.pipelines.animatediff.pipeline_animatediff_video2video_controlnet.AnimateDiffVideoToVideoControlNetPipeline"
        )
        ms_pipe_cls = get_module(
            "mindone.diffusers.pipelines.animatediff.pipeline_animatediff_video2video_controlnet.AnimateDiffVideoToVideoControlNetPipeline"
        )

        pt_pipe = pt_pipe_cls(**pt_components)
        ms_pipe = ms_pipe_cls(**ms_components)

        pt_pipe.set_progress_bar_config(disable=None)
        ms_pipe.set_progress_bar_config(disable=None)

        ms_dtype, pt_dtype = getattr(ms, dtype), getattr(torch, dtype)
        pt_pipe = pt_pipe.to(pt_dtype)
        ms_pipe = ms_pipe.to(ms_dtype)

        inputs = self.get_dummy_inputs()

        torch.manual_seed(0)
        pt_frame = pt_pipe(**inputs)
        torch.manual_seed(0)
        ms_frame = ms_pipe(**inputs)

        pt_image_slice = pt_frame.frames[0][0, -3:, -3:, -1]
        ms_image_slice = ms_frame[0][0][0, -3:, -3:, -1]

        threshold = THRESHOLD_FP32 if dtype == "float32" else THRESHOLD_FP16
        assert np.max(np.linalg.norm(pt_image_slice - ms_image_slice) / np.linalg.norm(pt_image_slice)) < threshold
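The closing assertion scores a 3x3 slice of the PyTorch and MindSpore outputs by relative Frobenius norm and compares it against a per-dtype threshold imported from pipeline_test_utils. The snippet below is a self-contained illustration of that metric; the DEMO threshold and the slice values are made-up placeholders, not the real THRESHOLD_FP32 / THRESHOLD_FP16 values or pipeline outputs.

import numpy as np

# Placeholder threshold (the real per-dtype values live in pipeline_test_utils).
DEMO_THRESHOLD_FP32 = 1e-3

pt_slice = np.array([[0.48, 0.51, 0.50], [0.49, 0.52, 0.47], [0.50, 0.50, 0.49]])
ms_slice = pt_slice + 1e-4  # pretend cross-framework numerical drift

# Same structure as the test's final check: relative Frobenius-norm error.
rel_err = np.linalg.norm(pt_slice - ms_slice) / np.linalg.norm(pt_slice)
assert rel_err < DEMO_THRESHOLD_FP32
print(f"relative error: {rel_err:.2e}")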