
Remove prims.embedding and prims.embedding_backward #1689

Open · wants to merge 7 commits into main
Conversation

@IvanYashchuk (Collaborator) commented Jan 24, 2025

We don't need prims.embedding and prims.embedding_backward because we can simply use thunder.torch.embedding and thunder.torch.embedding_backward as our transformation targets. When we added these primitives, we didn't have a good idea of how multi-level symbols could be targeted.

With this PR, thunder.torch.embedding is transformed directly into torch.nn.functional.embedding for execution. Before this PR, thunder.torch.embedding was decomposed into prims.embedding, which was then transformed into torch.nn.functional.embedding (and similarly for embedding_backward).
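
For illustration, a minimal sketch of how the change can be observed, assuming the lightning-thunder public API (thunder.jit and thunder.last_traces); the function and tensor shapes below are made up:

    import torch
    import thunder

    def f(idx, weight):
        return torch.nn.functional.embedding(idx, weight)

    jf = thunder.jit(f)
    weight = torch.randn(10, 4)
    idx = torch.tensor([1, 2, 3])
    out = jf(idx, weight)

    # Inspect the execution trace: after this PR the torch-level symbol
    # (thunder.torch.embedding) should be claimed directly by the executor,
    # with no intermediate prims.embedding decomposition in the trace.
    print(thunder.last_traces(jf)[-1])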

Testing:

pytest thunder/tests/test_grad.py -k "embedding" -vvv
pytest thunder/tests/test_ops.py -k "embedding" -vvv

@IvanYashchuk IvanYashchuk enabled auto-merge (squash) January 24, 2025 16:35
@mruberry (Collaborator) commented Jan 24, 2025

This is interesting.

Isn't this regressing support for calling torch.embedding with the max_norm argument, though?

Comment on lines -4087 to -4088
if max_norm is not None:
raise NotImplementedError
IvanYashchuk (Collaborator, Author):

The max_norm argument was not supported. I don't know the reason for this. I think it should work and I'll update OpInfo samples to test it.

IvanYashchuk (Collaborator, Author):

The max_norm argument, when active, modifies the weight argument in place, which is probably why it raised NotImplementedError. I'll keep the behavior as is in this PR and created issue #1699 to track completion.
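
As a side note, a minimal sketch of the in-place behavior (assuming at least one looked-up row has a norm above max_norm; the tensor shapes are made up):

    import torch
    import torch.nn.functional as F

    weight = torch.randn(10, 4) * 5.0  # make some rows exceed the norm limit
    before = weight.clone()
    idx = torch.tensor([1, 2, 3])

    # With max_norm set, the rows selected by idx are renormalized inside the
    # weight tensor itself, so the caller-visible weight is mutated.
    F.embedding(idx, weight, max_norm=1.0)
    print(torch.equal(before, weight))  # expected: False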

@IvanYashchuk IvanYashchuk marked this pull request as draft January 24, 2025 16:49
auto-merge was automatically disabled January 24, 2025 16:49

Pull request was converted to draft

@riccardofelluga (Collaborator) commented Jan 24, 2025

Ah wait, then we might want to change this too:

register_supported(PrimIDs.EMBEDDING, embedding, _embedding_check)
register_supported(ltorch.embedding, embedding, _embedding_check)

cc @Priya2698 for #1674

@IvanYashchuk IvanYashchuk marked this pull request as ready for review January 27, 2025 19:26
@IvanYashchuk IvanYashchuk enabled auto-merge (squash) January 27, 2025 19:27