Allow external positions to be inputed in RoPE embedding layer #926

Merged
merged 14 commits into from
Jan 27, 2025

Firenze11
Contributor

Use case: In RoPE embedding, position embeddings are applied to the Q, K, V values after `i_proj`. Unlike the current implementation of `RoFormerQKVLinear`, in MaskedDiT we need to customize the positions to indicate masked versus non-masked positions in the position embedding. When we convert this masked RoFormer attention module to flash attention, its signature must be supported by `MultiheadAttention`.
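A minimal sketch of the idea described above, not the axlearn implementation: a sinusoidal embedding function that accepts either caller-supplied positions or a `max_seq_len` from which default positions are derived. The function name `sinusoidal_embeddings`, the keyword-only arguments, and the sin-then-cos layout are assumptions made for illustration; only `dim`, `theta`, `default_query_positions`, and `max_seq_len` appear in the PR.

```python
from typing import Optional

import jax.numpy as jnp


def default_query_positions(max_seq_len: int) -> jnp.ndarray:
    """Default positions 0..max_seq_len-1 with a leading batch axis of size 1."""
    return jnp.arange(max_seq_len)[None]


def sinusoidal_embeddings(
    positions: Optional[jnp.ndarray] = None,
    *,
    max_seq_len: Optional[int] = None,
    dim: int = 8,
    theta: float = 10000.0,
) -> jnp.ndarray:
    """Hypothetical helper: returns [batch, seq_len, dim], sin half then cos half."""
    if positions is None:
        if max_seq_len is None:
            raise ValueError("Provide either positions or max_seq_len.")
        positions = default_query_positions(max_seq_len)
    # Inverse frequencies theta^(-2i/dim) for i in [0, dim/2).
    inv_freq = 1.0 / (theta ** (jnp.arange(0, dim, 2) / dim))
    # Broadcast [batch, seq_len, 1] against [dim/2].
    angles = positions[..., None].astype(jnp.float32) * inv_freq
    return jnp.concatenate([jnp.sin(angles), jnp.cos(angles)], axis=-1)


# Default positions derived from max_seq_len.
default_emb = sinusoidal_embeddings(max_seq_len=4)
# Externally supplied positions, e.g. with a gap where tokens were dropped.
custom_emb = sinusoidal_embeddings(jnp.array([[0, 1, 5, 6]]))
```

Both calls return an array of shape `[1, 4, 8]`. The first two rows agree because positions 0 and 1 are shared, while the last two differ.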
@Firenze11 requested review from ruomingp, markblee and a team as code owners January 15, 2025 02:56
@@ -1216,18 +1216,37 @@ class Config(BaseLayer.Config):
         dim: Required[int] = REQUIRED  # The dimensionality of the positional embedding.
         theta: float = 10000.0  # The scale of base frequency.

-    def forward(self, positions: Tensor) -> Tensor:
+    def default_query_positions(self, max_seq_len: int) -> Tensor:
Contributor

Suggested change
-    def default_query_positions(self, max_seq_len: int) -> Tensor:
+    def _default_query_positions(self, max_seq_len: int) -> Tensor:

Users should pass max_seq_len rather than calling this method publicly.

Contributor Author

There may be situations where we want to access the default query positions from outside, for example to fetch the default positions and compute custom positions from them before passing the result to `forward`. We therefore want to keep this method public.
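The workflow described here could look roughly like the sketch below. It is an illustration under assumptions, not code from the PR: `positions_for_visible_tokens` and the "keep only visible tokens" masking scheme are hypothetical, and `default_query_positions` is reproduced as a free function following the one-line `jnp.arange` implementation mentioned in this thread.

```python
import jax.numpy as jnp


def default_query_positions(max_seq_len: int) -> jnp.ndarray:
    """Default positions 0..max_seq_len-1 with a leading batch axis of size 1."""
    return jnp.arange(max_seq_len)[None]


def positions_for_visible_tokens(
    visible_indices: jnp.ndarray, max_seq_len: int
) -> jnp.ndarray:
    """Hypothetical helper: look up the default position of each visible token.

    visible_indices has shape [batch, num_visible] and holds, per example,
    the indices of tokens that remain after masking.
    """
    defaults = default_query_positions(max_seq_len)  # [1, max_seq_len]
    # Integer-array indexing on the 1-D defaults yields [batch, num_visible].
    return defaults[0][visible_indices]


# Two examples, each keeping 3 of 6 tokens.
visible = jnp.array([[0, 2, 5], [1, 3, 4]])
positions = positions_for_visible_tokens(visible, max_seq_len=6)
```

With the plain `arange` defaults the result equals `visible` itself. Going through the defaults matters once a subclass returns something other than `0..max_seq_len-1`, because the lookup then picks up the overridden values without any change on the caller side.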

Contributor

Do we expect subclasses to override this method?

If not, it will be more readable for callers to call jnp.arange directly given the implementation is only one line.

Contributor Author

Yes, we do expect subclasses to override this method. For example, we might want to specify a rotation start index and a rotation end index for this embedding class, in which case the default embedding positions would differ.
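As a hedged illustration of such an override: the classes below are plain Python stand-ins rather than axlearn layers, and the meaning given to the rotation start and end indices (tokens inside `[rotation_start, rotation_end)` count up from 0, all others stay at position 0 and so receive no rotation) is only one possible reading, assumed here for the example.

```python
import jax.numpy as jnp


class SketchEmbedding:
    """Stand-in for the embedding class; only the position logic is modeled."""

    def default_query_positions(self, max_seq_len: int) -> jnp.ndarray:
        return jnp.arange(max_seq_len)[None]  # [1, max_seq_len]


class RangedSketchEmbedding(SketchEmbedding):
    """Hypothetical subclass that rotates only a sub-range of the sequence."""

    def __init__(self, rotation_start: int, rotation_end: int):
        self.rotation_start = rotation_start
        self.rotation_end = rotation_end

    def default_query_positions(self, max_seq_len: int) -> jnp.ndarray:
        idx = jnp.arange(max_seq_len)
        in_range = (idx >= self.rotation_start) & (idx < self.rotation_end)
        # Inside the range, positions restart at 0; outside, they stay at 0.
        return jnp.where(in_range, idx - self.rotation_start, 0)[None]


# The caller asks the layer for its defaults instead of calling jnp.arange.
base_positions = SketchEmbedding().default_query_positions(6)
ranged_positions = RangedSketchEmbedding(2, 5).default_query_positions(6)
# base_positions   -> [[0, 1, 2, 3, 4, 5]]
# ranged_positions -> [[0, 0, 0, 1, 2, 0]]
```

Because callers go through the method, the same calling code works for both classes, which is the argument for keeping the position logic inside the class rather than inlining `jnp.arange` at each call site.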

Contributor

Thanks for the explanation. Given this, it's reasonable to consolidate the position computation logic in this class.

axlearn/common/attention.py (outdated, resolved)
Contributor

@kelvin-zou left a comment

Thanks

@Firenze11 enabled auto-merge January 27, 2025 16:51
@Firenze11 added this pull request to the merge queue Jan 27, 2025
Merged via the queue into apple:main with commit a854738 Jan 27, 2025
6 checks passed
@Firenze11 deleted the rope_emb_pos branch January 27, 2025 17:51
5 participants