diff --git a/src/targets/gpu/mlir.cpp b/src/targets/gpu/mlir.cpp index ff6bc7ada46..be8d89d4d87 100644 --- a/src/targets/gpu/mlir.cpp +++ b/src/targets/gpu/mlir.cpp @@ -610,9 +610,16 @@ struct mlir_program return "migraphx." + ins->name(); } - static value get_operator_value(const operation& op) + static value get_operator_value(instruction_ref ins) { - auto v = op.to_value(); + const operation& op = ins->get_operator(); + auto v = op.to_value(); + + // Reshape operator can have dim 0 or -1. + // Avoid passing those on to MLIR: + if(op.name() == "reshape") + v["dims"] = ins->get_shape().lens(); + if(op.name() == "convolution" or op.name() == "quant_convolution") { // Adjust symetrical padding @@ -668,7 +675,7 @@ struct mlir_program } auto name = get_name(ins); auto ops = create_operation_state(name); - ops.add_attribute_value(get_operator_value(ins->get_operator())); + ops.add_attribute_value(get_operator_value(ins)); if(ins->name() != "@return") ops.add_results({get_shape(ins)}); if(ins->name() == "@literal") diff --git a/test/gpu/mlir.cpp b/test/gpu/mlir.cpp index a2530053c8a..ae3726729d4 100644 --- a/test/gpu/mlir.cpp +++ b/test/gpu/mlir.cpp @@ -227,6 +227,32 @@ module { EXPECT(verify_mlir(m)); } +// The following test checks that a dimension -1, within reshape operator is handled properly.. +TEST_CASE(conv_reshape_dim_minus_one) +{ + const std::string mlir_output = R"__migraphx__( +module { + func.func @mlir_convolution_reshape(%arg0: !migraphx.shaped<2x8x3x3xf32, 72x9x3x1>, %arg1: !migraphx.shaped<1x8x4x4xf32, 128x16x4x1>) -> !migraphx.shaped<1x4x1x2xf32, 8x2x2x1> attributes {arch = "", kernel = "mixr", num_cu = 0 : i64} { + %0 = migraphx.convolution %arg1, %arg0 {dilation = [1, 1], group = 1 : i64, padding = [0, 0, 0, 0], padding_mode = 0 : i64, stride = [1, 1]} : <1x8x4x4xf32, 128x16x4x1>, <2x8x3x3xf32, 72x9x3x1> -> <1x2x2x2xf32, 8x4x2x1> + %1 = migraphx.reshape %0 {dims = [1, 4, 1, 2]} : <1x2x2x2xf32, 8x4x2x1> -> <1x4x1x2xf32, 8x2x2x1> + return %1 : !migraphx.shaped<1x4x1x2xf32, 8x2x2x1> + } +} +)__migraphx__"; + migraphx::module m; + auto x = m.add_parameter("x", {migraphx::shape::float_type, {1, 8, 4, 4}}); + auto w = m.add_parameter("w", {migraphx::shape::float_type, {2, 8, 3, 3}}); + auto conv = m.add_instruction(migraphx::make_op("convolution"), x, w); + auto reshape = m.add_instruction(migraphx::make_op("reshape", {{"dims", {1, -1, 1, 2}}}), conv); + m.add_return({reshape}); + auto s = migraphx::gpu::dump_mlir(m); + // Skip test if MLIR is not enabled + if(s.empty()) + return; + CHECK(encode(s) == encode(mlir_output)); + EXPECT(verify_mlir(m)); +} + TEST_CASE(quant_dot_add) { const std::string mlir_output = R"__migraphx__(