Why is training ~3x slower than Swin? #15

Open

rayleizhu opened this issue Sep 12, 2022 · 2 comments

@rayleizhu

Thanks for open-sourcing this great work. While trying the code, I found that training is ~3x slower than Swin Transformer. For example, quadtree-b2, which has similar FLOPs to Swin-T, takes ~2.5 s per batch to train, and it is even slower (~3 s/batch) when I align its macro design (depths, embedding dims, etc.) with Swin-T.

Can you give some insight into what accounts for this?
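(For reference, per-batch times like the ones above can be measured with a simple wall-clock loop; the sketch below is a generic PyTorch pattern, not the repository's training script, and `model`, `batch`, and `optimizer` are placeholders.)

```python
import time
import torch

def time_per_batch(model, batch, optimizer, warmup=5, iters=20):
    """Rough wall-clock time per training step (forward + backward + optimizer step)."""
    model.train()
    images, targets = batch
    criterion = torch.nn.CrossEntropyLoss()
    for i in range(warmup + iters):
        if i == warmup:                        # start timing only after warm-up iterations
            if torch.cuda.is_available():
                torch.cuda.synchronize()       # make sure queued GPU work has finished
            start = time.perf_counter()
        optimizer.zero_grad(set_to_none=True)
        loss = criterion(model(images), targets)
        loss.backward()
        optimizer.step()
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    return (time.perf_counter() - start) / iters
```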

@rayleizhu
Author

Personally, I have a couple of guesses about the slowness:

  1. the recursive assembly process of quadtree attention
  2. the low compute intensity of quadtree attention due to scattered keys and values

Am I correct? Of the two, which has the larger effect? Are there any other likely causes in your experience?
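(For illustration, here is a minimal sketch of the kind of top-K gather that quadtree-style sparse attention performs; the shapes, names, and random indices are assumptions for the example, not the repository's actual implementation. The scattered `torch.gather` reads are what keep compute intensity low compared with a dense matmul of similar FLOPs.)

```python
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"

# Hypothetical shapes: B batches, H heads, N fine-level tokens,
# K selected key/value tokens per query, D head dimension.
B, H, N, K, D = 2, 8, 4096, 8, 32

q = torch.randn(B, H, N, D, device=device)  # fine-level queries
k = torch.randn(B, H, N, D, device=device)  # fine-level keys
v = torch.randn(B, H, N, D, device=device)  # fine-level values

# Indices of the K key/value tokens each query is allowed to attend to.
# In the real method these come from coarse-level attention scores;
# random indices are enough to show the scattered access pattern.
topk_idx = torch.randint(0, N, (B, H, N, K), device=device)

# Gather the selected keys/values for every query: (B, H, N, K, D).
# These gathers touch non-contiguous memory, so arithmetic intensity is low.
idx = topk_idx.reshape(B, H, N * K, 1).expand(-1, -1, -1, D)
k_sel = torch.gather(k, 2, idx).reshape(B, H, N, K, D)
v_sel = torch.gather(v, 2, idx).reshape(B, H, N, K, D)

# Attention restricted to the K gathered tokens per query.
attn = torch.einsum("bhnd,bhnkd->bhnk", q, k_sel).softmax(dim=-1)
out = torch.einsum("bhnk,bhnkd->bhnd", attn, v_sel)
```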

@Tangshitao
Owner

Not exactly. There are two reasons: 1) we implement quadtree attention in raw CUDA without much optimization, and we would expect a speedup if it were implemented with torch.geometry; 2) the sparse nature of quadtree attention makes it unfriendly to hardware, and this cannot be solved at the code level.
The easiest solution I can suggest is to reduce the top K, which gives a significant speedup without much performance loss.
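(To make the top-K trade-off concrete, below is a back-of-the-envelope FLOP model for the attention part only. It assumes a quadtree with 4 children per node and attention at each finer level restricted to the children of the K tokens kept at the level above; the function and numbers are illustrative assumptions, not measurements of this code.)

```python
def quadtree_attn_flops(n_tokens, dim, top_k, levels=3):
    """Rough FLOP count: dense attention at the coarsest level, then
    4 * top_k attended tokens per query at every finer level."""
    coarse = n_tokens // (4 ** (levels - 1))
    flops = 2 * coarse * coarse * dim            # dense attention at the coarsest level
    tokens = coarse
    for _ in range(levels - 1):
        tokens *= 4                              # one level finer
        flops += 2 * tokens * (4 * top_k) * dim  # each query attends to 4*K child tokens
    return flops

n, d = 3136, 96  # e.g. a 56x56 feature map with 96 channels
for k in (1, 2, 4, 8):
    print(f"top_k={k}: ~{quadtree_attn_flops(n, d, k) / 1e6:.1f} MFLOPs")
```

The fine-level terms scale linearly in K, which is why cutting the top K trades a small amount of accuracy for a roughly proportional reduction in the sparse attention cost.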
