Skip to content

Commit

Permalink
Remove the requirement for power of twos in EnumBitSet
Browse files Browse the repository at this point in the history
Summary:
The currenty implementation of `sparta::EnumBitSet<Enum>` requires the Enum to have values that are power of 2.
This is not really necessary and makes the enum definition a bit annoying.

Let's update EnumBitSet to accept enums with values from 0 to 63.

Reviewed By: arnaudvenet

Differential Revision: D51667840

fbshipit-source-id: 1b36badbbfff58301680faa789432949026ce062
  • Loading branch information
arthaud authored and facebook-github-bot committed Nov 29, 2023
1 parent e26ca86 commit 3f9d7ae
Showing 1 changed file with 34 additions and 25 deletions.
59 changes: 34 additions & 25 deletions include/sparta/utils/EnumBitSet.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

#pragma once

#include <cstdint>
#include <initializer_list>
#include <type_traits>

Expand All @@ -16,35 +17,53 @@ namespace sparta {
* A set of enum values.
*
* `EnumBitSet<Enum>` can be used to store an OR-combination of enum values,
* where `Enum` is an enum class type. `Enum` underlying values must be a power
* of 2.
* where `Enum` is an enum class type.
*
* `Enum` underlying values must be unsigned integers between 0 and 63. `Enum`
* must have a `_Count` member containing the maximum value.
*/
template <typename Enum>
class EnumBitSet final {
private:
static_assert(std::is_enum_v<Enum>, "Enum must be an enumeration type");
static_assert(
std::is_unsigned_v<std::underlying_type_t<Enum>>,
"The underlying type of Enum must be an unsigned arithmetic type");
static_assert(static_cast<std::underlying_type_t<Enum>>(Enum::_Count) < 64u,
"Enum::_Count must be less than 64");

public:
using EnumType = Enum;
using IntT = std::underlying_type_t<Enum>;
using EnumUnderlyingT = std::underlying_type_t<Enum>;

private:
using IntT = std::conditional_t<
static_cast<EnumUnderlyingT>(Enum::_Count) < 8u,
std::uint8_t,
std::conditional_t<static_cast<EnumUnderlyingT>(Enum::_Count) < 32u,
std::uint32_t,
std::uint64_t>>;

static constexpr IntT enum_to_bit(Enum value) {
return static_cast<IntT>(1)
<< static_cast<IntT>(static_cast<EnumUnderlyingT>(value));
}

public:
EnumBitSet() = default;

/* implicit */ constexpr EnumBitSet(Enum value)
: value_(static_cast<IntT>(value)) {}
: value_(enum_to_bit(value)) {}

/* implicit */ constexpr EnumBitSet(std::initializer_list<Enum> set)
: value_(0) {
: value_(0u) {
for (auto value : set) {
value_ |= static_cast<IntT>(value);
value_ |= enum_to_bit(value);
}
}

constexpr EnumBitSet& operator&=(Enum value) {
value_ &= static_cast<IntT>(value);
value_ &= enum_to_bit(value);
return *this;
}

Expand All @@ -54,7 +73,7 @@ class EnumBitSet final {
}

constexpr EnumBitSet& operator|=(Enum value) {
value_ |= static_cast<IntT>(value);
value_ |= enum_to_bit(value);
return *this;
}

Expand All @@ -64,7 +83,7 @@ class EnumBitSet final {
}

constexpr EnumBitSet& operator^=(Enum value) {
value_ ^= static_cast<IntT>(value);
value_ ^= enum_to_bit(value);
return *this;
}

Expand All @@ -74,23 +93,23 @@ class EnumBitSet final {
}

constexpr EnumBitSet operator&(Enum value) const {
return EnumBitSet(value_ & static_cast<IntT>(value));
return EnumBitSet(value_ & enum_to_bit(value));
}

constexpr EnumBitSet operator&(EnumBitSet set) const {
return EnumBitSet(value_ & set.value_);
}

constexpr EnumBitSet operator|(Enum value) const {
return EnumBitSet(value_ | static_cast<IntT>(value));
return EnumBitSet(value_ | enum_to_bit(value));
}

constexpr EnumBitSet operator|(EnumBitSet set) const {
return EnumBitSet(value_ | set.value_);
}

constexpr EnumBitSet operator^(Enum value) const {
return EnumBitSet(value_ ^ static_cast<IntT>(value));
return EnumBitSet(value_ ^ enum_to_bit(value));
}

constexpr EnumBitSet operator^(EnumBitSet set) const {
Expand All @@ -112,18 +131,14 @@ class EnumBitSet final {
}

constexpr bool test(Enum value) const {
if (static_cast<IntT>(value) == 0) {
return value_ == 0;
} else {
return (value_ & static_cast<IntT>(value)) == static_cast<IntT>(value);
}
return (value_ & enum_to_bit(value)) == enum_to_bit(value);
}

constexpr EnumBitSet& set(Enum value, bool on = true) {
if (on) {
value_ |= static_cast<IntT>(value);
value_ |= enum_to_bit(value);
} else {
value_ &= ~static_cast<IntT>(value);
value_ &= ~enum_to_bit(value);
}
return *this;
}
Expand All @@ -140,12 +155,6 @@ class EnumBitSet final {
return (value_ && !(value_ & (value_ - 1)));
}

constexpr IntT encode() const { return value_; }

static constexpr EnumBitSet decode(IntT encoding) {
return EnumBitSet(encoding);
}

private:
explicit constexpr EnumBitSet(IntT value) : value_(value) {}

Expand Down

0 comments on commit 3f9d7ae

Please sign in to comment.