Merge branch 'bump_to_285b087a' into bump_to_6382dbbc
mgehre-amd authored Aug 28, 2024
Parents: ad1facc + a22c27c · Commit: 0ef5530
Showing 4 changed files with 31 additions and 33 deletions.
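All four files apply the same mechanical cleanup: the deprecated member-function casts on mlir::Type (`ty.cast<T>()`, `ty.dyn_cast<T>()`, `ty.isa<T>()`) are replaced by the LLVM free functions `cast<T>(ty)`, `dyn_cast<T>(ty)`, and `isa<T>(ty)`, which also drops the `template` disambiguators the member form needed in dependent contexts. A minimal sketch of the pattern, using builtin MLIR types rather than code from this commit (the helper function is hypothetical):

```cpp
#include "mlir/IR/BuiltinTypes.h" // mlir::RankedTensorType, mlir::FloatType
#include "mlir/IR/Value.h"        // mlir::Value
#include "llvm/Support/Casting.h" // llvm::cast, llvm::dyn_cast, llvm::isa

using namespace mlir;

// Hypothetical helper (not from this commit) showing the free-function style.
RankedTensorType getRankedFloatType(Value v) {
  // Deprecated member form:  v.getType().dyn_cast<RankedTensorType>()
  // Free-function form used throughout this commit:
  auto ty = dyn_cast<RankedTensorType>(v.getType()); // null on type mismatch
  if (ty && isa<FloatType>(ty.getElementType()))     // isa<> for yes/no checks
    return ty;
  return RankedTensorType(); // default-constructed type is null
}
```

The semantic difference matters when reading the hunks below: `cast<T>` asserts that the conversion succeeds, while `dyn_cast<T>` returns a null object, so the `dyn_cast` results still need their `if (!ty)` match-failure checks.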
10 changes: 5 additions & 5 deletions lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp
@@ -2574,7 +2574,7 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
Value modeStrValue;

auto extract = [&rewriter, &binder](Value x, Value v) {
- auto xTy = x.getType().cast<Torch::ValueTensorType>();
+ auto xTy = cast<Torch::ValueTensorType>(x.getType());
Type extractTy = rewriter.getType<Torch::FloatType>();
if (isa<IntegerType>(xTy.getDtype()))
extractTy = rewriter.getType<Torch::IntType>();
@@ -2588,7 +2588,7 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
auto sizes =
dyn_cast<Torch::ValueTensorType>(operand.getType()).getSizes();
Torch::BaseTensorType operandType =
-     operand.getType().cast<Torch::BaseTensorType>();
+     cast<Torch::BaseTensorType>(operand.getType());

SmallVector<int64_t> selectSizes;
selectSizes.push_back(1);
@@ -2605,7 +2605,7 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
Value item = extract(operand, ext);
itemList.push_back(item);
}
- auto xTy = operand.getType().cast<Torch::ValueTensorType>();
+ auto xTy = cast<Torch::ValueTensorType>(operand.getType());
Value ValueList;
if (isa<IntegerType>(xTy.getDtype())) {
ValueList = rewriter.create<Torch::PrimListConstructOp>(
@@ -2674,8 +2674,8 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
scalesValueList = noneVal;
sizesValueList = getValueList(sizeOperand);
}
- if (scalesValueList.getType().isa<Torch::NoneType>() &&
-     sizesValueList.getType().isa<Torch::NoneType>()) {
+ if (isa<Torch::NoneType>(scalesValueList.getType()) &&
+     isa<Torch::NoneType>(sizesValueList.getType())) {
return rewriter.notifyMatchFailure(binder.op, "unknown scaling mode");
}
rewriter
50 changes: 24 additions & 26 deletions lib/Conversion/TorchToTosa/TorchToTosa.cpp
@@ -267,7 +267,7 @@ class ConvertAtenAddSubOp : public OpConversionPattern<AtenOpT> {
op, "Currently only scalar constants are supported for "
"conversion in TOSA operation");
}
- rhsType = rhs.getType().dyn_cast<TensorType>();
+ rhsType = dyn_cast<TensorType>(rhs.getType());
}

// aten.rsub(lhs, rhs, alpha) computes rhs - lhs * alpha
@@ -1016,7 +1016,7 @@ LogicalResult ConvertAtenOp<AtenPowScalarOp>::matchAndRewrite(
ConversionPatternRewriter &rewriter) const {

Value exp = adaptor.getExponent();
- auto expTy = exp.getType().template dyn_cast<RankedTensorType>();
+ auto expTy = dyn_cast<RankedTensorType>(exp.getType());

if (!expTy)
return rewriter.notifyMatchFailure(
@@ -1035,7 +1035,7 @@ LogicalResult ConvertAtenOp<AtenPowScalarOp>::matchAndRewrite(
"conversion in TOSA Pow operation");

auto outType =
-     getTypeConverter()->convertType(op.getType()).template cast<TensorType>();
+     cast<TensorType>(getTypeConverter()->convertType(op.getType()));

auto powOp = tosa::createBinaryOpAndCast<tosa::PowOp>(rewriter, op, outType,
selfTensor, exp);
@@ -1084,7 +1084,7 @@ LogicalResult ConvertAtenOp<AtenPowTensorTensorOp>::matchAndRewrite(
ConversionPatternRewriter &rewriter) const {

Value self = adaptor.getSelf();
- auto selfTy = self.getType().template cast<RankedTensorType>();
+ auto selfTy = cast<RankedTensorType>(self.getType());

if (!selfTy)
return rewriter.notifyMatchFailure(
@@ -1095,7 +1095,7 @@ LogicalResult ConvertAtenOp<AtenPowTensorTensorOp>::matchAndRewrite(
op, "Only floating-point datatype legalization supported");

auto outType =
-     getTypeConverter()->convertType(op.getType()).template cast<TensorType>();
+     cast<TensorType>(getTypeConverter()->convertType(op.getType()));

Value expTensor = adaptor.getExponent();
if (expTensor.getType() != selfTy) {
@@ -2014,7 +2014,7 @@ Value createConvInGroups(PatternRewriter &rewriter, Operation *op,
// Set up constants outside of loop
const int64_t sizeOfSliceInput = weightShape[1];
const int64_t sizeOfSliceKernel = weightShape[0] / groups;
- auto inputShape = input.getType().cast<ShapedType>().getShape();
+ auto inputShape = cast<ShapedType>(input.getType()).getShape();

llvm::SmallVector<int64_t, 4> inputSize = {inputShape[0], inputShape[1],
inputShape[2], sizeOfSliceInput};
@@ -2023,7 +2023,7 @@ Value createConvInGroups(PatternRewriter &rewriter, Operation *op,
llvm::SmallVector<Value> sliceValues;
Type outputType = RankedTensorType::get(
llvm::SmallVector<int64_t, 4>(4, ShapedType::kDynamic),
-     resultType.cast<ShapedType>().getElementType());
+     cast<ShapedType>(resultType).getElementType());
for (int64_t i = 0; i < groups; i++) {
// Slice input
Value sliceInput = tosa::buildSlice(
@@ -3884,7 +3884,7 @@ class SimplifyAten_IndexPutImplOp
LogicalResult matchAndRewrite(Aten_IndexPutImplOp op,
PatternRewriter &rewriter) const override {

- auto ty = op.getType().dyn_cast<BaseTensorType>();
+ auto ty = dyn_cast<BaseTensorType>(op.getType());
if (!ty || !ty.areAllSizesKnown()) {
return rewriter.notifyMatchFailure(op, "Required ranked tensor type");
}
@@ -3896,7 +3896,7 @@ class SimplifyAten_IndexPutImplOp
}
int64_t numSelfElements = shape[1];

- auto valuesTy = op.getValues().getType().dyn_cast<BaseTensorType>();
+ auto valuesTy = dyn_cast<BaseTensorType>(op.getValues().getType());
if (!valuesTy || !valuesTy.areAllSizesKnown()) {
return rewriter.notifyMatchFailure(
op, "Required ranked tensor type for values");
@@ -3922,7 +3922,7 @@ class SimplifyAten_IndexPutImplOp
// Here, we know that self is 1xN, so we are only interested for the indices
// of the 2nd dimension.
auto indices = indicesList[1];
- auto indicesTy = indices.getType().dyn_cast<BaseTensorType>();
+ auto indicesTy = dyn_cast<BaseTensorType>(indices.getType());
if (!indicesTy || !indicesTy.areAllSizesKnown()) {
return rewriter.notifyMatchFailure(
op, "Required ranked tensor type for indices");
@@ -4087,13 +4087,12 @@ LogicalResult SimplifyAtenOp<AtenConvolutionOp>::matchAndRewrite(
// %conv2d = AtenConvolution (%view) : (4D type) -> (4D type)
// %view2 = AtenViewOp (%conv2d) : (4D type) -> (3D type)

- auto inputTy = adaptor.getInput().getType().cast<RankedTensorType>();
- auto weightTy = adaptor.getWeight().getType().cast<RankedTensorType>();
- auto outputTy = getTypeConverter()
-                 ->convertType(op.getType())
-                 .template cast<RankedTensorType>();
+ auto inputTy = cast<RankedTensorType>(adaptor.getInput().getType());
+ auto weightTy = cast<RankedTensorType>(adaptor.getWeight().getType());
+ auto outputTy =
+     cast<RankedTensorType>(getTypeConverter()->convertType(op.getType()));

- auto ty = op.getType().dyn_cast_or_null<BaseTensorType>();
+ auto ty = dyn_cast_or_null<BaseTensorType>(op.getType());
if (!ty || !ty.hasSizes())
return rewriter.notifyMatchFailure(
op, "unimplemented: input must have known sizes");
@@ -5661,14 +5660,14 @@ LogicalResult ConvertAtenOp<AtenEmptyMemoryFormatOp>::matchAndRewrite(
const TypeConverter *typeConverter = this->getTypeConverter();

bool pinMemory;
- if (!op.getPinMemory().getType().template isa<Torch::NoneType>() &&
+ if (!isa<Torch::NoneType>(op.getPinMemory().getType()) &&
(!matchPattern(op.getPinMemory(), m_TorchConstantBool(&pinMemory)) ||
pinMemory)) {
return rewriter.notifyMatchFailure(
op, "Unsupported pin_memory, should be either None or false");
}

- if (!op.getDevice().getType().template isa<Torch::NoneType>()) {
+ if (!isa<Torch::NoneType>(op.getDevice().getType())) {
std::string device;
if (!matchPattern(op.getDevice(), m_TorchConstantDevice(device)))
return rewriter.notifyMatchFailure(
@@ -5678,7 +5677,7 @@ LogicalResult ConvertAtenOp<AtenEmptyMemoryFormatOp>::matchAndRewrite(
op, "unimplemented: device is expected to be none or cpu");
}

- if (!op.getLayout().getType().template isa<Torch::NoneType>()) {
+ if (!isa<Torch::NoneType>(op.getLayout().getType())) {
int64_t tensorLayout;
if (!matchPattern(op.getLayout(), m_TorchConstantInt(&tensorLayout)))
return rewriter.notifyMatchFailure(
@@ -5688,7 +5687,7 @@ LogicalResult ConvertAtenOp<AtenEmptyMemoryFormatOp>::matchAndRewrite(
op, "unimplemented: layout is expected to be strided");
}
// Only `none`, `contiguous` and `preserve` memory_format are supported.
- if (!op.getMemoryFormat().getType().template isa<Torch::NoneType>()) {
+ if (!isa<Torch::NoneType>(op.getMemoryFormat().getType())) {
int64_t memoryFormat;
if (!matchPattern(op.getMemoryFormat(), m_TorchConstantInt(&memoryFormat)))
return rewriter.notifyMatchFailure(
@@ -5707,11 +5706,11 @@ LogicalResult ConvertAtenOp<AtenEmptyMemoryFormatOp>::matchAndRewrite(
op, "unimplemented: size must be a ListConstruct");
SmallVector<Value> resultSize =
getTypeConvertedValues(rewriter, loc, typeConverter, size);
- auto resultType = typeConverter->convertType(op.getType())
-                   .template cast<RankedTensorType>();
+ auto resultType =
+     cast<RankedTensorType>(typeConverter->convertType(op.getType()));

DenseElementsAttr emptyVal;
- if (op.getDtype().getType().template isa<Torch::NoneType>()) {
+ if (isa<Torch::NoneType>(op.getDtype().getType())) {
emptyVal = DenseFPElementsAttr::get(resultType, {0.0F});
} else {
int64_t dtypeInt;
@@ -5754,9 +5753,8 @@ LogicalResult ConvertAtenOp<AtenRepeatInterleaveTensorOp>::matchAndRewrite(
AtenRepeatInterleaveTensorOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {

- auto outputTy = getTypeConverter()
-                 ->convertType(op.getType())
-                 .dyn_cast<RankedTensorType>();
+ auto outputTy =
+     dyn_cast<RankedTensorType>(getTypeConverter()->convertType(op.getType()));
if (!outputTy)
return rewriter.notifyMatchFailure(
op, "Only ranked tensor type outputs permitted");
2 changes: 1 addition & 1 deletion lib/Conversion/TorchToTosa/TosaLegalizeUtils.cpp
@@ -142,7 +142,7 @@ Value buildSlice(PatternRewriter &rewriter, Value &input,
rewriter, input.getLoc(),
RankedTensorType::get(
llvm::SmallVector<int64_t, 4>(size.size(), ShapedType::kDynamic),
-     input.getType().cast<ShapedType>().getElementType()),
+     cast<ShapedType>(input.getType()).getElementType()),
input, start, size);
}

2 changes: 1 addition & 1 deletion lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
@@ -7492,7 +7492,7 @@ class DecomposeAtenArcSinCosOp : public OpRewritePattern<ArcASinCosOp> {
LogicalResult matchAndRewrite(ArcASinCosOp op,
PatternRewriter &rewriter) const override {
Location loc = op.getLoc();
- auto outType = op.getType().template dyn_cast<BaseTensorType>();
+ auto outType = dyn_cast<BaseTensorType>(op.getType());
if (!outType)
return rewriter.notifyMatchFailure(
op, "Only tensor types input are currently supported");
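One variant worth noting from the TorchToTosa hunks above: `dyn_cast_or_null` (applied to `op.getType()` in the AtenConvolutionOp pattern) migrates the same way. A minimal sketch of the distinction, using builtin MLIR types rather than code from this commit (the helper function is hypothetical):

```cpp
#include "mlir/IR/BuiltinTypes.h" // mlir::ShapedType, mlir::Type
#include "llvm/Support/Casting.h" // llvm::dyn_cast, llvm::dyn_cast_or_null

// dyn_cast<T>(t) asserts that t itself is non-null; a type mismatch yields a
// null result. dyn_cast_or_null<T>(t) additionally tolerates a null t.
mlir::ShapedType asShapedType(mlir::Type t) {
  return llvm::dyn_cast_or_null<mlir::ShapedType>(t); // null if t is null or not shaped
}
```

Since a failed `cast<T>` is an assertion failure (undefined behavior with assertions disabled), conversion patterns generally reserve it for types already guaranteed by verification and use the `dyn_cast` family with `notifyMatchFailure` everywhere else.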
