Add support for MatMulNBits (#3496)
music-dino authored Oct 10, 2024
1 parent 0851540 commit a1e3396
Showing 16 changed files with 1,087 additions and 0 deletions.
192 changes: 192 additions & 0 deletions src/onnx/parse_matmulnbits.cpp
@@ -0,0 +1,192 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include "migraphx/errors.hpp"
#include "migraphx/instruction_ref.hpp"
#include "migraphx/onnx/onnx_parser.hpp"
#include <cstddef>
#include <migraphx/onnx/op_parser.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/stringutils.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace onnx {

struct parse_matmulnbits : op_parser<parse_matmulnbits>
{
std::vector<op_desc> operators() const { return {{"MatMulNBits"}}; }

instruction_ref parse(const op_desc& /*opd*/,
const onnx_parser& parser,
onnx_parser::node_info info,
const std::vector<instruction_ref>& args) const
{
const size_t n = parse_attribute(parser, info, "N");
const size_t k = parse_attribute(parser, info, "K");
const size_t bits = parse_attribute(parser, info, "bits");
const size_t block_size = parse_attribute(parser, info, "block_size");

if(bits != 4)
MIGRAPHX_THROW("MatMulNBits: bits only supported for value of 4, actual value " +
std::to_string(bits));

if(block_size < 16 or (block_size & (block_size - 1)) != 0)
MIGRAPHX_THROW("MatMulNBits: block_size must be a power of 2 and >=16, actual value " +
std::to_string(block_size));

const size_t n_blocks_per_col = (k + block_size - 1) / block_size;
const size_t blob_size = std::ceil(block_size * bits / 8.0f);

std::vector<size_t> expected_b_lens{n, n_blocks_per_col, blob_size};
if(args[1]->get_shape().lens() != expected_b_lens)
MIGRAPHX_THROW("MatMulNBits: Input B does not match expected dims: " +
to_string_range(expected_b_lens) +
". Actual dims: " + to_string_range(args[1]->get_shape().lens()));

std::vector<size_t> expected_scales_lens{n * n_blocks_per_col};
if(args[2]->get_shape().lens() != expected_scales_lens)
MIGRAPHX_THROW("MatMulNBits: Input scales does not match expected dims: " +
to_string_range(expected_scales_lens) +
". Actual dims: " + to_string_range(args[2]->get_shape().lens()));

if(args.size() > 3)
{
std::vector<size_t> expected_zp_lens{
static_cast<size_t>(n * std::ceil(n_blocks_per_col * bits / 8.0f))};
if(args[3]->get_shape().lens() != expected_zp_lens)
MIGRAPHX_THROW("MatMulNBits: Input zero_points does not match expected dims: " +
to_string_range(expected_zp_lens) +
". Actual dims: " + to_string_range(args[3]->get_shape().lens()));
}

auto b = dequantize_b(info, n, k, block_size, args);
b = info.add_instruction(make_op("transpose", {{"permutation", {1, 0}}}), b);
return matmul(info, args[0], b);
}

private:
int parse_attribute(const onnx_parser& parser,
onnx_parser::node_info& info,
const std::string& attribute_name) const
{
if(not contains(info.attributes, attribute_name))
MIGRAPHX_THROW("MatMulNBits: Attribute " + attribute_name +
" required, but is missing");

return parser.parse_value(info.attributes[attribute_name]).at<int>();
}

instruction_ref dequantize_b(onnx_parser::node_info& info,
int n,
int k,
int block_size,
const std::vector<instruction_ref>& args) const
{
auto b = unpack(info, n, k, args[1]);

auto n_blocks_per_col = (k + block_size - 1) / block_size;
auto scales = info.add_instruction(make_op("reshape", {{"dims", {n, -1}}}), args[2]);
scales = prepare_blockwise_dq_arg(info, n, k, block_size, scales);

instruction_ref zp;
if(args.size() == 4)
{
zp = unpack(info, n, n_blocks_per_col, args[3]);
zp = prepare_blockwise_dq_arg(info, n, k, block_size, zp);
}
else
{
zp = info.add_literal(literal{shape{shape::uint8_type, {1}}, {8}});
zp = info.add_instruction(
make_op("multibroadcast", {{"out_lens", b->get_shape().lens()}}), zp);
}
return info.add_instruction(make_op("dequantizelinear"), {b, scales, zp});
}

instruction_ref unpack(onnx_parser::node_info& info, int n, int dim1, instruction_ref x) const
{
x = info.add_instruction(make_op("reshape", {{"dims", {n, -1}}}), x);
x = info.add_instruction(make_op("unpack_int4"), x);
if(x->get_shape().lens()[1] > dim1)
{
x = info.add_instruction(
make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {dim1}}}), x);
}
return x;
}

instruction_ref prepare_blockwise_dq_arg(
onnx_parser::node_info& info, int n, int k, int block_size, instruction_ref x) const
{
x = info.add_instruction(make_op("unsqueeze", {{"axes", {2}}}), x);

auto bc_lens = x->get_shape().lens();
bc_lens[2] = block_size;
x = info.add_instruction(make_op("multibroadcast", {{"out_lens", bc_lens}}), x);
x = info.add_instruction(make_op("reshape", {{"dims", {n, -1}}}), x);

// Detect runt block
if(x->get_shape().lens()[1] > k)
{
x = info.add_instruction(
make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {k}}}), x);
}

return x;
}

instruction_ref matmul(onnx_parser::node_info& info, instruction_ref a, instruction_ref b) const
{
const auto a_rank = a->get_shape().ndim();
// B is always rank 2:
// If A is rank 1, unsqueeze A to make it rank 2 to prepare for dot
// If A is rank 2, just a regular dot
// If A is rank > 2, broadcast B to match outer dims of A to prepare for dot
if(a_rank == 1)
{
a = info.add_instruction(make_op("unsqueeze", {{"axes", {0}}}), a);
}
else if(a_rank > 2)
{
auto b_lens = b->get_shape().lens();
auto b_bc_lens = a->get_shape().lens();
std::copy(b_lens.begin(), b_lens.end(), b_bc_lens.end() - 2);
b = info.add_instruction(make_op("multibroadcast", {{"out_lens", b_bc_lens}}), b);
}

auto dot = info.add_instruction(make_op("dot"), a, b);

if(a_rank == 1)
dot = info.add_instruction(
make_op("squeeze", {{"axes", {dot->get_shape().ndim() - 2}}}), dot);

return dot;
}
};

} // namespace onnx
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
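
For reference, the graph that parse_matmulnbits builds reduces to: unpack B's 4-bit values, slice away any padding, broadcast each scale and zero point across its block, dequantize, transpose, and multiply. The following NumPy sketch mirrors that pipeline for bits=4. It is not part of the commit: the helper names are illustrative and the low-nibble-first packing order is an assumption.

import numpy as np

def unpack_nibbles(x, n):
    # Reshape to [n, -1] and split each byte into two 4-bit values,
    # low nibble first (assumed packing order).
    x = x.reshape(n, -1)
    return np.stack([x & 0x0F, x >> 4], axis=-1).reshape(n, -1)

def matmulnbits_ref(a, b_packed, scales, zero_points=None, *, n, k, block_size):
    n_blocks = (k + block_size - 1) // block_size

    # unpack_int4 plus slice, as in unpack() above.
    b = unpack_nibbles(b_packed, n)[:, :k].astype(np.float32)

    # Scales: one per block, broadcast across the block, runt block trimmed
    # (prepare_blockwise_dq_arg).
    s = np.repeat(scales.reshape(n, n_blocks), block_size, axis=1)[:, :k]

    # Zero points default to 8; otherwise they are packed the same way as B.
    if zero_points is None:
        zp = np.full_like(b, 8.0)
    else:
        zp = unpack_nibbles(zero_points, n)[:, :n_blocks].astype(np.float32)
        zp = np.repeat(zp, block_size, axis=1)[:, :k]

    # dequantizelinear, transpose, dot.
    w = (b - zp) * s
    return a @ w.T

# Dims from matmulnbits_mm_test below: N=4, K=16, block_size=16.
a = np.random.randn(2, 16).astype(np.float32)
b = np.random.randint(0, 256, size=(4, 1, 8), dtype=np.uint8)
scales = np.random.randn(4).astype(np.float32)
out = matmulnbits_ref(a, b, scales, n=4, k=16, block_size=16)  # shape (2, 4)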
21 changes: 21 additions & 0 deletions src/targets/gpu/lowering.cpp
@@ -123,6 +123,7 @@ struct miopen_apply
add_select_module_op();
add_reshape_lazy_op();
add_scan_slice_op();
add_unpack_int4_op();
}

void copy_params() const
@@ -527,6 +528,26 @@ struct miopen_apply
ins, mod->insert_instruction(ins, ins->get_operator(), inputs));
});
}

void add_unpack_int4_op()
{
apply_map.emplace("unpack_int4", [=](instruction_ref ins) {
auto inputs = ins->inputs();
auto output = insert_allocation(ins, ins->get_shape());
std::vector<instruction_ref> cpu_inputs;
auto gpu_inputs = ins->inputs();
std::transform(
gpu_inputs.begin(), gpu_inputs.end(), std::back_inserter(cpu_inputs), [&](auto in) {
return mod->insert_instruction(ins, make_op("hip::copy_from_gpu"), in);
});
cpu_inputs.front() =
mod->insert_instruction(ins, make_op("hip::sync_stream"), cpu_inputs);
auto cpu_out = mod->insert_instruction(ins, ins->get_operator(), cpu_inputs);
auto gpu_out =
mod->insert_instruction(ins, make_op("hip::copy_to_gpu"), cpu_out, output);
return mod->replace_instruction(ins, gpu_out);
});
}
};

void lowering::apply(module_pass_manager& mpm) const
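The new mapping gives unpack_int4 a host fallback on the GPU target: inputs are copied out of device memory, the reference implementation runs on the CPU, and the result is copied back into the preallocated device buffer. For orientation, a minimal NumPy sketch of the unpacking itself, doubling the last axis and assuming the low nibble holds the earlier element:

import numpy as np

def unpack_int4(packed):
    # Split each uint8 into two 4-bit values along the last axis
    # (low nibble first, assumed).
    low, high = packed & 0x0F, packed >> 4
    return np.stack([low, high], axis=-1).reshape(*packed.shape[:-1], -1)

print(unpack_int4(np.array([[0x21, 0x43]], dtype=np.uint8)))  # [[1 2 3 4]]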
144 changes: 144 additions & 0 deletions test/onnx/gen_onnx.py
@@ -9020,6 +9020,150 @@ def qlinearmatmul_3D_test():
[sc_a, zero_pt_a, sc_b, zero_pt_b, sc_c, zero_pt_c])


@onnx_test()
def matmulnbits_mm_test():
a = onnx.helper.make_tensor_value_info("a", onnx.TensorProto.FLOAT,
[2, 16])
b = onnx.helper.make_tensor_value_info("b", onnx.TensorProto.UINT8,
[4, 1, 8])
scales = onnx.helper.make_tensor_value_info("scales",
onnx.TensorProto.FLOAT, [4])
zp = onnx.helper.make_tensor_value_info("zp", onnx.TensorProto.UINT8, [4])
c = onnx.helper.make_tensor_value_info("c", onnx.TensorProto.FLOAT, [2, 4])

node = onnx.helper.make_node("MatMulNBits",
inputs=["a", "b", "scales", "zp"],
outputs=["c"],
bits=4,
block_size=16,
K=16,
N=4,
domain='com.microsoft')
return ([node], [a, b, scales, zp], [c])


@onnx_test()
def matmulnbits_mm2_test():
a = onnx.helper.make_tensor_value_info("a", onnx.TensorProto.FLOAT,
[2, 33])
b = onnx.helper.make_tensor_value_info("b", onnx.TensorProto.UINT8,
[2, 3, 8])
scales = onnx.helper.make_tensor_value_info("scales",
onnx.TensorProto.FLOAT, [6])
c = onnx.helper.make_tensor_value_info("c", onnx.TensorProto.FLOAT, [2, 2])

node = onnx.helper.make_node("MatMulNBits",
inputs=["a", "b", "scales"],
outputs=["c"],
bits=4,
block_size=16,
K=33,
N=2,
domain='com.microsoft')
return ([node], [a, b, scales], [c])


@onnx_test()
def matmulnbits_vm_test():
a = onnx.helper.make_tensor_value_info("a", onnx.TensorProto.FLOAT, [20])
b = onnx.helper.make_tensor_value_info("b", onnx.TensorProto.UINT8,
[3, 2, 8])
scales = onnx.helper.make_tensor_value_info("scales",
onnx.TensorProto.FLOAT, [6])
zp = onnx.helper.make_tensor_value_info("zp", onnx.TensorProto.UINT8, [3])
c = onnx.helper.make_tensor_value_info("c", onnx.TensorProto.FLOAT, [3])

node = onnx.helper.make_node("MatMulNBits",
inputs=["a", "b", "scales", "zp"],
outputs=["c"],
bits=4,
block_size=16,
K=20,
N=3,
domain='com.microsoft')
return ([node], [a, b, scales, zp], [c])


@onnx_test()
def matmulnbits_bmm_test():
a = onnx.helper.make_tensor_value_info("a", onnx.TensorProto.FLOAT,
[2, 3, 8])
b = onnx.helper.make_tensor_value_info("b", onnx.TensorProto.UINT8,
[2, 1, 8])
scales = onnx.helper.make_tensor_value_info("scales",
onnx.TensorProto.FLOAT, [2])
c = onnx.helper.make_tensor_value_info("c", onnx.TensorProto.FLOAT,
[2, 3, 2])

node = onnx.helper.make_node("MatMulNBits",
inputs=["a", "b", "scales"],
outputs=["c"],
bits=4,
block_size=16,
K=8,
N=2,
domain='com.microsoft')
return ([node], [a, b, scales], [c])


def matmulnbits_negative_test(bits=4,
block_size=16,
a_dims=[2, 16],
b_dims=[4, 1, 8],
scales_dims=[4],
zp_dims=[4],
out_dims=[2, 4]):
a = onnx.helper.make_tensor_value_info("a", onnx.TensorProto.FLOAT, a_dims)
b = onnx.helper.make_tensor_value_info("b", onnx.TensorProto.UINT8, b_dims)
scales = onnx.helper.make_tensor_value_info("scales",
onnx.TensorProto.FLOAT,
scales_dims)
zp = onnx.helper.make_tensor_value_info("zp", onnx.TensorProto.UINT8,
zp_dims)
c = onnx.helper.make_tensor_value_info("c", onnx.TensorProto.FLOAT,
out_dims)

node = onnx.helper.make_node("MatMulNBits",
inputs=["a", "b", "scales", "zp"],
outputs=["c"],
bits=bits,
block_size=block_size,
K=16,
N=4,
domain='com.microsoft')
return ([node], [a, b, scales, zp], [c])


@onnx_test()
def matmulnbits_invalid_bits_value_test():
return matmulnbits_negative_test(bits=5)


@onnx_test()
def matmulnbits_block_size_too_small_test():
return matmulnbits_negative_test(block_size=8)


@onnx_test()
def matmulnbits_block_size_not_power_of_two_test():
return matmulnbits_negative_test(block_size=20)


@onnx_test()
def matmulnbits_invalid_b_dims_test():
return matmulnbits_negative_test(b_dims=[4, 2, 8])


@onnx_test()
def matmulnbits_invalid_scales_dims_test():
return matmulnbits_negative_test(scales_dims=[3])


@onnx_test()
def matmulnbits_invalid_zp_dims_test():
return matmulnbits_negative_test(zp_dims=[5])


@onnx_test()
def qlinearmul_test():
a = helper.make_tensor_value_info('A', TensorProto.UINT8, [64])
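The MatMulNBits tests above feed arbitrary uint8 blobs, which is enough to exercise shapes and parsing. To build inputs that correspond to a real float weight matrix, B has to be quantized blockwise and packed two values per byte. A rough sketch of one way to do that; the helper name, symmetric scaling, fixed zero point of 8, and low-nibble-first packing are all assumptions rather than anything this commit prescribes:

import numpy as np

def pack_b_4bit(w, block_size=16):
    # Quantize a float [N, K] matrix into the [N, n_blocks_per_col, blob_size] layout.
    n, k = w.shape
    n_blocks = (k + block_size - 1) // block_size
    padded = np.zeros((n, n_blocks * block_size), dtype=np.float32)
    padded[:, :k] = w
    blocks = padded.reshape(n, n_blocks, block_size)
    scales = np.abs(blocks).max(axis=-1) / 7.0
    scales = np.where(scales == 0, 1.0, scales).astype(np.float32)
    q = np.clip(np.round(blocks / scales[..., None]) + 8, 0, 15).astype(np.uint8)
    packed = q[..., 0::2] | (q[..., 1::2] << 4)  # low nibble first (assumed)
    return packed, scales.reshape(-1)

w = np.random.randn(4, 16).astype(np.float32)
b_packed, scales = pack_b_4bit(w)  # shapes (4, 1, 8) and (4,), as in matmulnbits_mm_test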