Remove prims.embedding and prims.embedding_backward #1689
base: main
Changes from all commits: f77c202, 979f872, 97ca50b, 8688844, ec0fe4f, 1d76408, 5e52ea7, b8bb32c
```diff
@@ -4710,18 +4710,21 @@ def embedding(
 ) -> TensorLike:
     # TODO: add embedding_renorm_ so we can remove embedding prim
     # NOTE: padding_idx has impact on backward and is not supported by take
-    if max_norm is not None or padding_idx is not None:
-        padding_idx = padding_idx if padding_idx is not None else -1
-        return prims.embedding(
-            a,
-            weight,
-            padding_idx=padding_idx,
-            max_norm=max_norm,
-            norm_type=norm_type,
-            scale_grad_by_freq=scale_grad_by_freq,
-            sparse=sparse,
-        )
+    if max_norm is not None:
+        # See https://github.com/Lightning-AI/lightning-thunder/issues/1699
+        raise NotImplementedError(
+            "max_norm argument is currently not supported. Please create an issue detailing your use case."
+        )
+
+    if padding_idx is not None:
+        padding_idx = padding_idx if padding_idx is not None else -1
+
+        utils.check(a.dtype == dtypes.int64, lambda: f"Expected a.dtype={a.dtype} to be int64")
+        utils.check(weight.ndim == 2, lambda: f"Expected weight (weight.shape={weight.shape}) to be a matrix")
+        shape = list(a.shape)
+        shape.append(weight.shape[1])
+        return TensorProxy(like=weight, shape=shape)
```
Review comment (on the direct `TensorProxy` construction above): Maybe want to add a comment for the direct construction of a TensorProxy here instead of calling an operator that would create one -- I'm curious, too.
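For reference, a minimal plain-PyTorch check of the shape rule that the directly constructed proxy encodes (output shape is the index shape plus the embedding dimension). The tensors below are illustrative, not taken from the PR:

```python
import torch
import torch.nn.functional as F

# The embedding output shape is indices.shape + (embedding_dim,),
# which is what TensorProxy(like=weight, shape=shape) records above.
weight = torch.randn(10, 4)              # num_embeddings=10, embedding_dim=4
indices = torch.randint(0, 10, (2, 3))   # arbitrary index shape
out = F.embedding(indices, weight, padding_idx=0)
assert out.shape == (*indices.shape, weight.shape[1])  # (2, 3, 4)
```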
```diff
 
     # padding_idx / sparse not used by forward
 
     if a.ndim == 1:
@@ -4733,10 +4736,12 @@ def embedding(
     return reshape(flatten_output, output_shape)
 
 
 # TODO Once we have fusible index_put we can implement it using primitives
 # For now we just use the PyTorch implementation
 @torchsymbol(torch.ops.aten.embedding_backward)
 def embedding_backward(grad, indices, num_weights, padding_idx, scale_grad_by_freq, sparse):
-    result = prims.embedding_backward(grad, indices, num_weights, padding_idx, scale_grad_by_freq, sparse)
-    return result
+    shape = (num_weights, grad.shape[-1])
+    return TensorProxy(shape=shape, device=grad.device, dtype=grad.dtype)
 
 
 @torchsymbol(torch.nn.functional.one_hot, id="torch.nn.functional.one_hot", is_method=False)
```
Review comment (on `embedding_backward`): Same question here about directly constructing and returning a TensorProxy over calling some op that would create one. I guess I'm a little confused about operations in the torch namespace directly constructing TensorProxies, as that seems more like a primitives thing? Curious to hear your thoughts.
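For intuition, here is a hedged plain-PyTorch sketch of the dense gradient a downstream executor would be expected to compute for `aten.embedding_backward`; the helper name and the omission of `scale_grad_by_freq`/`sparse` handling are simplifications, not the PR's implementation:

```python
import torch

# Illustrative dense embedding backward: accumulate each output-row gradient
# into the row of the weight gradient selected by `indices`.
def embedding_backward_dense(grad, indices, num_weights, padding_idx=-1):
    grad_weight = torch.zeros(num_weights, grad.shape[-1], dtype=grad.dtype, device=grad.device)
    flat_indices = indices.reshape(-1)
    flat_grad = grad.reshape(-1, grad.shape[-1])
    if padding_idx >= 0:
        # Rows gathered from padding_idx receive no gradient.
        keep = flat_indices != padding_idx
        flat_indices, flat_grad = flat_indices[keep], flat_grad[keep]
    grad_weight.index_add_(0, flat_indices, flat_grad)
    return grad_weight  # shape: (num_weights, grad.shape[-1]), matching the proxy above
```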
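For context, the take-based forward path kept in the first hunk (flatten the indices, gather rows, reshape back) corresponds roughly to the following plain-PyTorch sketch; the names are illustrative:

```python
import torch

# Illustrative analogue of the forward decomposition: reshape the indices to 1-D,
# gather the corresponding rows of `weight`, then restore the leading index shape.
def embedding_forward_decomposed(indices, weight):
    output_shape = list(indices.shape) + [weight.shape[1]]
    flat_indices = indices.reshape(-1)
    flat_output = weight.index_select(0, flat_indices)  # the "take" along dim 0
    return flat_output.reshape(output_shape)

weight = torch.randn(10, 4)
indices = torch.randint(0, 10, (2, 3))
torch.testing.assert_close(
    embedding_forward_decomposed(indices, weight),
    torch.nn.functional.embedding(indices, weight),
)
```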
Review comment: The `max_norm` argument was not supported. I don't know the reason for this. I think it should work, and I'll update the OpInfo samples to test it.

Reply: The `max_norm` argument, when active, modifies the `weight` argument in place. That's probably the reason this argument raised NotImplementedError. I'll keep the behavior as is in this PR. Created an issue for tracking completion: #1699.
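To illustrate why `max_norm` is awkward to trace as a pure function, a small plain-PyTorch demo (not code from this PR): `F.embedding` with `max_norm` renormalizes the referenced rows of `weight` in place.

```python
import torch
import torch.nn.functional as F

weight = 10 * torch.randn(10, 4)   # rows with norms well above 1
indices = torch.tensor([1, 2, 3])
before = weight.clone()

F.embedding(indices, weight, max_norm=1.0)

# The referenced rows were rescaled in place; untouched rows are unchanged.
assert not torch.equal(before[1:4], weight[1:4])
assert torch.equal(before[0], weight[0])
```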