[ET-VK][int4] patch 4-bit linear op for ensuring w-packed in/out
Pull Request resolved: #8225

If the partitioner uses the channels-packed setting for activations, the packed-dim checks in check_q_4w_linear_args will throw.

Remove the checks and conditionally re-pack the input/output tensors if they are not width-packed.
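
For reference, the re-pack-around-dispatch pattern reads roughly as follows. This is a minimal sketch using only the graph helpers visible in the diff below; add_op_with_w_packed_io and dispatch_compute are hypothetical names introduced for illustration, not part of this change:

  void add_op_with_w_packed_io(ComputeGraph& graph, ValueRef in, ValueRef out) {
    auto viewFn = VK_GET_OP_FN("aten.view_copy.default");
    ValueRef in_W = in;
    ValueRef out_W = out;
    // Width-packed temporaries, only used if a re-pack is required.
    TmpTensor in_tmp(
        &graph, graph.sizes_of(in), graph.dtype_of(in), utils::kWidthPacked);
    TmpTensor out_tmp(
        &graph, graph.sizes_of(out), graph.dtype_of(out), utils::kWidthPacked);
    if (graph.packed_dim_of(in) != WHCN::kWidthDim) {
      // Copy the input into a width-packed layout before dispatch.
      in_W = in_tmp;
      viewFn(graph, {in, graph.add_none(), in_W});
      out_W = out_tmp;
    }
    dispatch_compute(graph, in_W, out_W); // hypothetical dispatch helper
    if (graph.packed_dim_of(out) != WHCN::kWidthDim) {
      // Copy the width-packed result back into the caller's layout.
      viewFn(graph, {out_W, graph.add_none(), out});
    }
  }
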
ghstack-source-id: 264952605
@exported-using-ghexport

Differential Revision: [D68813946](https://our.internmc.facebook.com/intern/diff/D68813946/)

---------

Co-authored-by: Nathanael See <[email protected]>
pytorchbot and Nathanael See authored Feb 6, 2025
1 parent 8ec08f9 commit 8f0d797
Showing 1 changed file with 29 additions and 8 deletions.

backends/vulkan/runtime/graph/ops/impl/QuantizedLinear.cpp
@@ -260,9 +260,6 @@ void check_q_4w_linear_args(
   const int group_size_val = graph.extract_scalar<int>(group_size);
   VK_CHECK_COND(K % group_size_val == 0);
 
-  VK_CHECK_COND(graph.packed_dim_of(mat1) == WHCN::kWidthDim);
-  VK_CHECK_COND(graph.packed_dim_of(out) == WHCN::kWidthDim);
-
   VK_CHECK_COND(graph.has_standard_axis_map(mat1));
   VK_CHECK_COND(graph.has_standard_axis_map(out));
 }
@@ -320,13 +317,32 @@ void add_q_4w_linear_node(

   const uint32_t group_size_val = graph.extract_scalar<uint32_t>(group_size);
 
+  ValueRef mat1_W_packed = mat1;
+  ValueRef out_W_packed = out;
+  auto viewFn = VK_GET_OP_FN("aten.view_copy.default");
+  // Create temporary tensors to store the width packed versions of mat1 and out
+  TmpTensor mat1_tmp(
+      &graph, graph.sizes_of(mat1), graph.dtype_of(mat1), utils::kWidthPacked);
+  TmpTensor out_tmp(
+      &graph, graph.sizes_of(out), graph.dtype_of(out), utils::kWidthPacked);
+  if (storage_type == utils::kTexture3D) {
+    if (!graph.is_buffer_storage(out) &&
+        graph.packed_dim_of(mat1) != WHCN::kWidthDim) {
+      // Ensure mat1 is width packed
+      mat1_W_packed = mat1_tmp;
+      viewFn(graph, {mat1, graph.add_none(), mat1_W_packed});
+      // Ensure out is packed correctly
+      out_W_packed = out_tmp;
+    }
+  }
+
   vkapi::ParamsBindList ubos({});
-  ubos.append(graph.logical_limits_ubo(out));
-  ubos.append(graph.sizes_ubo(mat1));
+  ubos.append(graph.logical_limits_ubo(out_W_packed));
+  ubos.append(graph.sizes_ubo(mat1_W_packed));
   ubos.append(graph.strides_ubo(mat2));
   ubos.append(graph.strides_ubo(scales_and_zeros));
 
-  utils::uvec3 global_wg_size = graph.logical_limits_of(out);
+  utils::uvec3 global_wg_size = graph.logical_limits_of(out_W_packed);
   utils::uvec3 local_wg_size = graph.create_local_wg_size(global_wg_size);
 
   graph.execute_nodes().emplace_back(new DispatchNode(
@@ -335,15 +351,20 @@ void add_q_4w_linear_node(
       global_wg_size,
       local_wg_size,
       // Inputs and Outputs
-      {{out, vkapi::MemoryAccessType::WRITE},
-       {{mat1, mat2, scales_and_zeros}, vkapi::MemoryAccessType::READ}},
+      {{out_W_packed, vkapi::MemoryAccessType::WRITE},
+       {{mat1_W_packed, mat2, scales_and_zeros},
+        vkapi::MemoryAccessType::READ}},
       // Shader params buffers
       ubos,
       // Specialization Constants
       {SV(group_size_val)},
       // Resizing Logic
       resize_q_4w_linear_node,
       {}));
+  if (!graph.is_buffer_storage(out) &&
+      graph.packed_dim_of(out) != WHCN::kWidthDim) {
+    viewFn(graph, {out_W_packed, graph.add_none(), out});
+  }
 }
 
 void linear_weight_int4(