Skip to content

Commit

Permalink
[ET-VK] Using shared memory to save position in conv2d dw output op.
Browse files Browse the repository at this point in the history
This diff introduces a change to conv2d dw op to save output positions in shared memory, which reduces register usage and improves performance.

Differential Revision: [D68400890](https://our.internmc.facebook.com/intern/diff/D68400890/)

ghstack-source-id: 262823001
Pull Request resolved: #7923
  • Loading branch information
trivedivivek committed Jan 24, 2025
1 parent 0cbce05 commit 95d3072
Showing 1 changed file with 13 additions and 3 deletions.
16 changes: 13 additions & 3 deletions backends/vulkan/runtime/graph/ops/glsl/conv2d_dw_output_tile.glsl
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,13 @@ ${layout_declare_ubo(8, "float", "out_min", "float", "out_max")}

layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;

// macro to offset shared memory access index. Padding position index by 1 offset per 16 positions avoidd bank access conflict and thus improves performance.
#define offset_pos_index(index) (index + ((index) >> 4))

// shared memory to hold calculated positions, this would reduce register usage thus improving performance.
// 64 is the number of threads in the local wg
shared ivec3 pos_shared[offset_pos_index(64)];

/*
* Computes a depthwise convolution. Each shader invocation calculates the
* output at a single output location.
Expand All @@ -63,6 +70,8 @@ void main() {
return;
}

pos_shared[offset_pos_index(gl_LocalInvocationIndex)] = pos;

// Compute the index of the top-left element of the overlay region. Negative
// indices indicate that the top-left element is in a region added by padding.
const ivec2 ipos = pos.xy * stride - padding;
Expand Down Expand Up @@ -109,18 +118,19 @@ void main() {
for (int j = 0; j < TILE_SIZE; j++, kx++) {
prev_kernel_line[j] = texelFetch(t_kernel, ivec2(kx, pos.z), 0);
for (int s = 0; s < BATCH_SIZE_X; s++) {
sum[0][s] = fma(in_texels[j + s], prev_kernel_line[j], sum[0][s]);
sum[0][s] = fma(in_texels[j + s], prev_kernel_line[j], sum[0][s]);
}
}
}
}

const ivec3 out_pos = pos_shared[offset_pos_index(gl_LocalInvocationIndex)];
for (int y = 0; y < BATCH_SIZE_Y; y++) {
for (int x = 0; x < BATCH_SIZE_X; x++) {
if (any(greaterThanEqual(ivec3(pos.x + x, pos.y + y, pos.z), out_limits))) {
if (any(greaterThanEqual(ivec3(out_pos.x + x, out_pos.y + y, out_pos.z), out_limits))) {
continue;
}
imageStore(t_out, ivec3(pos.x + x, pos.y + y, pos.z), op(sum[y][x], out_min, out_max));
imageStore(t_out, ivec3(out_pos.x + x, out_pos.y + y, out_pos.z), op(sum[y][x], out_min, out_max));
}
}
}

0 comments on commit 95d3072

Please sign in to comment.