diff --git a/source/adapters/level_zero/v2/event.cpp b/source/adapters/level_zero/v2/event.cpp index edccd7f429..d2332ddafb 100644 --- a/source/adapters/level_zero/v2/event.cpp +++ b/source/adapters/level_zero/v2/event.cpp @@ -87,17 +87,17 @@ uint64_t *event_profiling_data_t::eventEndTimestampAddr() { return &recordEventEndTimestamp; } -ur_event_handle_t_::ur_event_handle_t_(ur_context_handle_t hContext, - ze_event_handle_t hZeEvent, - v2::event_flags_t flags) - : hContext(hContext), hZeEvent(hZeEvent), flags(flags), - profilingData(hZeEvent) {} +ur_event_handle_t_::ur_event_handle_t_( + ur_context_handle_t hContext, ur_event_handle_t_::event_variant hZeEvent, + v2::event_flags_t flags, v2::event_pool *pool) + : hContext(hContext), event_pool(pool), hZeEvent(std::move(hZeEvent)), + flags(flags), profilingData(getZeEvent()) {} void ur_event_handle_t_::resetQueueAndCommand(ur_queue_handle_t hQueue, ur_command_t commandType) { this->hQueue = hQueue; this->commandType = commandType; - profilingData = event_profiling_data_t(hZeEvent); + profilingData = event_profiling_data_t(getZeEvent()); } void ur_event_handle_t_::recordStartTimestamp() { @@ -123,13 +123,16 @@ void ur_event_handle_t_::reset() { // consider make an abstraction for regular/counter based // events if there's more of this type of conditions if (!(flags & v2::EVENT_FLAGS_COUNTER)) { - zeEventHostReset(hZeEvent); + zeEventHostReset(getZeEvent()); } } ze_event_handle_t ur_event_handle_t_::getZeEvent() const { - assert(hZeEvent); - return hZeEvent; + if (event_pool) { + return std::get(hZeEvent).get(); + } else { + return std::get(hZeEvent).get(); + } } ur_result_t ur_event_handle_t_::retain() { @@ -138,7 +141,7 @@ ur_result_t ur_event_handle_t_::retain() { } ur_result_t ur_event_handle_t_::releaseDeferred() { - assert(zeEventQueryStatus(hZeEvent) == ZE_RESULT_SUCCESS); + assert(zeEventQueryStatus(getZeEvent()) == ZE_RESULT_SUCCESS); assert(RefCount.load() == 0); return this->forceRelease(); @@ -176,7 +179,7 @@ bool ur_event_handle_t_::isProfilingEnabled() const { std::pair ur_event_handle_t_::getEventEndTimestampAndHandle() { - return {profilingData.eventEndTimestampAddr(), hZeEvent}; + return {profilingData.eventEndTimestampAddr(), getZeEvent()}; } ur_queue_handle_t ur_event_handle_t_::getQueue() const { return hQueue; } @@ -185,29 +188,33 @@ ur_context_handle_t ur_event_handle_t_::getContext() const { return hContext; } ur_command_t ur_event_handle_t_::getCommandType() const { return commandType; } -ur_pooled_event_t::ur_pooled_event_t( +ur_event_handle_t_::ur_event_handle_t_( ur_context_handle_t hContext, v2::raii::cache_borrowed_event eventAllocation, v2::event_pool *pool) - : ur_event_handle_t_(hContext, eventAllocation.get(), pool->getFlags()), - zeEvent(std::move(eventAllocation)), pool(pool) {} - -ur_result_t ur_pooled_event_t::forceRelease() { - pool->free(this); - return UR_RESULT_SUCCESS; -} + : ur_event_handle_t_(hContext, std::move(eventAllocation), pool->getFlags(), + pool) {} -ur_native_event_t::ur_native_event_t( - ur_native_handle_t hNativeEvent, ur_context_handle_t hContext, +ur_event_handle_t_::ur_event_handle_t_( + ur_context_handle_t hContext, ur_native_handle_t hNativeEvent, const ur_event_native_properties_t *pProperties) : ur_event_handle_t_( hContext, - reinterpret_cast(hNativeEvent), v2::EVENT_FLAGS_PROFILING_ENABLED /* TODO: this follows legacy adapter logic, we could check this with zeEventGetPool */), - zeEvent(reinterpret_cast(hNativeEvent), - pProperties ? pProperties->isNativeHandleOwned : false) {} - -ur_result_t ur_native_event_t::forceRelease() { - zeEvent.release(); - delete this; + v2::raii::ze_event_handle_t{ + reinterpret_cast(hNativeEvent), + pProperties ? pProperties->isNativeHandleOwned : false}, + v2::EVENT_FLAGS_PROFILING_ENABLED /* TODO: this follows legacy adapter + logic, we could check this with + zeEventGetPool */ + , + nullptr) {} + +ur_result_t ur_event_handle_t_::forceRelease() { + if (event_pool) { + event_pool->free(this); + } else { + std::get(hZeEvent).release(); + delete this; + } return UR_RESULT_SUCCESS; } @@ -389,7 +396,7 @@ urEventCreateWithNativeHandle(ur_native_handle_t hNativeEvent, *phEvent = hContext->nativeEventsPool.allocate(); ZE2UR_CALL(zeEventHostSignal, ((*phEvent)->getZeEvent())); } else { - *phEvent = new ur_native_event_t(hNativeEvent, hContext, pProperties); + *phEvent = new ur_event_handle_t_(hContext, hNativeEvent, pProperties); } return UR_RESULT_SUCCESS; } catch (...) { diff --git a/source/adapters/level_zero/v2/event.hpp b/source/adapters/level_zero/v2/event.hpp index 9e2331c649..f4a2bb8c11 100644 --- a/source/adapters/level_zero/v2/event.hpp +++ b/source/adapters/level_zero/v2/event.hpp @@ -47,15 +47,24 @@ struct event_profiling_data_t { struct ur_event_handle_t_ : _ur_object { public: - ur_event_handle_t_(ur_context_handle_t hContext, ze_event_handle_t hZeEvent, - v2::event_flags_t flags); + // cache_borrowed_event is used for pooled events, whilst ze_event_handle_t is + // used for native events + using event_variant = + std::variant; + + ur_event_handle_t_(ur_context_handle_t hContext, + v2::raii::cache_borrowed_event eventAllocation, + v2::event_pool *pool); + + ur_event_handle_t_(ur_context_handle_t hContext, + ur_native_handle_t hNativeEvent, + const ur_event_native_properties_t *pProperties); // Set the queue and command that this event is associated with void resetQueueAndCommand(ur_queue_handle_t hQueue, ur_command_t commandType); // releases event immediately - virtual ur_result_t forceRelease() = 0; - virtual ~ur_event_handle_t_() = default; + ur_result_t forceRelease(); void reset(); ze_event_handle_t getZeEvent() const; @@ -97,11 +106,16 @@ struct ur_event_handle_t_ : _ur_object { uint64_t getEventStartTimestmap() const; uint64_t getEventEndTimestamp(); +private: + ur_event_handle_t_(ur_context_handle_t hContext, event_variant hZeEvent, + v2::event_flags_t flags, v2::event_pool *pool); + protected: ur_context_handle_t hContext; - // non-owning handle to the L0 event - const ze_event_handle_t hZeEvent; + // Pool is used if and only if this is a pooled event + v2::event_pool *event_pool = nullptr; + event_variant hZeEvent; // queue and commandType that this event is associated with, set by enqueue // commands @@ -111,26 +125,3 @@ struct ur_event_handle_t_ : _ur_object { v2::event_flags_t flags; event_profiling_data_t profilingData; }; - -struct ur_pooled_event_t : ur_event_handle_t_ { - ur_pooled_event_t(ur_context_handle_t hContext, - v2::raii::cache_borrowed_event eventAllocation, - v2::event_pool *pool); - - ur_result_t forceRelease() override; - -private: - v2::raii::cache_borrowed_event zeEvent; - v2::event_pool *pool; -}; - -struct ur_native_event_t : ur_event_handle_t_ { - ur_native_event_t(ur_native_handle_t hNativeEvent, - ur_context_handle_t hContext, - const ur_event_native_properties_t *pProperties); - - ur_result_t forceRelease() override; - -private: - v2::raii::ze_event_handle_t zeEvent; -}; diff --git a/source/adapters/level_zero/v2/event_pool.cpp b/source/adapters/level_zero/v2/event_pool.cpp index d7e1d451ac..d9639a1a6d 100644 --- a/source/adapters/level_zero/v2/event_pool.cpp +++ b/source/adapters/level_zero/v2/event_pool.cpp @@ -17,7 +17,7 @@ namespace v2 { static constexpr size_t EVENTS_BURST = 64; -ur_pooled_event_t *event_pool::allocate() { +ur_event_handle_t event_pool::allocate() { TRACK_SCOPE_LATENCY("event_pool::allocate"); std::unique_lock lock(*mutex); @@ -42,7 +42,7 @@ ur_pooled_event_t *event_pool::allocate() { return event; } -void event_pool::free(ur_pooled_event_t *event) { +void event_pool::free(ur_event_handle_t event) { TRACK_SCOPE_LATENCY("event_pool::free"); std::unique_lock lock(*mutex); diff --git a/source/adapters/level_zero/v2/event_pool.hpp b/source/adapters/level_zero/v2/event_pool.hpp index fd029f09a7..faafab2a86 100644 --- a/source/adapters/level_zero/v2/event_pool.hpp +++ b/source/adapters/level_zero/v2/event_pool.hpp @@ -41,10 +41,10 @@ class event_pool { event_pool &operator=(const event_pool &) = delete; // Allocate an event from the pool. Thread safe. - ur_pooled_event_t *allocate(); + ur_event_handle_t allocate(); // Free an event back to the pool. Thread safe. - void free(ur_pooled_event_t *event); + void free(ur_event_handle_t event); event_provider *getProvider() const; event_flags_t getFlags() const; @@ -53,8 +53,8 @@ class event_pool { ur_context_handle_t hContext; std::unique_ptr provider; - std::deque events; - std::vector freelist; + std::deque events; + std::vector freelist; std::unique_ptr mutex; };