[ONNX] Change AtenScatterReduce to AtenScatterReduceTwoOp with tm_tensor pass for onnx.ScatterElements #3754
Conversation
A few minor comments!
@AmosLewis Can you please add an end-to-end test case in SHARK-TestSuite to make sure this lowering works fine?
Based on the e2e op IREE inference, the lowering is not correct: nod-ai/SHARK-TestSuite#363 (comment). %0 = src = [1, 2, 3, 4, 5, 6]
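For reference, a minimal PyTorch sketch of the expected scatter-reduce (sum) semantics the e2e test checks against. Only `src = [1, 2, 3, 4, 5, 6]` comes from the linked test; the `self` and `index` tensors below are hypothetical placeholders.

```python
import torch

# Hypothetical destination and index tensors; only src matches the linked test.
self_t = torch.zeros(3)
index = torch.tensor([0, 1, 2, 0, 1, 2])
src = torch.tensor([1., 2., 3., 4., 5., 6.])

# Eager-mode reference: self[index[i]] += src[i] for every i.
expected = self_t.scatter_reduce(0, index, src, reduce="sum")
print(expected)  # tensor([5., 7., 9.])
```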
Force-pushed from bcd5dd0 to 07e2d59
Found a simple fix for this op: just use the tm_tensor AtenScatterReduceTwoOp pass instead of writing a linalg scatter-reduce op. Here is the test result: nod-ai/SHARK-TestSuite#363
Force-pushed from 07e2d59 to 5191b15
…atterElements
This enables the AtenScatterReduceTwoOp lowering to tm_tensor/linalg_ext and removes the incorrect AtenScatterReduce-to-linalg pass.
Force-pushed from 5191b15 to faec62d
From onnx.ScatterElements: nod-ai/SHARK-ModelDev#823
e2e test: nod-ai/SHARK-TestSuite#363
In the final target examples, the `index` and `src` sizes are dynamic ([?]); here we use static sizes to explain the algorithm.
torch.scatter.reduce step-by-step example:
self[index[i]] += src[i]
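A minimal PyTorch sketch of this step-by-step rule; the concrete tensor values below are illustrative only, not taken from the e2e test. AtenScatterReduceTwoOp corresponds to scatter_reduce with an explicit reduce mode (here "sum").

```python
import torch

self_t = torch.zeros(4)
index = torch.tensor([0, 2, 2, 3])
src = torch.tensor([1., 2., 3., 4.])

# Loop form of the rule above: self[index[i]] += src[i]
looped = self_t.clone()
for i in range(src.numel()):
    looped[index[i]] += src[i]

# Fused form via scatter_reduce with reduce="sum".
fused = self_t.scatter_reduce(0, index, src, reduce="sum")
assert torch.equal(looped, fused)  # both are tensor([1., 0., 5., 4.])
```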
torch-mlir-opt --convert-torch-onnx-to-torch --torch-decompose-complex-ops --cse --canonicalize --convert-torch-to-linalg torch-model.mlir
2. torch-model.mlir