Skip to content

Commit

Permalink
Make CallbackOnLastSignal more readable
Browse files Browse the repository at this point in the history
Summary:
facebookincubator#9548

Simpler to understand.

Reviewed By: bikramSingh91, marxhxxx

Differential Revision: D56363766
  • Loading branch information
Daniel Munoz authored and facebook-github-bot committed Apr 19, 2024
1 parent cb87bc8 commit 95be9d8
Show file tree
Hide file tree
Showing 2 changed files with 114 additions and 48 deletions.
156 changes: 108 additions & 48 deletions velox/dwio/common/UnitLoaderTools.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,16 @@

#pragma once

#include <atomic>
#include <chrono>
#include <cstdint>
#include <functional>
#include <memory>
#include <optional>
#include <vector>

#include "folly/synchronization/CallOnce.h"

namespace facebook::velox::dwio::common::unit_loader_tools {

class Measure {
Expand Down Expand Up @@ -59,84 +62,141 @@ inline std::optional<Measure> measureBlockedOnIo(
// This class can create many callbacks that can be distributed to unit loader
// factories. Only when the last created callback is activated, this class will
// emit the original callback.
// If the callbacks created are never activated (because of an exception for
// example), this will still work, since they will emit the signal on the
// destructor.
// The class expects that all callbacks are created before they start emiting
// signals.
// If the callback objects created are never explicitly called (because of an
// exception for example), the callback object will do the call in the
// destructor, guaranteeting the call.
class CallbackOnLastSignal {
class EnsureCallOnlyOnce {
class Callable {
public:
EnsureCallOnlyOnce(std::function<void()> cb) : cb_{std::move(cb)} {}
virtual ~Callable() {}
virtual void call() = 0;
};

void callOriginalIfNotCalled() {
class CallableFunction : public Callable {
public:
explicit CallableFunction(std::function<void()> cb) : cb_{std::move(cb)} {}

void call() override {
if (cb_) {
// Could be null
cb_();
cb_ = nullptr;
}
}

~EnsureCallOnlyOnce() {
callOriginalIfNotCalled();
}

private:
std::function<void()> cb_;
};

class CallWhenOneInstanceLeft {
// This class will ensure that the contained callback is only called once.
class CallOnce : public Callable {
public:
CallWhenOneInstanceLeft(
std::shared_ptr<EnsureCallOnlyOnce> callWrapper,
std::shared_ptr<size_t> factoryExists)
: factoryExists_{std::move(factoryExists)},
callOnce_{std::move(callWrapper)} {}

CallWhenOneInstanceLeft(const CallWhenOneInstanceLeft& other)
: factoryExists_{other.factoryExists_}, callOnce_{other.callOnce_} {}

CallWhenOneInstanceLeft(CallWhenOneInstanceLeft&& other) noexcept {
factoryExists_ = std::move(other.factoryExists_);
callOnce_ = std::move(other.callOnce_);
explicit CallOnce(std::shared_ptr<Callable> cb) : cb_{std::move(cb)} {}

CallOnce(const CallOnce& other) = delete;
CallOnce(CallOnce&& other) = delete;
CallOnce& operator=(const CallOnce& other) = delete;
CallOnce& operator=(CallOnce&& other) noexcept = delete;

void call() override {
folly::call_once(called_, [&]() { cb_->call(); });
}

void operator()() {
if (!callOnce_) {
return;
private:
std::shared_ptr<Callable> cb_;
folly::once_flag called_;
};

// This class will ensure that only the call from the last caller will go
// through.
class CallOnCountZero : public Callable {
public:
CallOnCountZero(
std::shared_ptr<std::atomic_size_t> callsLeft,
std::shared_ptr<Callable> cb)
: callsLeft_{std::move(callsLeft)}, cb_{std::move(cb)} {}

CallOnCountZero(const CallOnCountZero& other) = delete;
CallOnCountZero(CallOnCountZero&& other) = delete;
CallOnCountZero& operator=(const CallOnCountZero& other) = delete;
CallOnCountZero& operator=(CallOnCountZero&& other) noexcept = delete;

void call() override {
if (*callsLeft_ > 0) {
--*(callsLeft_);
}
// If this is the last callback created out callOnce_, call the
// original callback.
if (callOnce_.use_count() <= (1 + *factoryExists_)) {
callOnce_->callOriginalIfNotCalled();
if (*callsLeft_ == 0) {
cb_->call();
}
callOnce_ = nullptr;
}

~CallWhenOneInstanceLeft() {
(*this)();
private:
std::shared_ptr<std::atomic_size_t> callsLeft_;
std::shared_ptr<Callable> cb_;
};

// This class will ensure that the contained callback is called when the
// operator() is invoked, or when the object is destructed, whatever comes
// first.
class EnsureCall : public Callable {
public:
explicit EnsureCall(std::shared_ptr<Callable> cb)
: cb_{std::make_shared<CallOnce>(std::move(cb))} {}

EnsureCall(const EnsureCall& other) = delete;
EnsureCall(EnsureCall&& other) = delete;
EnsureCall& operator=(const EnsureCall& other) = delete;
EnsureCall& operator=(EnsureCall&& other) noexcept = delete;

~EnsureCall() override {
cb_->call();
}

void call() override {
cb_->call();
}

private:
std::shared_ptr<size_t> factoryExists_;
std::shared_ptr<EnsureCallOnlyOnce> callOnce_;
std::shared_ptr<Callable> cb_;
};

public:
CallbackOnLastSignal(std::function<void()> cb)
: factoryExists_{std::make_shared<size_t>(1)},
callWrapper_{std::make_shared<EnsureCallOnlyOnce>(std::move(cb))} {}
class CountCaller {
public:
CountCaller(
std::shared_ptr<Callable> cb,
std::shared_ptr<std::atomic_size_t> callsLeft)
: cb_{std::move(cb)} {
++(*callsLeft);
}

~CallbackOnLastSignal() {
*factoryExists_ = 0;
}
void operator()() {
cb_->call();
}

private:
std::shared_ptr<Callable> cb_;
};

public:
explicit CallbackOnLastSignal(std::function<void()> cb)
: callsLeft_{std::make_shared<std::atomic_size_t>(0)},
cb_{cb ? std::make_shared<CallOnCountZero>(
callsLeft_,
std::make_shared<CallOnce>(std::make_shared<EnsureCall>(
std::make_shared<CallableFunction>(std::move(cb)))))
: nullptr} {}

std::function<void()> getCallback() const {
return CallWhenOneInstanceLeft{callWrapper_, factoryExists_};
if (!cb_) {
return nullptr;
}
return CountCaller{
std::make_shared<CallOnce>(std::make_shared<EnsureCall>(cb_)),
callsLeft_};
}

private:
std::shared_ptr<size_t> factoryExists_;
std::shared_ptr<EnsureCallOnlyOnce> callWrapper_;
std::shared_ptr<std::atomic_size_t> callsLeft_;
std::shared_ptr<Callable> cb_;
};

} // namespace facebook::velox::dwio::common::unit_loader_tools
6 changes: 6 additions & 0 deletions velox/dwio/common/tests/UnitLoaderToolsTests.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,12 @@ TEST(UnitLoaderToolsTests, NoCallbacksCreated) {
EXPECT_EQ(callCount, 1);
}

TEST(UnitLoaderToolsTests, SupportsNullCallbacks) {
CallbackOnLastSignal callback(nullptr);
auto cb = callback.getCallback();
EXPECT_TRUE(cb == nullptr);
}

TEST(UnitLoaderToolsTests, NoExplicitCalls) {
std::atomic_size_t callCount = 0;
{
Expand Down

0 comments on commit 95be9d8

Please sign in to comment.