Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ET-VK] Splitting TILE_SIZE to TILE_SIZE_X and TILE_SIZE_Y in conv2d pw. #7816

Merged
merged 6 commits into from
Jan 27, 2025
30 changes: 15 additions & 15 deletions backends/vulkan/runtime/graph/ops/glsl/conv2d_pw.glsl
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@

#define VEC4_T ${texel_type(DTYPE)}

#define TILE_SIZE ${TILE_SIZE}
#define TILE_SIZE_X ${TILE_SIZE_X}
#define TILE_SIZE_Y ${TILE_SIZE_Y}

#define op(X, A, B) ${OPERATOR}

Expand Down Expand Up @@ -43,7 +44,7 @@ layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;

// Shared memory to hold the calculated positions; this reduces register usage, thus improving performance.
// 64 is the number of threads in the local workgroup.
$num_shared = 64 * TILE_SIZE * TILE_SIZE
$num_shared = 64 * TILE_SIZE_X * TILE_SIZE_Y
shared ivec2 pos_shared[${num_shared}];

/*
Expand All @@ -52,8 +53,8 @@ shared ivec2 pos_shared[${num_shared}];
* size is only 1x1, making it easier to re-use loaded texels from t_kernel.
*/
void main() {
const ivec2 out_limits_scaled = (out_limits.xy + TILE_SIZE - 1) / TILE_SIZE;
const uint shared_mem_stride = gl_WorkGroupSize.x * gl_WorkGroupSize.y * gl_WorkGroupSize.z;
const ivec2 out_limits_scaled = (out_limits.xy + ivec2(TILE_SIZE_X - 1, TILE_SIZE_Y - 1)) / ivec2(TILE_SIZE_X, TILE_SIZE_Y);
const uint shared_mem_stride = 64;

const uint div_by_x = gl_GlobalInvocationID.x / out_limits_scaled.x;
const ivec3 gpos = ivec3(
Expand All @@ -67,11 +68,10 @@ void main() {
// +--------+--------+
// | pos[2] | pos[3] |
// +--------+--------+
ivec2 pos[TILE_SIZE * TILE_SIZE];
for (int y = 0, i = 0; y < TILE_SIZE; ++y) {
for (int x = 0; x < TILE_SIZE; ++x) {
pos[i] = ivec2(
gpos.x * TILE_SIZE + x, gpos.y * TILE_SIZE + y);
ivec2 pos[TILE_SIZE_X * TILE_SIZE_Y];
for (int y = 0, i = 0; y < TILE_SIZE_Y; ++y) {
for (int x = 0; x < TILE_SIZE_X; ++x) {
pos[i] = ivec2(gpos.x * TILE_SIZE_X + x, gpos.y * TILE_SIZE_Y + y);
pos_shared[(shared_mem_stride * i) + gl_LocalInvocationIndex] = pos[i];
i++;
}
Expand All @@ -86,14 +86,14 @@ void main() {
// Compute the index of the input texture that needs to be loaded for each
// output position. Note that negative indices can be produced, indicating that
// the top-left element is in a region added by padding.
ivec2 ipos[TILE_SIZE * TILE_SIZE];
for (int i = 0; i < TILE_SIZE * TILE_SIZE; ++i) {
ivec2 ipos[TILE_SIZE_X * TILE_SIZE_Y];
for (int i = 0; i < TILE_SIZE_X * TILE_SIZE_Y; ++i) {
ipos[i] = pos[i] * stride - padding;
}

vec4 sum[TILE_SIZE * TILE_SIZE];
vec4 sum[TILE_SIZE_X * TILE_SIZE_Y];
sum[0] = texelFetch(t_bias, ivec2(gpos.z, 0), 0);
for (int i = 1; i < TILE_SIZE * TILE_SIZE; ++i) {
for (int i = 1; i < TILE_SIZE_X * TILE_SIZE_Y; ++i) {
sum[i] = sum[0];
}

Expand All @@ -109,7 +109,7 @@ void main() {
const vec4 ktex_3 = texelFetchOffset(t_kernel, ivec2(z, gpos.z), 0, ivec2(3, 0));

#pragma unroll
for (int i = 0; i < TILE_SIZE * TILE_SIZE; ++i) {
for (int i = 0; i < TILE_SIZE_X * TILE_SIZE_Y; ++i) {
const vec4 in_tex = texelFetch(t_in, ivec3(ipos[i], z4), 0);
// For a 2x2 tile size, the algorithm works as follows.
// To explain the calculations below, the contents of one in_tex and the
Expand Down Expand Up @@ -151,7 +151,7 @@ void main() {
}
}

for (int i = 0; i < TILE_SIZE * TILE_SIZE; ++i) {
for (int i = 0; i < TILE_SIZE_X * TILE_SIZE_Y; ++i) {
const ivec2 pos = pos_shared[(shared_mem_stride * i) + gl_LocalInvocationIndex];
if (all(lessThan(ivec3(pos, gpos.z), out_limits.xyz))) {
imageStore(t_out, ivec3(pos, gpos.z), op(sum[i], out_min, out_max));
Expand Down
3 changes: 2 additions & 1 deletion backends/vulkan/runtime/graph/ops/glsl/conv2d_pw.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@ conv2d_pw:
OPERATOR: X
NDIM: 3
DTYPE: float
TILE_SIZE: 2
TILE_SIZE_X: 2
TILE_SIZE_Y: 2
generate_variant_forall:
DTYPE:
- VALUE: half
Expand Down
Loading