Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[RTG][Elaboration] Add support for 'index.add' and 'index.cmp' #7978

Open
wants to merge 1 commit into
base: maerhart-rtg-elaboration-sequences
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion include/circt/Dialect/RTG/Transforms/RTGPasses.td
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def ElaborationPass : Pass<"rtg-elaborate", "mlir::ModuleOp"> {
"The seed for any RNG constructs used in the pass.">,
];

let dependentDialects = ["mlir::arith::ArithDialect"];
let dependentDialects = ["mlir::index::IndexDialect"];
}

#endif // CIRCT_DIALECT_RTG_TRANSFORMS_RTGPASSES_TD
2 changes: 1 addition & 1 deletion lib/Dialect/RTG/Transforms/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ add_circt_dialect_library(CIRCTRTGTransforms

LINK_LIBS PRIVATE
CIRCTRTGDialect
MLIRArithDialect
MLIRIndexDialect
MLIRIR
MLIRPass
)
Expand Down
161 changes: 145 additions & 16 deletions lib/Dialect/RTG/Transforms/ElaborationPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@
#include "circt/Dialect/RTG/IR/RTGVisitors.h"
#include "circt/Dialect/RTG/Transforms/RTGPasses.h"
#include "circt/Support/Namespace.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Index/IR/IndexDialect.h"
#include "mlir/Dialect/Index/IR/IndexOps.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/PatternMatch.h"
#include "llvm/Support/Debug.h"
Expand Down Expand Up @@ -85,7 +86,7 @@ namespace {
/// The abstract base class for elaborated values.
struct ElaboratorValue {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This seems like just a std::variant.

public:
enum class ValueKind { Attribute, Set, Bag, Sequence };
enum class ValueKind { Attribute, Set, Bag, Sequence, Index, Bool };

ElaboratorValue(ValueKind kind) : kind(kind) {}
virtual ~ElaboratorValue() {}
Expand Down Expand Up @@ -144,6 +145,74 @@ class AttributeValue : public ElaboratorValue {
const TypedAttr attr;
};

/// Holds an evaluated value of a `IndexType`'d value.
class IndexValue : public ElaboratorValue {
public:
IndexValue(size_t index) : ElaboratorValue(ValueKind::Index), index(index) {}

// Implement LLVMs RTTI
static bool classof(const ElaboratorValue *val) {
return val->getKind() == ValueKind::Index;
}

llvm::hash_code getHashValue() const override {
return llvm::hash_value(index);
}

bool isEqual(const ElaboratorValue &other) const override {
auto *indexValue = dyn_cast<IndexValue>(&other);
if (!indexValue)
return false;

return index == indexValue->index;
}

#ifndef NDEBUG
void print(llvm::raw_ostream &os) const override {
os << "<index " << index << " at " << this << ">";
}
#endif

size_t getIndex() const { return index; }

private:
const size_t index;
};

/// Holds an evaluated value of an `i1` type'd value.
class BoolValue : public ElaboratorValue {
public:
BoolValue(bool value) : ElaboratorValue(ValueKind::Bool), value(value) {}

// Implement LLVMs RTTI
static bool classof(const ElaboratorValue *val) {
return val->getKind() == ValueKind::Bool;
}

llvm::hash_code getHashValue() const override {
return llvm::hash_value(value);
}

bool isEqual(const ElaboratorValue &other) const override {
auto *val = dyn_cast<BoolValue>(&other);
if (!val)
return false;

return value == val->value;
}

#ifndef NDEBUG
void print(llvm::raw_ostream &os) const override {
os << "<bool " << (value ? "true" : "false") << " at " << this << ">";
}
#endif

bool getBool() const { return value; }

private:
const bool value;
};

/// Holds an evaluated value of a `SetType`'d value.
class SetValue : public ElaboratorValue {
public:
Expand Down Expand Up @@ -366,7 +435,8 @@ class Materializer {

OpBuilder builder(block, insertionPoint);
return TypeSwitch<ElaboratorValue *, Value>(val)
.Case<AttributeValue, SetValue, BagValue, SequenceValue>([&](auto val) {
.Case<AttributeValue, IndexValue, BoolValue, SetValue, BagValue,
SequenceValue>([&](auto val) {
return visit(val, builder, loc, elabRequests, emitError);
})
.Default([](auto val) {
Expand All @@ -389,13 +459,11 @@ class Materializer {
function_ref<InFlightDiagnostic()> emitError) {
auto attr = val->getAttr();

// For integer attributes (and arithmetic operations on them) we use the
// arith dialect.
if (isa<IntegerAttr>(attr)) {
Value res = builder.getContext()
->getLoadedDialect<arith::ArithDialect>()
->materializeConstant(builder, attr, attr.getType(), loc)
->getResult(0);
// For index attributes (and arithmetic operations on them) we use the
// index dialect.
if (auto intAttr = dyn_cast<IntegerAttr>(attr);
intAttr && isa<IndexType>(attr.getType())) {
Value res = builder.create<index::ConstantOp>(loc, intAttr);
materializedValues[val] = res;
return res;
}
Expand All @@ -417,6 +485,22 @@ class Materializer {
return res;
}

Value visit(IndexValue *val, OpBuilder &builder, Location loc,
std::queue<SequenceValue *> &elabRequests,
function_ref<InFlightDiagnostic()> emitError) {
Value res = builder.create<index::ConstantOp>(loc, val->getIndex());
materializedValues[val] = res;
return res;
}

Value visit(BoolValue *val, OpBuilder &builder, Location loc,
std::queue<SequenceValue *> &elabRequests,
function_ref<InFlightDiagnostic()> emitError) {
Value res = builder.create<index::BoolConstantOp>(loc, val->getBool());
materializedValues[val] = res;
return res;
}

Value visit(SetValue *val, OpBuilder &builder, Location loc,
std::queue<SequenceValue *> &elabRequests,
function_ref<InFlightDiagnostic()> emitError) {
Expand Down Expand Up @@ -451,7 +535,7 @@ class Materializer {
if (iter != integerValues.end()) {
materializedWeight = iter->second;
} else {
materializedWeight = builder.create<arith::ConstantOp>(
materializedWeight = builder.create<index::ConstantOp>(
loc, builder.getIndexAttr(weight));
integerValues[weight] = materializedWeight;
}
Expand Down Expand Up @@ -606,9 +690,8 @@ class Elaborator : public RTGOpVisitor<Elaborator, FailureOr<DeletionKind>> {
// If the multiple is not stored as an AttributeValue, the elaboration
// must have already failed earlier (since we don't have
// unevaluated/opaque values).
auto *interpMultiple = cast<AttributeValue>(state.at(multiple));
uint64_t m = cast<IntegerAttr>(interpMultiple->getAttr()).getInt();
bag[interpValue] += m;
auto *interpMultiple = cast<IndexValue>(state.at(multiple));
bag[interpValue] += interpMultiple->getIndex();
}

internalizeResult<BagValue>(op.getBag(), std::move(bag), op.getType());
Expand Down Expand Up @@ -688,6 +771,43 @@ class Elaborator : public RTGOpVisitor<Elaborator, FailureOr<DeletionKind>> {
return DeletionKind::Delete;
}

FailureOr<DeletionKind> visitOp(index::AddOp op) {
size_t lhs = cast<IndexValue>(state.at(op.getLhs()))->getIndex();
size_t rhs = cast<IndexValue>(state.at(op.getRhs()))->getIndex();
internalizeResult<IndexValue>(op.getResult(), lhs + rhs);
return DeletionKind::Delete;
}

FailureOr<DeletionKind> visitOp(index::CmpOp op) {
size_t lhs = cast<IndexValue>(state.at(op.getLhs()))->getIndex();
size_t rhs = cast<IndexValue>(state.at(op.getRhs()))->getIndex();
bool result;
switch (op.getPred()) {
case index::IndexCmpPredicate::EQ:
result = lhs == rhs;
break;
case index::IndexCmpPredicate::NE:
result = lhs != rhs;
break;
case index::IndexCmpPredicate::ULT:
result = lhs < rhs;
break;
case index::IndexCmpPredicate::ULE:
result = lhs <= rhs;
break;
case index::IndexCmpPredicate::UGT:
result = lhs > rhs;
break;
case index::IndexCmpPredicate::UGE:
result = lhs >= rhs;
break;
default:
return op->emitOpError("elaboration not supported");
}
internalizeResult<BoolValue>(op.getResult(), result);
return DeletionKind::Delete;
}

FailureOr<DeletionKind> dispatchOpVisitor(Operation *op) {
if (op->hasTrait<OpTrait::ConstantLike>()) {
SmallVector<OpFoldResult, 1> result;
Expand All @@ -700,11 +820,20 @@ class Elaborator : public RTGOpVisitor<Elaborator, FailureOr<DeletionKind>> {
return op->emitError(
"only typed attributes supported for constant-like operations");

internalizeResult<AttributeValue>(op->getResult(0), attr);
auto intAttr = dyn_cast<IntegerAttr>(attr);
if (intAttr && isa<IndexType>(attr.getType()))
internalizeResult<IndexValue>(op->getResult(0), intAttr.getInt());
else if (intAttr && intAttr.getType().isSignlessInteger(1))
internalizeResult<BoolValue>(op->getResult(0), intAttr.getInt());
else
internalizeResult<AttributeValue>(op->getResult(0), attr);

return DeletionKind::Delete;
}

return RTGBase::dispatchOpVisitor(op);
return TypeSwitch<Operation *, FailureOr<DeletionKind>>(op)
.Case<index::AddOp, index::CmpOp>([&](auto op) { return visitOp(op); })
.Default([&](Operation *op) { return RTGBase::dispatchOpVisitor(op); });
}

LogicalResult elaborate(SequenceOp family, SequenceOp dest,
Expand Down
Loading
Loading