Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Intial commit to target all reshaper opsops after mlir_op #3754

Draft
wants to merge 1 commit into
base: develop
Choose a base branch
from
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 64 additions & 0 deletions src/targets/gpu/fuse_mlir.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -234,6 +234,24 @@
return {upper_input, op_stream};
}

std::tuple<instruction_ref, std::vector<operation>>
get_fusable_output_op_stream(instruction_ref upper_output)
{
instruction_ref lower_output = upper_output;
std::vector<operation> op_stream;
while(contains(reshaper_names(), lower_output->name()))
{
operation op = lower_output->get_operator();
op_stream.push_back(op);

if(lower_output->outputs().size() > 1)
break;

lower_output = lower_output->outputs().at(0);
}
return {lower_output, op_stream};
}

void fuse_input_ops(module_ref mm,
const std::vector<instruction_ref>& inputs,
std::unordered_map<instruction_ref, instruction_ref>* map_ins)
Expand All @@ -257,6 +275,20 @@
}
}

void fuse_output_ops(module_ref mm,
const std::vector<instruction_ref> &outputs)
{
size_t input_cnt = mm->get_parameters().size();

Check warning on line 281 in src/targets/gpu/fuse_mlir.cpp

View workflow job for this annotation

GitHub Actions / tidy

unused variable 'input_cnt' [clang-diagnostic-unused-variable,-warnings-as-errors]

Check warning on line 281 in src/targets/gpu/fuse_mlir.cpp

View workflow job for this annotation

GitHub Actions / cppcheck

style: Variable 'input_cnt' is assigned a value that is never used. [unreadVariable]
for(instruction_ref output : outputs)
{
auto [lower_output, op_stream] = get_fusable_output_op_stream(output);
for(const auto& op : (op_stream))
{
lower_output = mm->add_instruction(op, lower_output->inputs());
}
}
}

std::tuple<instruction_ref, std::vector<instruction_ref>>
fuse_input_ops_and_gemm_based_op(module_ref mm,
const std::vector<instruction_ref>& gemm_based_op_inputs,
Expand Down Expand Up @@ -1040,6 +1072,37 @@
}
};

struct find_mlir_single_output_reshaper_op
{
auto matcher() const
{
auto reshapes = reshaper_names();

Check warning on line 1079 in src/targets/gpu/fuse_mlir.cpp

View workflow job for this annotation

GitHub Actions / tidy

the variable 'reshapes' is copy-constructed from a const reference but is only used as const reference; consider making it a const reference [performance-unnecessary-copy-initialization,-warnings-as-errors]
return match::name("gpu::mlir_op")(match::any_of[match::outputs()](match::name(reshapes).bind("output_reshaper")));
}

void apply(module_pass_manager& mpm, const match::matcher_result& r) const
{
static int find_count = 0;
auto ins = r.result;

if (ins->outputs().size() > 1)
return;

auto out_op = r.instructions["output_reshaper"];
auto* mlir_module = ins->module_inputs().front();
std::cout << "Found :" << find_count++ << " reshapers" << std::endl;

mpm.get_module().debug_print(ins);
mpm.get_module().debug_print(out_op);
mlir_module->debug_print();


fuse_output_ops(mlir_module, out_op>outputs());

Check warning on line 1100 in src/targets/gpu/fuse_mlir.cpp

View workflow job for this annotation

GitHub Actions / tidy

invalid operands to binary expression ('instruction_ref' (aka '_List_iterator<migraphx::instruction>') and '(lambda at src/include/migraphx/matcher.hpp:587:12)') [clang-diagnostic-error]

Check warning on line 1100 in src/targets/gpu/fuse_mlir.cpp

View workflow job for this annotation

GitHub Actions / tidy

use of undeclared identifier 'outputs'; did you mean 'match::outputs'? [clang-diagnostic-error]

mlir_module->debug_print();
}
};

} // namespace

#endif // MIGRAPHX_MLIR
Expand Down Expand Up @@ -1092,6 +1155,7 @@

match::find_matches(mpm, find_pointwise_mlir{});
match::find_matches(mpm, find_unpack_int4_mlir_op{});
match::find_matches(mpm, find_mlir_single_output_reshaper_op{});

#else
(void)mpm;
Expand Down
Loading