diff --git a/experimental/torch_xla2/docs/support_a_new_model.md b/experimental/torch_xla2/docs/support_a_new_model.md
index 07a2be41e748..8578861e4cf9 100644
--- a/experimental/torch_xla2/docs/support_a_new_model.md
+++ b/experimental/torch_xla2/docs/support_a_new_model.md
@@ -14,144 +14,124 @@ torch_xla2, because:
 2. Some op it needs is implemented incorrectly
 3. There are some non-torch-op code that interacts with torch_xla2 in a non-friendly matter.

-Here we present few steps to attempt to fix the related issues.
+Here we present a few steps to attempt to fix the related issues, using the
+dlrm model as an example.

-# Step 1. Attempt to run the model
+This assumes that you have already installed torch_xla2 locally with
+`pip install -e .`, following the instructions in the [README](../README.md).

-To run a model under torch_xla2, the first step is to
-instantiate the model and run it under normal torch.
-This usually means eager mode torch CPU. (NOTE: for large
- models, it's recommended to make a model of equal architecture but smaller, by setting fewer layers / dim sizes; OR, use GPU
-so that it can run reasonably fast).
-In this example, we will use `BERT_pytorch` model from
-torchbench.
+### Get the torchbench scripts

-## Install torchbench and instantiate a the model
+Follow the instructions in https://github.com/pytorch-tpu/run_torchbench.

-```bash
-git clone https://github.com/pytorch/benchmark.git torchbench
-cd torchbench
-pip install torchvision torchaudio
-pip install -e .
-```
-Now, torchbench is installed, now we need to download
-the model.
-```
-python install.py BERT_pytorch
-```
+### Run the script from run_torchbench

-NOTE: if you run `python install.py` without positional args
-it will download ALL the 100+ models which can take sometime.
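+
+The dlrm driver script itself lives in the run_torchbench repo and is not
+reproduced here. As a rough, hypothetical sketch (assuming torchbench's
+`Model(test=..., device=...)` / `get_module()` API and torch_xla2's
+`default_env()`), it does something like:
+
+```python
+# Hypothetical sketch of a run_torchbench-style driver; the real script
+# is models/dlrm.py in https://github.com/pytorch-tpu/run_torchbench.
+import importlib
+import torch_xla2
+
+model_name = 'torchbenchmark.models.dlrm'
+module = importlib.import_module(model_name)
+benchmark = module.Model(test='eval', device='cpu')
+model, example = benchmark.get_module()
+
+env = torch_xla2.default_env()
+model = env.to_xla(model)      # move weights to XLA tensors
+example = env.to_xla(example)  # move sample inputs to XLA tensors
+with env:
+  print(model(*example))
+```
+
+The first run fails: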
+```bash
+(xla2) hanq-macbookpro:run_torchbench hanq$ python models/dlrm.py
+Traceback (most recent call last):
+  File "/Users/hanq/git/qihqi/run_torchbench/models/dlrm.py", line 16, in <module>
+    module = importlib.import_module(model_name)
+  File "/Users/hanq/homebrew/Caskroom/miniconda/base/envs/xla2/lib/python3.10/importlib/__init__.py", line 126, in import_module
+    return _bootstrap._gcd_import(name[level:], package, level)
+  File "<frozen importlib._bootstrap>", line 1050, in _gcd_import
+  File "<frozen importlib._bootstrap>", line 1027, in _find_and_load
+  File "<frozen importlib._bootstrap>", line 1006, in _find_and_load_unlocked
+  File "<frozen importlib._bootstrap>", line 688, in _load_unlocked
+  File "<frozen importlib._bootstrap_external>", line 883, in exec_module
+  File "<frozen importlib._bootstrap>", line 241, in _call_with_frames_removed
+  File "/Users/hanq/homebrew/Caskroom/miniconda/base/envs/xla2/lib/python3.10/site-packages/torchbench-0.1-py3.10.egg/torchbenchmark/models/dlrm/__init__.py", line 15, in <module>
+    from .tricks.qr_embedding_bag import QREmbeddingBag
+ModuleNotFoundError: No module named 'torchbenchmark.models.dlrm.tricks'
+```

-Now, let's verify that the model is there by importing it in python. 
+Turns out I forgot to run `python install.py dlrm` in the benchmark folder
+(cloned from pytorch/benchmark).

-```python
-import torchbenchmark.models.BERT_pytorch
-model, sample_inputs = torchbenchmark.models.BERT_pytorch.Model(
-    test='eval', device='cpu'
-)
+### Fixing missing ops

-print(model(*sample_inputs))
+Rerunning:
+```bash
+(xla2) hanq-macbookpro:run_torchbench hanq$ python models/dlrm.py
+Traceback (most recent call last):
+  File "/Users/hanq/git/qihqi/run_torchbench/models/dlrm.py", line 28, in <module>
+    print(model(*example))
+  File "/Users/hanq/homebrew/Caskroom/miniconda/base/envs/xla2/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
+    return self._call_impl(*args, **kwargs)
+  File "/Users/hanq/homebrew/Caskroom/miniconda/base/envs/xla2/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
+    return forward_call(*args, **kwargs)
+  File "/Users/hanq/git/qihqi/run_torchbench/benchmark/torchbenchmark/models/dlrm/dlrm_s_pytorch.py", line 355, in forward
+    return self.sequential_forward(dense_x, lS_o, lS_i)
+  File "/Users/hanq/git/qihqi/run_torchbench/benchmark/torchbenchmark/models/dlrm/dlrm_s_pytorch.py", line 367, in sequential_forward
+    ly = self.apply_emb(lS_o, lS_i, self.emb_l)
+  File "/Users/hanq/git/qihqi/run_torchbench/benchmark/torchbenchmark/models/dlrm/dlrm_s_pytorch.py", line 308, in apply_emb
+    V = E(sparse_index_group_batch, sparse_offset_group_batch)
+  File "/Users/hanq/homebrew/Caskroom/miniconda/base/envs/xla2/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
+    return self._call_impl(*args, **kwargs)
+  File "/Users/hanq/homebrew/Caskroom/miniconda/base/envs/xla2/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
+    return forward_call(*args, **kwargs)
+  File "/Users/hanq/homebrew/Caskroom/miniconda/base/envs/xla2/lib/python3.10/site-packages/torch/nn/modules/sparse.py", line 390, in forward
+    return F.embedding_bag(input, self.weight, offsets,
+  File "/Users/hanq/homebrew/Caskroom/miniconda/base/envs/xla2/lib/python3.10/site-packages/torch/nn/functional.py", line 2360, in embedding_bag
+    return handle_torch_function(
+  File "/Users/hanq/homebrew/Caskroom/miniconda/base/envs/xla2/lib/python3.10/site-packages/torch/overrides.py", line 1619, in handle_torch_function
+    result = mode.__torch_function__(public_api, types, args, kwargs)
+  File "/Users/hanq/git/qihqi/torch_xla/experimental/torch_xla2/torch_xla2/tensor.py", line 215, in __torch_function__
+    return func(*args, **(kwargs or {}))
+  File "/Users/hanq/homebrew/Caskroom/miniconda/base/envs/xla2/lib/python3.10/site-packages/torch/nn/functional.py", line 2451, in embedding_bag
+    ret, _, _, _ = torch.embedding_bag(
+  File "/Users/hanq/git/qihqi/torch_xla/experimental/torch_xla2/torch_xla2/tensor.py", line 230, in __torch_dispatch__
+    return self.env.dispatch(func, types, args, kwargs)
+  File "/Users/hanq/git/qihqi/torch_xla/experimental/torch_xla2/torch_xla2/tensor.py", line 310, in dispatch
+    raise OperatorNotFound(
+torch_xla2.tensor.OperatorNotFound: Operator with name aten::_embedding_bag has no lowering
 ```

-If the above succeeds, then the model is ready.
+Now let's implement this op.
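+
+To iterate quickly, it helps to first reproduce the failure in isolation,
+outside the model. A minimal repro (hypothetical; the shapes are made up)
+looks like:
+
+```python
+import torch
+import torch.nn.functional as F
+import torch_xla2
+
+env = torch_xla2.default_env()
+with env:
+  weight = torch.randn(10, 3)
+  indices = torch.tensor([1, 2, 4, 5, 4, 3])
+  offsets = torch.tensor([0, 3])
+  # Before the fix this raises:
+  # OperatorNotFound: Operator with name aten::_embedding_bag has no lowering
+  print(F.embedding_bag(indices, weight, offsets))
+```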

-# Attempt to run the model in torchxla2
+A few tricks that help while implementing ops:

-To run the model in torch_xla2, we need to do 2 things:
-1. Move the model's weight to XLA device (i.e. XLA tensors)
-2. Move the sample_inputs to XLA device (i.e. XLA tensors)
+1. Feel free to edit the script `models/dlrm.py` while debugging.
+2. Setting `env.config.debug_print_each_op = True` will print out each op
+   that goes through the dispatcher.
+3. Setting `env.config.debug_accuracy_for_each_op = True` will, in addition
+   to running the JAX op, run it again on torch CPU and diff the results. If
+   the diff is too large, it drops you into pdb for inspection.
+4. Inspecting the inputs / outputs / shapes of the op may give you enough
+   hints to fix it. If not, it's advisable to save the inputs / outputs and
+   write a unit test for the op and iterate on that; a unit test is usually
+   faster to iterate on than running the whole model. (See the snippet after
+   this list for how to set these flags.)

-The API for the above is the `to_xla` method on `Environment` class.
-To get the current environment, one can use `torch_xla2.default_env()`.
-
-i.e.
-
-```python
-xla_env = torch_xla2.default_env()
-model2 = xla_env.to_xla(model)
-sample_inputs = xla_env.to_xla(sample_inputs)
-with xla_env:
-  print(model2(*sample_inputs))
-```
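+
+For example, the debug flags from tricks 2 and 3 can be flipped on the
+default environment (a sketch, reusing `model` and `example` from the driver
+sketch above):
+
+```python
+import torch_xla2
+
+env = torch_xla2.default_env()
+env.config.debug_print_each_op = True         # log each FUNCTION / DISPATCH
+env.config.debug_accuracy_for_each_op = True  # re-run each op on torch CPU and compare
+
+with env:
+  print(model(*example))
+```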
"/Users/hanq/homebrew/Caskroom/miniconda/base/envs/xla2/lib/python3.10/site-packages/torch/nn/modules/sparse.py", line 390, in forward + return F.embedding_bag(input, self.weight, offsets, + File "/Users/hanq/homebrew/Caskroom/miniconda/base/envs/xla2/lib/python3.10/site-packages/torch/nn/functional.py", line 2451, in embedding_bag + ret, _, _, _ = torch.embedding_bag( + File "/Users/hanq/git/qihqi/torch_xla/experimental/torch_xla2/torch_xla2/tensor.py", line 124, in __torch_dispatch__ return func(*args, **(kwargs or {})) - File "/Users/hanq/homebrew/Caskroom/miniconda/base/envs/xla2/lib/python3.10/site-packages/torch/nn/init.py", line 175, in normal_ - return _no_grad_normal_(tensor, mean, std, generator) - File "/Users/hanq/homebrew/Caskroom/miniconda/base/envs/xla2/lib/python3.10/site-packages/torch/nn/init.py", line 20, in _no_grad_normal_ - return tensor.normal_(mean, std, generator=generator) - File "/Users/hanq/git/qihqi/torch_xla/experimental/torch_xla2/torch_xla2/tensor.py", line 224, in __torch_dispatch__ + File "/Users/hanq/homebrew/Caskroom/miniconda/base/envs/xla2/lib/python3.10/site-packages/torch/_ops.py", line 594, in __call__ + return self_._op(*args, **kwargs) + File "/Users/hanq/git/qihqi/torch_xla/experimental/torch_xla2/torch_xla2/tensor.py", line 212, in __torch_function__ + return func(*args, **(kwargs or {})) + File "/Users/hanq/homebrew/Caskroom/miniconda/base/envs/xla2/lib/python3.10/site-packages/torch/_ops.py", line 594, in __call__ + return self_._op(*args, **kwargs) + File "/Users/hanq/git/qihqi/torch_xla/experimental/torch_xla2/torch_xla2/tensor.py", line 227, in __torch_dispatch__ return self.env.dispatch(func, types, args, kwargs) - File "/Users/hanq/git/qihqi/torch_xla/experimental/torch_xla2/torch_xla2/tensor.py", line 297, in dispatch + File "/Users/hanq/git/qihqi/torch_xla/experimental/torch_xla2/torch_xla2/tensor.py", line 308, in dispatch raise OperatorNotFound( -torch_xla2.tensor.OperatorNotFound: Operator with name aten::normal_ has no lowering +torch_xla2.tensor.OperatorNotFound: Operator with name aten::_embedding_bag_forward_only has no lowering ``` -if the issue is with operators. -Sometimes it's helpful to see how did this operator is called. -Note that, many times, an operator being called can also be -unnexpected. +Turns out, that is the same operator. so adding the @op(torch.ops.aten._embedding_bag_forward_only) +on top of the same op works. -We can turn on logging with -`xla_env.config.debug_print_each_op` and it will print each operator that is being run. +Now the resulting PR is: https://github.com/pytorch/xla/pull/7583 -The logs looks like this: +After this `python models/dlrm.py` runs. 

-We can turn on logging with
-`xla_env.config.debug_print_each_op` and it will print each operator that is being run.
+The resulting PR is: https://github.com/pytorch/xla/pull/7583

-The logs looks like this:
+After this, `python models/dlrm.py` runs.

-```
-2024-06-16 15:03:13,726 - root - DEBUG - FUNCTION: aten::view
-2024-06-16 15:03:13,726 - root - DEBUG - FUNCTION: aten::gelu
-2024-06-16 15:03:13,729 - root - DEBUG - FUNCTION: aten::view
-2024-06-16 15:03:13,729 - root - DEBUG - FUNCTION: aten::t
-2024-06-16 15:03:13,729 - root - DEBUG - FUNCTION: transpose
-2024-06-16 15:03:13,729 - root - DEBUG - DISPATCH: aten::transpose.int
-2024-06-16 15:03:13,730 - root - DEBUG - FUNCTION: permute
-2024-06-16 15:03:13,730 - root - DEBUG - DISPATCH: aten::permute
-2024-06-16 15:03:13,731 - root - DEBUG - FUNCTION: aten::addmm
-2024-06-16 15:03:13,737 - root - DEBUG - FUNCTION: aten::view
-2024-06-16 15:03:13,739 - root - DEBUG - FUNCTION: aten::add.Tensor
-2024-06-16 15:03:13,740 - root - DEBUG - FUNCTION: aten::slice.Tensor
-2024-06-16 15:03:13,740 - root - DEBUG - FUNCTION: aten::select.int
-2024-06-16 15:03:13,740 - root - DEBUG - FUNCTION: aten::t
-2024-06-16 15:03:13,740 - root - DEBUG - FUNCTION: transpose
-2024-06-16 15:03:13,740 - root - DEBUG - DISPATCH: aten::transpose.int
-2024-06-16 15:03:13,740 - root - DEBUG - FUNCTION: permute
-2024-06-16 15:03:13,740 - root - DEBUG - DISPATCH: aten::permute
-2024-06-16 15:03:13,740 - root - DEBUG - FUNCTION: aten::addmm
-2024-06-16 15:03:13,741 - root - DEBUG - FUNCTION: aten::_log_softmax
-2024-06-16 15:03:13,741 - root - DEBUG - FUNCTION: aten::view
-2024-06-16 15:03:13,741 - root - DEBUG - FUNCTION: aten::t
-2024-06-16 15:03:13,741 - root - DEBUG - FUNCTION: transpose
-2024-06-16 15:03:13,741 - root - DEBUG - DISPATCH: aten::transpose.int
-2024-06-16 15:03:13,741 - root - DEBUG - FUNCTION: permute
-2024-06-16 15:03:13,741 - root - DEBUG - DISPATCH: aten::permute
-2024-06-16 15:03:13,764 - root - DEBUG - FUNCTION: aten::addmm
-2024-06-16 15:03:13,788 - root - DEBUG - FUNCTION: aten::view
-2024-06-16 15:03:13,790 - root - DEBUG - FUNCTION: aten::_log_softmax
-```
\ No newline at end of file
+NOTE:
+The `_embedding_bag` implementation is very crude, just sufficient to make
+the model pass.
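+
+If this lowering later needs to be jittable, one option (a sketch, not part
+of this PR) is to replace the per-bag Python loop with `jax.ops.segment_sum`
+for the sum mode:
+
+```python
+import jax
+import jax.numpy as jnp
+
+def embedding_bag_sum(weight, indices, offsets):
+  embedded = jnp.take(weight, indices, axis=0)
+  # Map each position in `indices` to the bag it belongs to.
+  positions = jnp.arange(indices.shape[0])
+  bag_ids = jnp.searchsorted(offsets, positions, side='right') - 1
+  return jax.ops.segment_sum(embedded, bag_ids, num_segments=offsets.shape[0])
+```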
diff --git a/experimental/torch_xla2/torch_xla2/config.py b/experimental/torch_xla2/torch_xla2/config.py
index e7baee5b563e..2f971f13fa44 100644
--- a/experimental/torch_xla2/torch_xla2/config.py
+++ b/experimental/torch_xla2/torch_xla2/config.py
@@ -5,6 +5,7 @@ class Configuration:
   debug_print_each_op: bool = False
   debug_accuracy_for_each_op: bool = False
+  debug_mixed_tensor: bool = False
   use_int32_for_index: bool = False

   # Flash attention
diff --git a/experimental/torch_xla2/torch_xla2/ops/jaten.py b/experimental/torch_xla2/torch_xla2/ops/jaten.py
index 3d93c7735a49..71edfa0b33f7 100644
--- a/experimental/torch_xla2/torch_xla2/ops/jaten.py
+++ b/experimental/torch_xla2/torch_xla2/ops/jaten.py
@@ -1,9 +1,12 @@
 """Torch ops implemented using jax."""
+import functools
 import sys

 import jax
 from jax import numpy as jnp
+from jax.experimental.sparse import BCOO
+
 import numpy as np
 import torch
 from torch_xla2.ops import ops_registry
@@ -294,6 +297,75 @@ def _aten_embedding(a, w, padding_idx=-1):
   return jnp.take(a, w, axis=0)


+# - func: _embedding_bag_forward_only(
+#     Tensor weight, Tensor indices, Tensor offsets, bool scale_grad_by_freq=False,
+#     int mode=0, bool sparse=False, Tensor? per_sample_weights=None,
+#     bool include_last_offset=False, int padding_idx=-1) -> (Tensor, Tensor, Tensor, Tensor)
+@op(torch.ops.aten._embedding_bag)
+@op(torch.ops.aten._embedding_bag_forward_only)
+def _aten__embedding_bag(
+    weight,
+    indices,
+    offsets=None,
+    scale_grad_by_freq=False,
+    mode=0,
+    sparse=False,
+    per_sample_weights=None,
+    include_last_offset=False,
+    padding_idx=-1):
+  """Jax implementation of the PyTorch _embedding_bag function.
+
+  Args:
+    weight: The learnable weights of the module of shape (num_embeddings, embedding_dim).
+    indices: A LongTensor containing the indices to extract.
+    offsets: A LongTensor containing the starting offset of each bag.
+    scale_grad_by_freq: Whether to scale gradients by the inverse of frequency of the words in the mini-batch.
+    mode: 0 = "sum", 1 = "mean", 2 = "max".
+    sparse: Whether the gradients with respect to weight should be a sparse tensor.
+    per_sample_weights: If given, each embedding vector is weighted by per_sample_weights.
+    include_last_offset: Whether to include the last offset as a valid bag.
+    padding_idx: If specified, the entries at padding_idx do not contribute to the gradient.
+
+  Returns:
+    A tuple of (output, offset2bag, bag_size, max_indices).
+  """
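+  # Embed every index up front, then reduce each bag's contiguous slice
+  # [offsets[i], offsets[i+1]) with the chosen reducer. The Python loop in
+  # reduce_by_segment below is why this lowering is not yet jittable.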
+  embedded = _aten_embedding(weight, indices, padding_idx)
+
+  def static_dynamic_slice(x, start, size):
+    return jax.lax.dynamic_slice_in_dim(x, start, size)
+
+  # TODO: not jittable
+  def reduce_by_segment(start, size, x, reducer):
+    res = []
+    for starti, sizei in zip(start, size):
+      res.append(reducer(static_dynamic_slice(x, starti, sizei), axis=0))
+    return jnp.stack(res)
+
+  def segsum(x, offsets, reducer):
+    start, end = offsets, jnp.concat([offsets[1:], jnp.array([x.shape[0]])])
+    return reduce_by_segment(start, end - start, x, reducer)
+
+  if mode not in (0, 1, 2):
+    raise ValueError("Invalid mode. Please choose 0 (sum), 1 (mean), or 2 (max).")
+  if mode == 0:  # sum
+    reducer = jnp.sum
+  elif mode == 1:  # mean
+    reducer = jnp.mean
+  elif mode == 2:  # max
+    reducer = jnp.max
+
+  if indices.ndim == 1 and offsets is not None:
+    output = segsum(embedded, offsets, reducer)
+  else:
+    output = reducer(embedded, axis=1)
+
+  # TODO: also return offset2bag, bag_size, max_indices
+  return output, None, None, None
+
+
 @op(torch.ops.aten.rsqrt)
 def _aten_rsqrt(x):
   if isinstance(x, int):
diff --git a/experimental/torch_xla2/torch_xla2/tensor.py b/experimental/torch_xla2/torch_xla2/tensor.py
index f87f21f351e9..5b6631b1684a 100644
--- a/experimental/torch_xla2/torch_xla2/tensor.py
+++ b/experimental/torch_xla2/torch_xla2/tensor.py
@@ -1,4 +1,4 @@
-import logging
+import sys
 import contextlib
 from typing import Optional
 import jax
@@ -45,8 +45,9 @@ def j2t_dtype(dtype):


 @contextlib.contextmanager
-def log_nested(message):
-  logging.debug((' ' * log_nested.level) + message)
+def log_nested(env, message):
+  if env.config.debug_print_each_op:
+    print((' ' * log_nested.level) + message, file=sys.stderr)
   log_nested.level += 1
   yield
   log_nested.level -= 1
@@ -165,15 +166,11 @@ def mutate(self, slice, new_content):
     self._orig_tensor._elem = self._orig_tensor.at[slice].set(new_content)


-_DEBUG_ACCURACY = False
-
 def debug_accuracy(func, args, kwargs, current_output):
-  if not _DEBUG_ACCURACY:
-    return True
-
   args_torch, kwargs_torch, out_torch = torch_pytree.tree_map_only(
       torch.Tensor, lambda x: j2t(x._elem), (args, kwargs, current_output))
+
   expected_out = func(*args_torch, **kwargs_torch)

   flattened_current_out, _ = torch_pytree.tree_flatten(out_torch)
@@ -183,7 +180,7 @@ def debug_accuracy(func, args, kwargs, current_output):
     if ex.dtype != real.dtype:
       ex = ex.to(real.dtype)
     try:
-      if (_DEBUG_ACCURACY and isinstance(ex, torch.Tensor) and
+      if (isinstance(ex, torch.Tensor) and
           not torch.allclose(ex, real, atol=1e-3, equal_nan=True)):
         import pdb
@@ -207,7 +204,7 @@ def __torch_function__(self,
                        types,
                        args=(),
                        kwargs=None) -> torch.Tensor:
-    with log_nested(f'FUNCTION: {_name_of_func(func)}'):
+    with log_nested(self.env, f'FUNCTION: {_name_of_func(func)}'):
       try:
         return self.env.dispatch(func, types, args, kwargs)
       except OperatorNotFound:
@@ -221,7 +218,7 @@ def __init__(self, env):
     self.env = env

   def __torch_dispatch__(self, func, types, args=(), kwargs=None):
-    with log_nested(f'DISPATCH: {_name_of_func(func)}'):
+    with log_nested(self.env, f'DISPATCH: {_name_of_func(func)}'):
       if isinstance(func, torch._ops.OpOverloadPacket):
         with self:
           return func(*args, **kwargs)
@@ -289,6 +286,7 @@ def get_and_rotate_prng_key(self, generator: Optional[torch.Generator]=None):
     return jax.random.key(next_key)

   def dispatch(self, func, types, args, kwargs):
+    kwargs = kwargs or {}
     # If the func don't act on XLATensor2, and is not a tensor constructor,
@@ -310,8 +308,14 @@ def dispatch(self, func, types, args, kwargs):
       raise OperatorNotFound(
           f'Operator with name {_name_of_func(func)} has no lowering')

-    if op.is_jax_function:
-      args, kwargs = self.t2j_iso((args, kwargs))
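+    # Keep the original args/kwargs: debug_accuracy below re-runs the op on
+    # torch CPU with them. If converting to JAX values fails because
+    # torch.Tensors and XLATensor2s are mixed in one call, optionally drop
+    # into pdb (config.debug_mixed_tensor) to inspect the offending op.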
+    old_args, old_kwargs = args, kwargs
+    try:
+      if op.is_jax_function:
+        args, kwargs = self.t2j_iso((args, kwargs))
+    except AssertionError:
+      if self.config.debug_mixed_tensor:
+        import pdb; pdb.set_trace()
+
     if op.needs_env:
       kwargs['env'] = self
@@ -322,8 +326,8 @@ def dispatch(self, func, types, args, kwargs):
     if op.is_jax_function:
       res = self.j2t_iso(res)

-    #if self.config.debug_accuracy_for_each_op:
-    #  debug_accuracy(func, args, kwargs, res)
+    if self.config.debug_accuracy_for_each_op:
+      debug_accuracy(func, old_args, old_kwargs, res)
     return res

   def __enter__(self):
@@ -358,7 +362,7 @@ def to_xla(self, torchvalues):

   def t2j_iso(self, torchtensors):
     def to_jax(x):
-      assert isinstance(x, XLATensor2)
+      assert isinstance(x, XLATensor2), f'Expected an XLATensor2 but got {type(x)}; usually this means there is mixed math between XLATensor and torch.Tensor'
       return x.jax()
     return torch_pytree.tree_map_only(torch.Tensor, to_jax, torchtensors)