Skip to content

Commit

Permalink
[ET-VK] Splitting TILE_SIZE to TILE_SIZE_X and TILE_SIZE_Y in conv2d pw.
Browse files Browse the repository at this point in the history
Pull Request resolved: #7816

This diff splits the `TILE_SIZE` variable in the `conv2d_pw` GLSL code into `TILE_SIZE_X` and `TILE_SIZE_Y`. This change is made so tile size in different dimensions can be tuned separately.
ghstack-source-id: 263238734
@exported-using-ghexport

Differential Revision: [D68400783](https://our.internmc.facebook.com/intern/diff/D68400783/)

---------

Co-authored-by: Vivek Trivedi <[email protected]>
  • Loading branch information
pytorchbot and trivedivivek authored Jan 27, 2025
1 parent 0076965 commit 6db3f87
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 16 deletions.
30 changes: 15 additions & 15 deletions backends/vulkan/runtime/graph/ops/glsl/conv2d_pw.glsl
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@

#define VEC4_T ${texel_type(DTYPE)}

#define TILE_SIZE ${TILE_SIZE}
#define TILE_SIZE_X ${TILE_SIZE_X}
#define TILE_SIZE_Y ${TILE_SIZE_Y}

#define op(X, A, B) ${OPERATOR}

Expand Down Expand Up @@ -43,7 +44,7 @@ layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;

// shared memory to hold calculated positions, this would reduce register usage thus improving performance.
// 64 is the number of threads in the local wg
$num_shared = 64 * TILE_SIZE * TILE_SIZE
$num_shared = 64 * TILE_SIZE_X * TILE_SIZE_Y
shared ivec2 pos_shared[${num_shared}];

/*
Expand All @@ -52,8 +53,8 @@ shared ivec2 pos_shared[${num_shared}];
* size is only 1x1, making it easier to re-use loaded texels from t_kernel.
*/
void main() {
const ivec2 out_limits_scaled = (out_limits.xy + TILE_SIZE - 1) / TILE_SIZE;
const uint shared_mem_stride = gl_WorkGroupSize.x * gl_WorkGroupSize.y * gl_WorkGroupSize.z;
const ivec2 out_limits_scaled = (out_limits.xy + ivec2(TILE_SIZE_X - 1, TILE_SIZE_Y - 1)) / ivec2(TILE_SIZE_X, TILE_SIZE_Y);
const uint shared_mem_stride = 64;

const uint div_by_x = gl_GlobalInvocationID.x / out_limits_scaled.x;
const ivec3 gpos = ivec3(
Expand All @@ -67,11 +68,10 @@ void main() {
// +--------+--------+
// | pos[2] | pos[3] |
// +--------+--------+
ivec2 pos[TILE_SIZE * TILE_SIZE];
for (int y = 0, i = 0; y < TILE_SIZE; ++y) {
for (int x = 0; x < TILE_SIZE; ++x) {
pos[i] = ivec2(
gpos.x * TILE_SIZE + x, gpos.y * TILE_SIZE + y);
ivec2 pos[TILE_SIZE_X * TILE_SIZE_Y];
for (int y = 0, i = 0; y < TILE_SIZE_Y; ++y) {
for (int x = 0; x < TILE_SIZE_X; ++x) {
pos[i] = ivec2(gpos.x * TILE_SIZE_X + x, gpos.y * TILE_SIZE_Y + y);
pos_shared[(shared_mem_stride * i) + gl_LocalInvocationIndex] = pos[i];
i++;
}
Expand All @@ -86,14 +86,14 @@ void main() {
// Compute the index of the input texture that needs to be loaded for each
// output position. Note that negative indices can be produced indicating that
// the top-left element is in a region added by padding.
ivec2 ipos[TILE_SIZE * TILE_SIZE];
for (int i = 0; i < TILE_SIZE * TILE_SIZE; ++i) {
ivec2 ipos[TILE_SIZE_X * TILE_SIZE_Y];
for (int i = 0; i < TILE_SIZE_X * TILE_SIZE_Y; ++i) {
ipos[i] = pos[i] * stride - padding;
}

vec4 sum[TILE_SIZE * TILE_SIZE];
vec4 sum[TILE_SIZE_X * TILE_SIZE_Y];
sum[0] = texelFetch(t_bias, ivec2(gpos.z, 0), 0);
for (int i = 1; i < TILE_SIZE * TILE_SIZE; ++i) {
for (int i = 1; i < TILE_SIZE_X * TILE_SIZE_Y; ++i) {
sum[i] = sum[0];
}

Expand All @@ -109,7 +109,7 @@ void main() {
const vec4 ktex_3 = texelFetchOffset(t_kernel, ivec2(z, gpos.z), 0, ivec2(3, 0));

#pragma unroll
for (int i = 0; i < TILE_SIZE * TILE_SIZE; ++i) {
for (int i = 0; i < TILE_SIZE_X * TILE_SIZE_Y; ++i) {
const vec4 in_tex = texelFetch(t_in, ivec3(ipos[i], z4), 0);
// For 2x2 tile size algorithm works as follows.
// To explain the calculations below, the contents of one in_tex and the
Expand Down Expand Up @@ -151,7 +151,7 @@ void main() {
}
}

for (int i = 0; i < TILE_SIZE * TILE_SIZE; ++i) {
for (int i = 0; i < TILE_SIZE_X * TILE_SIZE_Y; ++i) {
const ivec2 pos = pos_shared[(shared_mem_stride * i) + gl_LocalInvocationIndex];
if (all(lessThan(ivec3(pos, gpos.z), out_limits.xyz))) {
imageStore(t_out, ivec3(pos, gpos.z), op(sum[i], out_min, out_max));
Expand Down
3 changes: 2 additions & 1 deletion backends/vulkan/runtime/graph/ops/glsl/conv2d_pw.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@ conv2d_pw:
OPERATOR: X
NDIM: 3
DTYPE: float
TILE_SIZE: 2
TILE_SIZE_X: 2
TILE_SIZE_Y: 2
generate_variant_forall:
DTYPE:
- VALUE: half
Expand Down

0 comments on commit 6db3f87

Please sign in to comment.