Skip to content

Commit

Permalink
rocr: Add WaitMultiple to core Signal
Browse files Browse the repository at this point in the history
Replaces WaitAny with WaitMultiple to more closely align with the
underlying driver API for waiting on multiple events.

WaitMultiple adds a single parameter, wait_on_all, to the WaitAny
interface providing a single function for waiting on multiple
events when we only need AND and OR semantics for the signal
checking logic.
  • Loading branch information
atgutier committed Jan 9, 2025
1 parent eb1a098 commit 39f4337
Show file tree
Hide file tree
Showing 10 changed files with 148 additions and 57 deletions.
9 changes: 9 additions & 0 deletions runtime/hsa-runtime/core/common/hsa_table_interface.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -909,6 +909,15 @@ hsa_status_t HSA_API
return amdExtTable->hsa_amd_async_function_fn(callback, arg);
}

// Mirrors Amd Extension Apis
uint32_t HSA_API hsa_amd_signal_wait_all(uint32_t signal_count, hsa_signal_t* signals,
hsa_signal_condition_t* conds, hsa_signal_value_t* values,
uint64_t timeout_hint, hsa_wait_state_t wait_hint,
hsa_signal_value_t* satisfying_values) {
return amdExtTable->hsa_amd_signal_wait_all_fn(signal_count, signals, conds, values, timeout_hint,
wait_hint, satisfying_values);
}

// Mirrors Amd Extension Apis
uint32_t HSA_API
hsa_amd_signal_wait_any(uint32_t signal_count, hsa_signal_t* signals,
Expand Down
6 changes: 6 additions & 0 deletions runtime/hsa-runtime/core/inc/hsa_ext_amd_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,12 @@ hsa_status_t hsa_amd_signal_create(hsa_signal_value_t initial_value, uint32_t nu
const hsa_agent_t* consumers, uint64_t attributes,
hsa_signal_t* signal);

// Mirrors Amd Extension Apis
uint32_t hsa_amd_signal_wait_all(uint32_t signal_count, hsa_signal_t* signals,
hsa_signal_condition_t* conds, hsa_signal_value_t* values,
uint64_t timeout_hint, hsa_wait_state_t wait_hint,
hsa_signal_value_t* satisfying_values);

// Mirrors Amd Extension Apis
uint32_t
hsa_amd_signal_wait_any(uint32_t signal_count, hsa_signal_t* signals,
Expand Down
34 changes: 25 additions & 9 deletions runtime/hsa-runtime/core/inc/signal.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
// The University of Illinois/NCSA
// Open Source License (NCSA)
//
// Copyright (c) 2014-2020, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2014-2024, Advanced Micro Devices, Inc. All rights reserved.
//
// Developed by:
//
Expand Down Expand Up @@ -351,14 +351,30 @@ class Signal {
/// Returns NULL for DefaultEvent Type.
virtual HsaEvent* EopEvent() = 0;

/// @brief Waits until any signal in the list satisfies its condition or
/// timeout is reached.
/// Returns the index of a satisfied signal. Returns -1 on timeout and
/// errors.
static uint32_t WaitAny(uint32_t signal_count, const hsa_signal_t* hsa_signals,
const hsa_signal_condition_t* conds, const hsa_signal_value_t* values,
uint64_t timeout_hint, hsa_wait_state_t wait_hint,
hsa_signal_value_t* satisfying_value);
/// @brief Waits until multiple signals in the list satisfy their conditions
/// or a timeout is reached.
/// @param signal_count Number of hsa_signals in the list.
/// @param hsa_signals Pointer to array of HSA signals.
/// @param conds Pointer to array of signal conditions.
/// @param values Pointer to array of signal values.
/// @param timeout Timeout hint value.
/// @param wait_hint Hint about wait state.
/// @param satisfying_values Vector of satisfying values. If \p wait_on_all
/// is false (then we are waiting on any signal in the list) this will contain
/// only the first satisfying value.
/// @param wait_on_all Wait on all signals in the list to satisfy their
/// conditions if true, else wait on any signal in the list to satisfy its
/// condition.
/// @return Return the index of the first signal in the list that satisfies
/// its condition or -1 on a timeout. Note that if \p wait_on_all is true,
/// then all signals in the list satisfy their conditions, thus the index will
/// always be 0.
static uint32_t WaitMultiple(uint32_t signal_count, const hsa_signal_t* hsa_signals,
const hsa_signal_condition_t* conds,
const hsa_signal_value_t* values, uint64_t timeout,
hsa_wait_state_t wait_hint,
std::vector<hsa_signal_value_t>& satisfying_values,
bool wait_on_all);

/// @brief Dedicated funtion to wait on signals that are not of type HSA_EVENTTYPE_SIGNAL
/// these events can only be received by calling the underlying driver (i.e via the hsaKmtWaitOnMultipleEvents_Ext
Expand Down
3 changes: 2 additions & 1 deletion runtime/hsa-runtime/core/runtime/hsa_api_trace.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ void HsaApiTable::Init() {
// they can add preprocessor macros on the new functions

constexpr size_t expected_core_api_table_size = 1016;
constexpr size_t expected_amd_ext_table_size = 584;
constexpr size_t expected_amd_ext_table_size = 592;
constexpr size_t expected_image_ext_table_size = 128;
constexpr size_t expected_finalizer_ext_table_size = 64;
constexpr size_t expected_tools_table_size = 64;
Expand Down Expand Up @@ -412,6 +412,7 @@ void HsaApiTable::UpdateAmdExts() {
amd_ext_api.hsa_amd_profiling_convert_tick_to_system_domain_fn = AMD::hsa_amd_profiling_convert_tick_to_system_domain;
amd_ext_api.hsa_amd_signal_async_handler_fn = AMD::hsa_amd_signal_async_handler;
amd_ext_api.hsa_amd_async_function_fn = AMD::hsa_amd_async_function;
amd_ext_api.hsa_amd_signal_wait_all_fn = AMD::hsa_amd_signal_wait_all;
amd_ext_api.hsa_amd_signal_wait_any_fn = AMD::hsa_amd_signal_wait_any;
amd_ext_api.hsa_amd_queue_cu_set_mask_fn = AMD::hsa_amd_queue_cu_set_mask;
amd_ext_api.hsa_amd_queue_cu_get_mask_fn = AMD::hsa_amd_queue_cu_get_mask;
Expand Down
48 changes: 42 additions & 6 deletions runtime/hsa-runtime/core/runtime/hsa_ext_amd.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,13 +40,14 @@
//
////////////////////////////////////////////////////////////////////////////////

#include <new>
#include <typeinfo>
#include <algorithm>
#include <exception>
#include <map>
#include <memory>
#include <new>
#include <set>
#include <typeinfo>
#include <utility>
#include <memory>
#include <map>
#include <vector>

#include "core/inc/agent.h"
Expand Down Expand Up @@ -570,6 +571,35 @@ hsa_status_t hsa_amd_signal_value_pointer(hsa_signal_t hsa_signal,
CATCH;
}

uint32_t hsa_amd_signal_wait_all(uint32_t signal_count, hsa_signal_t* hsa_signals,
hsa_signal_condition_t* conds, hsa_signal_value_t* values,
uint64_t timeout_hint, hsa_wait_state_t wait_hint,
hsa_signal_value_t* satisfying_values) {
TRY;
if (!core::Runtime::runtime_singleton_->IsOpen()) {
assert(false && "hsa_amd_signal_wait_all called while not initialized.");
return 0;
}
// Do not check for signal invalidation. Invalidation may occur during async
// signal handler loop and is not an error.
for (int i = 0; i < signal_count; ++i)
assert(hsa_signals[i].handle != 0 && core::SharedSignal::Convert(hsa_signals[i])->IsValid() &&
"Invalid signal.");

std::vector<hsa_signal_value_t> satisfying_values_vec;
satisfying_values_vec.resize(signal_count);
uint32_t first_satysifying_signal_idx =
core::Signal::WaitMultiple(signal_count, hsa_signals, conds, values, timeout_hint, wait_hint,
satisfying_values_vec, true);

if (satisfying_values) {
std::copy(satisfying_values_vec.begin(), satisfying_values_vec.end(), satisfying_values);
}

return first_satysifying_signal_idx;
CATCHRET(uint32_t);
}

uint32_t hsa_amd_signal_wait_any(uint32_t signal_count, hsa_signal_t* hsa_signals,
hsa_signal_condition_t* conds, hsa_signal_value_t* values,
uint64_t timeout_hint, hsa_wait_state_t wait_hint,
Expand All @@ -585,8 +615,14 @@ uint32_t hsa_amd_signal_wait_any(uint32_t signal_count, hsa_signal_t* hsa_signal
assert(hsa_signals[i].handle != 0 && core::SharedSignal::Convert(hsa_signals[i])->IsValid() &&
"Invalid signal.");

return core::Signal::WaitAny(signal_count, hsa_signals, conds, values,
timeout_hint, wait_hint, satisfying_value);
std::vector<hsa_signal_value_t> satisfying_value_vec(1);
uint32_t satisfying_signal_idx =
core::Signal::WaitMultiple(signal_count, hsa_signals, conds, values, timeout_hint, wait_hint,
satisfying_value_vec, false);

if (satisfying_value) *satisfying_value = satisfying_value_vec.at(0);

return satisfying_signal_idx;
CATCHRET(uint32_t);
}

Expand Down
33 changes: 13 additions & 20 deletions runtime/hsa-runtime/core/runtime/runtime.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1555,7 +1555,7 @@ hsa_status_t Runtime::IPCDetach(void* ptr) {
}

void Runtime::AsyncEventsLoop(void* _eventsInfo) {
struct AsyncEventsInfo* eventsInfo = reinterpret_cast<struct AsyncEventsInfo*>(_eventsInfo);
AsyncEventsInfo* eventsInfo = reinterpret_cast<AsyncEventsInfo*>(_eventsInfo);

auto& async_events_control_ = eventsInfo->control;
auto& async_events_ = eventsInfo->events;
Expand Down Expand Up @@ -1624,26 +1624,19 @@ void Runtime::AsyncEventsLoop(void* _eventsInfo) {

while (!async_events_control_.exit) {
// Wait for a signal
hsa_signal_value_t value = 0;
std::vector<hsa_signal_value_t> value(1);
value[0] = 0;
uint32_t index = 0;
uint32_t wait_any = true;
if (eventsInfo->monitor_exceptions) {
index = Signal::WaitAnyExceptions(
uint32_t(async_events_.Size()),
&async_events_.signal_[0],
&async_events_.cond_[0],
&async_events_.value_[0],
&value);
index =
Signal::WaitAnyExceptions(uint32_t(async_events_.Size()), &async_events_.signal_[0],
&async_events_.cond_[0], &async_events_.value_[0], &value[0]);
} else {
if (core::Runtime::runtime_singleton_->flag().wait_any()) {
index = Signal::WaitAny(
uint32_t(async_events_.Size()),
&async_events_.signal_[0],
&async_events_.cond_[0],
&async_events_.value_[0],
uint64_t(-1),
HSA_WAIT_STATE_BLOCKED,
&value);
index = Signal::WaitMultiple(uint32_t(async_events_.Size()), &async_events_.signal_[0],
&async_events_.cond_[0], &async_events_.value_[0], uint64_t(-1),
HSA_WAIT_STATE_BLOCKED, value, false);
} else {
// Skip wake-up signal logic
index = 1;
Expand All @@ -1658,7 +1651,7 @@ void Runtime::AsyncEventsLoop(void* _eventsInfo) {
hsa_signal_handle(async_events_control_.wake)->StoreRelaxed(0);
} else if (index != -1) {
if (wait_any) {
processEvent(index, value, wait_any);
processEvent(index, value[0], wait_any);
} else {
index = 0;
}
Expand Down Expand Up @@ -1686,12 +1679,12 @@ void Runtime::AsyncEventsLoop(void* _eventsInfo) {
// Check remaining signals before sleeping.
for (size_t i = index; i < async_events_.Size(); i++) {
hsa_signal_handle sig(async_events_.signal_[i]);
value = atomic::Load(&sig->signal_.value, std::memory_order_relaxed);
if (checkCondition(async_events_.cond_[i], value, async_events_.value_[i])) {
value[0] = atomic::Load(&sig->signal_.value, std::memory_order_relaxed);
if (checkCondition(async_events_.cond_[i], value[0], async_events_.value_[i])) {
if (i == 0) {
hsa_signal_handle(async_events_control_.wake)->StoreRelaxed(0);
} else {
if (!processEvent(i, value, wait_any)) {
if (!processEvent(i, value[0], wait_any)) {
i--;
}
}
Expand Down
50 changes: 32 additions & 18 deletions runtime/hsa-runtime/core/runtime/signal.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,24 +2,24 @@
//
// The University of Illinois/NCSA
// Open Source License (NCSA)
//
// Copyright (c) 2014-2020, Advanced Micro Devices, Inc. All rights reserved.
//
//
// Copyright (c) 2014-2024, Advanced Micro Devices, Inc. All rights reserved.
//
// Developed by:
//
//
// AMD Research and AMD HSA Software Development
//
//
// Advanced Micro Devices, Inc.
//
//
// www.amd.com
//
//
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to
// deal with the Software without restriction, including without limitation
// the rights to use, copy, modify, merge, publish, distribute, sublicense,
// and/or sell copies of the Software, and to permit persons to whom the
// Software is furnished to do so, subject to the following conditions:
//
//
// - Redistributions of source code must retain the above copyright notice,
// this list of conditions and the following disclaimers.
// - Redistributions in binary form must reproduce the above copyright
Expand All @@ -29,7 +29,7 @@
// nor the names of its contributors may be used to endorse or promote
// products derived from this Software without specific prior written
// permission.
//
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
Expand All @@ -46,6 +46,9 @@
#include "core/inc/signal.h"

#include <algorithm>
#include <numeric>
#include <vector>

#include "core/util/timer.h"
#include "core/inc/runtime.h"

Expand Down Expand Up @@ -177,10 +180,11 @@ Signal::~Signal() {
}
}

uint32_t Signal::WaitAny(uint32_t signal_count, const hsa_signal_t* hsa_signals,
const hsa_signal_condition_t* conds, const hsa_signal_value_t* values,
uint64_t timeout, hsa_wait_state_t wait_hint,
hsa_signal_value_t* satisfying_value) {
uint32_t Signal::WaitMultiple(uint32_t signal_count, const hsa_signal_t* hsa_signals,
const hsa_signal_condition_t* conds, const hsa_signal_value_t* values,
uint64_t timeout, hsa_wait_state_t wait_hint,
std::vector<hsa_signal_value_t>& satisfying_values,
bool wait_on_all) {
hsa_signal_handle* signals =
reinterpret_cast<hsa_signal_handle*>(const_cast<hsa_signal_t*>(hsa_signals));

Expand Down Expand Up @@ -251,10 +255,14 @@ uint32_t Signal::WaitAny(uint32_t signal_count, const hsa_signal_t* hsa_signals,
timer::duration_from_seconds<timer::fast_clock::duration>(
double(timeout) / double(hsa_freq));

bool condition_met = false;
std::vector<uint32_t> unmet_condition_ids(signal_count);
std::iota(unmet_condition_ids.begin(), unmet_condition_ids.end(), 0);

while (true) {
// Cannot mwaitx - polling multiple signals
for (uint32_t i = 0; i < signal_count; i++) {
for (auto it = unmet_condition_ids.begin(); it != unmet_condition_ids.end();) {
auto i = *it;
bool condition_met = false;
if (!signals[i]->IsValid())
return uint32_t(-1);

Expand Down Expand Up @@ -282,8 +290,14 @@ uint32_t Signal::WaitAny(uint32_t signal_count, const hsa_signal_t* hsa_signals,
return uint32_t(-1);
}
if (condition_met) {
if (satisfying_value != NULL) *satisfying_value = value;
return i;
it = unmet_condition_ids.erase(it);
satisfying_values[i] = value;
if (!wait_on_all)
return i;
else if (unmet_condition_ids.empty())
return 0;
} else {
++it;
}
}

Expand All @@ -306,7 +320,7 @@ uint32_t Signal::WaitAny(uint32_t signal_count, const hsa_signal_t* hsa_signals,
uint64_t ct=timer::duration_cast<std::chrono::milliseconds>(
time_remaining).count();
wait_ms = (ct>0xFFFFFFFEu) ? 0xFFFFFFFEu : ct;
hsaKmtWaitOnMultipleEvents_Ext(evts, unique_evts, false, wait_ms, event_age);
hsaKmtWaitOnMultipleEvents_Ext(evts, unique_evts, wait_on_all, wait_ms, event_age);
}
}

Expand Down
1 change: 1 addition & 0 deletions runtime/hsa-runtime/hsacore.so.def
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,7 @@ global:
hsa_amd_profiling_get_async_copy_time;
hsa_amd_profiling_convert_tick_to_system_domain;
hsa_amd_signal_create;
hsa_amd_signal_wait_all;
hsa_amd_signal_wait_any;
hsa_amd_signal_async_handler;
hsa_amd_async_function;
Expand Down
3 changes: 2 additions & 1 deletion runtime/hsa-runtime/inc/hsa_api_trace.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
// The University of Illinois/NCSA
// Open Source License (NCSA)
//
// Copyright (c) 2014-2020, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2014-2024, Advanced Micro Devices, Inc. All rights reserved.
//
// Developed by:
//
Expand Down Expand Up @@ -205,6 +205,7 @@ struct AmdExtTable {
decltype(hsa_amd_profiling_convert_tick_to_system_domain)* hsa_amd_profiling_convert_tick_to_system_domain_fn;
decltype(hsa_amd_signal_async_handler)* hsa_amd_signal_async_handler_fn;
decltype(hsa_amd_async_function)* hsa_amd_async_function_fn;
decltype(hsa_amd_signal_wait_all)* hsa_amd_signal_wait_all_fn;
decltype(hsa_amd_signal_wait_any)* hsa_amd_signal_wait_any_fn;
decltype(hsa_amd_queue_cu_set_mask)* hsa_amd_queue_cu_set_mask_fn;
decltype(hsa_amd_memory_pool_get_info)* hsa_amd_memory_pool_get_info_fn;
Expand Down
Loading

0 comments on commit 39f4337

Please sign in to comment.