Remove prims.embedding and prims.embedding_backward #1689
base: main
Conversation
This is interesting. Isn't this regressing support for calling embedding with max_norm?
```python
if max_norm is not None:
    raise NotImplementedError
```
The `max_norm` argument was not supported. I don't know the reason for this. I think it should work, and I'll update the OpInfo samples to test it.
The `max_norm` argument, when active, modifies the `weight` argument in place. That's probably the reason this argument raised NotImplementedError. I'll keep the behavior as-is in this PR.
Created an issue for tracking completion: #1699.
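For context, a minimal plain-PyTorch sketch (not taken from this PR) of the side effect in question: `F.embedding` with `max_norm` renormalizes the looked-up rows of `weight` in place.

```python
import torch
import torch.nn.functional as F

weight = torch.randn(5, 3) * 10           # rows very likely have norm > 1
indices = torch.tensor([0, 2])
before = weight.clone()

F.embedding(indices, weight, max_norm=1.0)

# Only the rows that were looked up are renormalized in place.
print(torch.equal(weight[0], before[0]))  # likely False: row 0 was renormed
print(torch.equal(weight[1], before[1]))  # True: row 1 was never looked up
```

Modeling this mutation of `weight` is what makes `max_norm` support non-trivial, which is why the behavior is kept as-is here and tracked in #1699.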
Pull request was converted to draft
Ah wait, then we might want to change this too: lightning-thunder/thunder/executors/nvfuserex_impl.py, lines 2636 to 2637 in 52ee541
cc @Priya2698 for #1674
We don't need `prims.embedding` and `prims.embedding_backward` because we can simply use `thunder.torch.embedding` and `thunder.torch.embedding_backward` as our transformation targets. When we added these primitives we didn't have a good idea of how multi-level symbols could be targeted.

With this PR, `thunder.torch.embedding` is transformed directly to `torch.nn.functional.embedding` for execution. Before this PR, `thunder.torch.embedding` was decomposed into `prims.embedding`, which was then transformed into `torch.nn.functional.embedding` (and similarly for `embedding_backward`).
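A minimal sketch of how the change can be observed, assuming the public `thunder.jit` and `thunder.last_traces` APIs (illustrative only, not part of this PR's test suite):

```python
import torch
import torch.nn.functional as F
import thunder

def fn(indices, weight):
    return F.embedding(indices, weight)

weight = torch.randn(10, 4)
indices = torch.tensor([1, 2, 3])

jfn = thunder.jit(fn)
jfn(indices, weight)

# With this PR, the final execution trace should call
# torch.nn.functional.embedding directly, without an intermediate
# prims.embedding decomposition step.
print(thunder.last_traces(jfn)[-1])
```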
Testing: