diff --git a/lib/Conversion/StructuredToMemref/StructuredToMemrefPass.cpp b/lib/Conversion/StructuredToMemref/StructuredToMemrefPass.cpp index 2604f65a..d081d9cf 100644 --- a/lib/Conversion/StructuredToMemref/StructuredToMemrefPass.cpp +++ b/lib/Conversion/StructuredToMemref/StructuredToMemrefPass.cpp @@ -84,7 +84,7 @@ class LoopTypeConverter : public TypeConverter { // reinterpret_cast. addTargetMaterialization([&](OpBuilder &builder, MemRefType memrefType, ValueRange inputs, - Location loc) -> std::optional { + Location loc) -> Value { auto reinterpretCast = inputs[0].getDefiningOp(); if (!reinterpretCast) { diff --git a/lib/Conversion/TritonToStructured/TritonToStructuredPass.cpp b/lib/Conversion/TritonToStructured/TritonToStructuredPass.cpp index bcfea253..c84732ed 100644 --- a/lib/Conversion/TritonToStructured/TritonToStructuredPass.cpp +++ b/lib/Conversion/TritonToStructured/TritonToStructuredPass.cpp @@ -74,7 +74,7 @@ class TritonToStructuredPass RewritePatternSet patterns(&getContext()); auto context = &getContext(); - OneToNTypeConverter converter; + TypeConverter converter; converter.addConversion([](Type type) { return type; }); // We are doing a 1->1 type conversion here, where a triton pointer type @@ -145,10 +145,10 @@ class TritonToStructuredPass // Compute the target materialization, given a value with the pointer type, // convert that value to a tuple type. converter.addTargetMaterialization( - [](OpBuilder &builder, TypeRange resultTypes, Value input, - Location loc) -> std::optional> { + [](OpBuilder &builder, TypeRange resultTypes, ValueRange inputs, + Location loc) -> SmallVector { return builder - .create(loc, resultTypes, input) + .create(loc, resultTypes, inputs.front()) ->getResults(); }); @@ -172,7 +172,7 @@ class TritonToStructuredPass auto moduleOp = getOperation(); auto context = &getContext(); - OneToNTypeConverter converter; + TypeConverter converter; converter.addConversion([](Type type) { return type; }); // We are doing a 1->N type conversion here, where a pointer tuple type @@ -208,10 +208,10 @@ class TritonToStructuredPass // At the end of pointer analysis, we will use the PtrState to create the // correct offsets, strides, and remove these ops. converter.addTargetMaterialization([](OpBuilder &builder, - TypeRange resultTypes, Value input, + TypeRange resultTypes, ValueRange inputs, Location loc) { auto placeholder = builder.create( - loc, input.getDefiningOp()->getOperand(0)); + loc, inputs.front().getDefiningOp()->getOperand(0)); assert(llvm::equal(placeholder.getResultTypes(), resultTypes)); return placeholder.getResults(); }); diff --git a/triton b/triton index 31040564..94684d32 160000 --- a/triton +++ b/triton @@ -1 +1 @@ -Subproject commit 310405647df51a909943bed71c5a6fd9a3e402b4 +Subproject commit 94684d326723b67b146f23f342623ea058a32098