diff --git a/native/spark-expr/src/math_funcs/round.rs b/native/spark-expr/src/math_funcs/round.rs index a47b7bc29..f5a149654 100644 --- a/native/spark-expr/src/math_funcs/round.rs +++ b/native/spark-expr/src/math_funcs/round.rs @@ -85,9 +85,10 @@ pub fn spark_round( let (precision, scale) = get_precision_scale(data_type); make_decimal_array(array, precision, scale, &f) } - DataType::Float32 | DataType::Float64 => { - Ok(ColumnarValue::Array(round(&[Arc::clone(array)])?)) - } + DataType::Float32 | DataType::Float64 => Ok(ColumnarValue::Array(round(&[ + Arc::clone(array), + args[1].to_array(array.len())?, + ])?)), dt => exec_err!("Not supported datatype for ROUND: {dt}"), }, ColumnarValue::Scalar(a) => match a { @@ -109,7 +110,7 @@ pub fn spark_round( make_decimal_scalar(a, precision, scale, &f) } ScalarValue::Float32(_) | ScalarValue::Float64(_) => Ok(ColumnarValue::Scalar( - ScalarValue::try_from_array(&round(&[a.to_array()?])?, 0)?, + ScalarValue::try_from_array(&round(&[a.to_array()?, args[1].to_array(1)?])?, 0)?, )), dt => exec_err!("Not supported datatype for ROUND: {dt}"), }, @@ -135,3 +136,80 @@ fn decimal_round_f(scale: &i8, point: &i64) -> Box i128> { Box::new(move |x: i128| (x + x.signum() * half) / div) } } + +#[cfg(test)] +mod test { + use std::sync::Arc; + + use crate::spark_round; + + use arrow::array::{Float32Array, Float64Array}; + use arrow_schema::DataType; + use datafusion_common::cast::{as_float32_array, as_float64_array}; + use datafusion_common::{Result, ScalarValue}; + use datafusion_expr::ColumnarValue; + + #[test] + fn test_round_f32_array() -> Result<()> { + let args = vec![ + ColumnarValue::Array(Arc::new(Float32Array::from(vec![ + 125.2345, 15.3455, 0.1234, 0.125, 0.785, 123.123, + ]))), + ColumnarValue::Scalar(ScalarValue::Int64(Some(2))), + ]; + let ColumnarValue::Array(result) = spark_round(&args, &DataType::Float32)? else { + unreachable!() + }; + let floats = as_float32_array(&result)?; + let expected = Float32Array::from(vec![125.23, 15.35, 0.12, 0.13, 0.79, 123.12]); + assert_eq!(floats, &expected); + Ok(()) + } + + #[test] + fn test_round_f64_array() -> Result<()> { + let args = vec![ + ColumnarValue::Array(Arc::new(Float64Array::from(vec![ + 125.2345, 15.3455, 0.1234, 0.125, 0.785, 123.123, + ]))), + ColumnarValue::Scalar(ScalarValue::Int64(Some(2))), + ]; + let ColumnarValue::Array(result) = spark_round(&args, &DataType::Float64)? else { + unreachable!() + }; + let floats = as_float64_array(&result)?; + let expected = Float64Array::from(vec![125.23, 15.35, 0.12, 0.13, 0.79, 123.12]); + assert_eq!(floats, &expected); + Ok(()) + } + + #[test] + fn test_round_f32_scalar() -> Result<()> { + let args = vec![ + ColumnarValue::Scalar(ScalarValue::Float32(Some(125.2345))), + ColumnarValue::Scalar(ScalarValue::Int64(Some(2))), + ]; + let ColumnarValue::Scalar(ScalarValue::Float32(Some(result))) = + spark_round(&args, &DataType::Float32)? + else { + unreachable!() + }; + assert_eq!(result, 125.23); + Ok(()) + } + + #[test] + fn test_round_f64_scalar() -> Result<()> { + let args = vec![ + ColumnarValue::Scalar(ScalarValue::Float64(Some(125.2345))), + ColumnarValue::Scalar(ScalarValue::Int64(Some(2))), + ]; + let ColumnarValue::Scalar(ScalarValue::Float64(Some(result))) = + spark_round(&args, &DataType::Float64)? + else { + unreachable!() + }; + assert_eq!(result, 125.23); + Ok(()) + } +}