From 44dae8b90ef232ea663727470dfbbe9daff6972d Mon Sep 17 00:00:00 2001 From: Wenlei Bao <142055114+wenlei-bao@users.noreply.github.com> Date: Thu, 19 Sep 2024 08:40:30 -0700 Subject: [PATCH] Adjust profiler space for SM89 (#1553) --- python/cutlass_library/generator.py | 51 ++++++++++++++++++----------- 1 file changed, 32 insertions(+), 19 deletions(-) diff --git a/python/cutlass_library/generator.py b/python/cutlass_library/generator.py index c736551432..9f327154a6 100644 --- a/python/cutlass_library/generator.py +++ b/python/cutlass_library/generator.py @@ -4881,7 +4881,8 @@ def GenerateSM89_TensorOp_16832_fp8(manifest, cuda_version): return layouts = [ - (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor) + (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), + (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.RowMajor) ] math_instructions = [ @@ -4935,43 +4936,49 @@ def GenerateSM89_TensorOp_16832_fp8(manifest, cuda_version): for math_inst in math_instructions: tile_descriptions = [ + TileDescription([256, 128, 128], 3, [4, 2, 1], math_inst, min_cc, max_cc), TileDescription([256, 128, 64], 3, [4, 2, 1], math_inst, min_cc, max_cc), - TileDescription([128, 256, 64], 3, [2, 4, 1], math_inst, min_cc, max_cc), TileDescription([256, 128, 64], 6, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 256, 128], 3, [2, 4, 1], math_inst, min_cc, max_cc), + TileDescription([128, 256, 64], 3, [2, 4, 1], math_inst, min_cc, max_cc), TileDescription([128, 256, 64], 6, [2, 4, 1], math_inst, min_cc, max_cc), + TileDescription([256, 64, 128], 4, [4, 1, 1], math_inst, min_cc, max_cc), TileDescription([256, 64, 64], 3, [4, 1, 1], math_inst, min_cc, max_cc), - TileDescription([ 64, 256, 64], 3, [1, 4, 1], math_inst, min_cc, max_cc), TileDescription([256, 64, 64], 4, [4, 1, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 256, 128], 4, [1, 4, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 256, 64], 3, [1, 4, 1], math_inst, min_cc, max_cc), TileDescription([ 64, 256, 64], 4, [1, 4, 1], math_inst, min_cc, max_cc), + TileDescription([256, 32, 128], 4, [4, 1, 1], math_inst, min_cc, max_cc), TileDescription([256, 32, 64], 4, [4, 1, 1], math_inst, min_cc, max_cc), + TileDescription([ 32, 256, 128], 4, [1, 4, 1], math_inst, min_cc, max_cc), TileDescription([ 32, 256, 64], 4, [1, 4, 1], math_inst, min_cc, max_cc), + TileDescription([128, 128, 128], 3, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 128, 128], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 128, 128], 5, [2, 2, 1], math_inst, min_cc, max_cc), TileDescription([128, 128, 64], 3, [2, 2, 1], math_inst, min_cc, max_cc), TileDescription([128, 128, 64], 4, [2, 2, 1], math_inst, min_cc, max_cc), TileDescription([128, 128, 64], 5, [2, 2, 1], math_inst, min_cc, max_cc), TileDescription([128, 128, 64], 6, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([128, 64, 64], 6, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([ 64, 128, 64], 6, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([128, 32, 64], 6, [4, 1, 1], math_inst, min_cc, max_cc), - TileDescription([ 32, 128, 64], 6, [1, 4, 1], math_inst, min_cc, max_cc), - TileDescription([ 64, 64, 64], 6, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([ 64, 64, 64], 10, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([256, 128, 128], 3, [4, 2, 1], math_inst, min_cc, max_cc), - TileDescription([128, 256, 128], 3, [2, 4, 1], math_inst, min_cc, max_cc), - TileDescription([256, 64, 128], 4, [4, 1, 1], math_inst, min_cc, max_cc), - TileDescription([ 64, 256, 128], 4, [1, 4, 1], math_inst, min_cc, max_cc), - TileDescription([256, 32, 128], 4, [4, 1, 1], math_inst, min_cc, max_cc), - TileDescription([ 32, 256, 128], 4, [1, 4, 1], math_inst, min_cc, max_cc), - TileDescription([128, 128, 128], 3, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([128, 128, 128], 4, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([128, 128, 128], 5, [2, 2, 1], math_inst, min_cc, max_cc), TileDescription([128, 64, 128], 3, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([ 64, 128, 128], 3, [2, 2, 1], math_inst, min_cc, max_cc), TileDescription([128, 64, 128], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 128, 128], 3, [2, 2, 1], math_inst, min_cc, max_cc), TileDescription([ 64, 128, 128], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 64, 64], 3, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 64, 64], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 64, 64], 5, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 64, 64], 6, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 128, 64], 3, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 128, 64], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 128, 64], 5, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 128, 64], 6, [2, 2, 1], math_inst, min_cc, max_cc), TileDescription([128, 32, 128], 4, [4, 1, 1], math_inst, min_cc, max_cc), + TileDescription([128, 32, 64], 6, [4, 1, 1], math_inst, min_cc, max_cc), TileDescription([ 32, 128, 128], 4, [1, 4, 1], math_inst, min_cc, max_cc), + TileDescription([ 32, 128, 64], 6, [1, 4, 1], math_inst, min_cc, max_cc), TileDescription([ 64, 64, 128], 5, [2, 2, 1], math_inst, min_cc, max_cc), TileDescription([ 64, 64, 128], 6, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 64, 64], 6, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 64, 64], 10, [2, 2, 1], math_inst, min_cc, max_cc), ] data_types = [ @@ -4981,6 +4988,12 @@ def GenerateSM89_TensorOp_16832_fp8(manifest, cuda_version): DataType.f32, math_inst.element_accumulator ], + [ + math_inst.element_a, + math_inst.element_b, + DataType.bf16, + math_inst.element_accumulator + ], ] operations = []