Skip to content

Commit

Permalink
[AutoBump] Merge with 04aeb49
Browse files — browse the repository at this point in the history
  • Loading branch information
mgehre-amd committed Aug 13, 2024
2 parents 4a4ba3f + 04aeb49 commit e9a2c03
Show file tree
Hide file tree
Showing 20 changed files with 806 additions and 15 deletions.
2 changes: 2 additions & 0 deletions build_tools/python_deploy/build_linux_packages.sh
Original file line number Diff line number Diff line change
Expand Up @@ -434,6 +434,8 @@ function clean_build() {
}

function build_torch_mlir() {
# Disable LTC build for releases to avoid linker issues
export TORCH_MLIR_ENABLE_LTC=0
local torch_version="$1"
case $torch_version in
nightly)
Expand Down
19 changes: 19 additions & 0 deletions include/torch-mlir/Conversion/TorchOnnxToTorch/Patterns.h
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,25 @@ struct OpBinder {
return failure();
}

// Parses the op attribute "torch.onnx.<nameSuffix>" as an array of strings,
// appending each element to `values`.
// Returns success() when the attribute is absent (`values` is left untouched)
// or when it is an ArrayAttr whose elements are all StringAttrs; returns
// failure() when the attribute exists but is not an ArrayAttr, or when any
// element is not a StringAttr.
ParseResult stringArrayAttr(llvm::SmallVector<std::string> &values,
                            StringRef nameSuffix) {
  SmallString<64> name("torch.onnx.");
  name.append(nameSuffix);
  auto attr = op->getAttr(name);
  if (!attr)
    return success();
  if (auto arrayAttr = dyn_cast<ArrayAttr>(attr)) {
    for (auto element : arrayAttr) {
      // Free-function dyn_cast: the member form (Attribute::dyn_cast) is
      // deprecated in recent LLVM and the ArrayAttr cast above already uses
      // the free-function style.
      auto stringAttr = dyn_cast<StringAttr>(element);
      if (!stringAttr)
        return failure();
      values.push_back(stringAttr.getValue().str());
    }
    return success();
  }
  return failure();
}

ParseResult denseElementsAttr(ElementsAttr elementsattr,
StringRef nameSuffix) {
SmallString<64> name("torch.onnx.");
Expand Down
6 changes: 6 additions & 0 deletions include/torch-mlir/Conversion/TorchOnnxToTorch/Utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,10 @@
#ifndef TORCHMLIR_CONVERSION_TORCHONNXTOTORCH_UTILS_H
#define TORCHMLIR_CONVERSION_TORCHONNXTOTORCH_UTILS_H

#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "torch-mlir/Conversion/TorchOnnxToTorch/Patterns.h"
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
#include "torch-mlir/Dialect/Torch/Utils/Utils.h"

namespace mlir::torch::onnx_c {

Expand All @@ -20,6 +23,9 @@ Value createConstantIntList(OpBinder binder,

Type getQTorchTypeFromTorchIntType(Type ty);

LogicalResult OnnxLstmExpander(OpBinder binder,
ConversionPatternRewriter &rewriter);

bool areAllElementsDistinct(SmallVector<int64_t> array);

} // namespace mlir::torch::onnx_c
Expand Down
25 changes: 25 additions & 0 deletions include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -7585,6 +7585,31 @@ def Torch_Aten__And__TensorOp : Torch_Op<"aten.__and__.Tensor", [
}];
}

// ODS definition for `aten::__and__.Scalar : (Tensor, Scalar) -> (Tensor)`.
// The summary marks this as generated, so prefer regenerating the file over
// hand-editing this def.
def Torch_Aten__And__ScalarOp : Torch_Op<"aten.__and__.Scalar", [
AllowsTypeRefinement,
// HasValueSemantics + ReadOnly: the op does not mutate its operands.
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::__and__.Scalar : (Tensor, Scalar) -> (Tensor)`";
// Operands: the input tensor `self` and the scalar `other`.
let arguments = (ins
AnyTorchTensorType:$self,
AnyTorchScalarType:$other
);
let results = (outs
AnyTorchOptionalTensorType:$result
);
// Assembly format delegates to the shared default torch-op parser/printer,
// configured for 2 operands and 1 result.
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult Aten__And__ScalarOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 2, 1);
}
void Aten__And__ScalarOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 2, 1);
}
}];
// A canonicalization pattern is registered for this op (implemented in C++
// elsewhere in the project).
let hasCanonicalizer = 1;
}

def Torch_Aten__Or__TensorOp : Torch_Op<"aten.__or__.Tensor", [
AllowsTypeRefinement,
HasValueSemantics,
Expand Down
1 change: 1 addition & 0 deletions lib/Conversion/TorchOnnxToTorch/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ add_mlir_conversion_library(TorchMLIRTorchOnnxToTorch
DefaultDomainAtoF.cpp
DefaultDomainGtoP.cpp
DefaultDomainQtoZ.cpp
OnnxLstmExpander.cpp
Passes.cpp
Patterns.cpp
TorchOnnxToTorch.cpp
Expand Down
26 changes: 26 additions & 0 deletions lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1951,4 +1951,30 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
/*layout=*/noneVal, /*device=*/noneVal, /*pin_memory=*/noneVal);
return success();
});
patterns.onOp(
    "Einsum", 12, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
      // Lower onnx.Einsum to torch.aten.einsum: collect all operands into a
      // !torch.list of tensors plus the equation string attribute.
      Torch::ValueTensorType resultType;
      SmallVector<Value> tensors;
      std::string equation;
      if (binder.tensorOperands(tensors, binder.op->getNumOperands()) ||
          binder.customOpNameStringAttr(equation, "equation") ||
          binder.tensorResultType(resultType))
        return failure();
      // Drop static sizes/dtype so a single list element type covers every
      // operand. Use the free-function cast<> — the member form
      // (Type::cast) is deprecated in recent LLVM.
      Type listElemType =
          cast<Torch::BaseTensorType>(tensors[0].getType())
              .getWithSizesAndDtype(/*optionalSizes=*/std::nullopt,
                                    /*optionalDtype=*/nullptr);
      Type listType = Torch::ListType::get(listElemType);
      Value tensorList = rewriter.create<Torch::PrimListConstructOp>(
          binder.op->getLoc(), listType, tensors);
      Value cstEquation = rewriter.create<Torch::ConstantStrOp>(
          binder.getLoc(), rewriter.getType<Torch::StringType>(),
          rewriter.getStringAttr(equation));
      // aten.einsum's optional `path` argument is not provided; pass None.
      Value cstNone = rewriter.create<Torch::ConstantNoneOp>(binder.getLoc());
      rewriter.replaceOpWithNewOp<Torch::AtenEinsumOp>(
          binder.op, resultType, cstEquation, tensorList, /*path=*/cstNone);
      return success();
    });
}
1 change: 1 addition & 0 deletions lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,7 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
binder.op, resultType, operand);
return success();
});
patterns.onOp("LSTM", 1, onnx_c::OnnxLstmExpander);
patterns.onOp(
"LogSoftmax", 13,
[](OpBinder binder, ConversionPatternRewriter &rewriter) {
Expand Down
2 changes: 1 addition & 1 deletion lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -538,7 +538,7 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
return success();
});
patterns.onOp(
"Squeeze", 13, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
"Squeeze", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
Torch::ValueTensorType resultType;
Value data;
Value axes;
Expand Down
Loading

0 comments on commit e9a2c03

Please sign in to comment.