Skip to content

Commit

Permalink
Fix failures in write-after-write-elimination (#2591)
Browse files Browse the repository at this point in the history
* Fix failures in write-after-write-elimination

Signed-off-by: Anna Gringauze <[email protected]>

* Fix issues in lif-array-alloc

---------

Signed-off-by: Anna Gringauze <[email protected]>
  • Loading branch information
annagrin authored Feb 7, 2025
1 parent 8e54384 commit ab2d01d
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 27 deletions.
46 changes: 28 additions & 18 deletions lib/Optimizer/Transforms/WriteAfterWriteElimination.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -54,12 +54,11 @@ class SimplifyWritesAnalysis {
for (const auto &[ptr, stores] : ptrToStores) {
if (stores.size() > 1) {
auto replacement = stores.back();
for (auto it = stores.rend(); it != stores.rbegin(); it++) {
auto store = *it;
if (isReplacement(ptr, *store, *replacement)) {
LLVM_DEBUG(llvm::dbgs() << "replacing store " << store
<< " by: " << replacement << '\n');
toErase.push_back(store->getOperation());
for (auto *store : stores) {
if (isReplacement(ptr, store, replacement)) {
LLVM_DEBUG(llvm::dbgs() << "replacing store " << *store
<< " by: " << *replacement << '\n');
toErase.push_back(store);
}
}
}
Expand All @@ -72,13 +71,18 @@ class SimplifyWritesAnalysis {

private:
/// Detect if value is used in the op or its nested blocks.
bool isReplacement(Value ptr, cudaq::cc::StoreOp store,
cudaq::cc::StoreOp replacement) const {
// Check that there are no stores dominated by the store and not dominated
// by the replacement (i.e. used in between the store and the replacement)
for (auto *user : ptr.getUsers()) {
bool isReplacement(Operation *ptr, Operation *store,
Operation *replacement) const {
if (store == replacement)
return false;

// Check that there are no non-store uses dominated by the store and
// not dominated by the replacement, i.e. only uses between the two
// stores are other stores to the same pointer.
for (auto *user : ptr->getUsers()) {
if (user != store && user != replacement) {
if (dom.dominates(store, user) && !dom.dominates(replacement, user)) {
if (!isStoreToPtr(user, ptr) && dom.dominates(store, user) &&
!dom.dominates(replacement, user)) {
LLVM_DEBUG(llvm::dbgs() << "store " << replacement
<< " is used before: " << store << '\n');
return false;
Expand All @@ -88,6 +92,13 @@ class SimplifyWritesAnalysis {
return true;
}

/// Detects a store to the pointer.
static bool isStoreToPtr(Operation *op, Operation *ptr) {
return isa_and_present<cudaq::cc::StoreOp>(op) &&
(dyn_cast<cudaq::cc::StoreOp>(op).getPtrvalue().getDefiningOp() ==
ptr);
}

/// Collect all stores to a pointer for a block.
void collectBlockInfo(Block *block) {
for (auto &op : *block) {
Expand All @@ -96,11 +107,11 @@ class SimplifyWritesAnalysis {
collectBlockInfo(&b);

if (auto store = dyn_cast<cudaq::cc::StoreOp>(&op)) {
auto ptr = store.getPtrvalue();
auto ptr = store.getPtrvalue().getDefiningOp();
if (isStoreToStack(store)) {
auto ptrToStores = blockInfo.FindAndConstruct(block).second;
auto stores = ptrToStores.FindAndConstruct(ptr).second;
stores.push_back(&store);
auto &[b, ptrToStores] = blockInfo.FindAndConstruct(block);
auto &[p, stores] = ptrToStores.FindAndConstruct(ptr);
stores.push_back(&op);
}
}
}
Expand Down Expand Up @@ -128,8 +139,7 @@ class SimplifyWritesAnalysis {
}

DominanceInfo &dom;
DenseMap<Block *, DenseMap<Value, SmallVector<cudaq::cc::StoreOp *>>>
blockInfo;
DenseMap<Block *, DenseMap<Operation *, SmallVector<Operation *>>> blockInfo;
};

class WriteAfterWriteEliminationPass
Expand Down
22 changes: 13 additions & 9 deletions test/Quake/write_after_write_elimination.qke
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

// RUN: cudaq-opt -write-after-write-elimination %s | FileCheck %s


func.func @test_two_stores_same_pointer() {
%c0_i64 = arith.constant 0 : i64
%0 = quake.alloca !quake.veq<2>
Expand Down Expand Up @@ -63,23 +64,26 @@ func.func @test_two_stores_different_pointers() {
func.func @test_two_stores_same_pointer_interleaving() {
%c0_i64 = arith.constant 0 : i64
%c1_i64 = arith.constant 1 : i64
%1 = cc.alloca !cc.array<i64 x 2>
%2 = cc.cast %1 : (!cc.ptr<!cc.array<i64 x 2>>) -> !cc.ptr<i64>
%c2_i64 = arith.constant 2 : i64
%0 = cc.alloca !cc.array<i64 x 2>
%1 = cc.cast %0 : (!cc.ptr<!cc.array<i64 x 2>>) -> !cc.ptr<i64>
cc.store %c0_i64, %1 : !cc.ptr<i64>
%2 = cc.compute_ptr %0[1] : (!cc.ptr<!cc.array<i64 x 2>>) -> !cc.ptr<i64>
cc.store %c2_i64, %1 : !cc.ptr<i64>
cc.store %c0_i64, %2 : !cc.ptr<i64>
%3 = cc.compute_ptr %1[1] : (!cc.ptr<!cc.array<i64 x 2>>) -> !cc.ptr<i64>
cc.store %c0_i64, %3 : !cc.ptr<i64>
cc.store %c1_i64, %1 : !cc.ptr<i64>
cc.store %c1_i64, %2 : !cc.ptr<i64>
cc.store %c1_i64, %3 : !cc.ptr<i64>
return
}

// CHECK-LABEL: func.func @test_two_stores_same_pointer_interleaving() {
// CHECK: %[[VAL_0:.*]] = arith.constant 1 : i64
// CHECK: %[[VAL_c0:.*]] = arith.constant 0 : i64
// CHECK: %[[VAL_c1:.*]] = arith.constant 1 : i64
// CHECK: %[[VAL_c2:.*]] = arith.constant 2 : i64
// CHECK: %[[VAL_1:.*]] = cc.alloca !cc.array<i64 x 2>
// CHECK: %[[VAL_2:.*]] = cc.cast %[[VAL_1]] : (!cc.ptr<!cc.array<i64 x 2>>) -> !cc.ptr<i64>
// CHECK: %[[VAL_3:.*]] = cc.compute_ptr %[[VAL_1]][1] : (!cc.ptr<!cc.array<i64 x 2>>) -> !cc.ptr<i64>
// CHECK: cc.store %[[VAL_0]], %[[VAL_2]] : !cc.ptr<i64>
// CHECK: cc.store %[[VAL_0]], %[[VAL_3]] : !cc.ptr<i64>
// CHECK: cc.store %[[VAL_c1]], %[[VAL_2]] : !cc.ptr<i64>
// CHECK: cc.store %[[VAL_c1]], %[[VAL_3]] : !cc.ptr<i64>
// CHECK: return
// CHECK: }

0 comments on commit ab2d01d

Please sign in to comment.