Skip to content

Commit

Permalink
add SCATTER_API definition for scatter_mul in scatter.cpp & scatter.h (
Browse files Browse the repository at this point in the history
…#344)

Co-authored-by: zenghongtai <[email protected]>
  • Loading branch information
HunterTracer and zenghongtai authored Dec 9, 2022
1 parent 111ffc4 commit fe83843
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 3 deletions.
7 changes: 4 additions & 3 deletions csrc/scatter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -239,9 +239,10 @@ scatter_sum(torch::Tensor src, torch::Tensor index, int64_t dim,
return ScatterSum::apply(src, index, dim, optional_out, dim_size)[0];
}

torch::Tensor scatter_mul(torch::Tensor src, torch::Tensor index, int64_t dim,
torch::optional<torch::Tensor> optional_out,
torch::optional<int64_t> dim_size) {
SCATTER_API torch::Tensor
scatter_mul(torch::Tensor src, torch::Tensor index, int64_t dim,
torch::optional<torch::Tensor> optional_out,
torch::optional<int64_t> dim_size) {
return ScatterMul::apply(src, index, dim, optional_out, dim_size)[0];
}

Expand Down
5 changes: 5 additions & 0 deletions csrc/scatter.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,11 @@ scatter_sum(torch::Tensor src, torch::Tensor index, int64_t dim,
torch::optional<torch::Tensor> optional_out,
torch::optional<int64_t> dim_size);

SCATTER_API torch::Tensor
scatter_mul(torch::Tensor src, torch::Tensor index, int64_t dim,
torch::optional<torch::Tensor> optional_out,
torch::optional<int64_t> dim_size);

SCATTER_API torch::Tensor
scatter_mean(torch::Tensor src, torch::Tensor index, int64_t dim,
torch::optional<torch::Tensor> optional_out,
Expand Down

0 comments on commit fe83843

Please sign in to comment.