DiT now supports sequence conditions. #923

Merged 1 commit on Jan 14, 2025
71 changes: 46 additions & 25 deletions axlearn/common/dit.py
@@ -13,6 +13,8 @@

from typing import Optional, Union

import chex
import einops
import jax
import jax.numpy as jnp

@@ -31,7 +33,21 @@


def modulate(*, x, shift, scale):
return x * (1 + jnp.expand_dims(scale, 1)) + jnp.expand_dims(shift, 1)
"""Modulates the input x tensor.

Note: shift and scale must have the same shape.

Args:
x: input tensor with shape [batch_size, num_length, input_dim].
shift: shifting the norm tensor with shape [batch_size, 1|num_length, input_dim].
scale: scaling the norm tensor with shape [batch_size, 1|num_length, input_dim].

Returns:
A tensor with shape [batch_size, num_length, input_dim].
"""
chex.assert_equal_shape((shift, scale))
chex.assert_equal_rank((x, shift, scale))
return x * (1 + scale) + shift
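
The reworked `modulate` relies on plain broadcasting instead of `expand_dims`, so shift and scale may be either per-sample ([batch_size, 1, input_dim]) or per-token ([batch_size, num_length, input_dim]). A minimal shape sketch of both call patterns (shapes chosen only for illustration):

```python
import chex
import jax.numpy as jnp


def modulate(*, x, shift, scale):
    # Same broadcasting-based modulation as in the diff above.
    chex.assert_equal_shape((shift, scale))
    chex.assert_equal_rank((x, shift, scale))
    return x * (1 + scale) + shift


x = jnp.ones((2, 8, 16))            # [batch_size, num_length, input_dim]
per_sample = jnp.zeros((2, 1, 16))  # one shift/scale per example, broadcast over num_length
per_token = jnp.zeros((2, 8, 16))   # independent shift/scale for every position

assert modulate(x=x, shift=per_sample, scale=per_sample).shape == (2, 8, 16)
assert modulate(x=x, shift=per_token, scale=per_token).shape == (2, 8, 16)
```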


class TimeStepEmbedding(BaseLayer):
@@ -211,15 +227,18 @@ def forward(self, input: Tensor) -> Tensor:
"""Generate the parameters for modulation.

Args:
input: A tensor with shape [batch_size, ..., dim].
input: A tensor with shape [batch_size, dim] or [batch_size, num_length, dim].

Returns:
A list of tensors with length num_outputs.
Each tensor has shape [batch_size, ..., dim].
Each tensor has shape [batch_size, 1|num_length, dim].
"""
cfg = self.config
x = get_activation_fn(cfg.activation)(input)
output = self.linear(x)
assert output.ndim in (2, 3)

Contributor:

Raise a ValueError instead of assert (which should only be used to enforce internal logic errors and cannot be triggered by user error).

Contributor Author:

Since I clicked auto-merge, your comment was not addressed in this PR. I submitted #981 to handle it.

if output.ndim == 2:
output = einops.rearrange(output, "b d -> b 1 d")
output = jnp.split(output, cfg.num_outputs, axis=-1)
return output
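
The new 2-D/3-D handling above amounts to lifting a [batch_size, dim] condition to [batch_size, 1, dim] and then splitting the projection into num_outputs chunks along the feature axis. A self-contained sketch of that step (the activation and linear projection are elided; `lift_and_split` is a hypothetical helper, and the `ValueError` follows the reviewer's suggestion rather than the code as merged):

```python
import einops
import jax.numpy as jnp

num_outputs = 3  # e.g. shift, scale, gate


def lift_and_split(output):
    # Hypothetical helper mirroring the tail of forward() above.
    if output.ndim not in (2, 3):
        # Per the review thread, a shape error a caller can trigger should raise
        # ValueError rather than assert (addressed in follow-up PR #981).
        raise ValueError(f"Expected a 2-D or 3-D condition, got ndim={output.ndim}.")
    if output.ndim == 2:
        # Per-sample condition: add a length axis of 1 so it broadcasts over num_length.
        output = einops.rearrange(output, "b d -> b 1 d")
    return jnp.split(output, num_outputs, axis=-1)


shift, scale, gate = lift_and_split(jnp.ones((2, 3 * 16)))  # [batch_size, num_outputs * dim]
assert shift.shape == scale.shape == gate.shape == (2, 1, 16)
```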

@@ -292,14 +311,16 @@ def forward(self, *, input: Tensor, shift: Tensor, scale: Tensor, gate: Tensor)

Args:
input: input tensor with shape [batch_size, num_length, input_dim].
shift: shifting the norm tensor with shape [batch_size, input_dim].
scale: scaling the norm tensor with shape [batch_size, input_dim].
shift: shifting the norm tensor with shape [batch_size, 1|num_length, input_dim].
scale: scaling the norm tensor with shape [batch_size, 1|num_length, input_dim].
gate: applying before the residual addition with shape
[batch_size, input_dim].
[batch_size, 1|num_length, input_dim].

Returns:
A tensor with shape [batch_size, num_length, input_dim].
"""
chex.assert_equal_shape((shift, scale, gate))
chex.assert_equal_rank((input, shift))
cfg = self.config
remat_pt1 = "linear1_0"
remat_pt2 = "linear2"
@@ -325,7 +346,7 @@ def forward(self, *, input: Tensor, shift: Tensor, scale: Tensor, gate: Tensor)
x = self.postnorm(x)

x = self.dropout2(x)
x = x * jnp.expand_dims(gate, 1)
x = x * gate
x += input
return x

@@ -389,12 +410,12 @@ def forward(

Args:
input: input tensor with shape [batch_size, num_length, target_dim].
shift: If provided, shifting the norm tensor with shape [batch_size, target_dim] and
scale should be provided.
scale: If provided, scaling the norm tensor with shape [batch_size, target_dim] and
shift should be provided.
shift: If provided, shifting the norm tensor with shape [batch_size, 1|num_length,
target_dim] and scale should be provided.
scale: If provided, scaling the norm tensor with shape [batch_size, 1|num_length,
target_dim] and shift should be provided.
gate: If provided, applying before the residual addition with shape
[batch_size, target_dim].
[batch_size, 1|num_length, target_dim].
attention_logit_biases: Optional Tensor representing the self attention biases.

Returns:
@@ -426,7 +447,7 @@ def forward(
x = self.postnorm(x)

if gate is not None:
x = x * jnp.expand_dims(gate, 1)
x = x * gate

output = input + x
return output
@@ -463,12 +484,12 @@ def extend_step(
results of previous attentions, and index used for fast decoding. Contains
"attention" cached states.
target: target tensor with shape [batch_size, step_length, target_dim].
shift: If provided, shifting the norm tensor with shape [batch_size, target_dim] and
scale should be provided.
scale: If provided, scaling the norm tensor with shape [batch_size, target_dim] and
shift should be provided.
shift: If provided, shifting the norm tensor with shape [batch_size, 1|num_length,
target_dim] and scale should be provided.
scale: If provided, scaling the norm tensor with shape [batch_size, 1|num_length,
target_dim] and shift should be provided.
gate: If provided, applying before the residual addition with shape
[batch_size, target_dim].
[batch_size, 1|num_length, target_dim].

Returns:
A tuple (cached_states, output):
@@ -504,7 +525,7 @@ def extend_step(
x = self.postnorm(x)

if gate is not None:
x = x * jnp.expand_dims(gate, 1)
x = x * gate

output = target + x
return dict(attention=attn_states), output
@@ -542,8 +563,8 @@ def forward(self, *, input: Tensor, condition: Tensor) -> Tensor:

Args:
input: input tensor with shape [batch_size, num_length, input_dim].
condition: tensor with shape [batch_size, input_dim] for generating
layer norm shift, scale, and gate.
condition: tensor with shape [batch_size, input_dim] or [batch_size, num_length,
input_dim] for generating layer norm shift, scale, and gate.

Returns:
A tensor with shape [batch_size, num_length, input_dim].
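
For callers, the practical effect is that `condition` may now vary along the sequence. One way a caller might build such a sequence condition, shown below purely for illustration (the names and the additive combination are assumptions, not part of this PR):

```python
import jax.numpy as jnp

batch, num_length, dim = 2, 8, 16

timestep_emb = jnp.ones((batch, dim))            # per-sample condition, e.g. a timestep embedding
token_cond = jnp.ones((batch, num_length, dim))  # per-token conditioning signal (illustrative)

# Broadcast the per-sample embedding across the sequence and combine the two.
seq_condition = token_cond + timestep_emb[:, None, :]  # [batch_size, num_length, dim]
assert seq_condition.shape == (batch, num_length, dim)

# Either timestep_emb ([batch_size, dim]) or seq_condition
# ([batch_size, num_length, dim]) is now a valid `condition` for forward().
```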
@@ -584,8 +605,8 @@ def extend_step(
results of previous attentions, and index used for fast decoding. Contains
"attention" cached states.
target: target tensor with shape [batch_size, step_length, input_dim].
condition: tensor with shape [batch_size, input_dim] for generating
layer norm shift, scale, and gate.
condition: tensor with shape [batch_size, input_dim] or [batch_size, step_length,
input_dim] for generating layer norm shift, scale, and gate.

Returns:
A tuple (cached_states, output):
@@ -639,8 +660,8 @@ def forward(self, *, input: Tensor, condition: Tensor) -> Tensor:

Args:
input: input tensor with shape [batch_size, num_length, input_dim].
condition: tensor with shape [batch_size, input_dim] for generating
layer norm shift and scale.
condition: tensor with shape [batch_size, input_dim] or [batch_size, num_length,
input_dim] for generating layer norm shift and scale.

Returns:
A tensor with shape [batch_size, num_length, output_dim].