Skip to content

Commit

Permalink
[SYCLomatic] Enable migration of asm instruction 'lop3'. (#335)
Browse files Browse the repository at this point in the history
Signed-off-by: Tang, Jiajun [email protected]
  • Loading branch information
tangjj11 authored Nov 24, 2022
1 parent e9184b5 commit 594237d
Show file tree
Hide file tree
Showing 2 changed files with 120 additions and 2 deletions.
87 changes: 85 additions & 2 deletions clang/lib/DPCT/ASTTraversal.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14808,9 +14808,92 @@ void AsmRule::registerMatcher(ast_matchers::MatchFinder &MF) {
this);
}

/// Builds the C++ assignment text that replaces a PTX
/// `lop3.b32 d, a, b, c, immLut;` instruction.
///
/// Each bit of immLut selects one combination of the three inputs: bit
/// `(a << 2) | (b << 1) | c` is the result for that combination of input bits.
/// The generated text assigns to d the OR of one AND-term per set immLut bit.
///
/// \param Operands The five operands as source text, in the order
///        d, a, b, c, immLut. immLut is read as hexadecimal; an optional
///        "0x"/"0X" prefix and surrounding whitespace are accepted.
/// \returns The assignment text (no trailing ';'), or an empty string when
///          there are not exactly five operands or when immLut is not a
///          hexadecimal literal in the range [0x00, 0xFF].
std::string getAsmLop3Expr(const llvm::SmallVector<std::string, 5> &Operands) {
  if (Operands.size() != 5) {
    return "";
  }
  const auto &d = Operands[0];
  const auto &a = Operands[1];
  const auto &b = Operands[2];
  const auto &c = Operands[3];

  // Parse immLut by hand instead of with std::stoi: std::stoi throws when the
  // text is not a number (e.g. the operand is an expression), whereas an
  // unparsable immLut must simply yield "" so the caller can fall back.
  const std::string &ImmText = Operands[4];
  const char *const Blank = " \t\n\v\f\r";
  const auto First = ImmText.find_first_not_of(Blank);
  if (First == std::string::npos) {
    return "";
  }
  const auto Last = ImmText.find_last_not_of(Blank);
  auto Pos = First;
  // Skip the prefix only when at least one digit follows it.
  if (Last - Pos >= 2 && ImmText[Pos] == '0' &&
      (ImmText[Pos + 1] == 'x' || ImmText[Pos + 1] == 'X')) {
    Pos += 2;
  }
  unsigned imm = 0;
  for (; Pos <= Last; ++Pos) {
    const char Ch = ImmText[Pos];
    unsigned Digit = 0;
    if (Ch >= '0' && Ch <= '9') {
      Digit = Ch - '0';
    } else if (Ch >= 'a' && Ch <= 'f') {
      Digit = Ch - 'a' + 10;
    } else if (Ch >= 'A' && Ch <= 'F') {
      Digit = Ch - 'A' + 10;
    } else {
      return "";
    }
    imm = imm * 16 + Digit;
    // Only eight lookup-table bits are handled below. Rejecting larger values
    // also keeps the accumulator from overflowing on long digit strings.
    if (imm > 0xFF) {
      return "";
    }
  }

  if (imm == 0x00) {
    return buildString(d, " = 0");
  }

  // Emit the terms from bit 0 upwards. Within a term an input is negated when
  // its bit in the term's index is clear (a: 0x4, b: 0x2, c: 0x1).
  std::string Expr = d + " =";
  const char *Separator = "";
  for (unsigned Bit = 0; Bit < 8; ++Bit) {
    if ((imm & (1u << Bit)) == 0) {
      continue;
    }
    Expr += Separator;
    Expr += " (";
    Expr += (Bit & 4) ? "" : "~";
    Expr += a;
    Expr += " & ";
    Expr += (Bit & 2) ? "" : "~";
    Expr += b;
    Expr += " & ";
    Expr += (Bit & 1) ? "" : "~";
    Expr += c;
    Expr += ")";
    Separator = " |";
  }
  return Expr;
}

void AsmRule::runRule(const ast_matchers::MatchFinder::MatchResult &Result) {
if (auto E = getNodeAsType<Stmt>(Result, "asm")) {
report(E->getBeginLoc(), Diagnostics::DEVICE_ASM, true);
if (auto AS = getNodeAsType<AsmStmt>(Result, "asm")) {
auto AsmString = AS->generateAsmString(*Result.Context);
auto TemplateString = StringRef(AsmString).substr(0, AsmString.find(';'));
auto CurrIndex = TemplateString.find(' ');
auto OpCode = TemplateString.substr(0, CurrIndex);
if (OpCode == "lop3.b32") {
// ASM instruction pattern: lop3.b32 d, a, b, c, immLut;
llvm::SmallVector<std::string, 4> Args;
for (const auto *const it : AS->children()) {
ExprAnalysis EA;
EA.analyze(cast<Expr>(it));
if (isa<IntegerLiteral>(it) || isa<DeclRefExpr>(it) ||
isa<ImplicitCastExpr>(it)) {
Args.push_back(EA.getReplacedString());
} else {
Args.push_back("(" + EA.getReplacedString() + ")");
}
}
llvm::SmallVector<std::string, 5> Operands;
auto PreIndex = CurrIndex;
CurrIndex = TemplateString.find(",", PreIndex);
// Clang will generate the ASM instruction into a string like this:
// Cuda code: asm("lop3.b32 %0, %1*%1, %1, 3, 0x1A;" : "=r"(b) : "r"(a));
// TemplateString: "lop3.b32 $0, $1*$1, $1, 3, 0x1A"
while (PreIndex != StringRef::npos) {
auto TempStr =
TemplateString.substr(PreIndex + 1, CurrIndex - PreIndex - 1).str();
// Replace all args, for example: "$1*$1" will be replaced by "(a*a)".
if (TempStr.find('$') != TempStr.length() - 2 && Operands.size() != 4) {
// Wrap the operand in parentheses unless it is just a single register
// reference or it is the last operand (the immLut immediate).
TempStr = "(" + TempStr + ")";
}
auto ArgIndex = TempStr.find('$');
while (ArgIndex != std::string::npos) {
// PTX instructions mostly have no more than 4 parameters, so only the
// single character after '$' is used as the argument index.
auto ArgNo = TempStr[ArgIndex + 1] - '0';
TempStr.replace(ArgIndex, 2, Args[ArgNo]);
ArgIndex = TempStr.find('$');
}
Operands.push_back(std::move(TempStr));
PreIndex = CurrIndex;
CurrIndex = TemplateString.find(",", PreIndex + 1);
}
auto Replacement = getAsmLop3Expr(Operands);
if (!Replacement.empty()) {
return emplaceTransformation(new ReplaceStmt(AS, Replacement));
}
}
report(AS->getAsmLoc(), Diagnostics::DEVICE_ASM, true);
}
return;
}
Expand Down
35 changes: 35 additions & 0 deletions clang/test/dpct/asm_lop3.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
// RUN: dpct -out-root %T/asm_lop3 %s --cuda-include-path="%cuda-path/include" -- -std=c++14 -x cuda --cuda-host-only
// RUN: FileCheck %s --match-full-lines --input-file %T/asm_lop3/asm_lop3.dp.cpp

// Case 1: a ^ b ^ c.
// immLut 0x96 selects exactly the input combinations with an odd number of
// set inputs. The inputs are passed in declaration order (%1, %2, %3 are
// a, b, c), so the expected expansion below lists them in that order.
static __device__ __forceinline__ uint32_t LOP3LUT_XOR(uint32_t a, uint32_t b, uint32_t c) {
uint32_t d1;
// CHECK: d1 = (~a & ~b & c) | (~a & b & ~c) | (a & ~b & ~c) | (a & b & c);
asm("lop3.b32 %0, %1, %2, %3, 0x96;" : "=r"(d1) : "r"(a), "r"(b), "r"(c));
return d1;
}

// Case 2: (a ^ (c & (b ^ a)))
// The inputs are deliberately passed out of declaration order: %3 (c) is the
// second lop3 source and %2 (b) the third. Every term of the expected
// expansion below therefore names c before b.
static __device__ __forceinline__ uint32_t LOP3LUT_XORAND(uint32_t a, uint32_t b, uint32_t c) {
uint32_t d2;
// CHECK: d2 = (~a & c & b) | (a & ~c & ~b) | (a & ~c & b) | (a & c & b);
asm("lop3.b32 %0, %1, %3, %2, 0xb8;" : "=r"(d2) : "r"(a), "r"(b), "r"(c));
return d2;
}

// Case 3: ((a & (b | b)) | (b & b))
// The same asm operand (%2, bound to b) is used for two of the three lop3
// inputs, so b appears twice in every term of the expected expansion below.
static __device__ __forceinline__ uint32_t LOP3LUT_ANDOR(uint32_t a, uint32_t b) {
uint32_t d3;
// CHECK: d3 = (~a & b & b) | (a & ~b & b) | (a & b & ~b) | (a & b & b);
asm("lop3.b32 %0, %1, %2, %2, 0xe8;" : "=r"(d3) : "r"(a), "r"(b));
return d3;
}

#define B 3
// Case 4: ((((a + B) * (a + B)) & B) | 3) ^ ((a + B) * (a + B))
// The lop3 inputs are expressions rather than plain variables: `%1 * %1` with
// %1 bound to `a + B`, the macro B, and the literal 3 written directly in the
// asm string. The expected expansion below wraps the product and the literal
// in parentheses and leaves B bare.
__device__ int hard(int a) {
int d4;
// CHECK: d4 = (~((a + B) * (a + B)) & ~B & (3)) | (~((a + B) * (a + B)) & B & (3)) | (((a + B) * (a + B)) & ~B & ~(3));
asm("lop3.b32 %0, %1 * %1, %2, 3, 0x1A;" : "=r"(d4) : "r"(a + B), "r"(B));
return d4;
}

0 comments on commit 594237d

Please sign in to comment.