
Remove prims.embedding and prims.embedding_backward #1689

Open
wants to merge 8 commits into base: main
35 changes: 0 additions & 35 deletions thunder/core/prims.py
@@ -265,8 +265,6 @@ class PrimIDs(Enum):
MATMUL = auto()
# NN prims (Experimental!)
CONVOLUTION = auto()
EMBEDDING = auto()
EMBEDDING_BACKWARD = auto()
LINEAR = auto()
PAD = auto()
# Memory access methods
@@ -4082,39 +4080,6 @@ def maybe_expand_seq(seq, ndim):
)


def embedding_meta(
a: TensorProxy, /, weight, *, padding_idx=-1, max_norm=None, norm_type=2.0, scale_grad_by_freq=False, sparse=False
) -> TensorProxy:
# TODO: canonicalize and validating padding idx with weight.shape[0]

if max_norm is not None:
raise NotImplementedError
Comment on lines -4090 to -4091
Collaborator Author:
The max_norm argument was not supported. I don't know the reason for this. I think it should work and I'll update OpInfo samples to test it.

Collaborator Author:
The max_norm argument, when active, modifies the weight argument in place. That's probably the reason this argument raised NotImplementedError. I'll keep the behavior as is in this PR.
Created issue #1699 for tracking completion.
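
For readers unfamiliar with the behavior discussed above, here is a minimal plain-PyTorch sketch (editorial illustration, not part of this diff) of the in-place side effect of max_norm:

import torch
import torch.nn.functional as F

# Passing max_norm makes the forward renormalize the looked-up rows of
# `weight` in place, which is the side effect the comment refers to.
weight = torch.randn(10, 4) * 10        # make row norms clearly exceed max_norm
before = weight.clone()
idx = torch.tensor([1, 3])
F.embedding(idx, weight, max_norm=1.0)   # lookup with in-place renormalization
print(torch.equal(weight, before))       # False: rows 1 and 3 were rescaled
print(weight[1].norm())                  # ~1.0 after the in-place renorm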


utils.check(a.dtype == dtypes.int64, lambda: f"Expected a.dtype={a.dtype} to be int64")
utils.check(weight.ndim == 2, lambda: f"Expected weight (weight.shape={weight.shape} to be a matrix)")

shape = list(a.shape)
shape.append(weight.shape[1])

return TensorProxy(like=weight, shape=shape)


embedding = make_prim(PrimIDs.EMBEDDING, "embedding", meta=embedding_meta)


# TODO Update this so it's not a prim
# TODO Add annotations
# TODO Review requires_grad=False -- what about double backward?
# TODO Once we have fusible index_put we can implement it using primitives
# For now we just use the PyTorch implementation
def embedding_backward_meta(grad, indices, num_weights, padding_idx, scale_grad_by_freq, sparse):
shape = (num_weights, grad.shape[-1])
return TensorProxy(shape=shape, device=grad.device, dtype=grad.dtype, requires_grad=False)


embedding_backward = make_prim(PrimIDs.EMBEDDING_BACKWARD, "embedding_backward", meta=embedding_backward_meta)


def copy__meta(
copy_from: TensorProxy,
copy_to: TensorProxy,
7 changes: 6 additions & 1 deletion thunder/core/transform_common.py
@@ -52,7 +52,12 @@ def _remove_noop_subsymbols(bsym: BoundSymbol) -> None:
nsbsyms: list[BoundSymbol] = []
sbsym: BoundSymbol
for sbsym in bsym.subsymbols:
if len(sbsym.subsymbols) == 0 and not sbsym.sym.is_prim:
# We don't remove symbols that are explicitly marked as not to be removed
if has_tags(sbsym, {prims.OpTags.DONT_DCE}):
nsbsyms.append(sbsym)
continue
# If all outputs are in the arguments, we eliminate the symbol
if all(o in sbsym.flat_variableified_proxy_args for o in sbsym.flat_variableified_proxy_outs):
continue
# If all outputs are constants, we eliminate the subsymbol
if not has_tags(bsym, {prims.OpTags.DONT_DCE}) and not any(
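As an aside on the new check above: a subsymbol whose outputs are all already present among its arguments computes nothing new and can be dropped. A standalone toy sketch of that set-containment test follows (illustration only, not Thunder code; in the real implementation flat_variableified_proxy_args and flat_variableified_proxy_outs provide the comparable identities):

# Toy illustration of the "all outputs are among the arguments" test.
def is_noop(arg_ids: set[str], out_ids: set[str]) -> bool:
    # If every output proxy is also an argument proxy, the subsymbol
    # produced nothing new and is safe to eliminate (unless tagged DONT_DCE).
    return all(o in arg_ids for o in out_ids)

print(is_noop({"t0", "t1"}, {"t0"}))  # True  -> subsymbol is removed
print(is_noop({"t0"}, {"t2"}))        # False -> subsymbol is kept
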
23 changes: 0 additions & 23 deletions thunder/core/transforms.py
@@ -1348,29 +1348,6 @@ def _matmul_prim_grad(a: TensorProxy, b: TensorProxy, /) -> TensorProxy:
#


def _embedding_prim_grad(
a: TensorProxy, /, weight, *, padding_idx=-1, max_norm=None, norm_type=2.0, scale_grad_by_freq=False, sparse=False
) -> TensorProxy:
fwd = prims.embedding(
a,
weight,
padding_idx=padding_idx,
max_norm=max_norm,
norm_type=norm_type,
scale_grad_by_freq=scale_grad_by_freq,
sparse=sparse,
)

g = get_grad(fwd)
a_grad = prims.embedding_backward(g, a, weight.shape[0], padding_idx, scale_grad_by_freq, sparse)
put_grad(a, a_grad)

return fwd


register_grad(pids.EMBEDDING, _embedding_prim_grad)


def _maximum_grad(a: TensorProxy, b: TensorProxy, /):
fwd = prims.maximum(a, b)

1 change: 0 additions & 1 deletion thunder/executors/nvfuserex_impl.py
@@ -2662,5 +2662,4 @@ def embedding(
return fd.ops.embedding_fwd(*nv_inputs)


register_supported(PrimIDs.EMBEDDING, embedding, _embedding_check)
register_supported(ltorch.embedding, embedding, _embedding_check)
2 changes: 0 additions & 2 deletions thunder/executors/torchex.py
@@ -1722,8 +1722,6 @@ def _pad_prim_impl(


_register_implementation(prims.convolution, checker=_always_executable, execution_transform=_convolution_transform)
_register_implementation(prims.embedding, embedding, checker=_always_executable)
_register_implementation(prims.embedding_backward, embedding_backward, checker=_always_executable)
_register_implementation(prims.linear, linear, checker=_always_executable)

_register_implementation(ltorch.baddbmm, baddbmm, checker=_always_executable)
29 changes: 17 additions & 12 deletions thunder/torch/__init__.py
@@ -4710,18 +4710,21 @@ def embedding(
) -> TensorLike:
# TODO: add embedding_renorm_ so we can remove embedding prim
# NOTE: padding_idx has impact on backward and is not supported by take
if max_norm is not None or padding_idx is not None:
padding_idx = padding_idx if padding_idx is not None else -1
return prims.embedding(
a,
weight,
padding_idx=padding_idx,
max_norm=max_norm,
norm_type=norm_type,
scale_grad_by_freq=scale_grad_by_freq,
sparse=sparse,
if max_norm is not None:
# See https://github.com/Lightning-AI/lightning-thunder/issues/1699
raise NotImplementedError(
"max_norm argument is currently not support. Please create an issue detailing your use case."
)

if padding_idx is not None:
padding_idx = padding_idx if padding_idx is not None else -1

utils.check(a.dtype == dtypes.int64, lambda: f"Expected a.dtype={a.dtype} to be int64")
utils.check(weight.ndim == 2, lambda: f"Expected weight (weight.shape={weight.shape}) to be a matrix")
shape = list(a.shape)
shape.append(weight.shape[1])
return TensorProxy(like=weight, shape=shape)
Collaborator:
Maybe add a comment explaining the direct construction of a TensorProxy here instead of calling an operator that would create one -- I'm curious, too.
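
To illustrate the distinction the reviewer is raising, here is a deliberately toy sketch (editorial, not Thunder's API): calling a traced operation leaves a record that later transforms can decompose or replace, while constructing the output proxy by hand only describes metadata.

# Hypothetical mini-tracer, only to illustrate the question above.
trace: list[str] = []

class Proxy:
    def __init__(self, shape: tuple[int, ...]):
        self.shape = shape

def traced_lookup(idx: Proxy, weight: Proxy) -> Proxy:
    trace.append("lookup")  # the op call is recorded and can be transformed later
    return Proxy((*idx.shape, weight.shape[1]))

def direct_proxy(idx: Proxy, weight: Proxy) -> Proxy:
    # same output metadata, but nothing is recorded for later passes
    return Proxy((*idx.shape, weight.shape[1]))

traced_lookup(Proxy((2, 3)), Proxy((10, 4)))
direct_proxy(Proxy((2, 3)), Proxy((10, 4)))
print(trace)  # ['lookup']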


# padding_idx / sparse not used by forward

if a.ndim == 1:
@@ -4733,10 +4736,12 @@
return reshape(flatten_output, output_shape)


# TODO Once we have fusible index_put we can implement it using primitives
# For now we just use the PyTorch implementation
@torchsymbol(torch.ops.aten.embedding_backward)
def embedding_backward(grad, indices, num_weights, padding_idx, scale_grad_by_freq, sparse):
result = prims.embedding_backward(grad, indices, num_weights, padding_idx, scale_grad_by_freq, sparse)
return result
shape = (num_weights, grad.shape[-1])
return TensorProxy(shape=shape, device=grad.device, dtype=grad.dtype)
Collaborator:
Same question here about directly constructing and returning a TensorProxy rather than calling some op that would create one.

I guess I'm a little confused about operations in the torch namespace directly constructing TensorProxies, as that seems more like a primitives thing? Curious to hear your thoughts



@torchsymbol(torch.nn.functional.one_hot, id="torch.nn.functional.one_hot", is_method=False)