Lower aten::silu_backward to XLA.

Adds XLANativeFunctions::silu_backward, XLATensor::silu_backward, the
ir::ops::SiLUBackward node and the BuildSiLUBackward helper, which computes
  grad_output * (sigmoid(input) * (1 + input * (1 - sigmoid(input))))
and lists silu_backward under "supported" in xla_native_functions.yaml.
TestSiLUBackward runs TestBackward over torch::silu on every device and then
expects the xla::silu_backward counter to have changed and no aten::.* counter
(outside the ignored set) to have changed.

The SiLUBackward IR node is tagged with OpKind(at::aten::silu_backward) so that
it does not share the op kind of the forward SiLU node.

diff --git a/test/cpp/test_aten_xla_tensor.cpp b/test/cpp/test_aten_xla_tensor.cpp
index c59387b34e0..fd3873586fc 100644
--- a/test/cpp/test_aten_xla_tensor.cpp
+++ b/test/cpp/test_aten_xla_tensor.cpp
@@ -3406,6 +3406,20 @@ TEST_F(AtenXlaTensorTest, TestSiLU) {
   ExpectCounterChanged("xla::silu_out", cpp_test::GetIgnoredCounters());
 }
 
+TEST_F(AtenXlaTensorTest, TestSiLUBackward) {
+  auto testfn = [&](const std::vector<torch::Tensor>& inputs) -> torch::Tensor {
+    return torch::silu(inputs[0]);
+  };
+  ForEachDevice([&](const torch::Device& device) {
+    TestBackward(
+        {torch::rand({2, 2},
+                     torch::TensorOptions(torch::kFloat).requires_grad(true))},
+        device, testfn);
+  });
+  ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters());
+  ExpectCounterChanged("xla::silu_backward", cpp_test::GetIgnoredCounters());
+}
+
 TEST_F(AtenXlaTensorTest, TestSigmoid) {
   torch::Tensor a = torch::rand({2, 2}, torch::TensorOptions(torch::kFloat));
   torch::Tensor b = torch::sigmoid(a);
diff --git a/torch_xla/csrc/aten_xla_type.cpp b/torch_xla/csrc/aten_xla_type.cpp
index 7617d3fb642..3bda92aa522 100644
--- a/torch_xla/csrc/aten_xla_type.cpp
+++ b/torch_xla/csrc/aten_xla_type.cpp
@@ -2879,6 +2879,15 @@ at::Tensor& XLANativeFunctions::silu_out(const at::Tensor& self,
   return out;
 }
 
+at::Tensor XLANativeFunctions::silu_backward(const at::Tensor& grad_output,
+                                             const at::Tensor& self) {
+  XLA_FN_COUNTER("xla::");
+  XLATensor grad_output_tensor = bridge::GetXlaTensor(grad_output);
+  XLATensor self_tensor = bridge::GetXlaTensor(self);
+  return bridge::AtenFromXlaTensor(
+      XLATensor::silu_backward(grad_output_tensor, self_tensor));
+}
+
 at::Tensor XLANativeFunctions::sigmoid(const at::Tensor& self) {
   XLA_FN_COUNTER("xla::");
   return bridge::AtenFromXlaTensor(
diff --git a/torch_xla/csrc/elementwise.cpp b/torch_xla/csrc/elementwise.cpp
index 4fbfa5a6c55..78077d6b395 100644
--- a/torch_xla/csrc/elementwise.cpp
+++ b/torch_xla/csrc/elementwise.cpp
@@ -182,6 +182,13 @@ xla::XlaOp BuildSigmoid(xla::XlaOp input) {
   return half + half * xla::Tanh(half * input);
 }
 
+xla::XlaOp BuildSiLUBackward(xla::XlaOp grad_output, xla::XlaOp input) {
+  const xla::Shape& shape = XlaHelpers::ShapeOfXlaOp(input);
+  xla::XlaOp one = xla::One(input.builder(), shape.element_type());
+  xla::XlaOp input_sigmoid = BuildSigmoid(input);
+  return grad_output * (input_sigmoid * (one + input * (one - input_sigmoid)));
+}
+
 xla::XlaOp BuildReciprocal(xla::XlaOp input) {
   const xla::Shape& shape = XlaHelpers::ShapeOfXlaOp(input);
   xla::XlaOp one = xla::One(input.builder(), shape.element_type());
diff --git a/torch_xla/csrc/elementwise.h b/torch_xla/csrc/elementwise.h
index 336feb6eb61..5a61b888762 100644
--- a/torch_xla/csrc/elementwise.h
+++ b/torch_xla/csrc/elementwise.h
@@ -48,6 +48,10 @@ xla::XlaOp BuildLeakyReluBackward(xla::XlaOp grad_output, xla::XlaOp input,
 // Sigmoid(x) = (tanh(x ∗ 0.5) + 1) ∗ 0.5
 xla::XlaOp BuildSigmoid(xla::XlaOp input);
 
+// Computes the backward of Silu
+// grad_output * (sigmoid(input) * (1 + input * (1 - sigmoid(input))))
+xla::XlaOp BuildSiLUBackward(xla::XlaOp grad_output, xla::XlaOp input);
+
 // Computes the reciprocal function.
 // Reciprocal(x) = 1 / x
 xla::XlaOp BuildReciprocal(xla::XlaOp input);
diff --git a/torch_xla/csrc/ops/ops.cpp b/torch_xla/csrc/ops/ops.cpp
index 5e01b62cff6..184008d06f9 100644
--- a/torch_xla/csrc/ops/ops.cpp
+++ b/torch_xla/csrc/ops/ops.cpp
@@ -219,6 +219,25 @@ NodePtr SiLU(const Value& input) {
                    std::move(lower_fn));
 }
 
+NodePtr SiLUBackward(const Value& grad_output, const Value& input) {
+  auto lower_fn = [](const Node& node, LoweringContext* loctx) -> XlaOpVector {
+    xla::XlaOp xla_grad_output = loctx->GetOutputOp(node.operand(0));
+    xla::XlaOp xla_input = loctx->GetOutputOp(node.operand(1));
+    return node.ReturnOp(BuildSiLUBackward(xla_grad_output, xla_input), loctx);
+  };
+  auto lower_for_shape_fn =
+      [](absl::Span<const xla::XlaOp> operands) -> xla::XlaOp {
+    return BuildSiLUBackward(operands[0], operands[1]);
+  };
+  return GenericOp(OpKind(at::aten::silu_backward), {grad_output, input},
+                   [&]() {
+                     return InferOutputShape(
+                         {grad_output.shape(), input.shape()},
+                         lower_for_shape_fn);
+                   },
+                   std::move(lower_fn));
+}
+
 NodePtr Sigmoid(const Value& input) {
   auto lower_fn = [](const Node& node, LoweringContext* loctx) -> XlaOpVector {
     xla::XlaOp xla_input = loctx->GetOutputOp(node.operand(0));
diff --git a/torch_xla/csrc/ops/ops.h b/torch_xla/csrc/ops/ops.h
index 7e65b49de6c..e69221a34f8 100644
--- a/torch_xla/csrc/ops/ops.h
+++ b/torch_xla/csrc/ops/ops.h
@@ -128,6 +128,8 @@ NodePtr Sigmoid(const Value& input);
 
 NodePtr SiLU(const Value& input);
 
+NodePtr SiLUBackward(const Value& grad_output, const Value& input);
+
 NodePtr SigmoidBackward(const Value& grad_output, const Value& output);
 
 NodePtr LogSoftmaxBackwardOp(const Value& grad_output, const Value& output,
diff --git a/torch_xla/csrc/tensor.h b/torch_xla/csrc/tensor.h
index 835763cf85f..f9e2c154ec1 100644
--- a/torch_xla/csrc/tensor.h
+++ b/torch_xla/csrc/tensor.h
@@ -966,6 +966,7 @@ class XLATensor {
                           xla::int64_t index);
 
   static void silu_out(XLATensor& input, XLATensor& out);
+  static XLATensor silu_backward(XLATensor& grad_output, XLATensor& input);
   static XLATensor sigmoid(const XLATensor& input);
   static XLATensor sigmoid_backward(const XLATensor& grad_output,
                                     const XLATensor& output);
diff --git a/torch_xla/csrc/tensor_methods.cpp b/torch_xla/csrc/tensor_methods.cpp
index fe53c482621..045abad0068 100644
--- a/torch_xla/csrc/tensor_methods.cpp
+++ b/torch_xla/csrc/tensor_methods.cpp
@@ -2286,6 +2286,11 @@ void XLATensor::silu_out(XLATensor& input, XLATensor& out) {
   out.SetInPlaceIrValue(ir::ops::SiLU(input.GetIrValue()));
 }
 
+XLATensor XLATensor::silu_backward(XLATensor& grad_output, XLATensor& input) {
+  return input.CreateFrom(
+      ir::ops::SiLUBackward(grad_output.GetIrValue(), input.GetIrValue()));
+}
+
 XLATensor XLATensor::sigmoid(const XLATensor& input) {
   return input.CreateFrom(ir::ops::Sigmoid(input.GetIrValue()));
 }
diff --git a/xla_native_functions.yaml b/xla_native_functions.yaml
index 29cd3f5ddab..ace9c86d6b4 100644
--- a/xla_native_functions.yaml
+++ b/xla_native_functions.yaml
@@ -111,6 +111,7 @@ supported:
   - rsqrt
   - select.int
   - silu.out
+  - silu_backward
   - sigmoid
   - sin
   - sinh