[ONNX] Change AtenScatterReduce to AtenScatterReduceTwoOp with tm_tensor pass for onnx.ScatterElements #3754

Merged: 4 commits, Oct 6, 2024


@AmosLewis AmosLewis commented Oct 2, 2024

From onnx.ScatterElements: nod-ai/SHARK-ModelDev#823
e2e test: nod-ai/SHARK-TestSuite#363

In the final target models, the index and src sizes are dynamic ([?]); static sizes are used here to explain the algorithm.
Step-by-step example of torch.scatter_reduce with a sum reduction. For each i:
self[index[i]] += src[i]

src = [1, 2, 3, 4, 5, 6]
index = [0, 1, 0, 1, 2, 1]
self = [1, 2, 3, 4]
Step 0:
self[index[0]] += src[0]
self[0] += 1, so self[0] = 1 + 1 = 2
self = [2, 2, 3, 4]

Step 1:
self[index[1]] += src[1]
self[1] += 2, so self[1] = 2 + 2 = 4
self = [2, 4, 3, 4]

Step 2:
self[index[2]] += src[2]
self[0] += 3, so self[0] = 2 + 3 = 5
self = [5, 4, 3, 4]

Step 3:
self[index[3]] += src[3]
self[1] += 4, so self[1] = 4 + 4 = 8
self = [5, 8, 3, 4]

Step 4:
self[index[4]] += src[4]
self[2] += 5, so self[2] = 3 + 5 = 8
self = [5, 8, 8, 4]

Step 5:
self[index[5]] += src[5]
self[1] += 6, so self[1] = 8 + 6 = 14
self = [5, 14, 8, 4]
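
For reference, the six steps above can be reproduced with a small pure-Python sketch. The function name `scatter_reduce_sum` is made up for illustration; it only models the 1-D sum case with the original values of self included, and is not the torch-mlir implementation.

```python
def scatter_reduce_sum(self_vals, index, src):
    """1-D scatter-reduce with a sum reduction, keeping the original self values."""
    out = list(self_vals)  # copy so the caller's list is not mutated
    for i, target in enumerate(index):
        if not 0 <= target < len(out):
            raise IndexError(f"index[{i}] = {target} is out of range")
        out[target] += src[i]  # self[index[i]] += src[i]
    return out


result = scatter_reduce_sum([1, 2, 3, 4], [0, 1, 0, 1, 2, 1], [1, 2, 3, 4, 5, 6])
print(result)  # [5, 14, 8, 4]
```

Position 3 is never named in index, so its original value 4 is carried through unchanged.
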
1. onnx-model.mlir
module {
  func.func @scatter_graph(%arg0: !torch.vtensor<[4],f32>, %arg1: !torch.vtensor<[6],f32>) -> !torch.vtensor<[4],f32> attributes {torch.onnx_meta.ir_version = 10 : si64, torch.onnx_meta.opset_version = 21 : si64, torch.onnx_meta.producer_name = "", torch.onnx_meta.producer_version = ""} {
    %none = torch.constant.none
    %0 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<[0, 1, 0, 1, 2, 1]> : tensor<6xsi64>} : () -> !torch.vtensor<[6],si64> 
    %1 = torch.operator "onnx.ScatterElements"(%arg0, %0, %arg1) {torch.onnx.axis = 0 : si64, torch.onnx.reduction = "add"} : (!torch.vtensor<[4],f32>, !torch.vtensor<[6],si64>, !torch.vtensor<[6],f32>) -> !torch.vtensor<[4],f32> 
    return %1 : !torch.vtensor<[4],f32>
  }
}

torch-mlir-opt --convert-torch-onnx-to-torch --torch-decompose-complex-ops --cse --canonicalize --convert-torch-to-linalg torch-model.mlir
2. torch-model.mlir

module {
  func.func @scatter_graph(%arg0: !torch.vtensor<[4],f32>, %arg1: !torch.vtensor<[6],f32>) -> !torch.vtensor<[4],f32> attributes {torch.onnx_meta.ir_version = 10 : si64, torch.onnx_meta.opset_version = 21 : si64, torch.onnx_meta.producer_name = "", torch.onnx_meta.producer_version = ""} {
    %true = torch.constant.bool true
    %str = torch.constant.str "sum"
    %0 = torch.vtensor.literal(dense<[0, 1, 0, 1, 2, 1]> : tensor<6xsi64>) : !torch.vtensor<[6],si64>
    %int0 = torch.constant.int 0
    %1 = torch.aten.scatter_reduce.two %arg0, %int0, %0, %arg1, %str, %true : !torch.vtensor<[4],f32>, !torch.int, !torch.vtensor<[6],si64>, !torch.vtensor<[6],f32>, !torch.str, !torch.bool -> !torch.vtensor<[4],f32>
    return %1 : !torch.vtensor<[4],f32>
  }
}
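
Note how the ONNX attribute reduction = "add" shows up as the string "sum" in torch.aten.scatter_reduce.two. A rough Python sketch of that attribute translation follows. Only the "add" to "sum" pair is visible in the IR above; the other entries are assumptions based on the reduction names documented for ONNX ScatterElements and torch.scatter_reduce, and both the table and the function name are hypothetical.

```python
# Hypothetical lookup table: ONNX ScatterElements `reduction` attribute ->
# `reduce` string expected by scatter_reduce. Only "add" -> "sum" is
# confirmed by the IR in this PR; the rest are assumed from the op docs.
ONNX_TO_TORCH_REDUCE = {
    "add": "sum",
    "mul": "prod",
    "max": "amax",
    "min": "amin",
}


def torch_reduce_for(onnx_reduction):
    """Return the torch reduce string, or None when no reduction is requested."""
    if onnx_reduction == "none":
        return None  # plain scatter: later writes overwrite earlier ones
    try:
        return ONNX_TO_TORCH_REDUCE[onnx_reduction]
    except KeyError:
        raise ValueError(f"unsupported ScatterElements reduction: {onnx_reduction!r}")


print(torch_reduce_for("add"))  # sum
```
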
3. linalg-model.mlir
#map = affine_map<(d0) -> (d0, 0)>
#map1 = affine_map<(d0) -> (d0)>
module {
  func.func @scatter_graph(%arg0: !torch.vtensor<[4],f32>, %arg1: !torch.vtensor<[6],f32>) -> !torch.vtensor<[4],f32> attributes {torch.onnx_meta.ir_version = 10 : si64, torch.onnx_meta.opset_version = 21 : si64, torch.onnx_meta.producer_name = "", torch.onnx_meta.producer_version = ""} {
    %0 = torch_c.to_builtin_tensor %arg1 : !torch.vtensor<[6],f32> -> tensor<6xf32>
    %1 = torch_c.to_builtin_tensor %arg0 : !torch.vtensor<[4],f32> -> tensor<4xf32>
    %true = torch.constant.bool true
    %str = torch.constant.str "sum"
    %2 = torch.vtensor.literal(dense<[0, 1, 0, 1, 2, 1]> : tensor<6xsi64>) : !torch.vtensor<[6],si64>
    %3 = torch_c.to_builtin_tensor %2 : !torch.vtensor<[6],si64> -> tensor<6xi64>
    %int0 = torch.constant.int 0
    %c0 = arith.constant 0 : index
    %c6 = arith.constant 6 : index
    %c1 = arith.constant 1 : index
    %4 = arith.muli %c1, %c6 : index
    %5 = arith.index_cast %4 : index to i64
    %6 = arith.index_cast %5 : i64 to index
    %c0_0 = arith.constant 0 : index
    %c6_1 = arith.constant 6 : index
    %c1_2 = arith.constant 1 : index
    %7 = tensor.empty(%6) : tensor<?x1xi32>
    %c0_i32 = arith.constant 0 : i32
    %8 = linalg.fill ins(%c0_i32 : i32) outs(%7 : tensor<?x1xi32>) -> tensor<?x1xi32>
    %9 = tensor.empty(%6) : tensor<?xf32>
    %cst = arith.constant 0.000000e+00 : f32
    %10 = linalg.fill ins(%cst : f32) outs(%9 : tensor<?xf32>) -> tensor<?xf32>
    %11:2 = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel"]} outs(%8, %10 : tensor<?x1xi32>, tensor<?xf32>) {
    ^bb0(%out: i32, %out_13: f32):
      %16 = linalg.index 0 : index
      %17 = arith.remsi %16, %c6_1 : index
      %18 = arith.divsi %16, %c6_1 : index
      %extracted = tensor.extract %3[%17] : tensor<6xi64>
      %extracted_14 = tensor.extract %0[%17] : tensor<6xf32>
      %19 = arith.index_cast %17 : index to i64
      %20 = arith.trunci %19 : i64 to i32
      %21 = arith.trunci %extracted : i64 to i32
      linalg.yield %21, %extracted_14 : i32, f32
    } -> (tensor<?x1xi32>, tensor<?xf32>)
    %c0_3 = arith.constant 0 : index
    %c0_4 = arith.constant 0 : index
    %c1_5 = arith.constant 1 : index
    %c1_6 = arith.constant 1 : index
    %c1_7 = arith.constant 1 : index
    %12 = tensor.empty(%6) : tensor<?x1xi32>
    %c0_i32_8 = arith.constant 0 : i32
    %13 = linalg.fill ins(%c0_i32_8 : i32) outs(%12 : tensor<?x1xi32>) -> tensor<?x1xi32>
    %c0_9 = arith.constant 0 : index
    %dim = tensor.dim %11#0, %c0_9 : tensor<?x1xi32>
    %c1_10 = arith.constant 1 : index
    %c1_11 = arith.constant 1 : index
    %inserted_slice = tensor.insert_slice %11#0 into %13[0, 0] [%dim, 1] [1, 1] : tensor<?x1xi32> into tensor<?x1xi32>
    %c1_12 = arith.constant 1 : index
    %14 = tm_tensor.scatter {dimension_map = array<i64: 0>} unique_indices(false) ins(%11#1, %inserted_slice : tensor<?xf32>, tensor<?x1xi32>) outs(%1 : tensor<4xf32>) {
    ^bb0(%arg2: f32, %arg3: f32):
      %16 = arith.addf %arg2, %arg3 : f32
      tm_tensor.yield %16 : f32
    } -> tensor<4xf32>
    %cast = tensor.cast %14 : tensor<4xf32> to tensor<4xf32>
    %15 = torch_c.from_builtin_tensor %cast : tensor<4xf32> -> !torch.vtensor<[4],f32>
    return %15 : !torch.vtensor<[4],f32>
  }
}
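
The linalg-model.mlir dump above is easier to follow as two phases: the linalg.generic builds an N x 1 index tensor plus a flat updates tensor, and tm_tensor.scatter then folds each update into the original tensor through its addf region. Below is a simplified Python model of those two phases. The helper names are invented, the 1-D case is hard-coded, and the argument order passed to the combiner (update first, current value second) is an assumption that does not matter for addition.

```python
from operator import add


def build_scatter_operands(index, src):
    """Phase 1: one row per update, each row holding a single target coordinate."""
    indices = [[int(i)] for i in index]  # shape N x 1, like tensor<?x1xi32>
    updates = list(src)                  # shape N, like tensor<?xf32>
    return indices, updates


def scatter_with_combiner(updates, indices, original, combine):
    """Phase 2: apply updates one by one; repeated indices are allowed."""
    out = list(original)
    for update, (row,) in zip(updates, indices):
        out[row] = combine(update, out[row])  # plays the role of the addf region
    return out


indices, updates = build_scatter_operands(
    [0, 1, 0, 1, 2, 1], [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
)
result = scatter_with_combiner(updates, indices, [1.0, 2.0, 3.0, 4.0], add)
print(result)  # [5.0, 14.0, 8.0, 4.0]
```

Because index 1 appears three times, the updates cannot be treated as independent writes, which matches unique_indices(false) in the IR.
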
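| tests                          | model-run   | onnx-import   | torch-mlir   | iree-compile   | inference   |
|:-------------------------------|:------------|:--------------|:-------------|:---------------|:------------|
| onnx/operators/ScatterElements | passed      | passed        | passed       | passed         | passed      |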

@Shukla-Gaurav Shukla-Gaurav left a comment


A few minor comments!

test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir (two review threads, both outdated and resolved)
@Shukla-Gaurav

@AmosLewis Can you please add an end-to-end test case in SHARK-TestSuite to make sure this lowering works correctly?

@AmosLewis

AmosLewis commented Oct 4, 2024

Based on the e2e op IREE inference, the lowering is not correct: nod-ai/SHARK-TestSuite#363 (comment)
The trick I tried here is to use self as the output of linalg.generic. My understanding was that after tensor.scatter the output (self) would be updated, so the next iteration of linalg.generic would see the updated self as its input. But the inference result shows that self is not updated: in my example the output is still [1, 2, 3, 4], not the expected [5, 14, 8, 4]. So the trick does not work, and I have no idea how to fix it. I need your help. @rsuderman @Shukla-Gaurav @vivekkhandelwal1 @zjgarvey

%0 = src = [1, 2, 3, 4, 5, 6]
%1 = self = [1, 2, 3, 4]
%2 = index = [0, 1, 0, 1, 2, 1]

#map = affine_map<(d0) -> (d0)>
module {
  func.func @scatter_graph(%arg0: !torch.vtensor<[4],f32>, %arg1: !torch.vtensor<[6],f32>) -> !torch.vtensor<[4],f32> attributes {torch.onnx_meta.ir_version = 10 : si64, torch.onnx_meta.opset_version = 21 : si64, torch.onnx_meta.producer_name = "", torch.onnx_meta.producer_version = ""} {
    %0 = torch_c.to_builtin_tensor %arg1 : !torch.vtensor<[6],f32> -> tensor<6xf32>
    %1 = torch_c.to_builtin_tensor %arg0 : !torch.vtensor<[4],f32> -> tensor<4xf32>
    %str = torch.constant.str "sum"
    %2 = torch.vtensor.literal(dense<[0, 1, 0, 1, 2, 1]> : tensor<6xsi64>) : !torch.vtensor<[6],si64>
    %3 = torch_c.to_builtin_tensor %2 : !torch.vtensor<[6],si64> -> tensor<6xi64>
    %int0 = torch.constant.int 0
    %cast = tensor.cast %3 : tensor<6xi64> to tensor<?xi64>
    %cast_0 = tensor.cast %0 : tensor<6xf32> to tensor<?xf32>
    %4 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins(%cast, %cast_0 : tensor<?xi64>, tensor<?xf32>) outs(%1 : tensor<4xf32>) {
    ^bb0(%in: i64, %in_2: f32, %out: f32):
      %6 = arith.index_cast %in : i64 to index
      %extracted = tensor.extract %1[%6] : tensor<4xf32>
      %7 = arith.addf %extracted, %in_2 : f32
      %from_elements = tensor.from_elements %7 : tensor<1xf32>
      %from_elements_3 = tensor.from_elements %6 : tensor<1xindex>
      %scatter = tensor.scatter %from_elements into %1[%from_elements_3] scatter_dims([0]) unique : (tensor<1xf32>, tensor<4xf32>, tensor<1xindex>) -> tensor<4xf32>
      linalg.yield %out : f32
    } -> tensor<4xf32>
    %cast_1 = tensor.cast %4 : tensor<4xf32> to tensor<4xf32>
    %5 = torch_c.from_builtin_tensor %cast_1 : tensor<4xf32> -> !torch.vtensor<[4],f32>
    return %5 : !torch.vtensor<[4],f32>
  }
}
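
One way to read the failure: MLIR tensors have value semantics, so tensor.scatter returns a new tensor instead of modifying %1, the result %scatter is never used, and the body yields %out unchanged. The Python sketch below models that with immutable tuples. All names are invented, the iteration domain is simplified to the output length, and `threaded_reference` only shows what carrying the scatter result forward would compute; it is not the fix that was merged.

```python
def tensor_scatter(dest, position, value):
    """Value-semantics scatter: returns a new tuple and leaves dest untouched."""
    return dest[:position] + (value,) + dest[position + 1:]


def failed_lowering(self_t, index, src):
    """Models the linalg.generic body above: the scatter result is dropped."""
    out = []
    for i in range(len(self_t)):  # simplified: iterate over the output positions
        target = index[i]
        updated = self_t[target] + src[i]
        _unused = tensor_scatter(self_t, target, updated)  # never read again
        out.append(self_t[i])  # like `linalg.yield %out`
    return tuple(out)


def threaded_reference(self_t, index, src):
    """Sequential reference: each step consumes the previous step's result."""
    acc = self_t
    for i, target in enumerate(index):
        acc = tensor_scatter(acc, target, acc[target] + src[i])
    return acc


self_t = (1, 2, 3, 4)
index = (0, 1, 0, 1, 2, 1)
src = (1, 2, 3, 4, 5, 6)
print(failed_lowering(self_t, index, src))     # (1, 2, 3, 4)
print(threaded_reference(self_t, index, src))  # (5, 14, 8, 4)
```

The sequential dependency in `threaded_reference` is also why a "parallel" iterator is an awkward fit for accumulating into repeated indices.
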

@AmosLewis

AmosLewis commented Oct 4, 2024

Found a simple fix for this op: use the existing tm_tensor lowering of AtenScatterReduceTwoOp instead of writing a new linalg lowering for scatter_reduce. Here is the test result: nod-ai/SHARK-TestSuite#363

Status report for run: test-run using mode:onnx todtype:default backend:llvm-cpu

| tests                          | model-run   | onnx-import   | torch-mlir   | iree-compile   | inference   |
|:-------------------------------|:------------|:--------------|:-------------|:---------------|:------------|
| onnx/operators/ScatterElements | passed      | passed        | passed       | passed         | passed      |

…atterElements

This enables the AtenScatterReduceTwoOp lowering to tm_tensor/linalg_ext.
Remove the incorrect AtenScatterReduce to linalg pass.
@AmosLewis AmosLewis changed the title from "[Linalg] Add torch.scatter.reduce to linalg lowering" to "[ONNX] Change AtenScatterReduce to AtenScatterReduceTwoOp with tm_tensor pass for onnx.ScatterElements" on Oct 6, 2024
@AmosLewis AmosLewis merged commit f4840ed into llvm:main Oct 6, 2024
3 checks passed
3 participants