Skip to content

Commit

Permalink
Make CallbackOnLastSignal more readable
Browse files Browse the repository at this point in the history
Summary:
facebookincubator#9548

Simpler to understand.

Reviewed By: bikramSingh91, marxhxxx

Differential Revision: D56363766
  • Loading branch information
Daniel Munoz authored and facebook-github-bot committed Apr 19, 2024
1 parent 8505b03 commit ff5fc1b
Show file tree
Hide file tree
Showing 2 changed files with 113 additions and 40 deletions.
147 changes: 107 additions & 40 deletions velox/dwio/common/UnitLoaderTools.h
Original file line number Diff line number Diff line change
Expand Up @@ -62,74 +62,141 @@ inline std::optional<Measure> measureBlockedOnIo(
// This class can create many callbacks that can be distributed to unit loader
// factories. Only when the last created callback is activated, this class will
// emit the original callback.
// If the callbacks created are never activated (because of an exception for
// example), this will still work, since they will emit the signal on the
// destructor.
// The class expects that all callbacks are created before they start emiting
// signals.
// If the callback objects created are never explicitly called (because of an
// exception for example), the callback object will do the call in the
// destructor, guaranteeting the call.
class CallbackOnLastSignal {
class EnsureCallOnlyOnce {
class Callable {
public:
explicit EnsureCallOnlyOnce(std::function<void()> cb)
: cb_{std::move(cb)} {}
virtual ~Callable() {}
virtual void call() = 0;
};

EnsureCallOnlyOnce(const EnsureCallOnlyOnce& other) = delete;
EnsureCallOnlyOnce(EnsureCallOnlyOnce&& other) = delete;
EnsureCallOnlyOnce& operator=(const EnsureCallOnlyOnce& other) = delete;
EnsureCallOnlyOnce& operator=(EnsureCallOnlyOnce&& other) noexcept = delete;
class CallableFunction : public Callable {
public:
explicit CallableFunction(std::function<void()> cb) : cb_{std::move(cb)} {}

void call() {
folly::call_once(called_, cb_);
void call() override {
if (cb_) {
// Could be null
cb_();
}
}

private:
std::function<void()> cb_;
folly::once_flag called_;
};

class Caller {
// This class will ensure that the contained callback is only called once.
class CallOnce : public Callable {
public:
Caller(
std::shared_ptr<EnsureCallOnlyOnce> callWrapper,
std::shared_ptr<std::atomic_size_t> callersCount)
: callersCount_{std::move(callersCount)},
callOnce_{std::move(callWrapper)} {
++(*callersCount_);
explicit CallOnce(std::shared_ptr<Callable> cb) : cb_{std::move(cb)} {}

CallOnce(const CallOnce& other) = delete;
CallOnce(CallOnce&& other) = delete;
CallOnce& operator=(const CallOnce& other) = delete;
CallOnce& operator=(CallOnce&& other) noexcept = delete;

void call() override {
folly::call_once(called_, [&]() { cb_->call(); });
}

Caller(Caller&& other) noexcept
: callersCount_{std::move(other.callersCount_)},
callOnce_{std::move(other.callOnce_)} {}
private:
std::shared_ptr<Callable> cb_;
folly::once_flag called_;
};

~Caller() {
if (--(*callersCount_) == 0) {
callOnce_->call();
// This class will ensure that only the call from the last caller will go
// through.
class CallOnCountZero : public Callable {
public:
CallOnCountZero(
std::shared_ptr<std::atomic_size_t> callsLeft,
std::shared_ptr<Callable> cb)
: callsLeft_{std::move(callsLeft)}, cb_{std::move(cb)} {}

CallOnCountZero(const CallOnCountZero& other) = delete;
CallOnCountZero(CallOnCountZero&& other) = delete;
CallOnCountZero& operator=(const CallOnCountZero& other) = delete;
CallOnCountZero& operator=(CallOnCountZero&& other) noexcept = delete;

void call() override {
if (*callsLeft_ > 0) {
--*(callsLeft_);
}
if (*callsLeft_ == 0) {
cb_->call();
}
}

Caller(const Caller&) = delete;
Caller& operator=(const Caller&) = delete;
Caller& operator=(Caller&& other) noexcept = delete;
private:
std::shared_ptr<std::atomic_size_t> callsLeft_;
std::shared_ptr<Callable> cb_;
};

// This class will ensure that the contained callback is called when the
// operator() is invoked, or when the object is destructed, whatever comes
// first.
class EnsureCall : public Callable {
public:
explicit EnsureCall(std::shared_ptr<Callable> cb)
: cb_{std::make_shared<CallOnce>(std::move(cb))} {}

EnsureCall(const EnsureCall& other) = delete;
EnsureCall(EnsureCall&& other) = delete;
EnsureCall& operator=(const EnsureCall& other) = delete;
EnsureCall& operator=(EnsureCall&& other) noexcept = delete;

~EnsureCall() override {
cb_->call();
}

void call() override {
cb_->call();
}

private:
std::shared_ptr<Callable> cb_;
};

void operator()() {}
class CountCaller {
public:
CountCaller(
std::shared_ptr<Callable> cb,
std::shared_ptr<std::atomic_size_t> callsLeft)
: cb_{std::move(cb)} {
++(*callsLeft);
}

void operator()() {
cb_->call();
}

private:
std::shared_ptr<std::atomic_size_t> callersCount_;
std::shared_ptr<EnsureCallOnlyOnce> callOnce_;
std::shared_ptr<Callable> cb_;
};

public:
CallbackOnLastSignal(std::function<void()> cb)
: ensureCallOnce_{std::make_shared<EnsureCallOnlyOnce>(std::move(cb))},
callersCount_{std::make_shared<std::atomic_size_t>(0)} {}
explicit CallbackOnLastSignal(std::function<void()> cb)
: callsLeft_{std::make_shared<std::atomic_size_t>(0)},
cb_{cb ? std::make_shared<CallOnCountZero>(
callsLeft_,
std::make_shared<CallOnce>(std::make_shared<EnsureCall>(
std::make_shared<CallableFunction>(std::move(cb)))))
: nullptr} {}

std::function<void()> getCallback() const {
return Caller{ensureCallOnce_, callersCount_};
if (!cb_) {
return nullptr;
}
return CountCaller{
std::make_shared<CallOnce>(std::make_shared<EnsureCall>(cb_)),
callsLeft_};
}

private:
std::shared_ptr<EnsureCallOnlyOnce> ensureCallOnce_;
std::shared_ptr<std::atomic_size_t> callersCount_;
std::shared_ptr<std::atomic_size_t> callsLeft_;
std::shared_ptr<Callable> cb_;
};

} // namespace facebook::velox::dwio::common::unit_loader_tools
6 changes: 6 additions & 0 deletions velox/dwio/common/tests/UnitLoaderToolsTests.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,12 @@ TEST(UnitLoaderToolsTests, NoCallbacksCreated) {
EXPECT_EQ(callCount, 1);
}

TEST(UnitLoaderToolsTests, SupportsNullCallbacks) {
CallbackOnLastSignal callback(nullptr);
auto cb = callback.getCallback();
EXPECT_TRUE(cb == nullptr);
}

TEST(UnitLoaderToolsTests, NoExplicitCalls) {
std::atomic_size_t callCount = 0;
{
Expand Down

0 comments on commit ff5fc1b

Please sign in to comment.