fix bugs
Cui-yshoho committed Nov 30, 2024
1 parent ca0ae89 · commit baf9688
Showing 5 changed files with 13 additions and 11 deletions.
@@ -86,6 +86,7 @@
  >>> from controlnet_aux.processor import OpenposeDetector
  >>> open_pose = OpenposeDetector.from_pretrained("lllyasviel/Annotators")
+ >>> conditioning_frames = []
  >>> for frame in tqdm(video):
  ...     conditioning_frames.append(open_pose(frame))
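The hunk grows by one line (6 -> 7); the most plausible addition is the `conditioning_frames = []` initialization marked above, since without it the first `append` in the loop raises a NameError. A standalone sketch of the pattern, with stand-ins for the video frames and the OpenPose call:

```py
frames = ["frame0", "frame1", "frame2"]  # stand-in for the decoded video

conditioning_frames = []  # must exist before the loop, or append() raises NameError
for frame in frames:
    conditioning_frames.append(frame.upper())  # stand-in for open_pose(frame)

print(len(conditioning_frames))  # 3
```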
@@ -42,7 +42,7 @@
  >>> import mindspore as ms
  >>> from mindone.diffusers import CogView3PlusPipeline
- >>> pipe = CogView3PlusPipeline.from_pretrained("THUDM/CogView3Plus-3B", mindspore_dtype=ms.bfloat16)
+ >>> pipe = CogView3PlusPipeline.from_pretrained("THUDM/CogView3-Plus-3B", mindspore_dtype=ms.bfloat16)
  >>> prompt = "A photo of an astronaut riding a horse on mars"
  >>> image = pipe(prompt)[0][0]
@@ -44,7 +44,7 @@
  ```py
  >>> import numpy as np
  >>> import mindspore as ms
- >>> from mindone.diffusers.utils import load_image, check_min_version
+ >>> from mindone.diffusers.utils import load_image
  >>> from mindone.diffusers.pipelines import StableDiffusion3ControlNetInpaintingPipeline
  >>> from mindone.diffusers.models.controlnet_sd3 import SD3ControlNetModel
@@ -1235,7 +1235,7 @@ def __call__(
          self.scheduler.set_timesteps(num_inference_steps)
          timesteps, num_inference_steps = self.get_timesteps(num_inference_steps=num_inference_steps, strength=strength)
          # at which timestep to set the initial noise (n.b. 50% if strength is 0.5)
-         latent_timestep = timesteps[:1].tile((batch_size * num_images_per_prompt))
+         latent_timestep = timesteps[:1].tile((batch_size * num_images_per_prompt,))
          # create a boolean to check if the strength is set to 1. if so then initialise the latents with pure noise
          is_strength_max = strength == 1.0
          self._num_timesteps = len(timesteps)
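Why the trailing comma matters: in Python, `(n)` is just `n` in parentheses, while `(n,)` is a one-element tuple, and MindSpore's `Tensor.tile` takes a tuple of per-axis repetition counts, which is evidently what this fix relies on. A minimal standalone sketch of the distinction:

```py
import numpy as np
import mindspore as ms

timesteps = ms.tensor(np.array([981], dtype=np.int64))
n = 4  # stand-in for batch_size * num_images_per_prompt

assert (n) == 4      # parentheses alone do not make a tuple
assert (n,) == (4,)  # the trailing comma does

# repeat the single starting timestep across the batch, as the patched line does
latent_timestep = timesteps[:1].tile((n,))
print(latent_timestep.shape)  # (4,)
```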
@@ -1348,7 +1348,7 @@ def __call__(
          # 7.5 Optionally get Guidance Scale Embedding
          timestep_cond = None
          if self.unet.config.time_cond_proj_dim is not None:
-             guidance_scale_tensor = ms.tensor(self.guidance_scale - 1).tile((batch_size * num_images_per_prompt))
+             guidance_scale_tensor = ms.tensor(self.guidance_scale - 1).tile((batch_size * num_images_per_prompt,))
              timestep_cond = self.get_guidance_scale_embedding(
                  guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim
              ).to(dtype=latents.dtype)
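With the tuple in place, `guidance_scale_tensor` comes out with shape `(batch_size * num_images_per_prompt,)`: one guidance scalar per latent in the batch, which is what `get_guidance_scale_embedding` then embeds. A shape-only sketch under that assumption:

```py
import mindspore as ms

batch_size, num_images_per_prompt = 2, 2
guidance_scale = 7.5

# a 0-d tensor tiled with a one-element tuple becomes a 1-D batch of scalars
w = ms.tensor(guidance_scale - 1).tile((batch_size * num_images_per_prompt,))
print(w.shape)  # (4,)
```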
@@ -1438,7 +1438,7 @@ def __call__(
                  if i < len(timesteps) - 1:
                      noise_timestep = timesteps[i + 1]
                      init_latents_proper = self.scheduler.add_noise(
-                         init_latents_proper, noise, ms.tensor([noise_timestep])
+                         init_latents_proper, noise, ms.tensor([noise_timestep.item()])
                      )

                      latents = (1 - init_mask) * init_latents_proper + init_mask * latents
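`timesteps[i + 1]` is a zero-dimensional MindSpore Tensor rather than a Python scalar; the fix calls `.item()` so that `ms.tensor([...])` is built from a plain number instead of a nested Tensor. A standalone sketch:

```py
import numpy as np
import mindspore as ms

timesteps = ms.tensor(np.array([999, 500, 1], dtype=np.int64))

noise_timestep = timesteps[1]           # 0-d Tensor, not an int
t = ms.tensor([noise_timestep.item()])  # .item() unwraps it to a Python int
print(t.shape, t)                       # (1,) [500]
```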
@@ -55,6 +55,7 @@
  Examples:
  ```py
  >>> import mindspore as ms
+ >>> import torch
  >>> import numpy as np
  >>> from PIL import Image
@@ -63,7 +64,7 @@
  >>> from mindone.diffusers.utils import load_image
- >>> depth_estimator = DPTForDepthEstimation.from_pretrained("Intel/dpt-hybrid-midas").to("cuda")
+ >>> depth_estimator = DPTForDepthEstimation.from_pretrained("Intel/dpt-hybrid-midas")
  >>> feature_extractor = DPTFeatureExtractor.from_pretrained("Intel/dpt-hybrid-midas")
  >>> controlnet = ControlNetModel.from_pretrained(
  ...     "diffusers/controlnet-depth-sdxl-1.0-small",
@@ -83,11 +84,11 @@
  ...     )
  >>> def get_depth_map(image):
- ...     image = feature_extractor(images=image, return_tensors="pt").pixel_values.to("cuda")
- ...     with torch.no_grad(), torch.autocast("cuda"):
+ ...     image = feature_extractor(images=image, return_tensors="pt").pixel_values
+ ...     with torch.no_grad():
  ...         depth_map = depth_estimator(image).predicted_depth
- ...     depth_map = torch.nn.fuctional.interpolate(
+ ...     depth_map = torch.nn.functional.interpolate(
  ...         depth_map.unsqueeze(1),
  ...         size=(1024, 1024),
  ...         mode="bicubic",
@@ -97,7 +98,7 @@
  ...     depth_max = torch.amax(depth_map, dim=[1, 2, 3], keepdim=True)
  ...     depth_map = (depth_map - depth_min) / (depth_max - depth_min)
  ...     image = torch.cat([depth_map] * 3, dim=1)
- ...     image = image.permute(0, 2, 3, 1).cpu().numpy()[0]
+ ...     image = image.permute(0, 2, 3, 1).numpy()[0]
  ...     image = Image.fromarray((image * 255.0).clip(0, 255).astype(np.uint8))
  ...     return image
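Taken together, the edits to this docstring drop the CUDA-only calls (`.to("cuda")`, `torch.autocast("cuda")`, `.cpu()`) so the torch depth estimator runs on CPU alongside the MindSpore pipeline, add the apparently missing `torch` import, and fix the `torch.nn.fuctional` typo. A hypothetical usage of the patched helper (the file name is a placeholder; `Image` and `get_depth_map` come from the example above):

```py
>>> image = Image.open("input.png").convert("RGB")  # "input.png" is a placeholder
>>> depth_image = get_depth_map(image)  # 1024x1024 PIL depth map for ControlNet conditioning
```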
@@ -474,7 +475,7 @@ def encode_prompt(
              negative_prompt_embeds = negative_prompt_embeds.tile((1, num_images_per_prompt, 1))
              negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)

-         pooled_prompt_embeds = pooled_prompt_embeds.tile(1, num_images_per_prompt).view(
+         pooled_prompt_embeds = pooled_prompt_embeds.tile((1, num_images_per_prompt)).view(
              bs_embed * num_images_per_prompt, -1
          )
          if do_classifier_free_guidance:
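`tile(1, num_images_per_prompt)` passes two positional arguments, whereas `Tensor.tile` here takes a single tuple of multiples, as in the `negative_prompt_embeds.tile((1, num_images_per_prompt, 1))` call above; the fix makes the pooled call consistent. A standalone sketch of the repeat-then-fold pattern with illustrative shapes:

```py
import numpy as np
import mindspore as ms

bs_embed, dim, num_images_per_prompt = 2, 8, 3
pooled = ms.tensor(np.zeros((bs_embed, dim), dtype=np.float32))

# repeat along the feature axis, then fold the repeats into the batch axis
out = pooled.tile((1, num_images_per_prompt)).view(bs_embed * num_images_per_prompt, -1)
print(out.shape)  # (6, 8)
```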