Skip to content

Commit

Permalink
Canonicalize wrap and stride lists, to avoid BD dim limitation when possible (#389)
Browse files Browse the repository at this point in the history

* Add wrap and stride canonicalizer

* Clean up

* Add canonicalization patterns which get applied beyond the loop specialization pattern
  • Loading branch information
erwei-xilinx authored Jan 18, 2024
1 parent fde1ed2 commit 7f69c06
Show file tree
Hide file tree
Showing 3 changed files with 185 additions and 19 deletions.
2 changes: 1 addition & 1 deletion mlir/lib/Conversion/AIRToAIESchedulingUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,7 @@ int64_t air::get1DOffset(SmallVector<Value> memcpy_offsets, Value memref) {
if (memcpy_offsets.empty())
return 0;
SmallVector<int> memref_shape = getTensorShape(memref.getType());
assert(memref_shape.size() == memcpy_offsets.size());

int64_t one_d_offset = 0;
for (int i = memcpy_offsets.size() - 1; i >= 0; i--) {
auto offset = mlir::getConstantIntValue(memcpy_offsets[i]);
Expand Down
146 changes: 146 additions & 0 deletions mlir/lib/Transform/AIRDependencyScheduleOpt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1350,6 +1350,70 @@ struct CanonicalizeAffineApplyOnLoopInductionVar
}
};

// Canonicalize the (offsets, sizes, strides) wrap-and-stride lists of a DMA
// access pattern, in place:
//   1. left-pad `offsets` with zeros so all three lists have equal rank;
//   2. erase unit-sized dimensions (wrap == 1), which carry no information;
//   3. fold an outer dimension into its inner neighbour whenever the outer
//      stride equals inner size * inner stride, i.e. the two dimensions
//      together describe one contiguous stepping.
// Returns success() iff at least one dimension was erased, so that rewrite
// patterns built on this utility can report "no match" and terminate.
LogicalResult AIRCanonicalizeWrapAndStrideList(OpBuilder builder,
                                               SmallVector<Value> &offsets,
                                               SmallVector<Value> &sizes,
                                               SmallVector<Value> &strides) {

  // Match offsets size with sizes and strides, by left-padding with zeros.
  int max_dim_size = (int)std::max(std::max(offsets.size(), sizes.size()),
                                   strides.size());
  if (max_dim_size && (int)offsets.size() < max_dim_size) {
    for (int i = (int)offsets.size(); i < max_dim_size; i++) {
      offsets.insert(offsets.begin(), builder.create<arith::ConstantIndexOp>(
                                          builder.getUnknownLoc(), 0));
    }
  }

  // Collect unit-sized (wrap == 1) dimensions, in decreasing index order so
  // the erases below do not shift the indices still to be erased.
  SmallVector<int> unit_dims;
  for (int i = (int)sizes.size() - 1; i >= 0; i--) {
    if (auto const_val = getConstantIntValue(sizes[i])) {
      if (*const_val == 1) {
        unit_dims.push_back(i);
      }
    }
  }

  for (auto i : unit_dims) {
    offsets.erase(offsets.begin() + i);
    sizes.erase(sizes.begin() + i);
    strides.erase(strides.begin() + i);
  }

  // Fold adjacent dimensions that together express one contiguous stepping.
  // `target` is the surviving (innermost) dimension of the current run of
  // merges; accumulating the merged size there fixes the loss of the
  // accumulated count when three or more dimensions collapse in a chain
  // (the previous pairwise update wrote into a dimension that was itself
  // about to be erased).
  SmallVector<int> redundant_dims;
  int target = (int)sizes.size() - 1;
  for (int i = (int)sizes.size() - 1; i >= 1; i--) {
    bool merged = false;
    if (getConstantIntValue(sizes[i]) && getConstantIntValue(sizes[i - 1]) &&
        getConstantIntValue(strides[i]) &&
        getConstantIntValue(strides[i - 1])) {
      auto const_size = *getConstantIntValue(sizes[i]);
      auto const_size_next = *getConstantIntValue(sizes[i - 1]);
      auto const_stride = *getConstantIntValue(strides[i]);
      auto const_stride_next = *getConstantIntValue(strides[i - 1]);
      // Skip over the first (innermost) dimension if stride is 1, preserving
      // the unit-stride innermost dimension expected downstream.
      bool skip_innermost = const_stride == 1 && i == (int)sizes.size() - 1;
      if (!skip_innermost && const_stride_next == const_size * const_stride) {
        redundant_dims.push_back(i - 1);
        auto target_size = *getConstantIntValue(sizes[target]);
        sizes[target] = builder.create<arith::ConstantIndexOp>(
            builder.getUnknownLoc(), target_size * const_size_next);
        merged = true;
      }
    }
    // No merge: dimension i - 1 starts a new run of potential merges.
    if (!merged)
      target = i - 1;
  }

  // redundant_dims was filled in decreasing index order; erase back-to-front.
  for (auto i : redundant_dims) {
    offsets.erase(offsets.begin() + i);
    sizes.erase(sizes.begin() + i);
    strides.erase(strides.begin() + i);
  }

  if (unit_dims.empty() && redundant_dims.empty()) {
    return failure();
  }

  return success();
}

struct AIRSpecializeChannelWrapAndStrideInScfFor
: public OpRewritePattern<scf::ForOp> {
using OpRewritePattern<scf::ForOp>::OpRewritePattern;
Expand Down Expand Up @@ -1411,6 +1475,8 @@ struct AIRSpecializeChannelWrapAndStrideInScfFor
rewriter, for_op.getOperation(), channel_ops[0].getOperation(), offsets,
wraps, strides, channel_ops[0].getMemref());

(void)AIRCanonicalizeWrapAndStrideList(rewriter, offsets, wraps, strides);

Operation *new_chan_op = nullptr;
SmallVector<Type, 1> tys;
if (isAsyncOp(channel_ops[0].getOperation())) {
Expand Down Expand Up @@ -1504,6 +1570,8 @@ struct AIRSpecializeChannelWrapAndStrideInAffineFor
rewriter, for_op.getOperation(), channel_ops[0].getOperation(), offsets,
wraps, strides, channel_ops[0].getMemref());

(void)AIRCanonicalizeWrapAndStrideList(rewriter, offsets, wraps, strides);

Operation *new_chan_op = nullptr;
SmallVector<Type, 1> tys;
if (isAsyncOp(channel_ops[0].getOperation())) {
Expand Down Expand Up @@ -1536,6 +1604,78 @@ struct AIRSpecializeChannelWrapAndStrideInAffineFor
private:
};

// Rewrite pattern canonicalizing the wrap-and-stride lists of an
// air::ChannelPutOp, folding redundant data-movement dimensions so the op
// fits within hardware buffer-descriptor dimension limits where possible.
struct AIRCanonicalizeChannelPutOpWrapAndStrideList
    : public OpRewritePattern<air::ChannelPutOp> {
  using OpRewritePattern<air::ChannelPutOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(air::ChannelPutOp op,
                                PatternRewriter &rewriter) const override {

    // Preserve asyncness: an async put carries an async-token result type
    // and a list of dependencies, both forwarded to the replacement op.
    SmallVector<Value, 1> deps;
    SmallVector<Type, 1> tys;
    if (isAsyncOp(op)) {
      tys.push_back(air::AsyncTokenType::get(op->getContext()));
      deps = op.getAsyncDependencies();
    }

    SmallVector<Value> offsets = op.getOffsets();
    SmallVector<Value> sizes = op.getSizes();
    SmallVector<Value> strides = op.getStrides();

    // Pattern does not apply when the lists are already canonical.
    if (failed(AIRCanonicalizeWrapAndStrideList(rewriter, offsets, sizes,
                                                strides)))
      return failure();

    auto new_op = rewriter.create<air::ChannelPutOp>(
        op->getLoc(), tys, deps, op.getChanName(), op.getIndices(),
        op.getMemref(), offsets, sizes, strides);
    // Replace through the rewriter so the greedy driver is notified of both
    // the use updates and the erasure (a raw replaceAllUsesWith followed by
    // eraseOp would bypass the rewriter's listener).
    rewriter.replaceOp(op, new_op->getResults());

    return success();
  }

private:
};

// Rewrite pattern canonicalizing the wrap-and-stride lists of an
// air::ChannelGetOp; mirror of the ChannelPutOp pattern above.
struct AIRCanonicalizeChannelGetOpWrapAndStrideList
    : public OpRewritePattern<air::ChannelGetOp> {
  using OpRewritePattern<air::ChannelGetOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(air::ChannelGetOp op,
                                PatternRewriter &rewriter) const override {

    // Preserve asyncness: an async get carries an async-token result type
    // and a list of dependencies, both forwarded to the replacement op.
    SmallVector<Value, 1> deps;
    SmallVector<Type, 1> tys;
    if (isAsyncOp(op)) {
      tys.push_back(air::AsyncTokenType::get(op->getContext()));
      deps = op.getAsyncDependencies();
    }

    SmallVector<Value> offsets = op.getOffsets();
    SmallVector<Value> sizes = op.getSizes();
    SmallVector<Value> strides = op.getStrides();

    // Pattern does not apply when the lists are already canonical.
    if (failed(AIRCanonicalizeWrapAndStrideList(rewriter, offsets, sizes,
                                                strides)))
      return failure();

    auto new_op = rewriter.create<air::ChannelGetOp>(
        op->getLoc(), tys, deps, op.getChanName(), op.getIndices(),
        op.getMemref(), offsets, sizes, strides);
    // Replace through the rewriter so the greedy driver is notified of both
    // the use updates and the erasure (a raw replaceAllUsesWith followed by
    // eraseOp would bypass the rewriter's listener).
    rewriter.replaceOp(op, new_op->getResults());

    return success();
  }

private:
};

struct UnrollChannelByFactorPattern {

public:
Expand Down Expand Up @@ -2261,6 +2401,12 @@ class AIRSpecializeChannelWrapAndStridePattern
AIRSpecializeChannelWrapAndStrideInScfFor,
AIRSpecializeChannelWrapAndStrideInAffineFor>(ctx);
(void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));

// Canonicalize wrap and stride list to remove redundant dimensions
RewritePatternSet cano_patterns(&getContext());
cano_patterns.insert<AIRCanonicalizeChannelGetOpWrapAndStrideList,
AIRCanonicalizeChannelPutOpWrapAndStrideList>(ctx);
(void)applyPatternsAndFoldGreedily(funcOp, std::move(cano_patterns));
}

void runOnOperation() override {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,13 @@ module {
// CHECK-LABEL: test0
// CHECK: air.segment
// CHECK-DAG: %[[CST0:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[CST8:.*]] = arith.constant 8 : index
// CHECK-DAG: %[[CST64:.*]] = arith.constant 64 : index
// CHECK-DAG: %[[CST1:.*]] = arith.constant 1 : index
// CHECK-DAG: %[[CST4:.*]] = arith.constant 4 : index
// CHECK-DAG: %[[CST16:.*]] = arith.constant 16 : index
// CHECK-DAG: %[[CST32:.*]] = arith.constant 32 : index
// CHECK-DAG: %[[CST64:.*]] = arith.constant 64 : index
// CHECK-DAG: %[[CST512:.*]] = arith.constant 512 : index
// CHECK-DAG: %[[CST4:.*]] = arith.constant 4 : index
// CHECK: %[[EVENT0:.*]] = air.channel.put async [{{.*}}] @channel_0[] (%[[VAL0:.*]][%[[CST0]], %[[CST0]]] [%[[CST8]], %[[CST4]], %[[CST4]]] [%[[CST64]], %[[CST16]], %[[CST1]]]) : (memref<8x16xi32, 1>)
// CHECK: %[[EVENT0:.*]] = air.channel.put async [{{.*}}] @channel_0[] (%[[VAL0:.*]][%[[CST0]], %[[CST0]]] [%[[CST32]], %[[CST4]]] [%[[CST16]], %[[CST1]]]) : (memref<8x16xi32, 1>)
// CHECK: scf.for{{.*}}iter_args(%[[EVENT1:.*]] = %[[EVENT0]])
// CHECK: air.herd
// CHECK: air.channel.get async{{.*}}@channel_0
Expand Down Expand Up @@ -66,11 +66,11 @@ module {
}

// CHECK-LABEL: test1
// CHECK: put @channel_1[%c0, %c0] (%arg0[%c0] [%c4, %c32] [%c32, %c1]) : (memref<128xf32>)
// CHECK: get @channel_2[%c0, %c0] (%arg1[%c0, %c0] [%c4, %c32, %c32] [%c4096, %c128, %c1]) : (memref<128x128xf32>)
// CHECK: put @channel_3[%c0, %c0] (%arg1[%c0, %c0] [%c4, %c32, %c32] [%c32, %c128, %c1]) : (memref<128x128xf32>)
// CHECK: put @channel_4[%c0, %c0] (%arg1[%c0, %c0] [%c4, %c4, %c32, %c32] [%c32, %c4096, %c128, %c1]) : (memref<128x128xf32>)
// CHECK: get @channel_5[%c0, %c0] (%arg1[%c0, %c0] [%c4, %c4, %c32, %c32] [%c4096, %c32, %c128, %c1]) : (memref<128x128xf32>)
// CHECK: put @channel_1[%c0, %c0] (%arg0[%c0, %c0] [%c4, %c32] [%c32, %c1]) : (memref<128xf32>)
// CHECK: get @channel_2[%c0, %c0] (%arg1[%c0, %c0] [%c128, %c32] [%c128, %c1]) : (memref<128x128xf32>)
// CHECK: put @channel_3[%c0, %c0] (%arg1[%c0, %c0, %c0] [%c4, %c32, %c32] [%c32, %c128, %c1]) : (memref<128x128xf32>)
// CHECK: put @channel_4[%c0, %c0] (%arg1[%c0, %c0, %c0] [%c4, %c128, %c32] [%c32, %c128, %c1]) : (memref<128x128xf32>)
// CHECK: get @channel_5[%c0, %c0] (%arg1[%c0, %c0, %c0, %c0] [%c4, %c4, %c32, %c32] [%c4096, %c32, %c128, %c1]) : (memref<128x128xf32>)

func.func @test1(%arg0: memref<128xf32>, %arg1: memref<128x128xf32>) -> memref<128xf32> {
%c0 = arith.constant 0 : index
Expand Down Expand Up @@ -101,11 +101,11 @@ module {
}

// CHECK-LABEL: test2
// CHECK: put @channel_6[%c0, %c0] (%arg0[%c0] [%c4, %c32] [%c32, %c1]) : (memref<128xf32>)
// CHECK: get @channel_7[%c0, %c0] (%arg1[%c0, %c0] [%c4, %c32, %c32] [%c4096, %c128, %c1]) : (memref<128x128xf32>)
// CHECK: put @channel_8[%c0, %c0] (%arg1[%c0, %c0] [%c4, %c32, %c32] [%c32, %c128, %c1]) : (memref<128x128xf32>)
// CHECK: put @channel_9[%c0, %c0] (%arg1[%c0, %c0] [%c4, %c4, %c32, %c32] [%c32, %c4096, %c128, %c1]) : (memref<128x128xf32>)
// CHECK: get @channel_10[%c0, %c0] (%arg1[%c0, %c0] [%c4, %c4, %c32, %c32] [%c4096, %c32, %c128, %c1]) : (memref<128x128xf32>)
// CHECK: put @channel_6[%c0, %c0] (%arg0[%c0, %c0] [%c4, %c32] [%c32, %c1]) : (memref<128xf32>)
// CHECK: get @channel_7[%c0, %c0] (%arg1[%c0, %c0] [%c128, %c32] [%c128, %c1]) : (memref<128x128xf32>)
// CHECK: put @channel_8[%c0, %c0] (%arg1[%c0, %c0, %c0] [%c4, %c32, %c32] [%c32, %c128, %c1]) : (memref<128x128xf32>)
// CHECK: put @channel_9[%c0, %c0] (%arg1[%c0, %c0, %c0] [%c4, %c128, %c32] [%c32, %c128, %c1]) : (memref<128x128xf32>)
// CHECK: get @channel_10[%c0, %c0] (%arg1[%c0, %c0, %c0, %c0] [%c4, %c4, %c32, %c32] [%c4096, %c32, %c128, %c1]) : (memref<128x128xf32>)

func.func @test2(%arg0: memref<128xf32>, %arg1: memref<128x128xf32>) -> memref<128xf32> {
%c0 = arith.constant 0 : index
Expand Down Expand Up @@ -136,10 +136,10 @@ module {
}

// CHECK-LABEL: test3
// CHECK: put @channel_11[%c0, %c0] (%arg0[%c0] [%c4, %c32] [%c32, %c1]) : (memref<128xf32>)
// CHECK: put @channel_12[%c0, %c0] (%arg1[%c0, %c0] [%c4, %c4, %c32, %c32] [%c32, %c4096, %c128, %c1]) : (memref<128x128xf32>)
// CHECK: put @channel_13[%c0, %c0] (%arg0[%c0] [%c4, %c32] [%c32, %c1]) : (memref<128xf32>)
// CHECK: put async [%0] @channel_14[%c0, %c0] (%arg0[%c0] [%c4, %c32] [%c32, %c1]) : (memref<128xf32>)
// CHECK: put @channel_11[%c0, %c0] (%arg0[%c0, %c0] [%c4, %c32] [%c32, %c1]) : (memref<128xf32>)
// CHECK: put @channel_12[%c0, %c0] (%arg1[%c0, %c0, %c0] [%c4, %c128, %c32] [%c32, %c128, %c1]) : (memref<128x128xf32>)
// CHECK: put @channel_13[%c0, %c0] (%arg0[%c0, %c0] [%c4, %c32] [%c32, %c1]) : (memref<128xf32>)
// CHECK: put async [%0] @channel_14[%c0, %c0] (%arg0[%c0, %c0] [%c4, %c32] [%c32, %c1]) : (memref<128xf32>)

func.func @test3(%arg0: memref<128xf32>, %arg1: memref<128x128xf32>) -> memref<128xf32> {
%c0 = arith.constant 0 : index
Expand Down Expand Up @@ -174,4 +174,24 @@ module {
}
return %alloc : memref<128xf32>
}

// CHECK-LABEL: test4
// CHECK: put async @channel_15[%c0, %c0] (%arg0[%c0] [%c32] [%c1]) : (memref<128xf32>)
// CHECK: put async @channel_16[%c0, %c0] (%arg1[%c0, %c0] [%c16, %c4] [%c4, %c1]) : (memref<128x128xf32>)

// Exercises canonicalization on standalone channel ops (no enclosing loop):
// unit-sized dimensions are dropped and adjacent foldable dimensions are
// merged, as pinned by the CHECK lines for @channel_15 / @channel_16.
func.func @test4(%arg0: memref<128xf32>, %arg1: memref<128x128xf32>) -> memref<128xf32> {
%c0 = arith.constant 0 : index
%c4 = arith.constant 4 : index
%c32 = arith.constant 32 : index
%c64 = arith.constant 64 : index
%c128 = arith.constant 128 : index
%c1 = arith.constant 1 : index
%c2 = arith.constant 2 : index
%c8 = arith.constant 8 : index
%alloc = memref.alloc() : memref<128xf32>
// Three leading unit wraps ([1,1,1,32]) collapse to a single [32] dim.
%0 = air.channel.put async @channel_15[%c0, %c0] (%arg0[%c0, %c0, %c0, %c0] [%c1, %c1, %c1, %c32] [%c128, %c128, %c128, %c1]) : (memref<128xf32>)
// [1,2,8,4]/[64,32,4,1]: the unit wrap is dropped, then the [2,8] dims with
// strides [32,4] fold into [16] with stride 4 (since 32 == 8 * 4); the
// innermost unit-stride dim is kept, giving [16,4]/[4,1].
%1 = air.channel.put async @channel_16[%c0, %c0] (%arg1[%c0, %c0, %c0, %c0] [%c1, %c2, %c8, %c4] [%c64, %c32, %c4, %c1]) : (memref<128x128xf32>)
%2 = air.wait_all async [%0, %1]
return %alloc : memref<128xf32>
}
}

0 comments on commit 7f69c06

Please sign in to comment.