Skip to content

Commit

Permalink
add flag to trace quantization passes
Browse files Browse the repository at this point in the history
  • Loading branch information
shivadbhavsar committed Oct 29, 2024
1 parent 47d11e8 commit 5df74ee
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 5 deletions.
5 changes: 5 additions & 0 deletions docs/dev/env_vars.rst
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,11 @@ Prints debug statements for the ``schedule`` pass.
Set to "1", "enable", "enabled", "yes", or "true" to use.
Traces instructions replaced with a constant.

.. envvar:: MIGRAPHX_TRACE_QUANTIZATION

Set to "1", "enable", "enabled", "yes", or "true" to use.
Prints traces for any passes run during quantization.

.. envvar:: MIGRAPHX_8BITS_QUANTIZATION_PARAMS

Set to "1", "enable", "enabled", "yes", or "true" to use.
Expand Down
22 changes: 17 additions & 5 deletions src/quantization.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,15 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {

MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_8BITS_QUANTIZATION_PARAMS)
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_TRACE_QUANTIZATION)

// Builds the tracer that the quantization entry points in this file pass to
// run_passes(). If the MIGRAPHX_TRACE_QUANTIZATION environment variable is
// enabled, the tracer is constructed on std::cout; otherwise a
// default-constructed tracer is returned (presumably one that emits
// nothing -- confirm against the tracer definition).
tracer quant_tracer()
{
    if(enabled(MIGRAPHX_TRACE_QUANTIZATION{}))
        return tracer{std::cout};

    return tracer{};
}

// Converts any instructions specified in the input from double or float
// to float16 by inserting a convert operator.
Expand All @@ -61,7 +70,8 @@ void quantize_fp16(program& prog, const std::vector<std::string>& ins_names)
{normalize_ops{},
optimize_module{{"quantizelinear", "dequantizelinear"}},
quantize_fp16_pass{ins_names},
optimize_module{{"quantizelinear", "dequantizelinear"}}});
optimize_module{{"quantizelinear", "dequantizelinear"}}},
quant_tracer());
}

void quantize_8bits(program& prog,
Expand All @@ -72,7 +82,7 @@ void quantize_8bits(program& prog,
{
// Run optimize_module() before converting to int8/fp8 to const eval and fold in FP32 to
// avoid loss of precision.
run_passes(prog, {normalize_ops{}, optimize_module{}});
run_passes(prog, {normalize_ops{}, optimize_module{}}, quant_tracer());

std::shared_ptr<std::vector<std::pair<float, float>>> quant_8bit_params =
std::make_shared<std::vector<std::pair<float, float>>>();
Expand Down Expand Up @@ -106,7 +116,8 @@ void quantize_8bits(program& prog,

// pass to add capture argument op
std::size_t param_num = 0;
run_passes(prog, {capture_arguments_pass{ins_names, calc_quant_params, &param_num}});
run_passes(
prog, {capture_arguments_pass{ins_names, calc_quant_params, &param_num}}, quant_tracer());
quant_8bit_params->resize(param_num, std::pair<float, float>(64.0f, 0.0f));
max_abs_vals->resize(param_num, 0.0f);

Expand Down Expand Up @@ -150,7 +161,8 @@ void quantize_8bits(program& prog,
{quantize_8bits_pass{precision, *quant_8bit_params},
simplify_qdq{},
optimize_module{},
dead_code_elimination{}});
dead_code_elimination{}},
quant_tracer());
}

void quantize_int8(program& prog,
Expand All @@ -168,7 +180,7 @@ void quantize_int8(program& prog,

void quantize_int4_weights(program& prog)
{
run_passes(prog, {normalize_ops{}, optimize_module{}, quantize_int4_pass{}});
run_passes(prog, {normalize_ops{}, optimize_module{}, quantize_int4_pass{}}, quant_tracer());
}

void quantize_fp8(program& prog, const target& t, const std::vector<parameter_map>& calibration)
Expand Down

0 comments on commit 5df74ee

Please sign in to comment.