Skip to content

Commit

Permalink
[ControlCodeToTransaction] Use aie-rt's serializer to generate transaction (#1002)
Browse files Browse the repository at this point in the history
  • Loading branch information
Yu-Zhewen authored Dec 31, 2024
1 parent 6fae427 commit 155723c
Show file tree
Hide file tree
Showing 5 changed files with 183 additions and 191 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -118,15 +118,18 @@ LogicalResult configureLocksAndBd(Block &block, const TileLoc &tileLoc,
assert(bdOp.getBdId().has_value() &&
"DMABDOp must have assigned bd_id; did you forget to run "
"aie-assign-bd-ids?");
bool validBd = true;
std::optional<uint8_t> packetType;
std::optional<uint8_t> packetID;
bool enablePacket = false;
auto maybePacketOps = block.getOps<DMABDPACKETOp>();
if (!maybePacketOps.empty()) {
assert(llvm::range_size(maybePacketOps) == 1 &&
"expected only one dma_bd_packet");
auto packetOp = *maybePacketOps.begin();
packetType = packetOp.getPacketType();
packetID = packetOp.getPacketId();
enablePacket = true;
}

BufferOp bufferOp = cast<BufferOp>(bdOp.getBuffer().getDefiningOp());
Expand All @@ -148,16 +151,20 @@ LogicalResult configureLocksAndBd(Block &block, const TileLoc &tileLoc,
BDPadLayout{dim.getConstPadBefore(), dim.getConstPadAfter()});
}
}
if (failed(configureDMABD(deviceModel, dmaTileBd.value(), tileLoc,
static_cast<uint8_t>(*bdOp.getBdId()),
bdOp.getNextBdId().has_value()
? std::optional<uint8_t>{static_cast<uint8_t>(
*bdOp.getNextBdId())}
: std::nullopt,
packetType, packetID, *bufferOp.getAddress(),
getLenInBytes(bdOp), getOffsetInBytes(bdOp),

bool enableNextBd = bdOp.getNextBdId().has_value();
std::optional<uint8_t> nextBdId =
enableNextBd
? std::optional<uint8_t>{static_cast<uint8_t>(*bdOp.getNextBdId())}
: std::nullopt;
std::optional<BDIterLayout> maybeIter = std::nullopt;
if (failed(configureDMABD(deviceModel, dmaTileBd.value(), tileLoc, validBd,
static_cast<uint8_t>(*bdOp.getBdId()), enableNextBd,
nextBdId, enablePacket, packetType, packetID,
*bufferOp.getAddress(), getLenInBytes(bdOp),
getOffsetInBytes(bdOp),
getBufferElementTypeWidthInBytes(bdOp), maybeDims,
maybePadDims))) {
maybePadDims, maybeIter))) {
return failure();
}
return success();
Expand Down Expand Up @@ -233,11 +240,12 @@ LogicalResult addInitConfigToCDO(const AMDAIEDeviceModel &deviceModel,
for (auto op : block.getOps<DMAStartOp>()) {
DMABDOp bd = *op.getDest()->getOps<DMABDOp>().begin();
int chNum = op.getChannelIndex();
auto channelDir = op.getChannelDir();
if (failed(pushToBdQueueAndEnable(
deviceModel, tileLoc, chNum,
static_cast<DMAChannelDir>(channelDir), bd.getBdId().value(),
op.getRepeatCount())))
auto channelDir = static_cast<DMAChannelDir>(op.getChannelDir());
bool issueToken = tileLoc.row == 0 && channelDir == DMAChannelDir::MM2S;
bool setChannelEnable = true;
if (failed(configurePushToBdQueue(
deviceModel, tileLoc, chNum, channelDir, bd.getBdId().value(),
op.getRepeatCount(), issueToken, setChannelEnable)))
return failure();
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,16 +8,16 @@
#include "iree-amd-aie/Transforms/Passes.h"
#include "iree-amd-aie/Transforms/Transforms.h"
#include "iree-amd-aie/Transforms/Utils/AMDAIEUtils.h"
#include "iree-amd-aie/aie_runtime/iree_aie_configure.h"
#include "iree-amd-aie/aie_runtime/iree_aie_runtime.h"
#include "mlir/IR/AsmState.h"
#include "mlir/IR/Iterators.h"
#include "mlir/Transforms/DialectConversion.h"

#define DEBUG_TYPE "iree-amdaie-controlcode-to-transaction"

#define TXN_OPC_WRITE 0x0
#define TXN_OPC_BLOCKWRITE 0x1
#define TXN_OPC_TCT 0x80
#define TXN_OPC_DDR_PATCH 0x81
#define TXN_OPC_TCT XAie_TxnOpcode::XAIE_IO_CUSTOM_OP_TCT
#define TXN_OPC_DDR_PATCH XAie_TxnOpcode::XAIE_IO_CUSTOM_OP_DDR_PATCH

namespace mlir::iree_compiler::AMDAIE {

Expand All @@ -29,16 +29,25 @@ class TransactionBuilder {

void clearAndInitialize() {
instructions.clear();
llvm::MutableArrayRef<uint32_t> words = reserveAndGetTail(4);
// setup txn header
words[0] = 0x06030100;
words[1] = 0x00000105;
// Setup txn header.
TRY_XAIE_API_FATAL_ERROR(XAie_StartTransaction, &deviceModel.devInst,
XAIE_TRANSACTION_DISABLE_AUTO_FLUSH);
}

size_t getInstructionSize() const { return instructions.size(); }

ArrayRef<uint32_t> finalizeAndReturnInstructions() {
finalizeHeader();
std::unique_ptr<uint8_t, decltype(&free)> txn_ptr(
XAie_ExportSerializedTransaction(&deviceModel.devInst, 0, 0), &free);
// Extract transaction size.
auto *hdr = reinterpret_cast<XAie_TxnHeader *>(txn_ptr.get());
size_t sizeInBytes = hdr->TxnSize;
size_t instructionCount = sizeInBytes / sizeof(uint32_t);
// Resize instructions and copy data.
instructions.resize(instructionCount);
memcpy(instructions.data(), txn_ptr.get(), sizeInBytes);
// Clear the transaction.
TRY_XAIE_API_FATAL_ERROR(XAie_ClearTransaction, &deviceModel.devInst);
return ArrayRef<uint32_t>(instructions.data(), instructions.size());
}

Expand All @@ -52,139 +61,88 @@ class TransactionBuilder {

LogicalResult appendAddressPatch(uint32_t addr, uint32_t argIdx,
uint32_t offset) {
llvm::MutableArrayRef<uint32_t> words = reserveAndGetTail(12);
words[0] = TXN_OPC_DDR_PATCH;
words[1] = words.size() * sizeof(uint32_t); // Operation Size
std::array<uint32_t, 10> words = {0};

words[6] = addr;
words[4] = addr;
words[5] = 0;
words[6] = argIdx;
words[7] = 0;
words[8] = argIdx;
words[8] = offset;
words[9] = 0;
words[10] = offset;
words[11] = 0;
instructionCounter++;
return success();

uint8_t opCode = static_cast<uint8_t>(TXN_OPC_DDR_PATCH);
uint32_t *data = &words[0];
uint32_t size = words.size() * sizeof(uint32_t);
return configureCustomTxnOp(deviceModel, opCode, data, size);
}

LogicalResult appendTCTSync(uint32_t col, uint32_t row, uint32_t direction,
uint32_t rowNum, uint32_t colNum,
uint32_t channel) {
llvm::MutableArrayRef<uint32_t> words = reserveAndGetTail(4);
words[0] = TXN_OPC_TCT;
words[1] = words.size() * sizeof(uint32_t); // Operation Size

words[2] |= direction & 0xff;
words[2] |= (row & 0xff) << 8;
words[2] |= (col & 0xff) << 16;

words[3] |= (rowNum & 0xff) << 8;
words[3] |= (colNum & 0xff) << 16;
words[3] |= (channel & 0xff) << 24;
instructionCounter++;
return success();
std::array<uint32_t, 2> words = {0};

words[0] |= direction & 0xff;
words[0] |= (row & 0xff) << 8;
words[0] |= (col & 0xff) << 16;

words[1] |= (rowNum & 0xff) << 8;
words[1] |= (colNum & 0xff) << 16;
words[1] |= (channel & 0xff) << 24;

uint8_t opCode = static_cast<uint8_t>(TXN_OPC_TCT);
uint32_t *data = &words[0];
uint32_t size = words.size() * sizeof(uint32_t);
return configureCustomTxnOp(deviceModel, opCode, data, size);
}

LogicalResult appendPushToQueueOp(uint32_t col, uint32_t row,
AMDAIE::DMAChannelDir direction,
uint32_t channel, uint32_t bdId,
uint32_t repeatCount, bool issueToken) {
uint32_t colShift = deviceModel.getColumnShift();
uint32_t rowShift = deviceModel.getRowShift();
uint32_t addr =
direction == AMDAIE::DMAChannelDir::MM2S ? 0x1D214 : 0x1D204;
if (channel == 1) addr += 0x8;
// TODO(jornt): use aie-rt's transaction serializer instead to avoid
// discrepancies between this file and aie-rt.
addr = ((col & 0xff) << colShift) | ((row & 0xff) << rowShift) |
(addr & 0xFFFFF);
uint32_t value = 0;
value |= bdId & 0xF;
value |= (repeatCount & 0xFF) << 16;
if (issueToken) value |= 0x80000000;
return appendWrite32Op(addr, value);
}

LogicalResult appendWrite32Op(uint32_t addr, uint32_t value) {
llvm::MutableArrayRef<uint32_t> words = reserveAndGetTail(6);
// XAIE_IO_WRITE
words[0] = TXN_OPC_WRITE;
words[1] = 0;
words[2] = addr;
words[3] = 0;
words[4] = value; // Value
words[5] = words.size() * sizeof(uint32_t); // Operation Size
instructionCounter++;
return success();
// Assume channel is enabled by default.
bool setChannelEnable = false;
auto tileLoc = XAie_TileLoc(col, row);
return configurePushToBdQueue(deviceModel, tileLoc, channel, direction,
bdId, repeatCount, issueToken,
setChannelEnable);
}

LogicalResult appendWriteBdOp(
uint32_t bdAddr, uint32_t bufferLength, uint32_t bufferOffset,
bool enablePacket, uint32_t outOfOrderId, uint32_t packetId,
uint32_t packetType, uint32_t d0Size, uint32_t d0Stride, uint32_t d1Size,
uint32_t d1Stride, uint32_t d2Stride, uint32_t iterationCurrent,
uint32_t iterationSize, uint32_t iterationStride, uint32_t nextBd,
bool useNextBd, bool validBd, int32_t lockRelVal, uint32_t lockRelId,
bool lockAcqEnable, int32_t lockAcqVal, uint32_t lockAcqId) {
llvm::MutableArrayRef<uint32_t> words = reserveAndGetTail(12);
words[0] = TXN_OPC_BLOCKWRITE;
words[1] = 0;
// RegOff
words[2] = bdAddr; // ADDR
words[3] = words.size() * sizeof(uint32_t); // Operation Size
// DMA_BDX_0
words[4] = bufferLength;
// DMA_BDX_1
words[5] = bufferOffset;
// DMA_BDX_2
// En Packet , OoO BD ID , Packet ID , Packet Type
words[6] |= ((int)enablePacket & 0x1) << 30;
words[6] |= (outOfOrderId & 0x3f) << 24;
words[6] |= (packetId & 0x1f) << 19;
words[6] |= (packetType & 0x7) << 16;
// DMA_BDX_3
// TODO: Secure Access
words[7] |= (d0Size & 0x3ff) << 20;
words[7] |= d0Stride & 0xfffff;
// DMA_BDX_4
words[8] = 0x80000000; // burst length;
words[8] |= (d1Size & 0x3ff) << 20;
words[8] |= d1Stride & 0xfffff;
// DMA_BDX_5
// TODO: SIMID, AxCache, AXQoS
words[9] = d2Stride & 0xfffff;
// DMA_BDX_6
words[10] |= (iterationCurrent & 0x3f) << 26;
words[10] |= (iterationSize & 0x3f) << 20;
words[10] |= iterationStride & 0xfffff;
// DMA_BDX_7
// TODO: TLAST Suppress
words[11] |= (nextBd & 0xf) << 27;
words[11] |= ((int)useNextBd & 0x1) << 26;
words[11] |= ((int)validBd & 0x1) << 25;
words[11] |= (lockRelVal & 0xef) << 18;
words[11] |= (lockRelId & 0xf) << 13;
words[11] |= ((int)lockAcqEnable & 0x1) << 12;
words[11] |= (lockAcqVal & 0xef) << 5;
words[11] |= lockAcqId & 0xf;
instructionCounter++;
return success();
uint32_t col, uint32_t row, uint32_t bdId, uint32_t bufferLength,
uint32_t bufferOffset, bool enablePacket, uint32_t packetId,
uint32_t packetType, ArrayRef<int32_t> sizes, ArrayRef<int32_t> strides,
uint32_t iterationCurrent, uint32_t iterationSize,
uint32_t iterationStride, uint32_t nextBd, bool useNextBd, bool validBd,
int32_t lockRelVal, uint32_t lockRelId, bool lockAcqEnable,
int32_t lockAcqVal, uint32_t lockAcqId) {
// Configure DMA Locks.
auto tileLoc = XAie_TileLoc(col, row);
FailureOr<XAie_DmaDesc> dmaTileBd = initDMADesc(deviceModel, tileLoc);
if (failed(dmaTileBd)) return failure();
if (failed(configureDMALocks(deviceModel, dmaTileBd.value(), tileLoc,
lockAcqVal, lockRelVal, lockAcqId, lockRelId,
lockAcqEnable))) {
return failure();
}
// Configure DMA BD.
uint32_t minStrideBitWidth = deviceModel.getMinStrideBitWidth();
uint32_t bufferElementTypeWidthInBytes = minStrideBitWidth / 8;
uint32_t bufferLengthInBytes = bufferLength * bufferElementTypeWidthInBytes;
std::vector<BDDimLayout> dims = {
{static_cast<uint16_t>(sizes[0]), static_cast<uint32_t>(strides[0])},
{static_cast<uint16_t>(sizes[1]), static_cast<uint32_t>(strides[1])},
{static_cast<uint16_t>(sizes[2]), static_cast<uint32_t>(strides[2])}};
std::optional<std::vector<BDPadLayout>> pads = std::nullopt;
BDIterLayout iter = {iterationStride, static_cast<uint8_t>(iterationSize),
static_cast<uint8_t>(iterationCurrent)};
return configureDMABD(deviceModel, dmaTileBd.value(), tileLoc, validBd,
bdId, useNextBd, nextBd, enablePacket, packetType,
packetId, deviceModel.devInst.BaseAddr,
bufferLengthInBytes, bufferOffset,
bufferElementTypeWidthInBytes, dims, pads, iter);
}

private:
void finalizeHeader() {
// Finalize txn header.
instructions[2] = instructionCounter;
instructions[3] = instructions.size() * sizeof(uint32_t);
}

llvm::MutableArrayRef<uint32_t> reserveAndGetTail(size_t tailSize) {
auto oldSize = instructions.size();
auto newSize = oldSize + tailSize;
instructions.resize(newSize, 0);
return llvm::MutableArrayRef<uint32_t>(instructions.data() + oldSize,
tailSize);
}
size_t instructionCounter{0};
std::vector<uint32_t> instructions;
};

Expand Down Expand Up @@ -223,33 +181,20 @@ LogicalResult convertOp(AMDAIE::NpuWriteBdOp op, TransactionBuilder &builder) {
uint32_t col = op.getCol();
uint32_t row = op.getRow();
uint32_t bdId = op.getBdId();
uint32_t colShift = builder.deviceModel.getColumnShift();
uint32_t rowShift = builder.deviceModel.getRowShift();
uint32_t bdAddr =
(col << colShift) | (row << rowShift) | (0x1D000 + bdId * 0x20);
ArrayRef<int32_t> sizes = op.getSizes();
ArrayRef<int32_t> strides = op.getStrides();
SmallVector<int32_t> strides(op.getStrides());
if (sizes.size() != 3) return op.emitOpError() << "expected 3 sizes";
if (strides.size() != 3) return op.emitOpError() << "expected 3 strides";
uint32_t d0Size = sizes[sizes.size() - 1];
uint32_t d1Size = sizes[sizes.size() - 2];
// Strides and iteration_size are encoded as `actual - 1`, but `0` should stay
// `0` as it's not supported.
uint32_t d0Stride =
std::max((int64_t)strides[strides.size() - 1] - 1, (int64_t)0);
uint32_t d1Stride =
std::max((int64_t)strides[strides.size() - 2] - 1, (int64_t)0);
uint32_t d2Stride =
std::max((int64_t)strides[strides.size() - 3] - 1, (int64_t)0);
uint32_t iterationSize =
std::max((int64_t)op.getIterationSize() - 1, (int64_t)0);
uint32_t iterationStride =
std::max((int64_t)op.getIterationStride() - 1, (int64_t)0);
// Strides and iteration_size will be encoded as `actual - 1`, so we need to
// ensure they are at least 1.
std::for_each(strides.begin(), strides.end(),
[](int32_t &stride) { stride = std::max(stride, int32_t(1)); });
uint32_t iterationSize = std::max(op.getIterationSize(), uint32_t(1));
uint32_t iterationStride = std::max(op.getIterationStride(), uint32_t(1));
if (failed(builder.appendWriteBdOp(
bdAddr, op.getBufferLength(), op.getBufferOffset(),
op.getEnablePacket(), op.getOutOfOrderId(), op.getPacketId(),
op.getPacketType(), d0Size, d0Stride, d1Size, d1Stride, d2Stride,
op.getIterationCurrent(), iterationSize, iterationStride,
col, row, bdId, op.getBufferLength(), op.getBufferOffset(),
op.getEnablePacket(), op.getPacketId(), op.getPacketType(), sizes,
strides, op.getIterationCurrent(), iterationSize, iterationStride,
op.getNextBd(), op.getUseNextBd(), op.getValidBd(),
op.getLockRelVal(), op.getLockRelId(), op.getLockAcqEnable(),
op.getLockAcqVal(), op.getLockAcqId()))) {
Expand Down
Loading

0 comments on commit 155723c

Please sign in to comment.