diff --git a/src/ops/matmul.rs b/src/ops/matmul.rs
index 5a55a0f3..1e8dfb75 100644
--- a/src/ops/matmul.rs
+++ b/src/ops/matmul.rs
@@ -222,6 +222,18 @@ where
         // nb. We assume `a` is likely already contiguous, so this will be cheap.
         let a_contig = a.to_contiguous_in(pool).auto_return(pool);
         let a_matrix = a_contig.reshaped([num_a_matrices * a_rows, a_cols].as_slice());
+
+        // Broadcast zero point to match new row count.
+        let a_quant: Option<Vec<_>> = a_quant.map(|a_quant| {
+            a_quant
+                .zero_point
+                .iter()
+                .copied()
+                .cycle()
+                .take(a_matrix.size(0))
+                .collect()
+        });
+
         let mut output = matmul_impl(
             pool,
             a_matrix.view(),
@@ -230,7 +242,9 @@ where
             strategy,
             bias,
             alpha,
-            a_quant,
+            a_quant.as_ref().map(|zero_point| QuantParams {
+                zero_point: zero_point.as_slice(),
+            }),
             b_quant,
         )?;
         output.reshape(out_shape);
@@ -1065,7 +1079,15 @@ mod tests {
                 b_zero_point: Some(Tensor::from([3, 4])),
                 expected_err: None,
             },
-            // A input which is a row vector
+            // LHS batch input with vector zero point
+            Case {
+                a: Tensor::zeros(&[3, 2, 2]),
+                b: Tensor::from([[5, 6], [7, 8]]),
+                a_zero_point: Some(Tensor::from([1, 2])),
+                b_zero_point: Some(Tensor::from([3, 4])),
+                expected_err: None,
+            },
+            // An input which is a row vector
             Case {
                 a: Tensor::from([[1, 2, 3, 4]]),
                 b: Tensor::from([[5, 6], [7, 8], [9, 10], [11, 12]]),
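
For reference, a minimal, self-contained sketch of the broadcasting step above, using the same `cycle().take()` idiom; the `broadcast_zero_points` helper and the `main` driver are illustrative only, not part of rten:

```rust
/// Repeat a per-row zero point so it covers every row of a batched LHS
/// that has been flattened from `[batch, rows, cols]` to
/// `[batch * rows, cols]`. Hypothetical helper for illustration.
fn broadcast_zero_points(zero_point: &[i32], flattened_rows: usize) -> Vec<i32> {
    zero_point
        .iter()
        .copied()
        .cycle()
        .take(flattened_rows)
        .collect()
}

fn main() {
    // Mirrors the new test case: a [3, 2, 2] LHS with zero point [1, 2].
    let (batch, rows) = (3, 2);
    assert_eq!(
        broadcast_zero_points(&[1, 2], batch * rows),
        [1, 2, 1, 2, 1, 2]
    );
}
```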
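The broadcast is needed because A's zero point is applied per row: once the batched LHS is flattened, each of the `batch * rows` rows still needs its own entry. A hedged reference sketch of MatMulInteger-style semantics, assuming a per-row zero point for A and a per-column zero point for B (as in the ONNX MatMulInteger operator; all names here are illustrative, not rten's API):

```rust
/// Reference quantized matmul:
/// out[i][j] = sum_p (a[i][p] - a_zero[i]) * (b[p][j] - b_zero[j]).
/// Sketch only; assumes well-formed shapes and ignores overflow handling.
fn matmul_integer_ref(
    a: &[Vec<i32>],  // [m, k]
    b: &[Vec<i32>],  // [k, n]
    a_zero: &[i32],  // one zero point per row of `a`
    b_zero: &[i32],  // one zero point per column of `b`
) -> Vec<Vec<i32>> {
    let (m, k, n) = (a.len(), b.len(), b[0].len());
    let mut out = vec![vec![0; n]; m];
    for i in 0..m {
        for j in 0..n {
            for p in 0..k {
                // Subtracting the zero points per row/column is why a
                // per-row `a_zero` of the wrong length would give wrong
                // results, hence the broadcast in the change above.
                out[i][j] += (a[i][p] - a_zero[i]) * (b[p][j] - b_zero[j]);
            }
        }
    }
    out
}
```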