Skip to content

Commit

Permalink
Fix coroutines crash with MSVC (#522)
Browse files Browse the repository at this point in the history
MSVC doesn't like when coroutine is destroyed inside final suspender's await_suspend().
Instead return false from it's await_ready() allowing compiler's injected code to destroy coroutine automatically.
In this case we also need to make sure that out Coroutine<> object doesn't destroy coroutine automatically since it will result in double free
(we call completion callback in final suspender's await_ready() which immediately destroys Coroutine<>).
  • Loading branch information
equeim committed Sep 6, 2024
1 parent daf213c commit 524cbe9
Show file tree
Hide file tree
Showing 3 changed files with 55 additions and 47 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

### Fixed
- Performance regression on Windows (and potential performance improvements on other platforms)
- Crash on Windows

## [2.7.0] - 2024-08-31
### Added
Expand Down
33 changes: 19 additions & 14 deletions src/coroutines/coroutines.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,24 +42,16 @@ namespace tremotesf::impl {
return true;
}

std::coroutine_handle<> CoroutinePromiseBase::onPerformedFinalSuspendBase() {
if (mOwningStandaloneCoroutine->completeCancellation()) {
return std::noop_coroutine();
}
return mParentCoroutineHandle;
}

void CoroutinePromiseBase::abortNoParent(std::coroutine_handle<> handle) {
fatal().log("No parent coroutine when completing coroutine {}", handle.address());
Q_UNREACHABLE();
}

std::coroutine_handle<> CoroutinePromise<void>::onPerformedFinalSuspend() {
if (const auto handle = onPerformedFinalSuspendBase(); handle) {
return handle;
}
mOwningStandaloneCoroutine->invokeCompletionCallback(std::move(mUnhandledException));
return std::noop_coroutine();
void CoroutinePromise<void>::invokeCompletionCallbackForStandaloneCoroutine() {
// Completion callback will destroy Coroutine<> object, but coroutine itself will be destroyed later by compiler's injected machinery
// because CoroutinePromiseFinalSuspendAwaiter::await_ready will return false
// Pass true for coroutineWillBeDestroyedAutomatically parameter here so that Coroutine<>'s destructor won't destroy coroutine resulting in double free
mOwningStandaloneCoroutine->invokeCompletionCallback(std::move(mUnhandledException), true);
}

void StandaloneCoroutine::cancel() {
Expand All @@ -76,11 +68,24 @@ namespace tremotesf::impl {
return false;
case CancellationState::Cancelling:
mCancellationState = CancellationState::Cancelled;
invokeCompletionCallback({});
invokeCompletionCallback({}, false);
return true;
case CancellationState::Cancelled:
return true;
}
return false;
}

void StandaloneCoroutine::invokeCompletionCallback(
std::exception_ptr&& unhandledException, bool coroutineWillBeDestroyedAutomatically
) {
if (coroutineWillBeDestroyedAutomatically) {
mCoroutine.mHandle = nullptr;
}
mCompletionCallback(std::move(unhandledException));
}

void StandaloneCoroutine::setCompletionCallback(std::function<void(std::exception_ptr)>&& callback) {
mCompletionCallback = std::move(callback);
}
}
68 changes: 35 additions & 33 deletions src/coroutines/coroutines.h
Original file line number Diff line number Diff line change
Expand Up @@ -62,16 +62,13 @@ namespace tremotesf {
};

namespace impl {
class CoroutinePromiseFinalSuspendAwaiter;

class CoroutinePromiseBase {
public:
inline ~CoroutinePromiseBase() = default;
Q_DISABLE_COPY_MOVE(CoroutinePromiseBase)

// promise object contract begin
inline std::suspend_always initial_suspend() { return {}; }
inline CoroutinePromiseFinalSuspendAwaiter final_suspend() noexcept;
inline void unhandled_exception() { mUnhandledException = std::current_exception(); }
// promise object contract end

Expand All @@ -87,7 +84,6 @@ namespace tremotesf {
inline void setOwningStandaloneCoroutine(StandaloneCoroutine* root) { mOwningStandaloneCoroutine = root; }
void setParentCoroutineHandle(std::coroutine_handle<> parentCoroutineHandle);

std::coroutine_handle<> onPerformedFinalSuspendBase();
void rethrowException() {
if (mUnhandledException) {
std::rethrow_exception(mUnhandledException);
Expand All @@ -97,8 +93,7 @@ namespace tremotesf {
protected:
inline CoroutinePromiseBase() = default;

[[noreturn]]
static void abortNoParent(std::coroutine_handle<> handle);
[[noreturn]] static void abortNoParent(std::coroutine_handle<> handle);

std::coroutine_handle<> mParentCoroutineHandle{};
std::variant<std::monostate, JustCompleteCancellation, std::function<void()>>
Expand All @@ -107,6 +102,23 @@ namespace tremotesf {
StandaloneCoroutine* mOwningStandaloneCoroutine{};
};

class CoroutinePromiseFinalSuspendAwaiter final {
public:
explicit CoroutinePromiseFinalSuspendAwaiter(std::coroutine_handle<> parentCoroutine)
: mParentCoroutine(std::move(parentCoroutine)) {}

// If there is no parent coroutine then await_ready returns false which causes our coroutine to be destroyed
// Otherwise control is transferred to parent coroutine, which destroys CoroutineAwaiter and therefore our coroutine

inline bool await_ready() noexcept { return !mParentCoroutine; }

std::coroutine_handle<> await_suspend(std::coroutine_handle<>) noexcept { return mParentCoroutine; }
inline void await_resume() noexcept {}

private:
std::coroutine_handle<> mParentCoroutine;
};

template<CoroutineReturnValue T>
class CoroutinePromise final : public CoroutinePromiseBase {
public:
Expand All @@ -117,14 +129,13 @@ namespace tremotesf {
inline Coroutine<T> get_return_object() { return Coroutine<T>(mCoroutineHandle); }
inline void return_value(const T& valueToReturn) { mValue = valueToReturn; }
inline void return_value(T&& valueToReturn) { mValue = std::move(valueToReturn); }
// promise object contract end

inline std::coroutine_handle<> onPerformedFinalSuspend() {
if (const auto handle = onPerformedFinalSuspendBase(); handle) {
return handle;
inline CoroutinePromiseFinalSuspendAwaiter final_suspend() noexcept {
if (mParentCoroutineHandle) {
return CoroutinePromiseFinalSuspendAwaiter(mParentCoroutineHandle);
}
abortNoParent(mCoroutineHandle);
}
// promise object contract end

inline T takeValueOrRethrowException() {
rethrowException();
Expand All @@ -147,26 +158,20 @@ namespace tremotesf {
return Coroutine<void>(std::coroutine_handle<CoroutinePromise<void>>::from_promise(*this));
}
inline void return_void() {}
inline CoroutinePromiseFinalSuspendAwaiter final_suspend() noexcept {
if (mParentCoroutineHandle) {
return CoroutinePromiseFinalSuspendAwaiter(mParentCoroutineHandle);
}
invokeCompletionCallbackForStandaloneCoroutine();
return CoroutinePromiseFinalSuspendAwaiter(nullptr);
}
// promise object contract end

std::coroutine_handle<> onPerformedFinalSuspend();
void invokeCompletionCallbackForStandaloneCoroutine();

inline void takeValueOrRethrowException() { rethrowException(); }
};

class CoroutinePromiseFinalSuspendAwaiter final {
public:
inline bool await_ready() noexcept { return false; }

template<std::derived_from<CoroutinePromiseBase> Promise>
inline std::coroutine_handle<> await_suspend(std::coroutine_handle<Promise> handle) noexcept {
return handle.promise().onPerformedFinalSuspend();
}
inline void await_resume() noexcept {}
};

CoroutinePromiseFinalSuspendAwaiter CoroutinePromiseBase::final_suspend() noexcept { return {}; }

template<CoroutineReturnValue T>
class CoroutineAwaiter final {
public:
Expand Down Expand Up @@ -205,8 +210,7 @@ namespace tremotesf {
};

template<typename Promise>
[[nodiscard]]
bool startAwaiting(std::coroutine_handle<Promise> handle) {
[[nodiscard]] bool startAwaiting(std::coroutine_handle<Promise> handle) {
if constexpr (std::derived_from<Promise, CoroutinePromiseBase>) {
auto& promise = handle.promise();
return promise.onStartedAwaiting(CoroutinePromiseBase::JustCompleteCancellation{});
Expand Down Expand Up @@ -241,12 +245,10 @@ namespace tremotesf {
void cancel();
bool completeCancellation();

inline void invokeCompletionCallback(std::exception_ptr&& unhandledException) {
mCompletionCallback(std::move(unhandledException));
}
inline void setCompletionCallback(std::function<void(std::exception_ptr)>&& callback) {
mCompletionCallback = std::move(callback);
}
void invokeCompletionCallback(
std::exception_ptr&& unhandledException, bool coroutineWillBeDestroyedAutomatically
);
void setCompletionCallback(std::function<void(std::exception_ptr)>&& callback);

private:
Coroutine<void> mCoroutine;
Expand Down

0 comments on commit 524cbe9

Please sign in to comment.