DiT now supports sequence conditions. (#923)
When using seq2seq models with DiT, the condition may have the same sequence
length as the input.
For example:
- Input shape: `[batch, seq_len, dim]`
- Condition shape: `[batch, seq_len, cond_dim]`

AdaptiveLayerNormModulation now accepts conditions in both `[batch, cond_dim]`
and `[batch, seq_len, cond_dim]` formats. It returns conditions with shape
`[batch, 1|seq_len, cond_dim]`, depending on whether the condition has a `seq_len` axis.
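As an illustration, here is a minimal standalone sketch (not the axlearn layer itself; the helper name and shapes are made up) of the rank handling described above: a rank-2 condition gains a length-1 sequence axis, a rank-3 condition passes through unchanged, and the projected parameters are split along the last axis.

```python
import einops
import jax.numpy as jnp

def split_condition(cond, num_outputs):
    # cond: [batch, num_outputs * cond_dim] or [batch, seq_len, num_outputs * cond_dim].
    assert cond.ndim in (2, 3)
    if cond.ndim == 2:
        cond = einops.rearrange(cond, "b d -> b 1 d")
    # Each output: [batch, 1|seq_len, cond_dim].
    return jnp.split(cond, num_outputs, axis=-1)

shift, scale, gate = split_condition(jnp.ones((2, 3 * 8)), num_outputs=3)
assert shift.shape == (2, 1, 8)
shift, scale, gate = split_condition(jnp.ones((2, 5, 3 * 8)), num_outputs=3)
assert shift.shape == (2, 5, 8)
```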

Accordingly, DiT has been updated to handle rank-3 conditions.
The codebase has also become simpler: the `jnp.expand_dims` calls that were previously
scattered across many call sites are no longer needed, because `AdaptiveLayerNormModulation`
now adjusts the rank of the condition to match the input before returning it.
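For example (a minimal sketch with made-up shapes, mirroring the updated `modulate` in the diff below), a `[batch, 1, dim]` or `[batch, seq_len, dim]` shift/scale broadcasts directly against a `[batch, seq_len, dim]` input, so no call-site `jnp.expand_dims` is needed:

```python
import jax.numpy as jnp

def modulate(x, shift, scale):
    # x: [batch, seq_len, dim]; shift/scale: [batch, 1|seq_len, dim].
    return x * (1 + scale) + shift

batch, seq_len, dim = 2, 5, 8
x = jnp.ones((batch, seq_len, dim))
# Condition derived from a [batch, cond_dim] tensor: one modulation per sequence.
y1 = modulate(x, jnp.zeros((batch, 1, dim)), jnp.ones((batch, 1, dim)))
# Condition derived from a [batch, seq_len, cond_dim] tensor: per-position modulation.
y2 = modulate(x, jnp.zeros((batch, seq_len, dim)), jnp.ones((batch, seq_len, dim)))
assert y1.shape == y2.shape == (batch, seq_len, dim)
```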

The speech detokenizer will use this DiT.
ds-hwang authored Jan 14, 2025
1 parent feb8357 commit a946f91
Showing 2 changed files with 113 additions and 67 deletions.
71 changes: 46 additions & 25 deletions axlearn/common/dit.py
@@ -13,6 +13,8 @@

from typing import Optional, Union

import chex
import einops
import jax
import jax.numpy as jnp

@@ -31,7 +33,21 @@


def modulate(*, x, shift, scale):
return x * (1 + jnp.expand_dims(scale, 1)) + jnp.expand_dims(shift, 1)
"""Modulates the input x tensor.
Note: shift and scale must have the same shape.
Args:
x: input tensor with shape [batch_size, num_length, input_dim].
shift: shifting the norm tensor with shape [batch_size, 1|num_length, input_dim].
scale: scaling the norm tensor with shape [batch_size, 1|num_length, input_dim].
Returns:
A tensor with shape [batch_size, num_length, input_dim].
"""
chex.assert_equal_shape((shift, scale))
chex.assert_equal_rank((x, shift, scale))
return x * (1 + scale) + shift


class TimeStepEmbedding(BaseLayer):
@@ -211,15 +227,18 @@ def forward(self, input: Tensor) -> Tensor:
"""Generate the parameters for modulation.
Args:
input: A tensor with shape [batch_size, ..., dim].
input: A tensor with shape [batch_size, dim] or [batch_size, num_length, dim].
Returns:
A list of tensors with length num_outputs.
Each tensor has shape [batch_size, ..., dim].
Each tensor has shape [batch_size, 1|num_length, dim].
"""
cfg = self.config
x = get_activation_fn(cfg.activation)(input)
output = self.linear(x)
assert output.ndim in (2, 3)
if output.ndim == 2:
output = einops.rearrange(output, "b d -> b 1 d")
output = jnp.split(output, cfg.num_outputs, axis=-1)
return output

@@ -292,14 +311,16 @@ def forward(self, *, input: Tensor, shift: Tensor, scale: Tensor, gate: Tensor)
Args:
input: input tensor with shape [batch_size, num_length, input_dim].
shift: shifting the norm tensor with shape [batch_size, input_dim].
scale: scaling the norm tensor with shape [batch_size, input_dim].
shift: shifting the norm tensor with shape [batch_size, 1|num_length, input_dim].
scale: scaling the norm tensor with shape [batch_size, 1|num_length, input_dim].
gate: applying before the residual addition with shape
[batch_size, input_dim].
[batch_size, 1|num_length, input_dim].
Returns:
A tensor with shape [batch_size, num_length, input_dim].
"""
chex.assert_equal_shape((shift, scale, gate))
chex.assert_equal_rank((input, shift))
cfg = self.config
remat_pt1 = "linear1_0"
remat_pt2 = "linear2"
@@ -325,7 +346,7 @@ def forward(self, *, input: Tensor, shift: Tensor, scale: Tensor, gate: Tensor)
x = self.postnorm(x)

x = self.dropout2(x)
x = x * jnp.expand_dims(gate, 1)
x = x * gate
x += input
return x

@@ -389,12 +410,12 @@ def forward(
Args:
input: input tensor with shape [batch_size, num_length, target_dim].
shift: If provided, shifting the norm tensor with shape [batch_size, target_dim] and
scale should be provided.
scale: If provided, scaling the norm tensor with shape [batch_size, target_dim] and
shift should be provided.
shift: If provided, shifting the norm tensor with shape [batch_size, 1|num_length,
target_dim] and scale should be provided.
scale: If provided, scaling the norm tensor with shape [batch_size, 1|num_length,
target_dim] and shift should be provided.
gate: If provided, applying before the residual addition with shape
[batch_size, target_dim].
[batch_size, 1|num_length, target_dim].
attention_logit_biases: Optional Tensor representing the self attention biases.
Returns:
Expand Down Expand Up @@ -426,7 +447,7 @@ def forward(
x = self.postnorm(x)

if gate is not None:
x = x * jnp.expand_dims(gate, 1)
x = x * gate

output = input + x
return output
@@ -463,12 +484,12 @@ def extend_step(
results of previous attentions, and index used for fast decoding. Contains
"attention" cached states.
target: target tensor with shape [batch_size, step_length, target_dim].
shift: If provided, shifting the norm tensor with shape [batch_size, target_dim] and
scale should be provided.
scale: If provided, scaling the norm tensor with shape [batch_size, target_dim] and
shift should be provided.
shift: If provided, shifting the norm tensor with shape [batch_size, 1|num_length,
target_dim] and scale should be provided.
scale: If provided, scaling the norm tensor with shape [batch_size, 1|num_length,
target_dim] and shift should be provided.
gate: If provided, applying before the residual addition with shape
[batch_size, target_dim].
[batch_size, 1|num_length, target_dim].
Returns:
A tuple (cached_states, output):
Expand Down Expand Up @@ -504,7 +525,7 @@ def extend_step(
x = self.postnorm(x)

if gate is not None:
x = x * jnp.expand_dims(gate, 1)
x = x * gate

output = target + x
return dict(attention=attn_states), output
@@ -542,8 +563,8 @@ def forward(self, *, input: Tensor, condition: Tensor) -> Tensor:
Args:
input: input tensor with shape [batch_size, num_length, input_dim].
condition: tensor with shape [batch_size, input_dim] for generating
layer norm shift, scale, and gate.
condition: tensor with shape [batch_size, input_dim] or [batch_size, num_length,
input_dim] for generating layer norm shift, scale, and gate.
Returns:
A tensor with shape [batch_size, num_length, input_dim].
@@ -584,8 +605,8 @@ def extend_step(
results of previous attentions, and index used for fast decoding. Contains
"attention" cached states.
target: target tensor with shape [batch_size, step_length, input_dim].
condition: tensor with shape [batch_size, input_dim] for generating
layer norm shift, scale, and gate.
condition: tensor with shape [batch_size, input_dim] or [batch_size, step_length,
input_dim] for generating layer norm shift, scale, and gate.
Returns:
A tuple (cached_states, output):
@@ -639,8 +660,8 @@ def forward(self, *, input: Tensor, condition: Tensor) -> Tensor:
Args:
input: input tensor with shape [batch_size, num_length, input_dim].
condition: tensor with shape [batch_size, input_dim] for generating
layer norm shift and scale.
condition: tensor with shape [batch_size, input_dim] or [batch_size, num_length,
input_dim] for generating layer norm shift and scale.
Returns:
A tensor with shape [batch_size, num_length, output_dim].