
DiT now supports sequence conditions. #923

Merged · 1 commit · Jan 14, 2025
Conversation

ds-hwang (Contributor)

When using seq2seq models with DiT, the condition may have the same sequence length as the input.
For example:

- Input shape: `[batch, seq_len, dim]`
- Condition shape: `[batch, seq_len, cond_dim]`

`AdaptiveLayerNormModulation` now supports conditions in both `[batch, cond_dim]` and `[batch, seq_len, cond_dim]` formats. It outputs conditions with shape `[batch, 1|seq_len, cond_dim]`, depending on whether `seq_len` is present.

Accordingly, DiT has been updated to handle rank-3 conditions. The codebase has also become simpler: previously, `jnp.expand_dims` calls were scattered across many places, but now `AdaptiveLayerNormModulation` adjusts the rank of the condition to match the input and returns it accordingly.

The speech detokenizer will use this DiT.

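The rank adjustment described above can be sketched as follows. This is a minimal illustration only; `adjust_condition_rank` is a hypothetical helper, not the PR's actual API:

```python
import jax.numpy as jnp


def adjust_condition_rank(condition: jnp.ndarray) -> jnp.ndarray:
    """Normalizes a DiT condition to rank 3 so it broadcasts against a
    [batch, seq_len, dim] input.

    A [batch, cond_dim] condition gains a singleton sequence axis, yielding
    [batch, 1, cond_dim]; a [batch, seq_len, cond_dim] condition passes
    through unchanged.
    """
    if condition.ndim == 2:
        # Insert the sequence axis once, here, instead of scattering
        # jnp.expand_dims calls across call sites.
        return jnp.expand_dims(condition, 1)
    if condition.ndim == 3:
        return condition
    raise ValueError(f"Condition must be rank 2 or 3, got rank {condition.ndim}.")
```

Either output shape then broadcasts cleanly against the `[batch, seq_len, dim]` input in the modulation step.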
@ds-hwang ds-hwang requested review from ruomingp, markblee and a team as code owners January 14, 2025 05:09
@ds-hwang
Contributor Author

@ruomingp Could you take a look? From 979

@ds-hwang ds-hwang enabled auto-merge January 14, 2025 05:09
"""
cfg = self.config
x = get_activation_fn(cfg.activation)(input)
output = self.linear(x)
assert output.ndim in (2, 3)
Contributor


Raise a `ValueError` instead of `assert` (asserts should only be used to enforce internal invariants and should not be triggerable by user error).

Contributor Author


Since auto-merge was already enabled, your comment was not addressed in this PR. I submitted #981 to handle it.

@ds-hwang ds-hwang added this pull request to the merge queue Jan 14, 2025
Merged via the queue into apple:main with commit a946f91 Jan 14, 2025
6 checks passed
@ds-hwang ds-hwang deleted the dit branch January 14, 2025 06:55