Skip to content

Commit

Permalink
[enhancement] Refactor onedal/datatypes in preparation for dlpa…
Browse files Browse the repository at this point in the history
…ck support (#2195)

* move to numpy and sycl_usm folders

* fix pre-commit

* fix pickling

* forgotten numpy namespace is BS

* missing another numpy

* missing another numpy

* add missing sycl_usm

* remove table_metadata

* remove unneeded includes and move dtype_dispatcher into a central location

* remove header reference

* missed save

* Delete onedal/datatypes/dtype_dispatcher.hpp

* Revert "Delete onedal/datatypes/dtype_dispatcher.hpp"

This reverts commit bfc66b6.

* helper -> utils

* move macro to a central spot
  • Loading branch information
icfaust authored Jan 28, 2025
1 parent 1b6d537 commit b86d5fc
Show file tree
Hide file tree
Showing 17 changed files with 130 additions and 130 deletions.
26 changes: 13 additions & 13 deletions onedal/basic_statistics/basic_statistics.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
#include "onedal/version.hpp"

#define NO_IMPORT_ARRAY // import_array called in table.cpp
#include "onedal/datatypes/data_conversion.hpp"
#include "onedal/datatypes/numpy/data_conversion.hpp"

#include <string>
#include <regex>
Expand Down Expand Up @@ -210,30 +210,30 @@ void init_partial_compute_result(py::module_& m) {
.def(py::pickle(
[](const result_t& res) {
return py::make_tuple(
py::cast<py::object>(convert_to_pyobject(res.get_partial_n_rows())),
py::cast<py::object>(convert_to_pyobject(res.get_partial_min())),
py::cast<py::object>(convert_to_pyobject(res.get_partial_max())),
py::cast<py::object>(convert_to_pyobject(res.get_partial_sum())),
py::cast<py::object>(convert_to_pyobject(res.get_partial_sum_squares())),
py::cast<py::object>(numpy::convert_to_pyobject(res.get_partial_n_rows())),
py::cast<py::object>(numpy::convert_to_pyobject(res.get_partial_min())),
py::cast<py::object>(numpy::convert_to_pyobject(res.get_partial_max())),
py::cast<py::object>(numpy::convert_to_pyobject(res.get_partial_sum())),
py::cast<py::object>(numpy::convert_to_pyobject(res.get_partial_sum_squares())),
py::cast<py::object>(
convert_to_pyobject(res.get_partial_sum_squares_centered())));
numpy::convert_to_pyobject(res.get_partial_sum_squares_centered())));
},
[](py::tuple t) {
if (t.size() != 6)
throw std::runtime_error("Invalid state!");
result_t res;
if (py::cast<int>(t[0].attr("size")) != 0)
res.set_partial_n_rows(convert_to_table(t[0]));
res.set_partial_n_rows(numpy::convert_to_table(t[0]));
if (py::cast<int>(t[1].attr("size")) != 0)
res.set_partial_min(convert_to_table(t[1]));
res.set_partial_min(numpy::convert_to_table(t[1]));
if (py::cast<int>(t[2].attr("size")) != 0)
res.set_partial_max(convert_to_table(t[2]));
res.set_partial_max(numpy::convert_to_table(t[2]));
if (py::cast<int>(t[3].attr("size")) != 0)
res.set_partial_sum(convert_to_table(t[3]));
res.set_partial_sum(numpy::convert_to_table(t[3]));
if (py::cast<int>(t[4].attr("size")) != 0)
res.set_partial_sum_squares(convert_to_table(t[4]));
res.set_partial_sum_squares(numpy::convert_to_table(t[4]));
if (py::cast<int>(t[5].attr("size")) != 0)
res.set_partial_sum_squares_centered(convert_to_table(t[5]));
res.set_partial_sum_squares_centered(numpy::convert_to_table(t[5]));

return res;
}));
Expand Down
15 changes: 8 additions & 7 deletions onedal/covariance/covariance.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
#include "oneapi/dal/algo/covariance.hpp"

#define NO_IMPORT_ARRAY // import_array called in table.cpp
#include "onedal/datatypes/data_conversion.hpp"
#include "onedal/datatypes/numpy/data_conversion.hpp"

#include "onedal/common.hpp"
#include "onedal/version.hpp"
Expand Down Expand Up @@ -141,20 +141,21 @@ inline void init_partial_compute_result(pybind11::module_& m) {
.def(py::pickle(
[](const result_t& res) {
return py::make_tuple(
py::cast<py::object>(convert_to_pyobject(res.get_partial_n_rows())),
py::cast<py::object>(convert_to_pyobject(res.get_partial_crossproduct())),
py::cast<py::object>(convert_to_pyobject(res.get_partial_sum())));
py::cast<py::object>(numpy::convert_to_pyobject(res.get_partial_n_rows())),
py::cast<py::object>(
numpy::convert_to_pyobject(res.get_partial_crossproduct())),
py::cast<py::object>(numpy::convert_to_pyobject(res.get_partial_sum())));
},
[](py::tuple t) {
if (t.size() != 3)
throw std::runtime_error("Invalid state!");
result_t res;
if (py::cast<int>(t[0].attr("size")) != 0)
res.set_partial_n_rows(convert_to_table(t[0]));
res.set_partial_n_rows(numpy::convert_to_table(t[0]));
if (py::cast<int>(t[1].attr("size")) != 0)
res.set_partial_crossproduct(convert_to_table(t[1]));
res.set_partial_crossproduct(numpy::convert_to_table(t[1]));
if (py::cast<int>(t[2].attr("size")) != 0)
res.set_partial_sum(convert_to_table(t[2]));
res.set_partial_sum(numpy::convert_to_table(t[2]));
return res;
}));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,27 @@ constexpr inline void apply(Op&& op, Args&&... args) {

#endif // Version check

#define SET_CTYPE_FROM_DAL_TYPE(_T, _FUNCT, _EXCEPTION) \
switch (_T) { \
case dal::data_type::float32: { \
_FUNCT(float); \
break; \
} \
case dal::data_type::float64: { \
_FUNCT(double); \
break; \
} \
case dal::data_type::int32: { \
_FUNCT(std::int32_t); \
break; \
} \
case dal::data_type::int64: { \
_FUNCT(std::int64_t); \
break; \
} \
default: _EXCEPTION; \
};

namespace oneapi::dal::python {

using supported_types_t = std::tuple<float,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@
#include "oneapi/dal/table/homogen.hpp"
#include "oneapi/dal/table/detail/homogen_utils.hpp"

#include "onedal/datatypes/data_conversion.hpp"
#include "onedal/datatypes/utils/numpy_helpers.hpp"
#include "onedal/datatypes/numpy/data_conversion.hpp"
#include "onedal/datatypes/numpy/numpy_utils.hpp"
#include "onedal/version.hpp"

#if ONEDAL_VERSION <= 20230100
Expand All @@ -32,7 +32,7 @@
#include "oneapi/dal/table/csr.hpp"
#endif

namespace oneapi::dal::python {
namespace oneapi::dal::python::numpy {

#if ONEDAL_VERSION <= 20230100
typedef oneapi::dal::detail::csr_table csr_table_t;
Expand Down Expand Up @@ -432,4 +432,4 @@ PyObject *convert_to_pyobject(const dal::table &input) {
return res;
}

} // namespace oneapi::dal::python
} // namespace oneapi::dal::python::numpy
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,11 @@

#include "oneapi/dal/table/common.hpp"

namespace oneapi::dal::python {
namespace oneapi::dal::python::numpy {

namespace py = pybind11;

PyObject *convert_to_pyobject(const dal::table &input);
dal::table convert_to_table(py::object inp_obj, py::object queue = py::none());

} // namespace oneapi::dal::python
} // namespace oneapi::dal::python::numpy
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,9 @@
* limitations under the License.
*******************************************************************************/

#include "onedal/datatypes/utils/numpy_helpers.hpp"
#include "onedal/datatypes/numpy/numpy_utils.hpp"

namespace oneapi::dal::python {
namespace oneapi::dal::python::numpy {

template <typename Key, typename Value>
auto reverse_map(const std::map<Key, Value>& input) {
Expand Down Expand Up @@ -50,4 +50,4 @@ npy_dtype_t convert_dal_to_npy_type(dal::data_type type) {
return get_dal_to_npy_map().at(type);
}

} // namespace oneapi::dal::python
} // namespace oneapi::dal::python::numpy
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@
#define array_data(a) PyArray_DATA((PyArrayObject *)a)
#define array_size(a, i) PyArray_DIM((PyArrayObject *)a, i)

namespace oneapi::dal::python {
namespace oneapi::dal::python::numpy {

using npy_dtype_t = decltype(NPY_FLOAT);
using npy_to_dal_t = std::map<npy_dtype_t, dal::data_type>;
Expand All @@ -152,4 +152,4 @@ const dal_to_npy_t &get_dal_to_npy_map();
dal::data_type convert_npy_to_dal_type(npy_dtype_t);
npy_dtype_t convert_dal_to_npy_type(dal::data_type);

} // namespace oneapi::dal::python
} // namespace oneapi::dal::python::numpy
Original file line number Diff line number Diff line change
Expand Up @@ -27,12 +27,11 @@
#include "oneapi/dal/table/detail/homogen_utils.hpp"

#include "onedal/common/sycl_interfaces.hpp"
#include "onedal/datatypes/data_conversion_sua_iface.hpp"
#include "onedal/datatypes/utils/dtype_conversions.hpp"
#include "onedal/datatypes/utils/dtype_dispatcher.hpp"
#include "onedal/datatypes/utils/sua_iface_helpers.hpp"
#include "onedal/datatypes/sycl_usm/data_conversion.hpp"
#include "onedal/datatypes/sycl_usm/dtype_conversion.hpp"
#include "onedal/datatypes/sycl_usm/sycl_usm_utils.hpp"

namespace oneapi::dal::python {
namespace oneapi::dal::python::sycl_usm {

using namespace pybind11::literals;
// Please follow <https://intelpython.github.io/dpctl/latest/
Expand Down Expand Up @@ -128,7 +127,7 @@ dal::table convert_to_homogen_impl(py::object obj) {
}

// Convert oneDAL table with zero-copy by use of `__sycl_usm_array_interface__` protocol.
dal::table convert_from_sua_iface(py::object obj) {
dal::table convert_to_table(py::object obj) {
// Get `__sycl_usm_array_interface__` dictionary representing USM allocations.
auto sua_iface_dict = get_sua_interface(obj);

Expand Down Expand Up @@ -236,6 +235,6 @@ void define_sycl_usm_array_property(py::class_<dal::table>& table_obj) {
table_obj.def_property_readonly("__sycl_usm_array_interface__", &construct_sua_iface);
}

} // namespace oneapi::dal::python
} // namespace oneapi::dal::python::sycl_usm

#endif // ONEDAL_DATA_PARALLEL
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,12 @@

#include "oneapi/dal/table/common.hpp"

namespace oneapi::dal::python {
namespace oneapi::dal::python::sycl_usm {

namespace py = pybind11;

// Convert oneDAL table with zero-copy by use of `__sycl_usm_array_interface__` protocol.
dal::table convert_from_sua_iface(py::object obj);
dal::table convert_to_table(py::object obj);

// Create a dictionary for `__sycl_usm_array_interface__` protocol from oneDAL table properties.
py::dict construct_sua_iface(const dal::table& input);
Expand All @@ -37,4 +37,4 @@ py::dict construct_sua_iface(const dal::table& input);
// USM allocations.
void define_sycl_usm_array_property(py::class_<dal::table>& t);

} // namespace oneapi::dal::python
} // namespace oneapi::dal::python::sycl_usm
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,10 @@
#include "oneapi/dal/common.hpp"
#include "oneapi/dal/detail/common.hpp"

#include "onedal/datatypes/utils/dtype_conversions.hpp"
#include "onedal/datatypes/utils/dtype_dispatcher.hpp"
#include "onedal/datatypes/sycl_usm/dtype_conversion.hpp"
#include "onedal/datatypes/dtype_dispatcher.hpp"

namespace oneapi::dal::python {
namespace oneapi::dal::python::sycl_usm {

using fwd_map_t = std::unordered_map<std::string, dal::data_type>;
using inv_map_t = std::unordered_map<dal::data_type, std::string>;
Expand Down Expand Up @@ -139,4 +139,4 @@ std::string convert_dal_to_sua_type(dal::data_type dtype) {
return get_inv_map().at(dtype);
}

} // namespace oneapi::dal::python
} // namespace oneapi::dal::python::sycl_usm
33 changes: 33 additions & 0 deletions onedal/datatypes/sycl_usm/dtype_conversion.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
/*******************************************************************************
* Copyright 2024 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#pragma once

#include <string>

#include <pybind11/pybind11.h>

#include "oneapi/dal/common.hpp"
#include "onedal/datatypes/dtype_dispatcher.hpp"

namespace py = pybind11;

namespace oneapi::dal::python::sycl_usm {

dal::data_type convert_sua_to_dal_type(std::string dtype);
std::string convert_dal_to_sua_type(dal::data_type dtype);

} // namespace oneapi::dal::python::sycl_usm
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,8 @@
#include "oneapi/dal/table/detail/homogen_utils.hpp"

#include "onedal/common/sycl_interfaces.hpp"
#include "onedal/datatypes/data_conversion_sua_iface.hpp"
#include "onedal/datatypes/utils/dtype_conversions.hpp"
#include "onedal/datatypes/utils/dtype_dispatcher.hpp"
#include "onedal/datatypes/sycl_usm/data_conversion.hpp"
#include "onedal/datatypes/sycl_usm/dtype_conversion.hpp"

/* __sycl_usm_array_interface__
*
Expand All @@ -53,7 +52,7 @@
* api_reference/dpctl/sycl_usm_array_interface.html#sycl-usm-array-interface-attribute>
*/

namespace oneapi::dal::python {
namespace oneapi::dal::python::sycl_usm {

// Convert a string encoding elemental data type of the array to oneDAL homogen table data type.
dal::data_type get_sua_dtype(const py::dict& sua) {
Expand Down Expand Up @@ -197,6 +196,6 @@ py::tuple get_npy_strides(const dal::data_layout& data_layout,
return strides;
}

} // namespace oneapi::dal::python
} // namespace oneapi::dal::python::sycl_usm

#endif // ONEDAL_DATA_PARALLEL
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,10 @@
#include "oneapi/dal/table/detail/homogen_utils.hpp"

#include "onedal/common/sycl_interfaces.hpp"
#include "onedal/datatypes/data_conversion_sua_iface.hpp"
#include "onedal/datatypes/utils/dtype_conversions.hpp"
#include "onedal/datatypes/utils/dtype_dispatcher.hpp"
#include "onedal/datatypes/sycl_usm/data_conversion.hpp"
#include "onedal/datatypes/sycl_usm/dtype_conversion.hpp"

namespace oneapi::dal::python {
namespace oneapi::dal::python::sycl_usm {

dal::data_type get_sua_dtype(const py::dict& sua);

Expand Down Expand Up @@ -62,6 +61,6 @@ py::tuple get_npy_strides(const dal::data_layout& data_layout,
npy_intp row_count,
npy_intp column_count);

} // namespace oneapi::dal::python
} // namespace oneapi::dal::python::sycl_usm

#endif // ONEDAL_DATA_PARALLEL
16 changes: 8 additions & 8 deletions onedal/datatypes/table.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,11 @@
#include "oneapi/dal/table/homogen.hpp"

#ifdef ONEDAL_DATA_PARALLEL
#include "onedal/datatypes/data_conversion_sua_iface.hpp"
#include "onedal/datatypes/sycl_usm/data_conversion.hpp"
#endif // ONEDAL_DATA_PARALLEL

#include "onedal/datatypes/data_conversion.hpp"
#include "onedal/datatypes/utils/numpy_helpers.hpp"
#include "onedal/datatypes/numpy/data_conversion.hpp"
#include "onedal/datatypes/numpy/numpy_utils.hpp"
#include "onedal/common/pybind11_helpers.hpp"
#include "onedal/version.hpp"

Expand Down Expand Up @@ -74,25 +74,25 @@ ONEDAL_PY_INIT_MODULE(table) {
});
table_obj.def_property_readonly("dtype", [](const table& t) {
// returns a numpy dtype, even if source was not from numpy
return py::dtype(convert_dal_to_npy_type(t.get_metadata().get_data_type(0)));
return py::dtype(numpy::convert_dal_to_npy_type(t.get_metadata().get_data_type(0)));
});

#ifdef ONEDAL_DATA_PARALLEL
define_sycl_usm_array_property(table_obj);
sycl_usm::define_sycl_usm_array_property(table_obj);
#endif // ONEDAL_DATA_PARALLEL

m.def("to_table", [](py::object obj, py::object queue) {
#ifdef ONEDAL_DATA_PARALLEL
if (py::hasattr(obj, "__sycl_usm_array_interface__")) {
return convert_from_sua_iface(obj);
return sycl_usm::convert_to_table(obj);
}
#endif // ONEDAL_DATA_PARALLEL

return convert_to_table(obj, queue);
return numpy::convert_to_table(obj, queue);
});

m.def("from_table", [](const dal::table& t) -> py::handle {
auto* obj_ptr = convert_to_pyobject(t);
auto* obj_ptr = numpy::convert_to_pyobject(t);
return obj_ptr;
});
}
Expand Down
Loading

0 comments on commit b86d5fc

Please sign in to comment.