Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support div_wrapping/rem_wrapping for numeric arithmetic kernels #7159

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 42 additions & 6 deletions arrow-arith/src/numeric.rs
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,16 @@ pub fn div(lhs: &dyn Datum, rhs: &dyn Datum) -> Result<ArrayRef, ArrowError> {
arithmetic_op(Op::Div, lhs, rhs)
}

/// Perform `lhs / rhs`, wrapping on integer overflow
///
/// The only overflowing integer division is `MIN / -1` on a signed integer type;
/// this kernel wraps that case around to `MIN` (e.g. `i16::MIN / -1 == i16::MIN`)
/// instead of returning an overflow error.
///
/// Floating point and decimal inputs are dispatched to the same code path as the
/// checked division kernel, so they behave identically to it.
///
/// Division by zero will result in an error, with the exception of floating point numbers,
/// which instead follow the IEEE 754 rules
// NOTE(review): the integer arm for `Op::DivWrapping` in `integer_op` goes through the
// infallible `op!` macro rather than `try_op!`. Confirm that a zero divisor (including
// zero values sitting in null slots) really surfaces as an `ArrowError` and cannot panic
// inside the native wrapping division.
pub fn div_wrapping(lhs: &dyn Datum, rhs: &dyn Datum) -> Result<ArrayRef, ArrowError> {
arithmetic_op(Op::DivWrapping, lhs, rhs)
}

/// Perform `lhs % rhs`
///
/// Overflow or division by zero will result in an error, with exception to
Expand All @@ -76,6 +86,16 @@ pub fn rem(lhs: &dyn Datum, rhs: &dyn Datum) -> Result<ArrayRef, ArrowError> {
arithmetic_op(Op::Rem, lhs, rhs)
}

/// Perform `lhs % rhs`, wrapping on integer overflow
///
/// The only overflowing integer remainder is `MIN % -1` on a signed integer type;
/// this kernel yields `0` for that case (e.g. `i16::MIN % -1 == 0`) instead of
/// returning an overflow error.
///
/// Floating point and decimal inputs are dispatched to the same code path as the
/// checked remainder kernel, so they behave identically to it.
///
/// Division by zero will result in an error, with the exception of floating point numbers,
/// which instead follow the IEEE 754 rules
// NOTE(review): the integer arm for `Op::RemWrapping` in `integer_op` goes through the
// infallible `op!` macro rather than `try_op!`. Confirm that a zero divisor (including
// zero values sitting in null slots) really surfaces as an `ArrowError` and cannot panic
// inside the native wrapping remainder.
pub fn rem_wrapping(lhs: &dyn Datum, rhs: &dyn Datum) -> Result<ArrayRef, ArrowError> {
arithmetic_op(Op::RemWrapping, lhs, rhs)
}

macro_rules! neg_checked {
($t:ty, $a:ident) => {{
let array = $a
Expand Down Expand Up @@ -178,7 +198,9 @@ enum Op {
Sub,
MulWrapping,
Mul,
DivWrapping,
Div,
RemWrapping,
Rem,
}

Expand All @@ -188,8 +210,8 @@ impl std::fmt::Display for Op {
Op::AddWrapping | Op::Add => write!(f, "+"),
Op::SubWrapping | Op::Sub => write!(f, "-"),
Op::MulWrapping | Op::Mul => write!(f, "*"),
Op::Div => write!(f, "/"),
Op::Rem => write!(f, "%"),
Op::DivWrapping | Op::Div => write!(f, "/"),
Op::RemWrapping | Op::Rem => write!(f, "%"),
}
}
}
Expand Down Expand Up @@ -312,7 +334,9 @@ fn integer_op<T: ArrowPrimitiveType>(
Op::Sub => try_op!(l, l_s, r, r_s, l.sub_checked(r)),
Op::MulWrapping => op!(l, l_s, r, r_s, l.mul_wrapping(r)),
Op::Mul => try_op!(l, l_s, r, r_s, l.mul_checked(r)),
Op::DivWrapping => op!(l, l_s, r, r_s, l.div_wrapping(r)),
Op::Div => try_op!(l, l_s, r, r_s, l.div_checked(r)),
Op::RemWrapping => op!(l, l_s, r, r_s, l.mod_wrapping(r)),
Op::Rem => try_op!(l, l_s, r, r_s, l.mod_checked(r)),
};
Ok(Arc::new(array))
Expand All @@ -332,8 +356,8 @@ fn float_op<T: ArrowPrimitiveType>(
Op::AddWrapping | Op::Add => op!(l, l_s, r, r_s, l.add_wrapping(r)),
Op::SubWrapping | Op::Sub => op!(l, l_s, r, r_s, l.sub_wrapping(r)),
Op::MulWrapping | Op::Mul => op!(l, l_s, r, r_s, l.mul_wrapping(r)),
Op::Div => op!(l, l_s, r, r_s, l.div_wrapping(r)),
Op::Rem => op!(l, l_s, r, r_s, l.mod_wrapping(r)),
Op::DivWrapping | Op::Div => op!(l, l_s, r, r_s, l.div_wrapping(r)),
Op::RemWrapping | Op::Rem => op!(l, l_s, r, r_s, l.mod_wrapping(r)),
};
Ok(Arc::new(array))
}
Expand Down Expand Up @@ -788,7 +812,7 @@ fn decimal_op<T: DecimalType>(
.with_precision_and_scale(result_precision, result_scale)?
}

Op::Div => {
Op::Div | Op::DivWrapping => {
// Follow postgres and MySQL adding a fixed scale increment of 4
// s1 + 4
let result_scale = s1.saturating_add(4).min(T::MAX_SCALE);
Expand Down Expand Up @@ -819,7 +843,7 @@ fn decimal_op<T: DecimalType>(
.with_precision_and_scale(result_precision, result_scale)?
}

Op::Rem => {
Op::Rem | Op::RemWrapping => {
// max(s1, s2)
let result_scale = *s1.max(s2);
// min(p1-s1, p2 -s2) + max( s1,s2 )
Expand Down Expand Up @@ -1041,6 +1065,18 @@ mod tests {
err,
"Arithmetic overflow: Overflow happened on: -32768 / -1"
);
let result = div_wrapping(&a, &b).unwrap();
assert_eq!(result.as_ref(), &Int16Array::from(vec![-32768]));

let a = Int16Array::from(vec![i16::MIN]);
let b = Int16Array::from(vec![-1]);
let err = rem(&a, &b).unwrap_err().to_string();
assert_eq!(
err,
"Arithmetic overflow: Overflow happened on: -32768 % -1"
);
let result = rem_wrapping(&a, &b).unwrap();
assert_eq!(result.as_ref(), &Int16Array::from(vec![0]));

let a = Int16Array::from(vec![21]);
let b = Int16Array::from(vec![0]);
Expand Down
6 changes: 6 additions & 0 deletions arrow/benches/arithmetic_kernels.rs
Original file line number Diff line number Diff line change
Expand Up @@ -60,12 +60,18 @@ fn add_benchmark(c: &mut Criterion) {
b.iter(|| criterion::black_box(mul_wrapping(&arr_a, &scalar).unwrap()))
});
c.bench_function(&format!("divide({null_density})"), |b| {
b.iter(|| criterion::black_box(div_wrapping(&arr_a, &arr_b).unwrap()))
});
c.bench_function(&format!("divide_checked({null_density})"), |b| {
b.iter(|| criterion::black_box(div(&arr_a, &arr_b).unwrap()))
});
c.bench_function(&format!("divide_scalar({null_density})"), |b| {
b.iter(|| criterion::black_box(div(&arr_a, &scalar).unwrap()))
});
c.bench_function(&format!("modulo({null_density})"), |b| {
b.iter(|| criterion::black_box(rem_wrapping(&arr_a, &arr_b).unwrap()))
});
c.bench_function(&format!("modulo_checked({null_density})"), |b| {
b.iter(|| criterion::black_box(rem(&arr_a, &arr_b).unwrap()))
});
c.bench_function(&format!("modulo_scalar({null_density})"), |b| {
Expand Down
Loading