-
Notifications
You must be signed in to change notification settings - Fork 281
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Allow external positions to be inputed in RoPE embedding layer #926
Conversation
Use case: In RoPE embedding, position embeddings are applied to Q, K, V values after `i_proj`. Unlike the implementation of current `RoFormerQKVLinear`, in MaskedDiT we need to customize positions to indicate masked versus non-masked positions in the position embedding. When we convert this masked roformer attention module to flash attention, we need its signature to be supported by `MultiheadAttention`.
@@ -1216,18 +1216,37 @@ class Config(BaseLayer.Config): | |||
dim: Required[int] = REQUIRED # The dimensionality of the positional embedding. | |||
theta: float = 10000.0 # The scale of base frequency. | |||
|
|||
def forward(self, positions: Tensor) -> Tensor: | |||
def default_query_positions(self, max_seq_len: int) -> Tensor: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
def default_query_positions(self, max_seq_len: int) -> Tensor: | |
def _default_query_positions(self, max_seq_len: int) -> Tensor: |
Users should pass max_seq_len
rather than calling this method publicly.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There might be situations that we want to access the default query positions from the outside, such as getting the default positions and computing the positions based on it before passing it to forward
. Therefore we want to keep this public.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Do we expect subclasses to override this method?
If not, it will be more readable for callers to call jnp.arange
directly given the implementation is only one line.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes we do expect subclasses to override this method. One example is that we might want to specify the rotation start index and rotation end index for this embedding class. And the default embedding positions will be different then.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the explanation. Given this, it's reasonable to consolidate the position computation logic in this class.
Co-authored-by: Mark Lee <[email protected]>
@@ -1216,18 +1216,37 @@ class Config(BaseLayer.Config): | |||
dim: Required[int] = REQUIRED # The dimensionality of the positional embedding. | |||
theta: float = 10000.0 # The scale of base frequency. | |||
|
|||
def forward(self, positions: Tensor) -> Tensor: | |||
def default_query_positions(self, max_seq_len: int) -> Tensor: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Do we expect subclasses to override this method?
If not, it will be more readable for callers to call jnp.arange
directly given the implementation is only one line.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks
@@ -1216,18 +1216,37 @@ class Config(BaseLayer.Config): | |||
dim: Required[int] = REQUIRED # The dimensionality of the positional embedding. | |||
theta: float = 10000.0 # The scale of base frequency. | |||
|
|||
def forward(self, positions: Tensor) -> Tensor: | |||
def default_query_positions(self, max_seq_len: int) -> Tensor: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the explanation. Given this, it's reasonable to consolidate the position computation logic in this class.
Use case: In RoPE embedding, position embeddings are applied to Q, K, V values after
i_proj
. Unlike the implementation of currentRoFormerQKVLinear
, in MaskedDiT we need to customize positions to indicate masked versus non-masked positions in the position embedding. When we convert this masked roformer attention module to flash attention, we need its signature to be supported byMultiheadAttention
.