From b277117234bd9dd32bebf945d87018972a0212e7 Mon Sep 17 00:00:00 2001 From: Haochen Wang Date: Tue, 5 Nov 2024 17:16:59 -0500 Subject: [PATCH] change MemoryManager in ExecutionContext.hpp to use smart pointers instead of raw pointers --- runtime/lib/capi/ExecutionContext.hpp | 41 +++++++++++++++++++------ runtime/lib/capi/RuntimeCAPI.cpp | 16 +++------- runtime/tests/Test_LightningCoreQIS.cpp | 2 -- 3 files changed, 36 insertions(+), 23 deletions(-) diff --git a/runtime/lib/capi/ExecutionContext.hpp b/runtime/lib/capi/ExecutionContext.hpp index 68777ccc75..315e49ec8c 100644 --- a/runtime/lib/capi/ExecutionContext.hpp +++ b/runtime/lib/capi/ExecutionContext.hpp @@ -40,38 +40,59 @@ extern "C" void __catalyst_inactive_callback(int64_t identifier, int64_t argc, i class MemoryManager final { private: - std::unordered_set _impl; + std::unordered_set> _impl; std::mutex mu; // To guard the memory manager public: explicit MemoryManager() { _impl.reserve(1024); }; - ~MemoryManager() + std::shared_ptr create(size_t size) { - // Lock the mutex to protect _impl free - std::lock_guard lock(mu); - for (auto allocation : _impl) { - free(allocation); - } + std::shared_ptr p(std::malloc(size), [](void *ptr) { free(ptr); }); + this->insert(p); + return p; } - void insert(void *ptr) + std::shared_ptr create_aligned(size_t alignment, size_t size) + { + std::shared_ptr p(std::aligned_alloc(alignment, size), [](void *ptr) { free(ptr); }); + this->insert(p); + return p; + } + + void insert(std::shared_ptr ptr) { // Lock the mutex to protect _impl update std::lock_guard lock(mu); _impl.insert(ptr); } + void erase(void *ptr) { // Lock the mutex to protect _impl update std::lock_guard lock(mu); - _impl.erase(ptr); + + std::shared_ptr target; + for (std::shared_ptr sharedP : _impl) { + if (sharedP.get() == ptr) { + target = sharedP; + } + } + _impl.erase(target); } + bool contains(void *ptr) { // Lock the mutex to protect _impl update std::lock_guard lock(mu); - return _impl.contains(ptr); + + bool result = false; + for (std::shared_ptr sharedP : _impl) { + if (sharedP.get() == ptr) { + result = true; + } + } + return result; } }; diff --git a/runtime/lib/capi/RuntimeCAPI.cpp b/runtime/lib/capi/RuntimeCAPI.cpp index 81b2b2affb..72a26763d3 100644 --- a/runtime/lib/capi/RuntimeCAPI.cpp +++ b/runtime/lib/capi/RuntimeCAPI.cpp @@ -156,16 +156,14 @@ void __catalyst__host__rt__unrecoverable_error() void *_mlir_memref_to_llvm_alloc(size_t size) { - void *ptr = malloc(size); - CTX->getMemoryManager()->insert(ptr); - return ptr; + std::shared_ptr ptr = CTX->getMemoryManager()->create(size); + return ptr.get(); } void *_mlir_memref_to_llvm_aligned_alloc(size_t alignment, size_t size) { - void *ptr = aligned_alloc(alignment, size); - CTX->getMemoryManager()->insert(ptr); - return ptr; + std::shared_ptr ptr = CTX->getMemoryManager()->create_aligned(alignment, size); + return ptr.get(); } bool _mlir_memory_transfer(void *ptr) @@ -177,11 +175,7 @@ bool _mlir_memory_transfer(void *ptr) return true; } -void _mlir_memref_to_llvm_free(void *ptr) -{ - CTX->getMemoryManager()->erase(ptr); - free(ptr); -} +void _mlir_memref_to_llvm_free(void *ptr) { CTX->getMemoryManager()->erase(ptr); } void __catalyst__rt__print_string(char *string) { diff --git a/runtime/tests/Test_LightningCoreQIS.cpp b/runtime/tests/Test_LightningCoreQIS.cpp index a165db4ceb..fe959881ba 100644 --- a/runtime/tests/Test_LightningCoreQIS.cpp +++ b/runtime/tests/Test_LightningCoreQIS.cpp @@ -491,7 +491,6 @@ TEST_CASE("Test memory transfer in rt", "[CoreQIS]") bool is_in_rt = _mlir_memory_transfer(a); CHECK(is_in_rt); __catalyst__rt__finalize(); - free(a); } TEST_CASE("Test memory transfer not in rt", "[CoreQIS]") @@ -501,7 +500,6 @@ TEST_CASE("Test memory transfer not in rt", "[CoreQIS]") bool is_in_rt = _mlir_memory_transfer(a); CHECK(!is_in_rt); __catalyst__rt__finalize(); - free(a); } TEST_CASE("Test __catalyst__qis__Measure", "[CoreQIS]")