diff --git a/mindone/diffusers/models/attention_processor.py b/mindone/diffusers/models/attention_processor.py index 000bb84630..15b8263d5a 100644 --- a/mindone/diffusers/models/attention_processor.py +++ b/mindone/diffusers/models/attention_processor.py @@ -734,7 +734,7 @@ def flash_attention_op( # process `attn_mask` as logic is different between PyTorch and Mindspore # In MindSpore, False indicates retention and True indicates discard, in PyTorch it is the opposite if attn_mask is not None: - attn_mask = ops.logical_not(attn_mask) if attn_mask.dtype == ms.bool_ else attn_mask.bool() + attn_mask = ops.logical_not(attn_mask) attn_mask = ops.broadcast_to( attn_mask, (attn_mask.shape[0], attn_mask.shape[1], query.shape[-2], key.shape[-2]) )[:, :1, :, :] @@ -2103,6 +2103,9 @@ def __call__( ) if attention_mask is not None: + # bool input dtype is not supported in tensor.repeat_interleave + if attention_mask.dtype == ms.bool_: + attention_mask = attention_mask.to(hidden_states.dtype) attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) # scaled_dot_product_attention expects attention_mask shape to be # (batch, heads, source_length, target_length) diff --git a/mindone/diffusers/schedulers/scheduling_euler_discrete.py b/mindone/diffusers/schedulers/scheduling_euler_discrete.py index 6b9baa5b23..3a031f6dfe 100644 --- a/mindone/diffusers/schedulers/scheduling_euler_discrete.py +++ b/mindone/diffusers/schedulers/scheduling_euler_discrete.py @@ -613,7 +613,11 @@ def add_noise( # self.begin_index is None when scheduler is used for training, or pipeline does not implement set_begin_index if self.begin_index is None: step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps] + elif self.step_index is not None: + # add_noise is called after first denoising step (for inpainting) + step_indices = [self.step_index] * timesteps.shape[0] else: + # add noise is called before first denoising step to create initial 
latent (img2img) step_indices = [self.begin_index] * timesteps.shape[0] sigma = sigmas[step_indices].flatten()