Add documentation on fixing a torchbench model on torchxla2 (pytorch#…
qihqi authored Jun 28, 2024
1 parent 5e51ca2 commit 6d375f2
Showing 4 changed files with 188 additions and 131 deletions.
210 changes: 95 additions & 115 deletions experimental/torch_xla2/docs/support_a_new_model.md
@@ -14,144 +14,124 @@ torch_xla2, because:
2. Some op it needs is implemented incorrectly
3. There is some non-torch-op code that interacts with torch_xla2 in an unfriendly manner.

Here we present a few steps to attempt to fix the related issues, using the dlrm model as
an example.

# Step 1. Attempt to run the model
This assumes that you have already installed torch_xla2 locally with `pip install -e .`,
following the instructions in the [README](../README.md).

To run a model under torch_xla2, the first step is to instantiate the model and run it
under normal torch. This usually means eager-mode torch on CPU. (NOTE: for large models,
it's recommended to build a model of the same architecture but smaller, by setting fewer
layers / smaller dim sizes; or to use a GPU so that it runs reasonably fast.)

In this example, we will use the `dlrm` model from torchbench.
### Get torchbench scripts

Follow the instructions in https://github.com/pytorch-tpu/run_torchbench.

### Install torchbench and download the model

```bash
git clone https://github.com/pytorch/benchmark.git torchbench
cd torchbench
pip install torchvision torchaudio
pip install -e .
```
Now that torchbench is installed, we need to download the model:

```
python install.py dlrm
```
NOTE: if you run `python install.py` without positional args,
it will download ALL of the 100+ models, which can take some time.

### Run script from run_torchbench:
```bash
(xla2) hanq-macbookpro:run_torchbench hanq$ python models/dlrm.py
Traceback (most recent call last):
File "/Users/hanq/git/qihqi/run_torchbench/models/dlrm.py", line 16, in <module>
module = importlib.import_module(model_name)
File "/Users/hanq/homebrew/Caskroom/miniconda/base/envs/xla2/lib/python3.10/importlib/__init__.py", line 126, in import_module
return _bootstrap._gcd_import(name[level:], package, level)
File "<frozen importlib._bootstrap>", line 1050, in _gcd_import
File "<frozen importlib._bootstrap>", line 1027, in _find_and_load
File "<frozen importlib._bootstrap>", line 1006, in _find_and_load_unlocked
File "<frozen importlib._bootstrap>", line 688, in _load_unlocked
File "<frozen importlib._bootstrap_external>", line 883, in exec_module
File "<frozen importlib._bootstrap>", line 241, in _call_with_frames_removed
File "/Users/hanq/homebrew/Caskroom/miniconda/base/envs/xla2/lib/python3.10/site-packages/torchbench-0.1-py3.10.egg/torchbenchmark/models/dlrm/__init__.py", line 15, in <module>
from .tricks.qr_embedding_bag import QREmbeddingBag
ModuleNotFoundError: No module named 'torchbenchmark.models.dlrm.tricks'
```

Turns out I had forgotten to run `python install.py dlrm` in the benchmarks folder
(cloned from pytorch/benchmark).

Now, let's verify that the model is there by importing it in Python and running it on CPU:

```python
import torchbenchmark.models.dlrm

# torchbench's Model class wraps the model; get_module() returns the eager
# model together with its sample inputs.
benchmark = torchbenchmark.models.dlrm.Model(
    test='eval', device='cpu'
)
model, sample_inputs = benchmark.get_module()

print(model(*sample_inputs))
```

If the above succeeds, then the model is ready.

# Attempt to run the model in torch_xla2

To run the model in torch_xla2, we need to do 2 things:
1. Move the model's weights to the XLA device (i.e. XLA tensors)
2. Move the sample_inputs to the XLA device (i.e. XLA tensors)

The API for the above is the `to_xla` method on the `Environment` class.
To get the current environment, one can use `torch_xla2.default_env()`.

i.e.

```python
xla_env = torch_xla2.default_env()
model2 = xla_env.to_xla(model)
sample_inputs = xla_env.to_xla(sample_inputs)
with xla_env:
    print(model2(*sample_inputs))
```

### Fixing missing ops:

Rerunning `python models/dlrm.py`:
```bash
(xla2) hanq-macbookpro:run_torchbench hanq$ python models/dlrm.py
Traceback (most recent call last):
File "/Users/hanq/git/qihqi/run_torchbench/models/dlrm.py", line 28, in <module>
print(model(*example))
File "/Users/hanq/homebrew/Caskroom/miniconda/base/envs/xla2/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/Users/hanq/homebrew/Caskroom/miniconda/base/envs/xla2/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
return forward_call(*args, **kwargs)
File "/Users/hanq/git/qihqi/run_torchbench/benchmark/torchbenchmark/models/dlrm/dlrm_s_pytorch.py", line 355, in forward
return self.sequential_forward(dense_x, lS_o, lS_i)
File "/Users/hanq/git/qihqi/run_torchbench/benchmark/torchbenchmark/models/dlrm/dlrm_s_pytorch.py", line 367, in sequential_forward
ly = self.apply_emb(lS_o, lS_i, self.emb_l)
File "/Users/hanq/git/qihqi/run_torchbench/benchmark/torchbenchmark/models/dlrm/dlrm_s_pytorch.py", line 308, in apply_emb
V = E(sparse_index_group_batch, sparse_offset_group_batch)
File "/Users/hanq/homebrew/Caskroom/miniconda/base/envs/xla2/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/Users/hanq/homebrew/Caskroom/miniconda/base/envs/xla2/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
return forward_call(*args, **kwargs)
File "/Users/hanq/homebrew/Caskroom/miniconda/base/envs/xla2/lib/python3.10/site-packages/torch/nn/modules/sparse.py", line 390, in forward
return F.embedding_bag(input, self.weight, offsets,
File "/Users/hanq/homebrew/Caskroom/miniconda/base/envs/xla2/lib/python3.10/site-packages/torch/nn/functional.py", line 2360, in embedding_bag
return handle_torch_function(
File "/Users/hanq/homebrew/Caskroom/miniconda/base/envs/xla2/lib/python3.10/site-packages/torch/overrides.py", line 1619, in handle_torch_function
result = mode.__torch_function__(public_api, types, args, kwargs)
File "/Users/hanq/git/qihqi/torch_xla/experimental/torch_xla2/torch_xla2/tensor.py", line 215, in __torch_function__
return func(*args, **(kwargs or {}))
File "/Users/hanq/homebrew/Caskroom/miniconda/base/envs/xla2/lib/python3.10/site-packages/torch/nn/functional.py", line 2451, in embedding_bag
ret, _, _, _ = torch.embedding_bag(
File "/Users/hanq/git/qihqi/torch_xla/experimental/torch_xla2/torch_xla2/tensor.py", line 230, in __torch_dispatch__
return self.env.dispatch(func, types, args, kwargs)
File "/Users/hanq/git/qihqi/torch_xla/experimental/torch_xla2/torch_xla2/tensor.py", line 310, in dispatch
raise OperatorNotFound(
torch_xla2.tensor.OperatorNotFound: Operator with name aten::_embedding_bag has no lowering
```
Now let's implement this op.

A few tricks while implementing ops:
1. Feel free to edit the script `models/dlrm.py` while debugging.
2. Setting `env.config.debug_print_each_op = True` will print out each op that goes
   through the dispatcher (see the sketch after this list).
3. Setting `env.config.debug_accuracy_for_each_op = True` will, in addition to running the
   Jax op, run it again on Torch CPU and diff the two results. If the diff is too large,
   it drops you into pdb for inspection.
4. After inspecting the inputs / outputs / shapes of the op, that may be enough of a hint
   to fix it. If it is not, it's advised to save the inputs / outputs, write a unit test
   for the op, and iterate on that. A unit test is usually faster to iterate on than
   running the whole model.
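For example, a minimal sketch of wiring up these debug options while iterating on the
script (assuming `model` and `sample_inputs` come from the torchbench instantiation above;
the variable names are only for illustration):

```python
import torch_xla2

env = torch_xla2.default_env()
env.config.debug_print_each_op = True         # log every op that hits the dispatcher
env.config.debug_accuracy_for_each_op = True  # also re-run each op on torch CPU and diff

model_xla = env.to_xla(model)
inputs_xla = env.to_xla(sample_inputs)

with env:
    print(model_xla(*inputs_xla))
```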
After finishing `embedding_bag` (badly), I reached the next op. You might get something like this:
```bash
Traceback (most recent call last):
File "/Users/hanq/git/qihqi/torch_xla/experimental/torch_xla2/examples/torchbench_models/BERT_pytorch.py", line 13, in <module>
benchmark = benchmark_cls(test="eval", device = "cpu") # test = train or eval device = cuda or cpu
File "/Users/hanq/homebrew/Caskroom/miniconda/base/envs/xla2/lib/python3.10/site-packages/torchbench-0.1-py3.10.egg/torchbenchmark/util/model.py", line 39, in __call__
obj = type.__call__(cls, *args, **kwargs)
File "/Users/hanq/homebrew/Caskroom/miniconda/base/envs/xla2/lib/python3.10/site-packages/torchbench-0.1-py3.10.egg/torchbenchmark/models/BERT_pytorch/__init__.py", line 174, in __init__
bert = BERT(
File "/Users/hanq/homebrew/Caskroom/miniconda/base/envs/xla2/lib/python3.10/site-packages/torchbench-0.1-py3.10.egg/torchbenchmark/models/BERT_pytorch/bert_pytorch/model/bert.py", line 30, in __init__
self.embedding = BERTEmbedding(vocab_size=vocab_size, embed_size=hidden)
File "/Users/hanq/homebrew/Caskroom/miniconda/base/envs/xla2/lib/python3.10/site-packages/torchbench-0.1-py3.10.egg/torchbenchmark/models/BERT_pytorch/bert_pytorch/model/embedding/bert.py", line 24, in __init__
self.token = TokenEmbedding(vocab_size=vocab_size, embed_size=embed_size)
File "/Users/hanq/homebrew/Caskroom/miniconda/base/envs/xla2/lib/python3.10/site-packages/torchbench-0.1-py3.10.egg/torchbenchmark/models/BERT_pytorch/bert_pytorch/model/embedding/token.py", line 6, in __init__
super().__init__(vocab_size, embed_size, padding_idx=0)
File "/Users/hanq/homebrew/Caskroom/miniconda/base/envs/xla2/lib/python3.10/site-packages/torch/nn/modules/sparse.py", line 145, in __init__
self.reset_parameters()
File "/Users/hanq/homebrew/Caskroom/miniconda/base/envs/xla2/lib/python3.10/site-packages/torch/nn/modules/sparse.py", line 154, in reset_parameters
init.normal_(self.weight)
File "/Users/hanq/homebrew/Caskroom/miniconda/base/envs/xla2/lib/python3.10/site-packages/torch/nn/init.py", line 172, in normal_
return torch.overrides.handle_torch_function(
File "/Users/hanq/homebrew/Caskroom/miniconda/base/envs/xla2/lib/python3.10/site-packages/torch/overrides.py", line 1619, in handle_torch_function
result = mode.__torch_function__(public_api, types, args, kwargs)
File "/Users/hanq/git/qihqi/torch_xla/experimental/torch_xla2/torch_xla2/tensor.py", line 210, in __torch_function__
File "/Users/hanq/homebrew/Caskroom/miniconda/base/envs/xla2/lib/python3.10/site-packages/torch/nn/modules/sparse.py", line 390, in forward
return F.embedding_bag(input, self.weight, offsets,
File "/Users/hanq/homebrew/Caskroom/miniconda/base/envs/xla2/lib/python3.10/site-packages/torch/nn/functional.py", line 2451, in embedding_bag
ret, _, _, _ = torch.embedding_bag(
File "/Users/hanq/git/qihqi/torch_xla/experimental/torch_xla2/torch_xla2/tensor.py", line 124, in __torch_dispatch__
return func(*args, **(kwargs or {}))
File "/Users/hanq/homebrew/Caskroom/miniconda/base/envs/xla2/lib/python3.10/site-packages/torch/nn/init.py", line 175, in normal_
return _no_grad_normal_(tensor, mean, std, generator)
File "/Users/hanq/homebrew/Caskroom/miniconda/base/envs/xla2/lib/python3.10/site-packages/torch/nn/init.py", line 20, in _no_grad_normal_
return tensor.normal_(mean, std, generator=generator)
File "/Users/hanq/git/qihqi/torch_xla/experimental/torch_xla2/torch_xla2/tensor.py", line 224, in __torch_dispatch__
File "/Users/hanq/homebrew/Caskroom/miniconda/base/envs/xla2/lib/python3.10/site-packages/torch/_ops.py", line 594, in __call__
return self_._op(*args, **kwargs)
File "/Users/hanq/git/qihqi/torch_xla/experimental/torch_xla2/torch_xla2/tensor.py", line 212, in __torch_function__
return func(*args, **(kwargs or {}))
File "/Users/hanq/homebrew/Caskroom/miniconda/base/envs/xla2/lib/python3.10/site-packages/torch/_ops.py", line 594, in __call__
return self_._op(*args, **kwargs)
File "/Users/hanq/git/qihqi/torch_xla/experimental/torch_xla2/torch_xla2/tensor.py", line 227, in __torch_dispatch__
return self.env.dispatch(func, types, args, kwargs)
File "/Users/hanq/git/qihqi/torch_xla/experimental/torch_xla2/torch_xla2/tensor.py", line 297, in dispatch
File "/Users/hanq/git/qihqi/torch_xla/experimental/torch_xla2/torch_xla2/tensor.py", line 308, in dispatch
raise OperatorNotFound(
torch_xla2.tensor.OperatorNotFound: Operator with name aten::normal_ has no lowering
torch_xla2.tensor.OperatorNotFound: Operator with name aten::_embedding_bag_forward_only has no lowering
```
Turns out, that is the same operator, so adding `@op(torch.ops.aten._embedding_bag_forward_only)`
on top of the same op implementation works.
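In code, the fix is just stacking a second registration on the existing implementation.
A minimal sketch (the full implementation is in the `jaten.py` change shown later in this
commit):

```python
# Sketch: register one jax-backed lowering under both aten op names.
@op(torch.ops.aten._embedding_bag)
@op(torch.ops.aten._embedding_bag_forward_only)
def _aten__embedding_bag(weight, indices, offsets=None, *args, **kwargs):
    ...  # the same lowering serves both operator overloads
```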
When the issue is with an operator, it is sometimes helpful to see how that operator ended
up being called. Note that, many times, the operator being called can be unexpected; for
example, `aten::normal_` may show up from weight initialization.

We can turn on logging with `xla_env.config.debug_print_each_op = True` and it will print
each operator that is being run. The logs look like this:
```
2024-06-16 15:03:13,726 - root - DEBUG - FUNCTION: aten::view
2024-06-16 15:03:13,726 - root - DEBUG - FUNCTION: aten::gelu
2024-06-16 15:03:13,729 - root - DEBUG - FUNCTION: aten::view
2024-06-16 15:03:13,729 - root - DEBUG - FUNCTION: aten::t
2024-06-16 15:03:13,729 - root - DEBUG - FUNCTION: transpose
2024-06-16 15:03:13,729 - root - DEBUG - DISPATCH: aten::transpose.int
2024-06-16 15:03:13,730 - root - DEBUG - FUNCTION: permute
2024-06-16 15:03:13,730 - root - DEBUG - DISPATCH: aten::permute
2024-06-16 15:03:13,731 - root - DEBUG - FUNCTION: aten::addmm
2024-06-16 15:03:13,737 - root - DEBUG - FUNCTION: aten::view
2024-06-16 15:03:13,739 - root - DEBUG - FUNCTION: aten::add.Tensor
2024-06-16 15:03:13,740 - root - DEBUG - FUNCTION: aten::slice.Tensor
2024-06-16 15:03:13,740 - root - DEBUG - FUNCTION: aten::select.int
2024-06-16 15:03:13,740 - root - DEBUG - FUNCTION: aten::t
2024-06-16 15:03:13,740 - root - DEBUG - FUNCTION: transpose
2024-06-16 15:03:13,740 - root - DEBUG - DISPATCH: aten::transpose.int
2024-06-16 15:03:13,740 - root - DEBUG - FUNCTION: permute
2024-06-16 15:03:13,740 - root - DEBUG - DISPATCH: aten::permute
2024-06-16 15:03:13,740 - root - DEBUG - FUNCTION: aten::addmm
2024-06-16 15:03:13,741 - root - DEBUG - FUNCTION: aten::_log_softmax
2024-06-16 15:03:13,741 - root - DEBUG - FUNCTION: aten::view
2024-06-16 15:03:13,741 - root - DEBUG - FUNCTION: aten::t
2024-06-16 15:03:13,741 - root - DEBUG - FUNCTION: transpose
2024-06-16 15:03:13,741 - root - DEBUG - DISPATCH: aten::transpose.int
2024-06-16 15:03:13,741 - root - DEBUG - FUNCTION: permute
2024-06-16 15:03:13,741 - root - DEBUG - DISPATCH: aten::permute
2024-06-16 15:03:13,764 - root - DEBUG - FUNCTION: aten::addmm
2024-06-16 15:03:13,788 - root - DEBUG - FUNCTION: aten::view
2024-06-16 15:03:13,790 - root - DEBUG - FUNCTION: aten::_log_softmax
```
After this, `python models/dlrm.py` runs. The resulting PR is: https://github.com/pytorch/xla/pull/7583

NOTE:
The `_embedding_bag` implementation is actually very crude, just sufficient to make
the model pass.
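Following trick 4 above, a quick way to keep iterating on an op like this is a standalone
unit test that compares the jax lowering against torch CPU. A minimal sketch (the test
name and tolerances are illustrative, not part of the repo):

```python
# Sketch of a unit test for the new _embedding_bag lowering, with torch CPU as reference.
import numpy as np
import torch
import torch.nn.functional as F
from jax import numpy as jnp

from torch_xla2.ops import jaten


def test_embedding_bag_sum_matches_cpu():
    weight = torch.randn(10, 4)
    indices = torch.tensor([1, 2, 4, 5, 4, 3, 2, 9])
    offsets = torch.tensor([0, 4])

    # Reference: eager torch CPU.
    ref = F.embedding_bag(indices, weight, offsets, mode='sum')

    # The jax implementation returns (output, offset2bag, bag_size, max_indices).
    out, _, _, _ = jaten._aten__embedding_bag(
        jnp.asarray(weight.numpy()),
        jnp.asarray(indices.numpy()),
        jnp.asarray(offsets.numpy()),
        mode=0,  # 0 == sum
    )
    np.testing.assert_allclose(np.asarray(out), ref.numpy(), rtol=1e-5, atol=1e-5)
```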
1 change: 1 addition & 0 deletions experimental/torch_xla2/torch_xla2/config.py
@@ -5,6 +5,7 @@
class Configuration:
debug_print_each_op: bool = False
debug_accuracy_for_each_op: bool = False
debug_mixed_tensor: bool = False
use_int32_for_index: bool = False

# Flash attention
72 changes: 72 additions & 0 deletions experimental/torch_xla2/torch_xla2/ops/jaten.py
@@ -1,9 +1,12 @@
"""Torch ops implemented using jax."""

import functools
import sys

import jax
from jax import numpy as jnp
from jax.experimental.sparse import BCOO

import numpy as np
import torch
from torch_xla2.ops import ops_registry
@@ -294,6 +297,75 @@ def _aten_embedding(a, w, padding_idx=-1):
  return jnp.take(a, w, axis=0)


#- func: _embedding_bag_forward_only(
# Tensor weight, Tensor indices, Tensor offsets, bool scale_grad_by_freq=False,
# int mode=0, bool sparse=False, Tensor? per_sample_weights=None, bool include_last_offset=False, int padding_idx=-1) -> (Tensor, Tensor, Tensor, Tensor)
@op(torch.ops.aten._embedding_bag)
@op(torch.ops.aten._embedding_bag_forward_only)
def _aten__embedding_bag(
    weight,
    indices,
    offsets=None,
    scale_grad_by_freq=False,
    mode=0,
    sparse=False,
    per_sample_weights=None,
    include_last_offset=False,
    padding_idx=-1):
  """Jax implementation of the PyTorch _embedding_bag function.

  Args:
    weight: The learnable weights of the module of shape (num_embeddings, embedding_dim).
    indices: A LongTensor containing the indices to extract.
    offsets: A LongTensor containing the starting offset of each bag.
    scale_grad_by_freq: Whether to scale gradients by the inverse of frequency of the words in the mini-batch.
    mode: 0 = "sum", 1 = "mean" or 2 = "max"
    sparse: Whether the gradients with respect to weight should be a sparse tensor.
    per_sample_weights: If given, each embedding vector is weighted by per_sample_weights.
    include_last_offset: Whether to include the last offset as a valid bag.
    padding_idx: If specified, the entries at padding_idx do not contribute to the gradient.

  Returns:
    A tuple of (output, offset2bag, bag_size, max_indices).
  """
  embedded = _aten_embedding(weight, indices, padding_idx)

  def static_dynamic_slice(x, start, size):
    return jax.lax.dynamic_slice_in_dim(x, start, size)

  # TODO not jittable
  def reduce_by_segment(start, size, x, reducer):
    res = []
    for starti, sizei in zip(start, size):
      res.append(reducer(static_dynamic_slice(x, starti, sizei), axis=0))
    return jnp.stack(res)

  def segsum(x, offsets, reducer):
    start, end = offsets, jnp.concat([offsets[1:], jnp.array([x.shape[0]])])
    return reduce_by_segment(start, end - start, x, reducer)

  if mode not in (0, 1, 2):
    raise ValueError("Invalid mode. Please choose 0 (sum), 1 (mean) or 2 (max).")
  if mode == 0:  # sum
    reducer = jnp.sum
  elif mode == 1:  # mean
    reducer = jnp.mean
  elif mode == 2:  # max
    reducer = jnp.max

  if indices.ndim == 1 and offsets is not None:
    output = segsum(embedded, offsets, reducer)
  else:
    output = reducer(embedded, axis=1)

  # TODO: return output, offset2bag, bag_size, max_indices
  return output, None, None, None





@op(torch.ops.aten.rsqrt)
def _aten_rsqrt(x):
  if isinstance(x, int):