diff --git a/third_party/xla/xla/python/BUILD b/third_party/xla/xla/python/BUILD index dfc27bfaabec5c..e93a35254914d2 100644 --- a/third_party/xla/xla/python/BUILD +++ b/third_party/xla/xla/python/BUILD @@ -296,7 +296,6 @@ cc_library( "py_device.cc", "py_device_list.cc", "py_executable.cc", - "py_host_callback.cc", "py_memory_space.cc", "py_program.cc", "py_values.cc", @@ -310,7 +309,6 @@ cc_library( "py_device.h", "py_device_list.h", "py_executable.h", - "py_host_callback.h", "py_memory_space.h", "py_program.h", "py_values.h", @@ -333,6 +331,8 @@ cc_library( ":nb_helpers", ":nb_numpy", ":pprof_profile_builder", + ":py_client_cpu", + ":py_host_callback", ":py_host_callback_proto_cc", ":python_ref_manager", ":traceback", @@ -427,6 +427,48 @@ cc_library( ] + if_google(["@com_google_protobuf//:any_cc_proto"]), ) +cc_library( + name = "py_host_callback", + srcs = ["py_host_callback.cc"], + hdrs = ["py_host_callback.h"], + compatible_with = [], + copts = [ + "-fexceptions", + "-fno-strict-aliasing", + ], + features = ["-use_header_modules"], + deps = [ + ":callback", + ":py_host_callback_proto_cc", + ":python_ref_manager", + ":types", + "//xla:shape_util", + "//xla:status_macros", + "//xla:util", + "//xla:xla_data_proto_cc", + "//xla/ffi", + "//xla/ffi:ffi_api", + "//xla/pjrt:host_callback", + "//xla/pjrt:pjrt_compiler", + "//xla/python/ifrt", + "//xla/python/pjrt_ifrt", + "//xla/python/pjrt_ifrt:xla_host_callback_proto_cc", + "//xla/tsl/concurrency:ref_count", + "//xla/tsl/platform:statusor", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/base", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", + "@llvm-project//llvm:Support", + "@nanobind", + ] + if_google([ + "@com_google_protobuf//:any_cc_proto", + ]), +) + cc_library( name = "callback", srcs = [ @@ -446,6 +488,7 @@ cc_library( ":python_ref_manager", "//xla:comparison_util", "//xla:xla_data_proto_cc", + "//xla/ffi", "//xla/pjrt:host_callback", "//xla/pjrt:transpose", "//xla/service:custom_call_status", @@ -462,6 +505,46 @@ cc_library( ], ) +cc_library( + name = "py_client_cpu", + srcs = ["py_client_cpu.cc"], + hdrs = ["py_client_cpu.h"], + compatible_with = [], + copts = [ + "-fexceptions", + "-fno-strict-aliasing", + ], + features = ["-use_header_modules"], + deps = [ + ":callback", + ":nb_numpy", + ":py_host_callback", + ":types", + "//xla:comparison_util", + "//xla:shape_util", + "//xla/ffi", + "//xla/ffi:ffi_api", + "//xla/pjrt:exceptions", + "//xla/pjrt:host_callback", + "//xla/pjrt:transpose", + "//xla/python/ifrt", + "//xla/service:custom_call_status", + "//xla/service:custom_call_target_registry", + "//xla/service:platform_util", + "//xla/tsl/concurrency:ref_count", + "//xla/tsl/platform:statusor", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", + "@llvm-project//llvm:Support", + "@local_tsl//tsl/platform:errors", + "@nanobind", + ], +) + cc_library( name = "py_client_gpu", srcs = if_google( diff --git a/third_party/xla/xla/python/callback.cc b/third_party/xla/xla/python/callback.cc index 60ecf6ba3db19b..7f09490504a937 100644 --- a/third_party/xla/xla/python/callback.cc +++ b/third_party/xla/xla/python/callback.cc @@ -35,6 +35,7 @@ limitations under the License. #include "absl/types/span.h" #include "nanobind/nanobind.h" #include "nanobind/stl/string_view.h" // IWYU pragma: keep +#include "xla/ffi/ffi.h" #include "xla/pjrt/host_callback.h" #include "xla/pjrt/transpose.h" #include "xla/primitive_util.h" @@ -181,4 +182,16 @@ void XlaPythonCpuCallback(void* output, void** inputs, } } +absl::StatusOr CpuCallback::FfiCall(nb::tuple args) { + nb::tuple result_tuple; + try { + auto result_object = callable_(*nb::borrow(args)); + result_tuple = nb::cast(result_object); + } catch (nb::python_error& e) { + return absl::InternalError( + absl::StrFormat("CpuCallback error calling callback: %s", e.what())); + } + return result_tuple; +} + } // namespace xla diff --git a/third_party/xla/xla/python/callback.h b/third_party/xla/xla/python/callback.h index e77647aa86f9a1..95f7406ec506ae 100644 --- a/third_party/xla/xla/python/callback.h +++ b/third_party/xla/xla/python/callback.h @@ -76,6 +76,8 @@ class CpuCallback { absl::StatusOr Call(nanobind::tuple args); + absl::StatusOr FfiCall(nanobind::tuple args); + private: nanobind::callable callable_; std::vector args_; diff --git a/third_party/xla/xla/python/py_client_cpu.cc b/third_party/xla/xla/python/py_client_cpu.cc new file mode 100644 index 00000000000000..4d4702e1b2a8d7 --- /dev/null +++ b/third_party/xla/xla/python/py_client_cpu.cc @@ -0,0 +1,136 @@ +/* Copyright 2025 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/python/py_client_cpu.h" + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "absl/algorithm/container.h" +#include "absl/container/inlined_vector.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/types/span.h" +#include "llvm/Support/Casting.h" +#include "nanobind/nanobind.h" +#include "xla/ffi/ffi.h" +#include "xla/ffi/ffi_api.h" +#include "xla/pjrt/host_callback.h" +#include "xla/pjrt/transpose.h" +#include "xla/primitive_util.h" +#include "xla/python/callback.h" +#include "xla/python/ifrt/host_callback.h" +#include "xla/python/nb_numpy.h" +#include "xla/python/py_host_callback.h" +#include "xla/python/types.h" +#include "xla/shape_util.h" +#include "xla/tsl/concurrency/ref_count.h" +#include "xla/tsl/platform/statusor.h" + +namespace nb = nanobind; + +namespace xla { + +absl::Status XlaFfiPythonCpuCallback( + std::vector>* callbacks, + uint64_t index, ffi::RemainingArgs args, ffi::RemainingRets rets) { + auto loaded_callback = llvm::dyn_cast_or_null( + callbacks->at(index).get()); + if (loaded_callback == nullptr) { + return absl::InternalError( + "Expected a PyCpuLoadedHostCallback, got something else."); + } + CpuCallback* callback = loaded_callback->cpu_callback(); + + nb::gil_scoped_acquire gil; + auto nb_args = nb::steal(PyTuple_New(args.size())); + for (size_t i = 0; i < args.size(); ++i) { + auto arg = args.get(i); + auto ptype = arg->element_type(); + if (ptype == TOKEN) { + PyTuple_SET_ITEM(nb_args.ptr(), i, nb::none().release().ptr()); + } else { + TF_ASSIGN_OR_RETURN(auto dtype, PrimitiveTypeToNbDtype(ptype)); + // We pass in data using default numpy layout i.e., std::nullopt. + auto array = nb_numpy_ndarray(dtype, arg->dimensions(), std::nullopt, + arg.value().untyped_data()); + array.attr("flags").attr("writeable") = nb::bool_(false); + PyTuple_SET_ITEM(nb_args.ptr(), i, array.release().ptr()); + } + } + + EnterHostCallback(); + // TODO(dsuo): Change this to use the Python vectorcall protocol, which allows + // you to avoid constructing a tuple for the arguments. + absl::StatusOr maybe_result_tuple = + callback->FfiCall(std::move(nb_args)); + LeaveHostCallback(); + TF_ASSIGN_OR_RETURN(auto result_tuple, maybe_result_tuple); + + for (size_t i = 0; i < rets.size(); ++i) { + auto arg = rets.get(i).value(); + auto ptype = arg->element_type(); + if (ptype == TOKEN) continue; + nb::object output = + nb::borrow(PyTuple_GetItem(result_tuple.ptr(), i)); + nb_numpy_ndarray array = nb_numpy_ndarray::ensure(std::move(output)); + absl::Span strides( + reinterpret_cast(array.strides()), array.ndim()); + // We expect the output to be in default numpy layout. + TF_ASSIGN_OR_RETURN(auto expected_shape, ShapeUtil::MakeValidatedShape( + ptype, arg->dimensions())); + auto expected_strides = ByteStridesForShape(expected_shape); + if (strides == expected_strides) { + std::memcpy(arg->untyped_data(), array.data(), arg->size_bytes()); + } else { + xla::TransposePlan::Options options; + options.elem_size_in_bytes = xla::primitive_util::ByteWidth(ptype); + absl::Span dims( + reinterpret_cast(array.shape()), array.ndim()); + options.dims = dims; + absl::InlinedVector reversed_layout; + reversed_layout.resize(expected_shape.dimensions_size()); + absl::c_reverse_copy(expected_shape.layout().minor_to_major(), + reversed_layout.begin()); + options.permutation = reversed_layout; + options.input_layout = xla::TransposePlan::Striding{strides}; + TF_ASSIGN_OR_RETURN(auto plan, + callback->transpose_cache().GetOrCreate(options)); + plan->Execute(array.data(), arg->untyped_data()); + } + } + + return absl::OkStatus(); +} + +XLA_FFI_DEFINE_HANDLER_SYMBOL( + kXlaFfiPythonCpuCallback, XlaFfiPythonCpuCallback, + ffi::Ffi::Bind() + .Ctx>>>() + .Attr("index") + .RemainingArgs() + .RemainingRets()); + +XLA_FFI_REGISTER_HANDLER(ffi::GetXlaFfiApi(), "xla_ffi_python_cpu_callback", + "HOST", kXlaFfiPythonCpuCallback); + +} // namespace xla diff --git a/third_party/xla/xla/python/py_client_cpu.h b/third_party/xla/xla/python/py_client_cpu.h new file mode 100644 index 00000000000000..1fea2914e47de5 --- /dev/null +++ b/third_party/xla/xla/python/py_client_cpu.h @@ -0,0 +1,27 @@ +/* Copyright 2025 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_PYTHON_PY_CLIENT_CPU_H_ +#define XLA_PYTHON_PY_CLIENT_CPU_H_ + +#include "xla/ffi/ffi.h" + +namespace xla { + +XLA_FFI_DECLARE_HANDLER_SYMBOL(kXlaFfiPythonCpuCallback); + +} // namespace xla + +#endif // XLA_PYTHON_PY_CLIENT_CPU_H_ diff --git a/third_party/xla/xla/python/py_host_callback.h b/third_party/xla/xla/python/py_host_callback.h index da0287aa4d1e8b..fc552aad588a8d 100644 --- a/third_party/xla/xla/python/py_host_callback.h +++ b/third_party/xla/xla/python/py_host_callback.h @@ -60,6 +60,8 @@ class PyCpuLoadedHostCallback final return absl::bit_cast(cpu_callback_.get()); } + CpuCallback* cpu_callback() { return cpu_callback_.get(); } + // LoadedHostCallback implementation. ~PyCpuLoadedHostCallback() override = default;