Skip to content

Commit

Permalink
[ADT] Fix specialization of ValueIsPresent for PointerUnion (#121847)
Browse files Browse the repository at this point in the history
Two instances of `PointerUnion` with different active members and null
value compare unequal. Currently, this results in counterintuitive
behavior when using functions from `Casting.h`, e.g.:

```C++
  PointerUnion<int *, float *> U;
  // U = (int *)nullptr;
  dyn_cast<int *>(U); // Aborts
  dyn_cast<float *>(U); // Aborts
  U = (float *)nullptr;
  dyn_cast<int *>(U); // OK
  dyn_cast<float *>(U); // OK
```

`dyn_cast` should abort in all cases because the argument is null.
Currently, it aborts only if the first member is active. This happens
because the partial template specialization of `ValueIsPresent` for
nullable types compares the union with a union constructed from nullptr,
and the two unions compare equal only if their active members are the
same.

This patch changed the specialization of `ValueIsPresent` for nullable
types to make `isPresent()` return false for all possible null values of
a PointerUnion, and fixes two places where the old behavior was
exploited.

Pull Request: llvm/llvm-project#121847
  • Loading branch information
s-barannikov authored Jan 10, 2025
1 parent 799e988 commit 7b05367
Show file tree
Hide file tree
Showing 4 changed files with 13 additions and 8 deletions.
8 changes: 4 additions & 4 deletions llvm/include/llvm/Support/Casting.h
Original file line number Diff line number Diff line change
Expand Up @@ -614,12 +614,12 @@ template <typename T> struct ValueIsPresent<std::optional<T>> {
static inline decltype(auto) unwrapValue(std::optional<T> &t) { return *t; }
};

// If something is "nullable" then we just compare it to nullptr to see if it
// exists.
// If something is "nullable" then we just cast it to bool to see if it exists.
template <typename T>
struct ValueIsPresent<T, std::enable_if_t<IsNullable<T>>> {
struct ValueIsPresent<
T, std::enable_if_t<IsNullable<T> && std::is_constructible_v<bool, T>>> {
using UnwrappedType = T;
static inline bool isPresent(const T &t) { return t != T(nullptr); }
static inline bool isPresent(const T &t) { return static_cast<bool>(t); }
static inline decltype(auto) unwrapValue(T &t) { return t; }
};

Expand Down
4 changes: 2 additions & 2 deletions llvm/lib/CodeGen/RegisterBankInfo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -134,10 +134,10 @@ const TargetRegisterClass *RegisterBankInfo::constrainGenericRegister(

// If the register already has a class, fallback to MRI::constrainRegClass.
auto &RegClassOrBank = MRI.getRegClassOrRegBank(Reg);
if (isa<const TargetRegisterClass *>(RegClassOrBank))
if (isa_and_present<const TargetRegisterClass *>(RegClassOrBank))
return MRI.constrainRegClass(Reg, &RC);

const RegisterBank *RB = cast<const RegisterBank *>(RegClassOrBank);
const auto *RB = dyn_cast_if_present<const RegisterBank *>(RegClassOrBank);
// Otherwise, all we can do is ensure the bank covers the class, and set it.
if (RB && !RB->covers(RC))
return nullptr;
Expand Down
4 changes: 2 additions & 2 deletions llvm/lib/Target/AMDGPU/SIRegisterInfo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3708,10 +3708,10 @@ const TargetRegisterClass *
SIRegisterInfo::getConstrainedRegClassForOperand(const MachineOperand &MO,
const MachineRegisterInfo &MRI) const {
const RegClassOrRegBank &RCOrRB = MRI.getRegClassOrRegBank(MO.getReg());
if (const RegisterBank *RB = dyn_cast<const RegisterBank *>(RCOrRB))
if (const auto *RB = dyn_cast_if_present<const RegisterBank *>(RCOrRB))
return getRegClassForTypeOnBank(MRI.getType(MO.getReg()), *RB);

if (const auto *RC = dyn_cast<const TargetRegisterClass *>(RCOrRB))
if (const auto *RC = dyn_cast_if_present<const TargetRegisterClass *>(RCOrRB))
return getAllocatableClass(RC);

return nullptr;
Expand Down
5 changes: 5 additions & 0 deletions llvm/unittests/ADT/PointerUnionTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,11 @@ TEST_F(PointerUnionTest, NewCastInfra) {
EXPECT_FALSE(isa<float *>(d4null));
EXPECT_FALSE(isa<long long *>(d4null));

EXPECT_FALSE(isa_and_present<int *>(i4null));
EXPECT_FALSE(isa_and_present<float *>(f4null));
EXPECT_FALSE(isa_and_present<long long *>(l4null));
EXPECT_FALSE(isa_and_present<double *>(d4null));

// test cast<>
EXPECT_EQ(cast<float *>(a), &f);
EXPECT_EQ(cast<int *>(b), &i);
Expand Down

0 comments on commit 7b05367

Please sign in to comment.