[AutoBump] Merge with fixes of aca33f17 (Oct 22, needs LLVM Oct 19) (92) #476

Merged 33 commits on Jan 30, 2025
Commits
aca33f1
[TorchToLinalg] Use Op with native channel order for quantized conv2d…
ubfx Oct 22, 2024
55ff110
[MLIR][TORCH] Only unroll prim loop-like ops within a `torch.shape.ca…
zjgarvey Oct 23, 2024
2f9a68c
Add canonicalization pattern for maxpool3d with indices op (#3704)
lingzhiz1998 Oct 23, 2024
d6feb21
Added support for Maxpool (Autopad) (#3774)
sriram-siloai Oct 23, 2024
1259e8a
Add Some Folders For Small Reshape Ops (#3813)
zjgarvey Oct 24, 2024
76209db
Update quantized matmul tests to DQ/Q format supported by fx_importer…
ubfx Oct 24, 2024
ad9dfe9
Fix clang warning about printf format (#3814)
dbabokin Oct 25, 2024
54d9e24
[TorchToLinalg] Implement lowering of torch.aten.rrelu_with_noise and…
bosko-syrmia Oct 25, 2024
2b01f8b
[Tosa] : Add support for negative indices in index.tensor and index.T…
sahas3 Oct 25, 2024
9ab2a15
[Torch] emit upsample_bilinear2d(.vec) ops (#3834)
qingyunqu Oct 30, 2024
16b3bd6
build: manually update PyTorch version and fix CI failure (#3830)
vivekkhandelwal1 Oct 30, 2024
6b58c89
Remove variable used for only assertion (#3837)
Max191 Oct 30, 2024
8b0bf2e
Bump LLVM to llvm/llvm-project@6c64c8a6f3f7 (#3818)
Max191 Oct 30, 2024
f3d07ce
[AutoBump] Merge with fixes of aca33f17 (Oct 22)
mgehre-amd Jan 24, 2025
191404b
[AutoBump] Merge with ad9dfe97 (Oct 25)
mgehre-amd Jan 27, 2025
442bc9e
[AutoBump] Merge with fixes of 54d9e240 (Oct 25)
mgehre-amd Jan 27, 2025
8ea73b7
Update xfail
mgehre-amd Jan 27, 2025
79d99ff
Bump llvm
mgehre-amd Jan 28, 2025
95486f8
Merge branch 'bump_to_aca33f17' into bump_to_ad9dfe97
mgehre-amd Jan 28, 2025
e454904
Merge branch 'bump_to_ad9dfe97' into bump_to_54d9e240
mgehre-amd Jan 28, 2025
6c213d7
Update xfail
mgehre-amd Jan 28, 2025
8e6a9e0
xfail
mgehre-amd Jan 28, 2025
5cbaae6
[AutoBump] Merge with fixes of 2b01f8b7 (Oct 26)
mgehre-amd Jan 28, 2025
38b2e89
Merge pull request #477 from Xilinx/bump_to_ad9dfe97
mgehre-amd Jan 29, 2025
157ed54
Merge pull request #478 from Xilinx/bump_to_54d9e240
mgehre-amd Jan 29, 2025
bb8a5fe
[AutoBump] Merge with fixes of 9ab2a150 (Oct 30)
mgehre-amd Jan 29, 2025
24fcff4
Merge commit '8b0bf2e2' into matthias.bump_to_8b0bf2e2
mgehre-amd Jan 29, 2025
9529dcc
Merge pull request #480 from Xilinx/bump_to_2b01f8b7
mgehre-amd Jan 29, 2025
d5ca144
Merge pull request #481 from Xilinx/bump_to_9ab2a150
mgehre-amd Jan 29, 2025
d55d7b9
Update xfail
mgehre-amd Jan 29, 2025
11e22b6
Merge pull request #482 from Xilinx/matthias.bump_to_8b0bf2e2
mgehre-amd Jan 29, 2025
2928f2e
Merge remote-tracking branch 'origin/feature/backport_ea1_ops' into b…
mgehre-amd Jan 29, 2025
0038abc
Update xfail
mgehre-amd Jan 29, 2025
2 changes: 1 addition & 1 deletion externals/llvm-project
Submodule llvm-project updated 3644 files
2 changes: 1 addition & 1 deletion externals/stablehlo
Submodule stablehlo updated 44 files
+16 −0 BUILD.bazel
+4 −0 CMakeLists.txt
+2 −2 WORKSPACE.bazel
+1 −1 build_tools/llvm_version.txt
+1 −0 docs/generated/stablehlo_linalg_passes.md
+7 −0 docs/generated/stablehlo_passes.md
+1 −0 docs/generated/stablehlo_tosa_passes.md
+6 −2 docs/spec.md
+199 −0 rfcs/20241001-microscaling-formats.md
+19 −0 stablehlo/conversions/linalg/tests/miscellaneous.mlir
+9 −10 stablehlo/conversions/linalg/transforms/TypeConversion.cpp
+2 −19 stablehlo/dialect/Base.cpp
+3 −2 stablehlo/dialect/Base.td
+44 −4 stablehlo/dialect/StablehloOps.cpp
+5 −2 stablehlo/dialect/Version.cpp
+1 −1 stablehlo/dialect/Version.h
+49 −1 stablehlo/dialect/VhloBytecode.cpp
+1 −0 stablehlo/dialect/VhloDialect.td
+24 −0 stablehlo/dialect/VhloTypes.cpp
+12 −0 stablehlo/dialect/VhloTypes.td
+15 −43 stablehlo/reference/Tensor.cpp
+6 −4 stablehlo/reference/Types.cpp
+1 −1 stablehlo/testdata/igamma_float64_20_20_float64_20_20_chlo.mlir
+1 −1 stablehlo/testdata/igammac_float64_20_20_float64_20_20_chlo.mlir
+32 −0 stablehlo/tests/interpret/constant.mlir
+40 −8 stablehlo/tests/ops_stablehlo.mlir
+53 −53 stablehlo/tests/ops_stablehlo_quantized.mlir
+4 −0 stablehlo/tests/ops_stablehlo_roundtrip.mlir
+220 −0 stablehlo/tests/transforms/stablehlo_aggressive_folder.mlir
+550 −526 stablehlo/tests/transforms/stablehlo_aggressive_simplification.mlir
+2,936 −0 stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.1_8_0.mlir
+ stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.1_8_0.mlir.bc
+32 −0 stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.mlir
+35 −0 stablehlo/tests/vhlo/vhlo_to_version_downgrade_invalid.1_7_0.mlir
+15 −0 stablehlo/tests/vhlo/vhlo_to_version_downgrade_patch.mlir
+7 −2 stablehlo/transforms/CMakeLists.txt
+31 −2 stablehlo/transforms/PassUtils.cpp
+27 −12 stablehlo/transforms/PassUtils.h
+5 −0 stablehlo/transforms/Passes.h
+2 −0 stablehlo/transforms/Passes.td
+245 −7 stablehlo/transforms/StablehloAggressiveFolder.cpp
+98 −492 stablehlo/transforms/StablehloAggressiveSimplification.cpp
+281 −0 stablehlo/transforms/StablehloAggressiveSimplificationPatterns.td
+7 −0 stablehlo/transforms/VhloToVersion.cpp
141 changes: 141 additions & 0 deletions include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
@@ -309,6 +309,61 @@ def Torch_AtenRrelu_Op : Torch_Op<"aten.rrelu_", [
}];
}

def Torch_AtenRreluWithNoiseOp : Torch_Op<"aten.rrelu_with_noise", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::rrelu_with_noise : (Tensor, Tensor, Scalar, Scalar, bool, Generator?) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self,
AnyTorchTensorType:$noise,
AnyTorchScalarType:$lower,
AnyTorchScalarType:$upper,
Torch_BoolType:$training,
AnyTorchOptionalGeneratorType:$generator
);
let results = (outs
AnyTorchOptionalTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenRreluWithNoiseOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 6, 1);
}
void AtenRreluWithNoiseOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 6, 1);
}
}];
}

def Torch_AtenRreluWithNoise_Op : Torch_Op<"aten.rrelu_with_noise_", [
IsTrailingUnderscoreInplaceVariant,
AllowsTypeRefinement
]> {
let summary = "Generated op for `aten::rrelu_with_noise_ : (Tensor, Tensor, Scalar, Scalar, bool, Generator?) -> (Tensor)`";
let arguments = (ins
Torch_NonValueTensorType:$self,
Torch_NonValueTensorType:$noise,
AnyTorchScalarType:$lower,
AnyTorchScalarType:$upper,
Torch_BoolType:$training,
AnyTorchOptionalGeneratorType:$generator
);
let results = (outs
AnyTorchOptionalNonValueTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenRreluWithNoise_Op::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 6, 1);
}
void AtenRreluWithNoise_Op::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 6, 1);
}
}];
}

def Torch_AtenCeluOp : Torch_Op<"aten.celu", [
AllowsTypeRefinement,
HasValueSemantics,
@@ -7352,6 +7407,7 @@ def Torch_AtenMaxPool3dWithIndicesOp : Torch_Op<"aten.max_pool3d_with_indices",
printDefaultTorchOp(printer, *this, 6, 2);
}
}];
let hasCanonicalizer = 1;
}

def Torch_AtenMaxPool3dWithIndicesBackwardOp : Torch_Op<"aten.max_pool3d_with_indices_backward", [
@@ -8079,6 +8135,7 @@ def Torch_AtenTransposeIntOp : Torch_Op<"aten.transpose.int", [
printDefaultTorchOp(printer, *this, 3, 1);
}
}];
let hasFolder = 1;
}

def Torch_AtenPixelShuffleOp : Torch_Op<"aten.pixel_shuffle", [
@@ -9671,6 +9728,7 @@ def Torch_AtenFlattenUsingIntsOp : Torch_Op<"aten.flatten.using_ints", [
printDefaultTorchOp(printer, *this, 3, 1);
}
}];
let hasFolder = 1;
}

def Torch_AtenUnflattenIntOp : Torch_Op<"aten.unflatten.int", [
@@ -9695,6 +9753,7 @@ def Torch_AtenUnflattenIntOp : Torch_Op<"aten.unflatten.int", [
printDefaultTorchOp(printer, *this, 3, 1);
}
}];
let hasFolder = 1;
let hasCanonicalizer = 1;
}

@@ -14085,6 +14144,59 @@ def Torch_AtenUpsampleNearest2dVecOp : Torch_Op<"aten.upsample_nearest2d.vec", [
}];
}

def Torch_AtenUpsampleBilinear2dOp : Torch_Op<"aten.upsample_bilinear2d", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::upsample_bilinear2d : (Tensor, int[], bool, float?, float?) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self,
AnyTorchListOfTorchIntType:$output_size,
Torch_BoolType:$align_corners,
AnyTorchOptionalFloatType:$scales_h,
AnyTorchOptionalFloatType:$scales_w
);
let results = (outs
AnyTorchOptionalTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenUpsampleBilinear2dOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 5, 1);
}
void AtenUpsampleBilinear2dOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 5, 1);
}
}];
}

def Torch_AtenUpsampleBilinear2dVecOp : Torch_Op<"aten.upsample_bilinear2d.vec", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::upsample_bilinear2d.vec : (Tensor, int[]?, bool, float[]?) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$input,
AnyTorchOptionalListOfTorchIntType:$output_size,
Torch_BoolType:$align_corners,
AnyTorchOptionalListOfTorchFloatType:$scale_factors
);
let results = (outs
AnyTorchOptionalTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenUpsampleBilinear2dVecOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 4, 1);
}
void AtenUpsampleBilinear2dVecOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 4, 1);
}
}];
}

def Torch_AtenScaledDotProductAttentionOp : Torch_Op<"aten.scaled_dot_product_attention", [
AllowsTypeRefinement,
HasValueSemantics,
@@ -16861,6 +16973,35 @@ def Torch_AtenLeakyReluBackwardOp : Torch_Op<"aten.leaky_relu_backward", [
}];
}

def Torch_AtenRreluWithNoiseBackwardOp : Torch_Op<"aten.rrelu_with_noise_backward", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::rrelu_with_noise_backward : (Tensor, Tensor, Tensor, Scalar, Scalar, bool, bool) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$grad_output,
AnyTorchTensorType:$self,
AnyTorchTensorType:$noise,
AnyTorchScalarType:$lower,
AnyTorchScalarType:$upper,
Torch_BoolType:$training,
Torch_BoolType:$self_is_result
);
let results = (outs
AnyTorchOptionalTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenRreluWithNoiseBackwardOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 7, 1);
}
void AtenRreluWithNoiseBackwardOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 7, 1);
}
}];
}

def Torch_AtenQuantizePerChannelOp : Torch_Op<"aten.quantize_per_channel", [
AllowsTypeRefinement,
HasValueSemantics,
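
As context for the ODS additions above: ops generated this way get a standard builder, so a lowering or decomposition pattern can construct the new `aten.rrelu_with_noise` op directly. The sketch below is not part of this diff; the function and value names are hypothetical, and it only illustrates the operand order declared in the `arguments` list.

```cpp
// Illustrative sketch (not from this PR): constructing the newly added
// aten.rrelu_with_noise op from C++ pattern code. Operand order follows the
// ODS declaration: self, noise, lower, upper, training, generator.
#include "mlir/IR/PatternMatch.h"
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"

using namespace mlir;
using namespace mlir::torch::Torch;

static Value createRreluWithNoise(PatternRewriter &rewriter, Location loc,
                                  Type resultType, Value self, Value noise,
                                  Value lower, Value upper, Value training) {
  // Pass the !torch.none value for the optional generator operand.
  Value generator = rewriter.create<ConstantNoneOp>(loc);
  return rewriter.create<AtenRreluWithNoiseOp>(
      loc, resultType, self, noise, lower, upper, training, generator);
}
```
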
32 changes: 29 additions & 3 deletions lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp
@@ -1087,9 +1087,6 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
if (binder.customOpNameStringAttr(autoPad, "auto_pad", "NOTSET"))
return rewriter.notifyMatchFailure(binder.op,
"auto_pad bind failure");
if (autoPad != "NOTSET")
return rewriter.notifyMatchFailure(
binder.op, "unsupported conversion: auto_pad != NOTSET");

Torch::ValueTensorType resultTypeOut;
Value operand;
@@ -1136,13 +1133,42 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
return rewriter.notifyMatchFailure(binder.op,
"dilations bind failure");

// set default padding
if (padding.empty())
padding.resize(spatial, 0);
if (strides.empty())
strides.resize(spatial, 1);
if (dilations.empty())
dilations.resize(spatial, 1);

auto inputTensorType = cast<Torch::ValueTensorType>(operand.getType());

// Padding for the beginning and end along each spatial axis; it can
// take any value greater than or equal to 0. The values represent the
// number of pixels added to the beginning and end of the corresponding
// axis. The pads format is [x1_begin, x2_begin, ..., x1_end, x2_end, ...],
// where xi_begin is the number of pixels added at the beginning of axis i
// and xi_end the number of pixels added at the end of axis i.
if (autoPad != "NOTSET" && autoPad != "VALID") {
const bool isSameLower = autoPad == "SAME_LOWER";
ArrayRef<int64_t> inputShape = inputTensorType.getSizes();
padding.resize_for_overwrite(2 * spatial);
for (unsigned dimIdx = 0; dimIdx < spatial; dimIdx++) {
const int64_t dilatedKernelSize =
dilations[dimIdx] * (kernel[dimIdx] - 1) + 1;
int64_t totalPad = ((inputShape[dimIdx + 2] + strides[dimIdx] - 1) /
strides[dimIdx] -
1) *
strides[dimIdx] +
dilatedKernelSize - inputShape[dimIdx + 2];
totalPad = totalPad >= 0 ? totalPad : 0;
padding[dimIdx] =
isSameLower ? ((totalPad + 1) / 2) : (totalPad / 2);
padding[spatial + dimIdx] = totalPad - padding[dimIdx];
}
}

// If the padding is symmetric we can push the padding operation to the
// torch operator.
if (padding.size() == static_cast<size_t>(2 * spatial)) {
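
The new `auto_pad` handling above computes ONNX SAME_UPPER/SAME_LOWER padding per spatial dimension. As a standalone sanity check of that arithmetic (not code from this PR, with made-up sizes), a minimal sketch:

```cpp
// Standalone illustration of the SAME_UPPER / SAME_LOWER padding arithmetic
// used by the MaxPool auto_pad support above. Example sizes are hypothetical.
#include <cstdint>
#include <iostream>

int main() {
  // One spatial dimension: input extent 10, kernel 4, stride 3, dilation 1.
  const int64_t in = 10, kernel = 4, stride = 3, dilation = 1;
  const bool isSameLower = false; // SAME_UPPER puts the extra pixel at the end.

  const int64_t dilatedKernel = dilation * (kernel - 1) + 1;
  // Output positions = ceil(in / stride); total padding is whatever is needed
  // so the last window still fits inside the padded input.
  int64_t totalPad =
      ((in + stride - 1) / stride - 1) * stride + dilatedKernel - in;
  totalPad = totalPad < 0 ? 0 : totalPad;

  const int64_t padBegin = isSameLower ? (totalPad + 1) / 2 : totalPad / 2;
  const int64_t padEnd = totalPad - padBegin;

  // Prints: totalPad=3 begin=1 end=2
  std::cout << "totalPad=" << totalPad << " begin=" << padBegin
            << " end=" << padEnd << "\n";
  return 0;
}
```
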
59 changes: 31 additions & 28 deletions lib/Conversion/TorchToLinalg/Linear.cpp
@@ -1125,54 +1125,57 @@ class ConvertAtenConvolutionOp : public OpConversionPattern<AtenConvolutionOp> {
}

if (numGroups == 1 && inputZp) {
// The quantized version uses a different channel ordering so we need to
// permute the tensors in order to use the existing path. We should
// eventually directly support this channel ordering.
llvm::SmallVector<int64_t> inPerms, weightPerms;
inPerms.push_back(0); // N stays at the front for input.
// Then we expect the spatial dimensions
for (size_t i = 0; i < numSpatialDims; ++i) {
inPerms.push_back(i + 2);
weightPerms.push_back(i + 2);
}
inPerms.push_back(1);
weightPerms.append({1, 0});

paddedInput = transposeValue(op.getLoc(), paddedInput, inPerms, rewriter);
weight = transposeValue(op.getLoc(), weight, weightPerms, rewriter);
outputTensor =
transposeValue(op.getLoc(), outputTensor, inPerms, rewriter);

switch (numSpatialDims) {
case 2:
conv = rewriter
.create<linalg::Conv2DNhwcHwcfQOp>(
.create<linalg::Conv2DNchwFchwQOp>(
loc, outputTensor.getType(),
ValueRange{paddedInput, weight, inputZp, weightZp},
outputTensor, stridesAttr, dilationAttr)
.getResult(0);
break;
case 3:
case 3: {
// The quantized version uses a different channel ordering so we need to
// permute the tensors in order to use the existing path. We should
// eventually directly support this channel ordering.
llvm::SmallVector<int64_t> inPerms, weightPerms;
inPerms.push_back(0); // N stays at the front for input.
// Then we expect the spatial dimensions
for (size_t i = 0; i < numSpatialDims; ++i) {
inPerms.push_back(i + 2);
weightPerms.push_back(i + 2);
}
inPerms.push_back(1);
weightPerms.append({1, 0});

paddedInput =
transposeValue(op.getLoc(), paddedInput, inPerms, rewriter);
weight = transposeValue(op.getLoc(), weight, weightPerms, rewriter);
outputTensor =
transposeValue(op.getLoc(), outputTensor, inPerms, rewriter);

conv = rewriter
.create<linalg::Conv3DNdhwcDhwcfQOp>(
loc, outputTensor.getType(),
ValueRange{paddedInput, weight, inputZp, weightZp},
outputTensor, stridesAttr, dilationAttr)
.getResult(0);

llvm::SmallVector<int64_t> outPerms;
outPerms.push_back(0);
outPerms.push_back(inPerms.size() - 1);
for (size_t i = 0; i < numSpatialDims; ++i) {
outPerms.push_back(i + 1);
}
conv = transposeValue(op.getLoc(), conv, outPerms, rewriter);

break;
}
default:
return rewriter.notifyMatchFailure(
op, "unimplemented: only 1D, 2D, and 3D convolution supported");
};

llvm::SmallVector<int64_t> outPerms;
outPerms.push_back(0);
outPerms.push_back(inPerms.size() - 1);
for (size_t i = 0; i < numSpatialDims; ++i) {
outPerms.push_back(i + 1);
}
conv = transposeValue(op.getLoc(), conv, outPerms, rewriter);

Type newResultType = getTypeConverter()->convertType(op.getType());
if (accumulatorDType != resultDTy) {
Type resultElementType =
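
With this change the quantized 2-D path emits `linalg.conv_2d_nchw_fchw_q` in the native NCHW order, so only the 3-D case still permutes to channels-last and back. A standalone sketch (not from this PR) of the permutation vectors that branch builds when `numSpatialDims == 3`:

```cpp
// Standalone illustration (not part of the PR) of the layout permutations the
// quantized 3-D convolution branch applies: NCDHW -> NDHWC for input/output,
// FCDHW -> DHWCF for the weight, and back to NCDHW for the result.
#include <cstddef>
#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  const size_t numSpatialDims = 3;

  std::vector<int64_t> inPerms, weightPerms;
  inPerms.push_back(0); // N stays at the front for the input.
  for (size_t i = 0; i < numSpatialDims; ++i) {
    inPerms.push_back(i + 2);
    weightPerms.push_back(i + 2);
  }
  inPerms.push_back(1);
  weightPerms.insert(weightPerms.end(), {1, 0});

  // Permutation that moves the NDHWC result back to NCDHW.
  std::vector<int64_t> outPerms{0, static_cast<int64_t>(inPerms.size()) - 1};
  for (size_t i = 0; i < numSpatialDims; ++i)
    outPerms.push_back(i + 1);

  auto dump = [](const char *name, const std::vector<int64_t> &perm) {
    std::cout << name << ":";
    for (int64_t d : perm)
      std::cout << " " << d;
    std::cout << "\n";
  };
  // Prints: inPerms: 0 2 3 4 1, weightPerms: 2 3 4 1 0, outPerms: 0 4 1 2 3
  dump("inPerms", inPerms);
  dump("weightPerms", weightPerms);
  dump("outPerms", outPerms);
  return 0;
}
```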