From 53d1b0d127d5ad0b06c06ed73251b913c54db9f3 Mon Sep 17 00:00:00 2001 From: trivedivivek <5340687+trivedivivek@users.noreply.github.com> Date: Tue, 28 Jan 2025 13:37:23 -0600 Subject: [PATCH] [ET-VK] Using shared memory to save position in conv2d dw output op. Differential Revision: D68400890 Pull Request resolved: https://github.com/pytorch/executorch/pull/7923 --- .../graph/ops/glsl/conv2d_dw_output_tile.glsl | 16 +++++++++++++--- 1 file changed, 13 insertions(+), 3 deletions(-) diff --git a/backends/vulkan/runtime/graph/ops/glsl/conv2d_dw_output_tile.glsl b/backends/vulkan/runtime/graph/ops/glsl/conv2d_dw_output_tile.glsl index c05c7e4450..5a42f50e91 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/conv2d_dw_output_tile.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/conv2d_dw_output_tile.glsl @@ -20,6 +20,8 @@ #define BATCH_SIZE_Y ${BATCH_SIZE_Y} +#define LOCAL_WG_SIZE 64 + #define op(X, A, B) ${OPERATOR} #include "indexing_utils.h" @@ -38,6 +40,11 @@ ${layout_declare_ubo(8, "float", "out_min", "float", "out_max")} layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; +// For performance improvement, reduce register usage by caching positions in shared memory. +// Offset index by 1 every 16 points to avoid bank access conflict. +#define offset_pos_index(index) (index + ((index) >> 4)) +shared ivec3 pos_shared[offset_pos_index(LOCAL_WG_SIZE)]; + /* * Computes a depthwise convolution. Each shader invocation calculates the * output at a single output location. @@ -63,6 +70,8 @@ void main() { return; } + pos_shared[offset_pos_index(gl_LocalInvocationIndex)] = pos; + // Compute the index of the top-left element of the overlay region. Negative // indices indicate that the top-left element is in a region added by padding. const ivec2 ipos = pos.xy * stride - padding; @@ -109,18 +118,19 @@ void main() { for (int j = 0; j < TILE_SIZE; j++, kx++) { prev_kernel_line[j] = texelFetch(t_kernel, ivec2(kx, pos.z), 0); for (int s = 0; s < BATCH_SIZE_X; s++) { - sum[0][s] = fma(in_texels[j + s], prev_kernel_line[j], sum[0][s]); + sum[0][s] = fma(in_texels[j + s], prev_kernel_line[j], sum[0][s]); } } } } + const ivec3 out_pos = pos_shared[offset_pos_index(gl_LocalInvocationIndex)]; for (int y = 0; y < BATCH_SIZE_Y; y++) { for (int x = 0; x < BATCH_SIZE_X; x++) { - if (any(greaterThanEqual(ivec3(pos.x + x, pos.y + y, pos.z), out_limits))) { + if (any(greaterThanEqual(ivec3(out_pos.x + x, out_pos.y + y, out_pos.z), out_limits))) { continue; } - imageStore(t_out, ivec3(pos.x + x, pos.y + y, pos.z), op(sum[y][x], out_min, out_max)); + imageStore(t_out, ivec3(out_pos.x + x, out_pos.y + y, out_pos.z), op(sum[y][x], out_min, out_max)); } } }