Wrap op erasure in check to catch invalid case (#376)
Replaces compiler assertion failure with MLIR error emitted and signalled failure. Follow-up work required.
newling authored Jan 5, 2024
1 parent 3ae51ef commit 763c580
Showing 1 changed file with 46 additions and 12 deletions.
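For background: the fix follows MLIR's usual recovery idiom, in which a pass emits a diagnostic on the offending op and calls signalPassFailure() rather than asserting (an assert aborts the entire compiler process). Below is a minimal sketch of that idiom under assumed names; `safeErase` is hypothetical, while the commit's actual helper, eraseOpWithCheck, appears in the diff that follows.

#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/Operation.h"
#include "mlir/Support/LogicalResult.h"

using namespace mlir;

// Hypothetical helper: erase `op` only if nothing still uses its results;
// otherwise report a recoverable error instead of asserting.
static LogicalResult safeErase(Operation *op) {
  if (!op->use_empty()) {
    InFlightDiagnostic diag =
        op->emitOpError("cannot be erased: results still have users");
    // Point the user at each op that still consumes a result.
    for (Operation *user : op->getUsers())
      diag.attachNote(user->getLoc()) << "live user: " << *user;
    return diag; // InFlightDiagnostic converts to a failed LogicalResult
  }
  op->erase();
  return success();
}

A caller inside a pass would then write `if (failed(safeErase(op))) signalPassFailure();`, which fails the pass pipeline cleanly instead of crashing the compiler.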
58 changes: 46 additions & 12 deletions mlir/lib/Transform/AIRDependency.cpp
@@ -49,6 +49,27 @@ using namespace xilinx;

namespace {

+// Remove an op if it has no users, else return failure.
+// This is a temporary measure, while the issue
+// https://github.com/Xilinx/mlir-air/issues/372
+// is open. Once the root cause is found, there should be no ops erased here
+// whose results have users.
+LogicalResult eraseOpWithCheck(Operation *op, std::string_view context = "") {
+  for (auto opResult : op->getResults()) {
+    for (auto &&user : opResult.getUsers()) {
+      auto result =
+          op->emitOpError("is being erased, but it has at least one user.");
+      result.attachNote(user->getLoc()) << "erased op has user:\n" << *user;
+      result.attachNote(op->getLoc())
+          << "additional context:'" << context << "'\n";
+      return result;
+    }
+  }
+
+  op->erase();
+  return success();
+}
+
// Construction of a dependency graph

struct executeNode {
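A note on the helper above: emitOpError() yields an InFlightDiagnostic whose implicit conversion to LogicalResult is failure(), so the single `return result;` both reports the error (with its attached notes) and makes eraseOpWithCheck() fail; the call sites in the hunks below only need to test .failed() and invoke signalPassFailure().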
@@ -609,7 +630,10 @@ class AIRDependency
region_to_g[async_region.getId()] = v;

// Erase op
-op->erase();
+if (eraseOpWithCheck(op, "createAsyncExecute (no SSA return)").failed()) {
+  signalPassFailure();
+}

return async_region;
}

@@ -651,7 +675,9 @@ class AIRDependency
region_to_g[async_region.getId()] = v;

// Erase op
-op->erase();
+if (eraseOpWithCheck(op, "createAsyncExecute (one SSA return)").failed()) {
+  signalPassFailure();
+}
return async_region;
}

@@ -683,7 +709,9 @@ class AIRDependency
dma_to_g[id] = v;

// Erase op
-op->erase();
+if (eraseOpWithCheck(op, "createAsyncDMA").failed()) {
+  signalPassFailure();
+}
}

// Re-instantiate the channel op with async interface; update graph
@@ -730,7 +758,9 @@ class AIRDependency
channel_to_g[ChannelOpID] = v;

// Erase op
-op->erase();
+if (eraseOpWithCheck(op, "createAsyncChannel").failed()) {
+  signalPassFailure();
+}
}

// Re-instantiate the hierarchy op with async interface; update graph
@@ -805,7 +835,9 @@ class AIRDependency
auto new_hier = dyn_cast<air::HierarchyInterface>(new_op);

// Erase op
-op->erase();
+if (eraseOpWithCheck(op, "createAsyncHierarchyImpls").failed()) {
+  signalPassFailure();
+}
return new_hier;
}

@@ -1530,8 +1562,9 @@ class AIRDependency
elevateAsyncTokens<scf::ForOp, scf::ParallelOp>(new_loop_op,
wait_all_op_yielded_v);

-loop_op.erase();
-
+if (eraseOpWithCheck(loop_op, "insertLoopCarriedDeps").failed()) {
+  signalPassFailure();
+}
loop_op = new_loop_op;
}

@@ -1549,8 +1582,8 @@ class AIRDependency
// Update op-to-graph map for wait_all ops
wa_to_g[wait_all_op_yielded.getId()] = wait_all_op_yielded_v;

-// (2) Create a new wait_all event before the parallel op which collects the
-// incoming deps.
+// (2) Create a new wait_all event before the parallel op which collects
+// the incoming deps.
SmallVector<Value, 4> incoming_tokens;
SmallVector<Value, 4> constants;
llvm::SetVector<Value> region_args;
@@ -1583,7 +1616,8 @@ class AIRDependency
// Remove the old scf::YieldOp
SmallVector<scf::YieldOp, 2> y_ops(new_loop_op.getOps<scf::YieldOp>());
for (auto y_op : y_ops)
-y_op.erase();
+if (eraseOpWithCheck(y_op, "insertLoopCarriedDeps").failed())
+  signalPassFailure();

// Create scf::ReduceOp
builder.setInsertionPointToEnd(new_loop_op.getBody());
@@ -1602,8 +1636,8 @@ class AIRDependency
elevateAsyncTokens<scf::ParallelOp, scf::ParallelOp>(new_loop_op,
wait_all_op_yielded_v);

-loop_op.erase();
-
+if (eraseOpWithCheck(loop_op, "insertLoopCarriedDeps 2").failed())
+  signalPassFailure();
loop_op = new_loop_op;
}

