From 739892e9d368b3e8dfd9a1d9b7e55da052b50fc0 Mon Sep 17 00:00:00 2001 From: Shiv Date: Sat, 11 Jan 2025 18:09:33 -0800 Subject: [PATCH 1/2] fix attn fusion when trailing pointwise has literal --- src/targets/gpu/fuse_mlir.cpp | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/targets/gpu/fuse_mlir.cpp b/src/targets/gpu/fuse_mlir.cpp index 65a27a76ad1..26c089e591d 100644 --- a/src/targets/gpu/fuse_mlir.cpp +++ b/src/targets/gpu/fuse_mlir.cpp @@ -872,7 +872,10 @@ struct find_mlir_standalone_attention_op if(contains(r.instructions, "trailing_pm")) { auto trailing_pm_ins = r.instructions["trailing_pm"]; + auto lit_map = create_param_map_with_literals( + &m_attn, trailing_pm_ins->module_inputs().front(), trailing_pm_ins->get_shape()); m_attn.add_params(trailing_pm_ins->inputs(), &map_main_to_mattn); + map_main_to_mattn.insert(lit_map.begin(), lit_map.end()); std::unordered_map map_pm_to_mattn(map_main_to_mattn); auto fused_pw_outs = m_attn .fuse(*trailing_pm_ins->module_inputs().front(), From 5b8d18c0685319f8d3f8df297c9cd79c2b7ac560 Mon Sep 17 00:00:00 2001 From: Shiv Date: Sat, 11 Jan 2025 18:20:31 -0800 Subject: [PATCH 2/2] license --- src/targets/gpu/fuse_mlir.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/targets/gpu/fuse_mlir.cpp b/src/targets/gpu/fuse_mlir.cpp index 26c089e591d..519955c2113 100644 --- a/src/targets/gpu/fuse_mlir.cpp +++ b/src/targets/gpu/fuse_mlir.cpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal