Make Flux Transformer RoPE use a custom IREE kernel #871

Open. sogartar wants to merge 1 commit into main from flux-transformer-rope-with-kernel.
Conversation

sogartar (Contributor)
We assume that the custom kernel will yield better performance than the equivalent PyTorch ops.
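For context, here is a minimal sketch (not the PR's code) of the kind of rotary-embedding application that such a custom kernel would replace, written with plain PyTorch ops; the function name, shapes, and layout below are illustrative assumptions.

```python
import torch


def apply_rope_reference(x: torch.Tensor, table: torch.Tensor) -> torch.Tensor:
    """Rotate channel pairs of `x` by the angles in `table`.

    Assumed shapes (illustrative, not the repo's convention):
      x:     (batch, seq_len, heads, head_dim), head_dim even
      table: (seq_len, head_dim // 2) rotation angles
    """
    x_pairs = x.unflatten(-1, (-1, 2))        # (..., head_dim // 2, 2)
    x_re, x_im = x_pairs[..., 0], x_pairs[..., 1]
    cos = torch.cos(table)[None, :, None, :]  # broadcast over batch and heads
    sin = torch.sin(table)[None, :, None, :]
    out_re = x_re * cos - x_im * sin
    out_im = x_re * sin + x_im * cos
    return torch.stack((out_re, out_im), dim=-1).flatten(-2)


# The PR's intent is to replace this chain of elementwise ops with a single
# custom IREE kernel call, assuming the fused kernel compiles to faster code.
```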

sogartar (Contributor, Author)
To match the dimension order required by the kernel, an axis permutation is performed. I am not sure whether this can be optimized out; it needs some investigation of the compilation passes (a rough sketch of the concern is shown after this comment).

This change currently introduces a significant (roughly 10x) performance regression: the baseline is 552 ms, while with this change it is 6131 ms.
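To illustrate the permutation concern, here is a minimal sketch; the dimension orders and the kernel entry point are assumptions, not the PR's actual code.

```python
import torch


def _kernel_stub(x: torch.Tensor, table: torch.Tensor) -> torch.Tensor:
    # Stand-in for the custom IREE kernel; the real kernel lives in the repo
    # and is not reproduced here.
    return x


def apply_rope_via_kernel(x: torch.Tensor, table: torch.Tensor) -> torch.Tensor:
    # Suppose the kernel expects (batch, heads, seq_len, head_dim) while the
    # Flux transformer produces (batch, seq_len, heads, head_dim).
    x_perm = x.permute(0, 2, 1, 3).contiguous()
    y_perm = _kernel_stub(x_perm, table)
    # Permute back to the caller's layout. Whether the compiler can fold these
    # transposes into the kernel is the open question mentioned above.
    return y_perm.permute(0, 2, 1, 3)
```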

sogartar force-pushed the flux-transformer-rope-with-kernel branch 2 times, most recently from 4ff5c3e to 3b751f4, on January 28, 2025 23:40.
sogartar marked this pull request as ready for review on January 31, 2025 15:12.
sogartar (Contributor, Author)
The large performance problem has been addressed by iree-org/iree#19822.



def compute_rotary_embedding_table(
positions: torch.Tensor,
Contributor (reviewer)
Nit: would it make more sense to just rename _compute_rotary_embedding_table?

sogartar (Contributor, Author)
I want to use this function outside of the class.

Contributor (reviewer)
Yes, I mean, instead of:

  • copying compute_rotary_embedding_table so we can use it outside the class
  • making the old _compute_rotary_embedding_table redirect to compute_rotary_embedding_table

just:

  • renaming _compute_rotary_embedding_table to compute_rotary_embedding_table and using it outside the class
  • changing all references to _compute_rotary_embedding_table to use compute_rotary_embedding_table instead

The latter requires an IDE and is slightly more work, but it does not leave a stub/redirect function behind (see the sketch below).
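A minimal sketch of the two options being discussed; the class and function names are simplified placeholders, not the repo's actual code.

```python
import torch


def compute_rotary_embedding_table(positions: torch.Tensor) -> torch.Tensor:
    """Module-level implementation, usable outside the class."""
    ...


# Option 1: keep the old method as a thin redirect to the module-level function.
class RotaryEmbeddingLayer:
    def _compute_rotary_embedding_table(self, positions: torch.Tensor) -> torch.Tensor:
        # stub/redirect left behind so existing callers keep working
        return compute_rotary_embedding_table(positions)


# Option 2: rename _compute_rotary_embedding_table to compute_rotary_embedding_table,
# move it out of the class, and update every caller; no redirect stub remains.
```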

sogartar (Contributor, Author)
The class variant got updated to handle the Hugging Face case, so it has more hair now.

sogartar force-pushed the flux-transformer-rope-with-kernel branch from 3b751f4 to 48e1b04 on January 31, 2025 20:00.
sogartar (Contributor, Author) commented Feb 1, 2025

This PR is waiting on iree-org/iree#19829.
