Skip to content

Commit

Permalink
update sd3_controlnet_inpaint to 0.31.0
Browse files Browse the repository at this point in the history
  • Loading branch information
Cui-yshoho committed Nov 29, 2024
1 parent 2c2e37d commit 9f4fba9
Show file tree
Hide file tree
Showing 5 changed files with 1,131 additions and 33 deletions.
3 changes: 2 additions & 1 deletion mindone/diffusers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@
"StableCascadeDecoderPipeline",
"StableCascadePriorPipeline",
"StableDiffusion3ControlNetPipeline",
"StableDiffusion3Img2ImgPipeline",
"StableDiffusion3ControlNetInpaintingPipeline",
"StableDiffusion3Img2ImgPipeline",
"StableDiffusion3InpaintPipeline",
"StableDiffusion3PAGPipeline",
"StableDiffusion3Pipeline",
Expand Down Expand Up @@ -343,6 +343,7 @@
StableCascadeCombinedPipeline,
StableCascadeDecoderPipeline,
StableCascadePriorPipeline,
StableDiffusion3ControlNetInpaintingPipeline,
StableDiffusion3ControlNetPipeline,
StableDiffusion3Img2ImgPipeline,
StableDiffusion3InpaintPipeline,
Expand Down
3 changes: 2 additions & 1 deletion mindone/diffusers/pipelines/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
],
"controlnet_sd3": [
"StableDiffusion3ControlNetPipeline",
"StableDiffusion3ControlNetInpaintingPipeline",
],
"dance_diffusion": ["DanceDiffusionPipeline"],
"ddim": ["DDIMPipeline"],
Expand Down Expand Up @@ -186,7 +187,7 @@
StableDiffusionXLControlNetPipeline,
)
from .controlnet_hunyuandit import HunyuanDiTControlNetPipeline
from .controlnet_sd3 import StableDiffusion3ControlNetPipeline
from .controlnet_sd3 import StableDiffusion3ControlNetInpaintingPipeline, StableDiffusion3ControlNetPipeline
from .controlnet_xs import StableDiffusionControlNetXSPipeline, StableDiffusionXLControlNetXSPipeline
from .ddim import DDIMPipeline
from .ddpm import DDPMPipeline
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,12 @@
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import numpy as np
from transformers import CLIPImageProcessor, CLIPTokenizer

import mindspore as ms
from mindspore import ops

from mindone.transformers import CLIPTextModel, CLIPVisionModelWithProjection
from transformers import CLIPImageProcessor, CLIPTokenizer

from ...image_processor import PipelineImageInput
from ...loaders import IPAdapterMixin, StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin
Expand All @@ -37,19 +39,20 @@
from ...utils.mindspore_utils import randn_tensor
from ...video_processor import VideoProcessor
from ..controlnet.multicontrolnet import MultiControlNetModel

# from ..free_init_utils import FreeInitMixin
# from ..free_noise_utils import AnimateDiffFreeNoiseMixin
from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
from .pipeline_output import AnimateDiffPipelineOutput


logger = logging.get_logger(__name__) # pylint: disable=invalid-name

EXAMPLE_DOC_STRING = """
Examples:
```py
>>> import mindspore as ms
>>> from PIL import Image
>>> import numpy as np
>>> from tqdm.auto import tqdm
>>> from mindone.diffusers import AnimateDiffVideoToVideoControlNetPipeline
Expand Down Expand Up @@ -338,9 +341,7 @@ def encode_prompt(
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not np.array_equal(
text_input_ids, untruncated_ids
):
removed_text = self.tokenizer.batch_decode(
untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
)
removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1])
logger.warning(
"The following part of your input was truncated because CLIP can only handle sequences up to"
f" {self.tokenizer.model_max_length} tokens: {removed_text}"
Expand Down Expand Up @@ -455,9 +456,7 @@ def encode_image(self, image, num_images_per_prompt, output_hidden_states=None):
if output_hidden_states:
image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True)[2][-2]
image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
uncond_image_enc_hidden_states = self.image_encoder(
ops.zeros_like(image), output_hidden_states=True
)[2][-2]
uncond_image_enc_hidden_states = self.image_encoder(ops.zeros_like(image), output_hidden_states=True)[2][-2]
uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave(
num_images_per_prompt, dim=0
)
Expand Down Expand Up @@ -644,19 +643,15 @@ def check_inputs(
)
num_frames = len(video) if latents is None else latents.shape[2]

if (
isinstance(self.controlnet, ControlNetModel)
or isinstance(self.controlnet._orig_mod, ControlNetModel)
):
if isinstance(self.controlnet, ControlNetModel) or isinstance(self.controlnet._orig_mod, ControlNetModel):
if not isinstance(conditioning_frames, list):
raise TypeError(
f"For single controlnet, `image` must be of type `list` but got {type(conditioning_frames)}"
)
if len(conditioning_frames) != num_frames:
raise ValueError(f"Excepted image to have length {num_frames} but got {len(conditioning_frames)=}")
elif (
isinstance(self.controlnet, MultiControlNetModel)
or isinstance(self.controlnet._orig_mod, MultiControlNetModel)
elif isinstance(self.controlnet, MultiControlNetModel) or isinstance(
self.controlnet._orig_mod, MultiControlNetModel
):
if not isinstance(conditioning_frames, list) or not isinstance(conditioning_frames[0], list):
raise TypeError(
Expand All @@ -672,15 +667,11 @@ def check_inputs(
assert False

# Check `controlnet_conditioning_scale`
if (
isinstance(self.controlnet, ControlNetModel)
or isinstance(self.controlnet._orig_mod, ControlNetModel)
):
if isinstance(self.controlnet, ControlNetModel) or isinstance(self.controlnet._orig_mod, ControlNetModel):
if not isinstance(controlnet_conditioning_scale, float):
raise TypeError("For single controlnet: `controlnet_conditioning_scale` must be type `float`.")
elif (
isinstance(self.controlnet, MultiControlNetModel)
or isinstance(self.controlnet._orig_mod, MultiControlNetModel)
elif isinstance(self.controlnet, MultiControlNetModel) or isinstance(
self.controlnet._orig_mod, MultiControlNetModel
):
if isinstance(controlnet_conditioning_scale, list):
if any(isinstance(i, list) for i in controlnet_conditioning_scale):
Expand Down Expand Up @@ -776,8 +767,7 @@ def prepare_latents(
)

init_latents = [
self.encode_video(video[i], generator[i], decode_chunk_size).unsqueeze(0)
for i in range(batch_size)
self.encode_video(video[i], generator[i], decode_chunk_size).unsqueeze(0) for i in range(batch_size)
]
else:
init_latents = [self.encode_video(vid, generator, decode_chunk_size).unsqueeze(0) for vid in video]
Expand Down Expand Up @@ -832,9 +822,7 @@ def prepare_conditioning_frames(
do_classifier_free_guidance=False,
guess_mode=False,
):
video = self.control_video_processor.preprocess_video(video, height=height, width=width).to(
dtype=ms.float32
)
video = self.control_video_processor.preprocess_video(video, height=height, width=width).to(dtype=ms.float32)
video = video.permute(0, 2, 1, 3, 4).flatten(start_dim=0, end_dim=1)
video_batch_size = video.shape[0]

Expand Down Expand Up @@ -1070,9 +1058,7 @@ def __call__(

# 3. Prepare timesteps
if not enforce_inference_steps:
timesteps, num_inference_steps = retrieve_timesteps(
self.scheduler, num_inference_steps, timesteps, sigmas
)
timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, timesteps, sigmas)
timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, timesteps, strength)
latent_timestep = timesteps[:1].tile((batch_size * num_videos_per_prompt,))
else:
Expand Down Expand Up @@ -1286,4 +1272,4 @@ def __call__(
if not return_dict:
return (video,)

return AnimateDiffPipelineOutput(frames=video)
return AnimateDiffPipelineOutput(frames=video)
5 changes: 5 additions & 0 deletions mindone/diffusers/pipelines/controlnet_sd3/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,14 @@

_import_structure = {}
_import_structure["pipeline_stable_diffusion_3_controlnet"] = ["StableDiffusion3ControlNetPipeline"]
_import_structure["pipeline_stable_diffusion_3_controlnet_inpainting"] = [
"StableDiffusion3ControlNetInpaintingPipeline"
]


if TYPE_CHECKING:
from .pipeline_stable_diffusion_3_controlnet import StableDiffusion3ControlNetPipeline
from .pipeline_stable_diffusion_3_controlnet_inpainting import StableDiffusion3ControlNetInpaintingPipeline
else:
import sys

Expand Down
Loading

0 comments on commit 9f4fba9

Please sign in to comment.