Skip to content

Commit

Permalink
fix_bugs_animatediff
Browse files Browse the repository at this point in the history
  • Loading branch information
Cui-yshoho committed Jan 2, 2025
1 parent bf21132 commit 6ff5acf
Show file tree
Hide file tree
Showing 4 changed files with 20 additions and 7 deletions.
6 changes: 4 additions & 2 deletions mindone/diffusers/models/unets/unet_motion_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,7 @@ def construct(
hidden_states = self.norm(hidden_states)
hidden_states = hidden_states.permute(0, 3, 4, 2, 1).reshape(batch_size * height * width, num_frames, channel)

hidden_states = self.proj_in(hidden_states)
hidden_states = self.proj_in(input=hidden_states)

# 2. Blocks
for block in self.transformer_blocks:
Expand All @@ -186,7 +186,7 @@ def construct(
)

# 3. Output
hidden_states = self.proj_out(hidden_states)
hidden_states = self.proj_out(input=hidden_states)
hidden_states = (
hidden_states[None, None, :]
.reshape(batch_size, height, width, num_frames, channel)
Expand Down Expand Up @@ -656,6 +656,7 @@ def construct(
hidden_states = ops.cat([hidden_states, res_hidden_states], axis=1)

hidden_states = resnet(input_tensor=hidden_states, temb=temb)

hidden_states = attn(
hidden_states=hidden_states,
encoder_hidden_states=encoder_hidden_states,
Expand Down Expand Up @@ -786,6 +787,7 @@ def construct(
hidden_states = ops.cat([hidden_states, res_hidden_states], axis=1)

hidden_states = resnet(input_tensor=hidden_states, temb=temb)

hidden_states = motion_module(hidden_states, num_frames=num_frames)

if self.upsamplers is not None:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -227,7 +227,7 @@ def encode_prompt(
# dynamically adjust the LoRA scale
scale_lora_layers(self.text_encoder, lora_scale)

if prompt is not None and isinstance(prompt, str):
if prompt is not None and isinstance(prompt, (str, dict)):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):
batch_size = len(prompt)
Expand Down Expand Up @@ -497,8 +497,8 @@ def check_inputs(
raise ValueError(
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
)
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
elif prompt is not None and not isinstance(prompt, (str, list, dict)):
raise ValueError(f"`prompt` has to be of type `str`, `list` or `dict` but is {type(prompt)}")

if negative_prompt is not None and negative_prompt_embeds is not None:
raise ValueError(
Expand Down Expand Up @@ -664,6 +664,10 @@ def cross_attention_kwargs(self):
def num_timesteps(self):
return self._num_timesteps

@property
def interrupt(self):
    """Whether the current ``__call__`` run has been flagged to stop early.

    Returns:
        bool: the value of ``self._interrupt``. It is reset to ``False`` at
        the start of ``__call__`` and checked on every iteration of the
        denoising loop, where a ``True`` value skips the loop body.
    """
    # NOTE(review): nothing in the visible diff sets this to True — presumably
    # a callback toggles self._interrupt externally; confirm against callers.
    return self._interrupt

def __call__(
self,
prompt: Union[str, List[str]] = None,
Expand Down Expand Up @@ -822,6 +826,7 @@ def __call__(
self._guidance_scale = guidance_scale
self._clip_skip = clip_skip
self._cross_attention_kwargs = cross_attention_kwargs
self._interrupt = False

# 2. Define call parameters
if prompt is not None and isinstance(prompt, str):
Expand Down Expand Up @@ -861,6 +866,8 @@ def __call__(
if self.do_classifier_free_guidance:
prompt_embeds = ops.cat([negative_prompt_embeds, prompt_embeds])

prompt_embeds = prompt_embeds.repeat_interleave(repeats=num_frames, dim=0)

if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
image_embeds = self.prepare_ip_adapter_image_embeds(
ip_adapter_image,
Expand Down Expand Up @@ -940,6 +947,9 @@ def __call__(
# 8. Denoising loop
with self.progress_bar(total=self._num_timesteps) as progress_bar:
for i, t in enumerate(timesteps):
if self.interrupt:
continue

# expand the latents if we are doing classifier free guidance
latent_model_input = ops.cat([latents] * 2) if self.do_classifier_free_guidance else latents
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
Expand All @@ -952,7 +962,6 @@ def __call__(
else:
control_model_input = latent_model_input
controlnet_prompt_embeds = prompt_embeds
controlnet_prompt_embeds = controlnet_prompt_embeds.repeat_interleave(num_frames, dim=0)

if isinstance(controlnet_keep[i], list):
cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -850,6 +850,8 @@ def __call__(
if self.do_classifier_free_guidance:
prompt_embeds = ops.cat([negative_prompt_embeds, prompt_embeds])

prompt_embeds = prompt_embeds.repeat_interleave(repeats=num_frames, dim=0)

# 4. Prepare IP-Adapter embeddings
if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
image_embeds = self.prepare_ip_adapter_image_embeds(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def retrieve_timesteps(
sigmas: Optional[List[float]] = None,
**kwargs,
):
"""
r"""
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
Expand Down

0 comments on commit 6ff5acf

Please sign in to comment.