diff --git a/CMakeLists.txt b/CMakeLists.txt
index c2d20125..17c3819e 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -61,7 +61,6 @@ file(GLOB_RECURSE ALL_SOURCES ${CSRC}/*.cpp)
 if (WITH_CUDA)
   file(GLOB_RECURSE ALL_SOURCES ${ALL_SOURCES} ${CSRC}/*.cu)
 endif()
-file(GLOB_RECURSE ALL_HEADERS ${CSRC}/*.h)
 add_library(${PROJECT_NAME} SHARED ${ALL_SOURCES})
 target_include_directories(${PROJECT_NAME} PUBLIC "${CMAKE_CURRENT_SOURCE_DIR}")
 if(MKL_INCLUDE_FOUND)
diff --git a/pyg_lib/classes/__init__.py b/pyg_lib/classes/__init__.py
index 002f7d67..4e8d69fb 100644
--- a/pyg_lib/classes/__init__.py
+++ b/pyg_lib/classes/__init__.py
@@ -4,7 +4,7 @@
 
 class HashMap:
     def __init__(self, key: Tensor) -> Tensor:
-        self._map = torch.classes.pyg.CPUHashMap(key)
+        self._map = torch.classes.pyg.HashMap(key)
 
     def get(self, query: Tensor) -> Tensor:
         return self._map.get(query)
diff --git a/pyg_lib/csrc/classes/cpu/hash_map.cpp b/pyg_lib/csrc/classes/cpu/hash_map.cpp
deleted file mode 100644
index 66b2977b..00000000
--- a/pyg_lib/csrc/classes/cpu/hash_map.cpp
+++ /dev/null
@@ -1,82 +0,0 @@
-#include "hash_map.h"
-
-#include <ATen/Parallel.h>
-#include <torch/library.h>
-
-namespace pyg {
-namespace classes {
-
-template <typename KeyType>
-CPUHashMapImpl<KeyType>::CPUHashMapImpl(const at::Tensor& key) {
-  at::TensorArg key_arg{key, "key", 0};
-  at::CheckedFrom c{"HashMap.init"};
-  at::checkDeviceType(c, key, at::DeviceType::CPU);
-  at::checkDim(c, key_arg, 1);
-  at::checkContiguous(c, key_arg);
-
-  map_.reserve(key.numel());
-
-  const auto num_threads = at::get_num_threads();
-  const auto grain_size = std::max(
-      (key.numel() + num_threads - 1) / num_threads, at::internal::GRAIN_SIZE);
-  const auto key_data = key.data_ptr<KeyType>();
-
-  at::parallel_for(0, key.numel(), grain_size, [&](int64_t beg, int64_t end) {
-    for (int64_t i = beg; i < end; ++i) {
-      auto [iterator, inserted] = map_.insert({key_data[i], i});
-      TORCH_CHECK(inserted, "Found duplicated key.");
-    }
-  });
-};
-
-template <typename KeyType>
-at::Tensor CPUHashMapImpl<KeyType>::get(const at::Tensor& query) {
-  at::TensorArg query_arg{query, "query", 0};
-  at::CheckedFrom c{"HashMap.get"};
-  at::checkDeviceType(c, query, at::DeviceType::CPU);
-  at::checkDim(c, query_arg, 1);
-  at::checkContiguous(c, query_arg);
-
-  const auto options = at::TensorOptions().dtype(at::kLong);
-  const auto out = at::empty({query.numel()}, options);
-  auto out_data = out.data_ptr<int64_t>();
-
-  const auto num_threads = at::get_num_threads();
-  const auto grain_size =
-      std::max((query.numel() + num_threads - 1) / num_threads,
-               at::internal::GRAIN_SIZE);
-  const auto query_data = query.data_ptr<KeyType>();
-
-  at::parallel_for(0, query.numel(), grain_size, [&](int64_t beg, int64_t end) {
-    for (int64_t i = beg; i < end; ++i) {
-      auto it = map_.find(query_data[i]);
-      out_data[i] = (it != map_.end()) ? it->second : -1;
-    }
-  });
-
-  return out;
-}
-
-CPUHashMap::CPUHashMap(const at::Tensor& key) {
-  // clang-format off
-  AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Bool,
-  key.scalar_type(),
-  "cpu_hash_map_init",
-  [&] {
-    map_ = std::make_unique<CPUHashMapImpl<scalar_t>>(key);
-  });
-  // clang-format on
-}
-
-at::Tensor CPUHashMap::get(const at::Tensor& query) {
-  return map_->get(query);
-}
-
-TORCH_LIBRARY(pyg, m) {
-  m.class_<CPUHashMap>("CPUHashMap")
-      .def(torch::init<at::Tensor&>())
-      .def("get", &CPUHashMap::get);
-}
-
-}  // namespace classes
-}  // namespace pyg
diff --git a/pyg_lib/csrc/classes/cpu/hash_map.h b/pyg_lib/csrc/classes/cpu/hash_map.h
deleted file mode 100644
index a1908164..00000000
--- a/pyg_lib/csrc/classes/cpu/hash_map.h
+++ /dev/null
@@ -1,44 +0,0 @@
-#pragma once
-
-#include <torch/custom_class.h>
-#include "parallel_hashmap/phmap.h"
-
-namespace pyg {
-namespace classes {
-
-struct IHashMap {
-  virtual ~IHashMap() = default;
-  virtual at::Tensor get(const at::Tensor& query) = 0;
-};
-
-template <typename KeyType>
-struct CPUHashMapImpl : IHashMap {
- public:
-  using ValueType = int64_t;
-
-  CPUHashMapImpl(const at::Tensor& key);
-  at::Tensor get(const at::Tensor& query) override;
-
- private:
-  phmap::parallel_flat_hash_map<
-      KeyType,
-      ValueType,
-      phmap::priv::hash_default_hash<KeyType>,
-      phmap::priv::hash_default_eq<KeyType>,
-      phmap::priv::Allocator<std::pair<const KeyType, ValueType>>,
-      12,
-      std::mutex>
-      map_;
-};
-
-struct CPUHashMap : torch::CustomClassHolder {
- public:
-  CPUHashMap(const at::Tensor& key);
-  at::Tensor get(const at::Tensor& query);
-
- private:
-  std::unique_ptr<IHashMap> map_;
-};
-
-}  // namespace classes
-}  // namespace pyg
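Aside (not part of the diff): the old and new front ends share the same observable contract. Each key is mapped to its position in the input tensor, and lookups of absent keys return -1. A minimal C++ sketch of that contract, reusing the values from the updated test at the bottom of this diff (assumes pyg_lib headers and an ATen build are available):

    #include <ATen/ATen.h>
    #include "pyg_lib/csrc/classes/hash_map.h"

    int main() {
      const auto options = at::TensorOptions().dtype(at::kLong);
      // Each key maps to its position: {0: 0, 10: 1, 30: 2, 20: 3}.
      pyg::classes::HashMap map(at::tensor({0, 10, 30, 20}, options));
      // Queries return positions; the absent key 40 yields -1.
      const auto out = map.get(at::tensor({30, 10, 20, 40}, options));
      // out == [2, 1, 3, -1]
      return 0;
    }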
diff --git a/pyg_lib/csrc/classes/cpu/hash_map_impl.h b/pyg_lib/csrc/classes/cpu/hash_map_impl.h
new file mode 100644
index 00000000..707f5f35
--- /dev/null
+++ b/pyg_lib/csrc/classes/cpu/hash_map_impl.h
@@ -0,0 +1,67 @@
+#pragma once
+
+#include <ATen/ATen.h>
+#include <ATen/Parallel.h>
+#include "../hash_map_impl.h"
+#include "parallel_hashmap/phmap.h"
+
+namespace pyg {
+namespace classes {
+
+template <typename KeyType>
+struct CPUHashMapImpl : HashMapImpl {
+ public:
+  using ValueType = int64_t;
+
+  CPUHashMapImpl(const at::Tensor& key) {
+    map_.reserve(key.numel());
+
+    const auto num_threads = at::get_num_threads();
+    const auto grain_size =
+        std::max((key.numel() + num_threads - 1) / num_threads,
+                 at::internal::GRAIN_SIZE);
+    const auto key_data = key.data_ptr<KeyType>();
+
+    at::parallel_for(0, key.numel(), grain_size, [&](int64_t beg, int64_t end) {
+      for (int64_t i = beg; i < end; ++i) {
+        auto [iterator, inserted] = map_.insert({key_data[i], i});
+        TORCH_CHECK(inserted, "Found duplicated key in 'HashMap'.");
+      }
+    });
+  }
+
+  at::Tensor get(const at::Tensor& query) override {
+    const auto options = at::TensorOptions().dtype(at::kLong);
+    const auto out = at::empty({query.numel()}, options);
+    auto out_data = out.data_ptr<int64_t>();
+
+    const auto num_threads = at::get_num_threads();
+    const auto grain_size =
+        std::max((query.numel() + num_threads - 1) / num_threads,
+                 at::internal::GRAIN_SIZE);
+    const auto query_data = query.data_ptr<KeyType>();
+
+    at::parallel_for(0, query.numel(), grain_size, [&](int64_t b, int64_t e) {
+      for (int64_t i = b; i < e; ++i) {
+        auto it = map_.find(query_data[i]);
+        out_data[i] = (it != map_.end()) ? it->second : -1;
+      }
+    });
+
+    return out;
+  }
+
+ private:
+  phmap::parallel_flat_hash_map<
+      KeyType,
+      ValueType,
+      phmap::priv::hash_default_hash<KeyType>,
+      phmap::priv::hash_default_eq<KeyType>,
+      phmap::priv::Allocator<std::pair<const KeyType, ValueType>>,
+      12,
+      std::mutex>
+      map_;
+};
+
+}  // namespace classes
+}  // namespace pyg
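Two details in CPUHashMapImpl carry the concurrency story. The grain size max(ceil(numel / num_threads), at::internal::GRAIN_SIZE) splits work into at most num_threads chunks while keeping each chunk above ATen's minimum parallel granularity, and concurrent inserts are safe because parallel_flat_hash_map shards its entries into 2^12 submaps, each guarded by a std::mutex. A standalone sketch of that same pattern using plain std::thread instead of at::parallel_for (assumes only phmap and the standard library; values and thread count are illustrative):

    #include <cstdint>
    #include <mutex>
    #include <thread>
    #include <utility>
    #include <vector>
    #include "parallel_hashmap/phmap.h"

    // 2^12 mutex-guarded submaps allow concurrent insertion from many threads.
    using Map = phmap::parallel_flat_hash_map<
        int64_t,
        int64_t,
        phmap::priv::hash_default_hash<int64_t>,
        phmap::priv::hash_default_eq<int64_t>,
        phmap::priv::Allocator<std::pair<const int64_t, int64_t>>,
        12,
        std::mutex>;

    int main() {
      Map map;
      map.reserve(1000);
      std::vector<std::thread> workers;
      for (int64_t t = 0; t < 4; ++t) {
        // Each worker inserts a disjoint set of keys, mirroring how
        // at::parallel_for hands each thread its own [beg, end) range.
        workers.emplace_back([&map, t] {
          for (int64_t i = t; i < 1000; i += 4)
            map.insert({i, i});
        });
      }
      for (auto& w : workers)
        w.join();
      return map.size() == 1000 ? 0 : 1;
    }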
diff --git a/pyg_lib/csrc/classes/hash_map.cpp b/pyg_lib/csrc/classes/hash_map.cpp
new file mode 100644
index 00000000..9589831e
--- /dev/null
+++ b/pyg_lib/csrc/classes/hash_map.cpp
@@ -0,0 +1,47 @@
+#include "hash_map.h"
+
+#include <torch/library.h>
+#include "cpu/hash_map_impl.h"
+
+namespace pyg {
+namespace classes {
+
+HashMap::HashMap(const at::Tensor& key) {
+  at::TensorArg key_arg{key, "key", 0};
+  at::CheckedFrom c{"HashMap.init"};
+  at::checkDeviceType(c, key, at::DeviceType::CPU);
+  at::checkDim(c, key_arg, 1);
+  at::checkContiguous(c, key_arg);
+
+  // clang-format off
+  AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Bool,
+  key.scalar_type(),
+  "hash_map_init",
+  [&] {
+    /* if (key.is_cpu()) { */
+      map_ = std::make_unique<CPUHashMapImpl<scalar_t>>(key);
+    /* } else { */
+    /*   AT_ERROR("Received invalid device type for 'HashMap'."); */
+    /* } */
+  });
+  // clang-format on
+}
+
+at::Tensor HashMap::get(const at::Tensor& query) {
+  at::TensorArg query_arg{query, "query", 0};
+  at::CheckedFrom c{"HashMap.get"};
+  at::checkDeviceType(c, query, at::DeviceType::CPU);
+  at::checkDim(c, query_arg, 1);
+  at::checkContiguous(c, query_arg);
+
+  return map_->get(query);
+}
+
+TORCH_LIBRARY(pyg, m) {
+  m.class_<HashMap>("HashMap")
+      .def(torch::init<at::Tensor&>())
+      .def("get", &HashMap::get);
+}
+
+}  // namespace classes
+}  // namespace pyg
diff --git a/pyg_lib/csrc/classes/hash_map.h b/pyg_lib/csrc/classes/hash_map.h
new file mode 100644
index 00000000..50dc881b
--- /dev/null
+++ b/pyg_lib/csrc/classes/hash_map.h
@@ -0,0 +1,19 @@
+#pragma once
+
+#include <torch/custom_class.h>
+#include "hash_map_impl.h"
+
+namespace pyg {
+namespace classes {
+
+struct HashMap : torch::CustomClassHolder {
+ public:
+  HashMap(const at::Tensor& key);
+  at::Tensor get(const at::Tensor& query);
+
+ private:
+  std::unique_ptr<HashMapImpl> map_;
+};
+
+}  // namespace classes
+}  // namespace pyg
diff --git a/pyg_lib/csrc/classes/hash_map_impl.h b/pyg_lib/csrc/classes/hash_map_impl.h
new file mode 100644
index 00000000..eb0c3689
--- /dev/null
+++ b/pyg_lib/csrc/classes/hash_map_impl.h
@@ -0,0 +1,14 @@
+#pragma once
+
+#include <ATen/ATen.h>
+
+namespace pyg {
+namespace classes {
+
+struct HashMapImpl {
+  virtual ~HashMapImpl() = default;
+  virtual at::Tensor get(const at::Tensor& query) = 0;
+};
+
+}  // namespace classes
+}  // namespace pyg
diff --git a/test/csrc/classes/test_hash_map.cpp b/test/csrc/classes/test_hash_map.cpp
index a54400b5..813fbe46 100644
--- a/test/csrc/classes/test_hash_map.cpp
+++ b/test/csrc/classes/test_hash_map.cpp
@@ -1,13 +1,13 @@
 #include <ATen/ATen.h>
 #include <gtest/gtest.h>
-#include "pyg_lib/csrc/classes/cpu/hash_map.h"
+#include "pyg_lib/csrc/classes/hash_map.h"
 
-TEST(CPUHashMapTest, BasicAssertions) {
+TEST(HashMapTest, BasicAssertions) {
   auto options = at::TensorOptions().dtype(at::kLong);
 
   auto key = at::tensor({0, 10, 30, 20}, options);
-  auto map = pyg::classes::CPUHashMap(key);
+  auto map = pyg::classes::HashMap(key);
 
   auto query = at::tensor({30, 10, 20, 40}, options);
   auto expected = at::tensor({2, 1, 3, -1}, options);
 
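The commented-out branch in HashMap::HashMap hints at where device dispatch would go. A backend for another device only needs to implement the two-member HashMapImpl interface; a hypothetical skeleton of what that could look like (CUDAHashMapImpl is an illustrative name only, not part of this diff, and no such backend exists yet):

    #include <ATen/ATen.h>
    #include "pyg_lib/csrc/classes/hash_map_impl.h"

    namespace pyg {
    namespace classes {

    template <typename KeyType>
    struct CUDAHashMapImpl : HashMapImpl {
      CUDAHashMapImpl(const at::Tensor& key) {
        // A real backend would build a device-side table from 'key' here.
      }

      at::Tensor get(const at::Tensor& query) override {
        // A real backend would probe the table with a CUDA kernel and fill
        // the output with positions (or -1 for missing keys).
        TORCH_CHECK(false, "'CUDAHashMapImpl' is a sketch only");
        return at::Tensor();
      }
    };

    }  // namespace classes
    }  // namespace pyg

With such a backend in place, the commented-out if/else in HashMap::HashMap would become the single dispatch point, while the Python wrapper and the TORCH_LIBRARY registration stay unchanged.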