From 64d06eeb843327ba2ab0323996fb835e7808a860 Mon Sep 17 00:00:00 2001 From: Curt Hagenlocher Date: Sat, 1 Feb 2025 10:40:52 -0800 Subject: [PATCH 1/6] Refactor some decimal-related code and tests in preparation for adding Decimal32 and Decimal64 support --- arrow-cast/src/cast/decimal.rs | 75 +-- arrow-cast/src/cast/dictionary.rs | 86 ++- arrow-cast/src/cast/mod.rs | 711 +++++++++++-------------- arrow-data/src/data.rs | 6 +- arrow-integration-test/src/datatype.rs | 12 +- arrow-ipc/src/convert.rs | 18 +- arrow-ord/src/sort.rs | 297 +++-------- arrow-schema/src/ffi.rs | 12 +- arrow/benches/builder.rs | 2 +- arrow/tests/array_cast.rs | 2 +- 10 files changed, 467 insertions(+), 754 deletions(-) diff --git a/arrow-cast/src/cast/decimal.rs b/arrow-cast/src/cast/decimal.rs index fee9c007c0ae..94d990b31a8f 100644 --- a/arrow-cast/src/cast/decimal.rs +++ b/arrow-cast/src/cast/decimal.rs @@ -25,6 +25,8 @@ pub(crate) trait DecimalCast: Sized { fn to_i256(self) -> Option; fn from_decimal(n: T) -> Option; + + fn from_f64(n: f64) -> Option; } impl DecimalCast for i128 { @@ -39,6 +41,10 @@ impl DecimalCast for i128 { fn from_decimal(n: T) -> Option { n.to_i128() } + + fn from_f64(n: f64) -> Option { + n.to_i128() + } } impl DecimalCast for i256 { @@ -53,6 +59,10 @@ impl DecimalCast for i256 { fn from_decimal(n: T) -> Option { n.to_i256() } + + fn from_f64(n: f64) -> Option { + i256::from_f64(n) + } } pub(crate) fn cast_decimal_to_decimal_error( @@ -463,7 +473,7 @@ where Ok(Arc::new(result)) } -pub(crate) fn cast_floating_point_to_decimal128( +pub(crate) fn cast_floating_point_to_decimal( array: &PrimitiveArray, precision: u8, scale: i8, @@ -471,78 +481,33 @@ pub(crate) fn cast_floating_point_to_decimal128( ) -> Result where ::Native: AsPrimitive, + D: DecimalType + ArrowPrimitiveType, + M: ArrowNativeTypeOp + DecimalCast, { let mul = 10_f64.powi(scale as i32); if cast_options.safe { array - .unary_opt::<_, Decimal128Type>(|v| { - (mul * v.as_()) - .round() - .to_i128() - .filter(|v| Decimal128Type::is_valid_decimal_precision(*v, precision)) + .unary_opt::<_, D>(|v| { + M::from_f64::((mul * v.as_()).round()) + .filter(|v| D::is_valid_decimal_precision(*v, precision)) }) .with_precision_and_scale(precision, scale) .map(|a| Arc::new(a) as ArrayRef) } else { array - .try_unary::<_, Decimal128Type, _>(|v| { - (mul * v.as_()) - .round() - .to_i128() + .try_unary::<_, D, _>(|v| { + M::from_f64::((mul * v.as_()).round()) .ok_or_else(|| { ArrowError::CastError(format!( "Cannot cast to {}({}, {}). Overflowing on {:?}", - Decimal128Type::PREFIX, + D::PREFIX, precision, scale, v )) }) - .and_then(|v| { - Decimal128Type::validate_decimal_precision(v, precision).map(|_| v) - }) - })? - .with_precision_and_scale(precision, scale) - .map(|a| Arc::new(a) as ArrayRef) - } -} - -pub(crate) fn cast_floating_point_to_decimal256( - array: &PrimitiveArray, - precision: u8, - scale: i8, - cast_options: &CastOptions, -) -> Result -where - ::Native: AsPrimitive, -{ - let mul = 10_f64.powi(scale as i32); - - if cast_options.safe { - array - .unary_opt::<_, Decimal256Type>(|v| { - i256::from_f64((v.as_() * mul).round()) - .filter(|v| Decimal256Type::is_valid_decimal_precision(*v, precision)) - }) - .with_precision_and_scale(precision, scale) - .map(|a| Arc::new(a) as ArrayRef) - } else { - array - .try_unary::<_, Decimal256Type, _>(|v| { - i256::from_f64((v.as_() * mul).round()) - .ok_or_else(|| { - ArrowError::CastError(format!( - "Cannot cast to {}({}, {}). Overflowing on {:?}", - Decimal256Type::PREFIX, - precision, - scale, - v - )) - }) - .and_then(|v| { - Decimal256Type::validate_decimal_precision(v, precision).map(|_| v) - }) + .and_then(|v| D::validate_decimal_precision(v, precision).map(|_| v)) })? .with_precision_and_scale(precision, scale) .map(|a| Arc::new(a) as ArrayRef) diff --git a/arrow-cast/src/cast/dictionary.rs b/arrow-cast/src/cast/dictionary.rs index ec0ab346f997..0fd8997555fb 100644 --- a/arrow-cast/src/cast/dictionary.rs +++ b/arrow-cast/src/cast/dictionary.rs @@ -214,50 +214,20 @@ pub(crate) fn cast_to_dictionary( UInt16 => pack_numeric_to_dictionary::(array, dict_value_type, cast_options), UInt32 => pack_numeric_to_dictionary::(array, dict_value_type, cast_options), UInt64 => pack_numeric_to_dictionary::(array, dict_value_type, cast_options), - Decimal128(p, s) => { - let dict = pack_numeric_to_dictionary::( - array, - dict_value_type, - cast_options, - )?; - let dict = dict - .as_dictionary::() - .downcast_dict::() - .ok_or_else(|| { - ArrowError::ComputeError( - "Internal Error: Cannot cast dict to Decimal128Array".to_string(), - ) - })?; - let value = dict.values().clone(); - // Set correct precision/scale - let value = value.with_precision_and_scale(p, s)?; - Ok(Arc::new(DictionaryArray::::try_new( - dict.keys().clone(), - Arc::new(value), - )?)) - } - Decimal256(p, s) => { - let dict = pack_numeric_to_dictionary::( - array, - dict_value_type, - cast_options, - )?; - let dict = dict - .as_dictionary::() - .downcast_dict::() - .ok_or_else(|| { - ArrowError::ComputeError( - "Internal Error: Cannot cast dict to Decimal256Array".to_string(), - ) - })?; - let value = dict.values().clone(); - // Set correct precision/scale - let value = value.with_precision_and_scale(p, s)?; - Ok(Arc::new(DictionaryArray::::try_new( - dict.keys().clone(), - Arc::new(value), - )?)) - } + Decimal128(p, s) => pack_decimal_to_dictionary::( + array, + dict_value_type, + p, + s, + cast_options, + ), + Decimal256(p, s) => pack_decimal_to_dictionary::( + array, + dict_value_type, + p, + s, + cast_options, + ), Float16 => { pack_numeric_to_dictionary::(array, dict_value_type, cast_options) } @@ -359,6 +329,34 @@ where Ok(Arc::new(b.finish())) } +pub(crate) fn pack_decimal_to_dictionary( + array: &dyn Array, + dict_value_type: &DataType, + precision: u8, + scale: i8, + cast_options: &CastOptions, +) -> Result +where + K: ArrowDictionaryKeyType, + D: DecimalType + ArrowPrimitiveType, + M: ArrowNativeTypeOp + DecimalCast, +{ + let dict = pack_numeric_to_dictionary::(array, dict_value_type, cast_options)?; + let dict = dict + .as_dictionary::() + .downcast_dict::>() + .ok_or_else(|| { + ArrowError::ComputeError(format!("Internal Error: Cannot cast dict to {}", D::PREFIX)) + })?; + let value = dict.values().clone(); + // Set correct precision/scale + let value = value.with_precision_and_scale(precision, scale)?; + Ok(Arc::new(DictionaryArray::::try_new( + dict.keys().clone(), + Arc::new(value), + )?)) +} + pub(crate) fn string_view_to_dictionary( array: &dyn Array, ) -> Result diff --git a/arrow-cast/src/cast/mod.rs b/arrow-cast/src/cast/mod.rs index c448ad96016c..3572357c2960 100644 --- a/arrow-cast/src/cast/mod.rs +++ b/arrow-cast/src/cast/mod.rs @@ -234,7 +234,7 @@ pub fn can_cast_types(from_type: &DataType, to_type: &DataType) -> bool { | BinaryView, ) => true, (Utf8 | LargeUtf8, Utf8View) => true, - (BinaryView, Binary | LargeBinary | Utf8 | LargeUtf8 | Utf8View ) => true, + (BinaryView, Binary | LargeBinary | Utf8 | LargeUtf8 | Utf8View) => true, (Utf8View | Utf8 | LargeUtf8, _) => to_type.is_numeric() && to_type != &Float16, (_, Utf8 | LargeUtf8) => from_type.is_primitive(), (_, Utf8View) => from_type.is_numeric(), @@ -830,6 +830,7 @@ pub fn cast_with_options( (Map(_, ordered1), Map(_, ordered2)) if ordered1 == ordered2 => { cast_map_values(array.as_map(), to_type, cast_options, ordered1.to_owned()) } + // Decimal to decimal, same width (Decimal128(p1, s1), Decimal128(p2, s2)) => { cast_decimal_to_decimal_same_type::( array.as_primitive(), @@ -850,6 +851,7 @@ pub fn cast_with_options( cast_options, ) } + // Decimal to decimal, different width (Decimal128(_, s1), Decimal256(p2, s2)) => { cast_decimal_to_decimal::( array.as_primitive(), @@ -868,315 +870,51 @@ pub fn cast_with_options( cast_options, ) } + // Decimal to non-decimal (Decimal128(_, scale), _) if !to_type.is_temporal() => { - // cast decimal to other type - match to_type { - UInt8 => cast_decimal_to_integer::( - array, - 10_i128, - *scale, - cast_options, - ), - UInt16 => cast_decimal_to_integer::( - array, - 10_i128, - *scale, - cast_options, - ), - UInt32 => cast_decimal_to_integer::( - array, - 10_i128, - *scale, - cast_options, - ), - UInt64 => cast_decimal_to_integer::( - array, - 10_i128, - *scale, - cast_options, - ), - Int8 => cast_decimal_to_integer::( - array, - 10_i128, - *scale, - cast_options, - ), - Int16 => cast_decimal_to_integer::( - array, - 10_i128, - *scale, - cast_options, - ), - Int32 => cast_decimal_to_integer::( - array, - 10_i128, - *scale, - cast_options, - ), - Int64 => cast_decimal_to_integer::( - array, - 10_i128, - *scale, - cast_options, - ), - Float32 => cast_decimal_to_float::(array, |x| { - (x as f64 / 10_f64.powi(*scale as i32)) as f32 - }), - Float64 => cast_decimal_to_float::(array, |x| { - x as f64 / 10_f64.powi(*scale as i32) - }), - Utf8View => value_to_string_view(array, cast_options), - Utf8 => value_to_string::(array, cast_options), - LargeUtf8 => value_to_string::(array, cast_options), - Null => Ok(new_null_array(to_type, array.len())), - _ => Err(ArrowError::CastError(format!( - "Casting from {from_type:?} to {to_type:?} not supported" - ))), - } + cast_from_decimal::( + array, + 10_i128, + scale, + from_type, + to_type, + |x: i128| x as f64, + cast_options, + ) } (Decimal256(_, scale), _) if !to_type.is_temporal() => { - // cast decimal to other type - match to_type { - UInt8 => cast_decimal_to_integer::( - array, - i256::from_i128(10_i128), - *scale, - cast_options, - ), - UInt16 => cast_decimal_to_integer::( - array, - i256::from_i128(10_i128), - *scale, - cast_options, - ), - UInt32 => cast_decimal_to_integer::( - array, - i256::from_i128(10_i128), - *scale, - cast_options, - ), - UInt64 => cast_decimal_to_integer::( - array, - i256::from_i128(10_i128), - *scale, - cast_options, - ), - Int8 => cast_decimal_to_integer::( - array, - i256::from_i128(10_i128), - *scale, - cast_options, - ), - Int16 => cast_decimal_to_integer::( - array, - i256::from_i128(10_i128), - *scale, - cast_options, - ), - Int32 => cast_decimal_to_integer::( - array, - i256::from_i128(10_i128), - *scale, - cast_options, - ), - Int64 => cast_decimal_to_integer::( - array, - i256::from_i128(10_i128), - *scale, - cast_options, - ), - Float32 => cast_decimal_to_float::(array, |x| { - (x.to_f64().unwrap() / 10_f64.powi(*scale as i32)) as f32 - }), - Float64 => cast_decimal_to_float::(array, |x| { - x.to_f64().unwrap() / 10_f64.powi(*scale as i32) - }), - Utf8View => value_to_string_view(array, cast_options), - Utf8 => value_to_string::(array, cast_options), - LargeUtf8 => value_to_string::(array, cast_options), - Null => Ok(new_null_array(to_type, array.len())), - _ => Err(ArrowError::CastError(format!( - "Casting from {from_type:?} to {to_type:?} not supported" - ))), - } + cast_from_decimal::( + array, + i256::from_i128(10_i128), + scale, + from_type, + to_type, + |x: i256| x.to_f64().unwrap(), + cast_options, + ) } + // Non-decimal to decimal (_, Decimal128(precision, scale)) if !from_type.is_temporal() => { - // cast data to decimal - match from_type { - UInt8 => cast_integer_to_decimal::<_, Decimal128Type, _>( - array.as_primitive::(), - *precision, - *scale, - 10_i128, - cast_options, - ), - UInt16 => cast_integer_to_decimal::<_, Decimal128Type, _>( - array.as_primitive::(), - *precision, - *scale, - 10_i128, - cast_options, - ), - UInt32 => cast_integer_to_decimal::<_, Decimal128Type, _>( - array.as_primitive::(), - *precision, - *scale, - 10_i128, - cast_options, - ), - UInt64 => cast_integer_to_decimal::<_, Decimal128Type, _>( - array.as_primitive::(), - *precision, - *scale, - 10_i128, - cast_options, - ), - Int8 => cast_integer_to_decimal::<_, Decimal128Type, _>( - array.as_primitive::(), - *precision, - *scale, - 10_i128, - cast_options, - ), - Int16 => cast_integer_to_decimal::<_, Decimal128Type, _>( - array.as_primitive::(), - *precision, - *scale, - 10_i128, - cast_options, - ), - Int32 => cast_integer_to_decimal::<_, Decimal128Type, _>( - array.as_primitive::(), - *precision, - *scale, - 10_i128, - cast_options, - ), - Int64 => cast_integer_to_decimal::<_, Decimal128Type, _>( - array.as_primitive::(), - *precision, - *scale, - 10_i128, - cast_options, - ), - Float32 => cast_floating_point_to_decimal128( - array.as_primitive::(), - *precision, - *scale, - cast_options, - ), - Float64 => cast_floating_point_to_decimal128( - array.as_primitive::(), - *precision, - *scale, - cast_options, - ), - Utf8View | Utf8 => cast_string_to_decimal::( - array, - *precision, - *scale, - cast_options, - ), - LargeUtf8 => cast_string_to_decimal::( - array, - *precision, - *scale, - cast_options, - ), - Null => Ok(new_null_array(to_type, array.len())), - _ => Err(ArrowError::CastError(format!( - "Casting from {from_type:?} to {to_type:?} not supported" - ))), - } + cast_to_decimal::( + array, + 10_i128, + precision, + scale, + from_type, + to_type, + cast_options, + ) } (_, Decimal256(precision, scale)) if !from_type.is_temporal() => { - // cast data to decimal - match from_type { - UInt8 => cast_integer_to_decimal::<_, Decimal256Type, _>( - array.as_primitive::(), - *precision, - *scale, - i256::from_i128(10_i128), - cast_options, - ), - UInt16 => cast_integer_to_decimal::<_, Decimal256Type, _>( - array.as_primitive::(), - *precision, - *scale, - i256::from_i128(10_i128), - cast_options, - ), - UInt32 => cast_integer_to_decimal::<_, Decimal256Type, _>( - array.as_primitive::(), - *precision, - *scale, - i256::from_i128(10_i128), - cast_options, - ), - UInt64 => cast_integer_to_decimal::<_, Decimal256Type, _>( - array.as_primitive::(), - *precision, - *scale, - i256::from_i128(10_i128), - cast_options, - ), - Int8 => cast_integer_to_decimal::<_, Decimal256Type, _>( - array.as_primitive::(), - *precision, - *scale, - i256::from_i128(10_i128), - cast_options, - ), - Int16 => cast_integer_to_decimal::<_, Decimal256Type, _>( - array.as_primitive::(), - *precision, - *scale, - i256::from_i128(10_i128), - cast_options, - ), - Int32 => cast_integer_to_decimal::<_, Decimal256Type, _>( - array.as_primitive::(), - *precision, - *scale, - i256::from_i128(10_i128), - cast_options, - ), - Int64 => cast_integer_to_decimal::<_, Decimal256Type, _>( - array.as_primitive::(), - *precision, - *scale, - i256::from_i128(10_i128), - cast_options, - ), - Float32 => cast_floating_point_to_decimal256( - array.as_primitive::(), - *precision, - *scale, - cast_options, - ), - Float64 => cast_floating_point_to_decimal256( - array.as_primitive::(), - *precision, - *scale, - cast_options, - ), - Utf8View | Utf8 => cast_string_to_decimal::( - array, - *precision, - *scale, - cast_options, - ), - LargeUtf8 => cast_string_to_decimal::( - array, - *precision, - *scale, - cast_options, - ), - Null => Ok(new_null_array(to_type, array.len())), - _ => Err(ArrowError::CastError(format!( - "Casting from {from_type:?} to {to_type:?} not supported" - ))), - } + cast_to_decimal::( + array, + i256::from_i128(10_i128), + precision, + scale, + from_type, + to_type, + cast_options, + ) } (Struct(_), Struct(to_fields)) => { let array = array.as_struct(); @@ -2192,6 +1930,150 @@ pub fn cast_with_options( } } +fn cast_from_decimal( + array: &dyn Array, + base: D::Native, + scale: &i8, + from_type: &DataType, + to_type: &DataType, + as_float: F, + cast_options: &CastOptions, +) -> Result +where + D: DecimalType + ArrowPrimitiveType, + ::Native: ArrowNativeTypeOp + ToPrimitive, + F: Fn(D::Native) -> f64, +{ + use DataType::*; + // cast decimal to other type + match to_type { + UInt8 => cast_decimal_to_integer::(array, base, *scale, cast_options), + UInt16 => cast_decimal_to_integer::(array, base, *scale, cast_options), + UInt32 => cast_decimal_to_integer::(array, base, *scale, cast_options), + UInt64 => cast_decimal_to_integer::(array, base, *scale, cast_options), + Int8 => cast_decimal_to_integer::(array, base, *scale, cast_options), + Int16 => cast_decimal_to_integer::(array, base, *scale, cast_options), + Int32 => cast_decimal_to_integer::(array, base, *scale, cast_options), + Int64 => cast_decimal_to_integer::(array, base, *scale, cast_options), + Float32 => cast_decimal_to_float::(array, |x| { + (as_float(x) / 10_f64.powi(*scale as i32)) as f32 + }), + Float64 => cast_decimal_to_float::(array, |x| { + as_float(x) / 10_f64.powi(*scale as i32) + }), + Utf8View => value_to_string_view(array, cast_options), + Utf8 => value_to_string::(array, cast_options), + LargeUtf8 => value_to_string::(array, cast_options), + Null => Ok(new_null_array(to_type, array.len())), + _ => Err(ArrowError::CastError(format!( + "Casting from {from_type:?} to {to_type:?} not supported" + ))), + } +} + +fn cast_to_decimal( + array: &dyn Array, + base: M, + precision: &u8, + scale: &i8, + from_type: &DataType, + to_type: &DataType, + cast_options: &CastOptions, +) -> Result +where + D: DecimalType + ArrowPrimitiveType, + M: ArrowNativeTypeOp + DecimalCast, + u8: num::traits::AsPrimitive, + u16: num::traits::AsPrimitive, + u32: num::traits::AsPrimitive, + u64: num::traits::AsPrimitive, + i8: num::traits::AsPrimitive, + i16: num::traits::AsPrimitive, + i32: num::traits::AsPrimitive, + i64: num::traits::AsPrimitive, +{ + use DataType::*; + // cast data to decimal + match from_type { + UInt8 => cast_integer_to_decimal::<_, D, M>( + array.as_primitive::(), + *precision, + *scale, + base, + cast_options, + ), + UInt16 => cast_integer_to_decimal::<_, D, _>( + array.as_primitive::(), + *precision, + *scale, + base, + cast_options, + ), + UInt32 => cast_integer_to_decimal::<_, D, _>( + array.as_primitive::(), + *precision, + *scale, + base, + cast_options, + ), + UInt64 => cast_integer_to_decimal::<_, D, _>( + array.as_primitive::(), + *precision, + *scale, + base, + cast_options, + ), + Int8 => cast_integer_to_decimal::<_, D, _>( + array.as_primitive::(), + *precision, + *scale, + base, + cast_options, + ), + Int16 => cast_integer_to_decimal::<_, D, _>( + array.as_primitive::(), + *precision, + *scale, + base, + cast_options, + ), + Int32 => cast_integer_to_decimal::<_, D, _>( + array.as_primitive::(), + *precision, + *scale, + base, + cast_options, + ), + Int64 => cast_integer_to_decimal::<_, D, _>( + array.as_primitive::(), + *precision, + *scale, + base, + cast_options, + ), + Float32 => cast_floating_point_to_decimal::<_, D, _>( + array.as_primitive::(), + *precision, + *scale, + cast_options, + ), + Float64 => cast_floating_point_to_decimal::<_, D, _>( + array.as_primitive::(), + *precision, + *scale, + cast_options, + ), + Utf8View | Utf8 => { + cast_string_to_decimal::(array, *precision, *scale, cast_options) + } + LargeUtf8 => cast_string_to_decimal::(array, *precision, *scale, cast_options), + Null => Ok(new_null_array(to_type, array.len())), + _ => Err(ArrowError::CastError(format!( + "Casting from {from_type:?} to {to_type:?} not supported" + ))), + } +} + /// Get the time unit as a multiple of a second const fn time_unit_multiple(unit: &TimeUnit) -> i64 { match unit { @@ -2527,7 +2409,7 @@ mod tests { }; } - fn create_decimal_array( + fn create_decimal128_array( array: Vec>, precision: u8, scale: i8, @@ -2596,7 +2478,7 @@ mod tests { Some(-3123456), None, ]; - let array = create_decimal_array(array, 20, 4).unwrap(); + let array = create_decimal128_array(array, 20, 4).unwrap(); // decimal128 to decimal128 let input_type = DataType::Decimal128(20, 4); let output_type = DataType::Decimal128(20, 3); @@ -2681,7 +2563,7 @@ mod tests { let output_type = DataType::Decimal128(20, 4); assert!(can_cast_types(&input_type, &output_type)); let array = vec![Some(1123456), Some(2123456), Some(3123456), None]; - let array = create_decimal_array(array, 20, 3).unwrap(); + let array = create_decimal128_array(array, 20, 3).unwrap(); generate_cast_test_case!( &array, Decimal128Array, @@ -2695,7 +2577,7 @@ mod tests { ); // negative test let array = vec![Some(123456), None]; - let array = create_decimal_array(array, 10, 0).unwrap(); + let array = create_decimal128_array(array, 10, 0).unwrap(); let result_safe = cast(&array, &DataType::Decimal128(2, 2)); assert!(result_safe.is_ok()); let options = CastOptions { @@ -2719,7 +2601,7 @@ mod tests { ); assert!(can_cast_types(&input_type, &output_type)); let array = vec![Some(1123456), Some(2123456), Some(3123456), None]; - let array = create_decimal_array(array, p, s).unwrap(); + let array = create_decimal128_array(array, p, s).unwrap(); let cast_array = cast_with_options(&array, &output_type, &CastOptions::default()).unwrap(); assert_eq!(cast_array.data_type(), &output_type); } @@ -2735,7 +2617,7 @@ mod tests { ); assert!(can_cast_types(&input_type, &output_type)); let array = vec![Some(1123456), Some(2123456), Some(3123456), None]; - let array = create_decimal_array(array, p, s).unwrap(); + let array = create_decimal128_array(array, p, s).unwrap(); let cast_array = cast_with_options(&array, &output_type, &CastOptions::default()).unwrap(); assert_eq!(cast_array.data_type(), &output_type); } @@ -2747,7 +2629,7 @@ mod tests { assert!(can_cast_types(&input_type, &output_type)); let array = vec![Some(i128::MAX)]; - let array = create_decimal_array(array, 38, 3).unwrap(); + let array = create_decimal128_array(array, 38, 3).unwrap(); let result = cast_with_options( &array, &output_type, @@ -2767,7 +2649,7 @@ mod tests { assert!(can_cast_types(&input_type, &output_type)); let array = vec![Some(i128::MAX)]; - let array = create_decimal_array(array, 38, 3).unwrap(); + let array = create_decimal128_array(array, 38, 3).unwrap(); let result = cast_with_options( &array, &output_type, @@ -2786,7 +2668,7 @@ mod tests { let output_type = DataType::Decimal256(20, 4); assert!(can_cast_types(&input_type, &output_type)); let array = vec![Some(1123456), Some(2123456), Some(3123456), None]; - let array = create_decimal_array(array, 20, 3).unwrap(); + let array = create_decimal128_array(array, 20, 3).unwrap(); generate_cast_test_case!( &array, Decimal256Array, @@ -2888,96 +2770,103 @@ mod tests { ); } + macro_rules! generate_decimal_to_numeric_cast_test_case { + ($INPUT_ARRAY: expr) => { + // u8 + generate_cast_test_case!( + $INPUT_ARRAY, + UInt8Array, + &DataType::UInt8, + vec![Some(1_u8), Some(2_u8), Some(3_u8), None, Some(5_u8)] + ); + // u16 + generate_cast_test_case!( + $INPUT_ARRAY, + UInt16Array, + &DataType::UInt16, + vec![Some(1_u16), Some(2_u16), Some(3_u16), None, Some(5_u16)] + ); + // u32 + generate_cast_test_case!( + $INPUT_ARRAY, + UInt32Array, + &DataType::UInt32, + vec![Some(1_u32), Some(2_u32), Some(3_u32), None, Some(5_u32)] + ); + // u64 + generate_cast_test_case!( + $INPUT_ARRAY, + UInt64Array, + &DataType::UInt64, + vec![Some(1_u64), Some(2_u64), Some(3_u64), None, Some(5_u64)] + ); + // i8 + generate_cast_test_case!( + $INPUT_ARRAY, + Int8Array, + &DataType::Int8, + vec![Some(1_i8), Some(2_i8), Some(3_i8), None, Some(5_i8)] + ); + // i16 + generate_cast_test_case!( + $INPUT_ARRAY, + Int16Array, + &DataType::Int16, + vec![Some(1_i16), Some(2_i16), Some(3_i16), None, Some(5_i16)] + ); + // i32 + generate_cast_test_case!( + $INPUT_ARRAY, + Int32Array, + &DataType::Int32, + vec![Some(1_i32), Some(2_i32), Some(3_i32), None, Some(5_i32)] + ); + // i64 + generate_cast_test_case!( + $INPUT_ARRAY, + Int64Array, + &DataType::Int64, + vec![Some(1_i64), Some(2_i64), Some(3_i64), None, Some(5_i64)] + ); + // f32 + generate_cast_test_case!( + $INPUT_ARRAY, + Float32Array, + &DataType::Float32, + vec![ + Some(1.25_f32), + Some(2.25_f32), + Some(3.25_f32), + None, + Some(5.25_f32) + ] + ); + // f64 + generate_cast_test_case!( + $INPUT_ARRAY, + Float64Array, + &DataType::Float64, + vec![ + Some(1.25_f64), + Some(2.25_f64), + Some(3.25_f64), + None, + Some(5.25_f64) + ] + ); + }; + } + #[test] - fn test_cast_decimal_to_numeric() { + fn test_cast_decimal128_to_numeric() { let value_array: Vec> = vec![Some(125), Some(225), Some(325), None, Some(525)]; - let array = create_decimal_array(value_array, 38, 2).unwrap(); - // u8 - generate_cast_test_case!( - &array, - UInt8Array, - &DataType::UInt8, - vec![Some(1_u8), Some(2_u8), Some(3_u8), None, Some(5_u8)] - ); - // u16 - generate_cast_test_case!( - &array, - UInt16Array, - &DataType::UInt16, - vec![Some(1_u16), Some(2_u16), Some(3_u16), None, Some(5_u16)] - ); - // u32 - generate_cast_test_case!( - &array, - UInt32Array, - &DataType::UInt32, - vec![Some(1_u32), Some(2_u32), Some(3_u32), None, Some(5_u32)] - ); - // u64 - generate_cast_test_case!( - &array, - UInt64Array, - &DataType::UInt64, - vec![Some(1_u64), Some(2_u64), Some(3_u64), None, Some(5_u64)] - ); - // i8 - generate_cast_test_case!( - &array, - Int8Array, - &DataType::Int8, - vec![Some(1_i8), Some(2_i8), Some(3_i8), None, Some(5_i8)] - ); - // i16 - generate_cast_test_case!( - &array, - Int16Array, - &DataType::Int16, - vec![Some(1_i16), Some(2_i16), Some(3_i16), None, Some(5_i16)] - ); - // i32 - generate_cast_test_case!( - &array, - Int32Array, - &DataType::Int32, - vec![Some(1_i32), Some(2_i32), Some(3_i32), None, Some(5_i32)] - ); - // i64 - generate_cast_test_case!( - &array, - Int64Array, - &DataType::Int64, - vec![Some(1_i64), Some(2_i64), Some(3_i64), None, Some(5_i64)] - ); - // f32 - generate_cast_test_case!( - &array, - Float32Array, - &DataType::Float32, - vec![ - Some(1.25_f32), - Some(2.25_f32), - Some(3.25_f32), - None, - Some(5.25_f32) - ] - ); - // f64 - generate_cast_test_case!( - &array, - Float64Array, - &DataType::Float64, - vec![ - Some(1.25_f64), - Some(2.25_f64), - Some(3.25_f64), - None, - Some(5.25_f64) - ] - ); + let array = create_decimal128_array(value_array, 38, 2).unwrap(); + + generate_decimal_to_numeric_cast_test_case!(&array); // overflow test: out of range of max u8 let value_array: Vec> = vec![Some(51300)]; - let array = create_decimal_array(value_array, 38, 2).unwrap(); + let array = create_decimal128_array(value_array, 38, 2).unwrap(); let casted_array = cast_with_options( &array, &DataType::UInt8, @@ -3004,7 +2893,7 @@ mod tests { // overflow test: out of range of max i8 let value_array: Vec> = vec![Some(24400)]; - let array = create_decimal_array(value_array, 38, 2).unwrap(); + let array = create_decimal128_array(value_array, 38, 2).unwrap(); let casted_array = cast_with_options( &array, &DataType::Int8, @@ -3041,7 +2930,7 @@ mod tests { Some(112345678), Some(112345679), ]; - let array = create_decimal_array(value_array, 38, 2).unwrap(); + let array = create_decimal128_array(value_array, 38, 2).unwrap(); generate_cast_test_case!( &array, Float32Array, @@ -3068,7 +2957,7 @@ mod tests { Some(112345678901234568), Some(112345678901234560), ]; - let array = create_decimal_array(value_array, 38, 2).unwrap(); + let array = create_decimal128_array(value_array, 38, 2).unwrap(); generate_cast_test_case!( &array, Float64Array, @@ -8381,7 +8270,7 @@ mod tests { let output_type = DataType::Decimal128(20, -1); assert!(can_cast_types(&input_type, &output_type)); let array = vec![Some(1123450), Some(2123455), Some(3123456), None]; - let input_decimal_array = create_decimal_array(array, 20, 0).unwrap(); + let input_decimal_array = create_decimal128_array(array, 20, 0).unwrap(); let array = Arc::new(input_decimal_array) as ArrayRef; generate_cast_test_case!( &array, @@ -8439,7 +8328,7 @@ mod tests { let output_type = DataType::Decimal128(10, -2); assert!(can_cast_types(&input_type, &output_type)); let array = vec![Some(123)]; - let input_decimal_array = create_decimal_array(array, 10, -1).unwrap(); + let input_decimal_array = create_decimal128_array(array, 10, -1).unwrap(); let array = Arc::new(input_decimal_array) as ArrayRef; generate_cast_test_case!(&array, Decimal128Array, &output_type, vec![Some(12_i128),]); @@ -8449,7 +8338,7 @@ mod tests { assert_eq!("1200", decimal_arr.value_as_string(0)); let array = vec![Some(125)]; - let input_decimal_array = create_decimal_array(array, 10, -1).unwrap(); + let input_decimal_array = create_decimal128_array(array, 10, -1).unwrap(); let array = Arc::new(input_decimal_array) as ArrayRef; generate_cast_test_case!(&array, Decimal128Array, &output_type, vec![Some(13_i128),]); @@ -8465,7 +8354,7 @@ mod tests { let output_type = DataType::Decimal256(10, 5); assert!(can_cast_types(&input_type, &output_type)); let array = vec![Some(123456), Some(-123456)]; - let input_decimal_array = create_decimal_array(array, 10, 3).unwrap(); + let input_decimal_array = create_decimal128_array(array, 10, 3).unwrap(); let array = Arc::new(input_decimal_array) as ArrayRef; let hundred = i256::from_i128(100); @@ -9287,15 +9176,15 @@ mod tests { test_decimal_to_string::( DataType::Utf8View, - create_decimal_array(array128.clone(), 7, 3).unwrap(), + create_decimal128_array(array128.clone(), 7, 3).unwrap(), ); test_decimal_to_string::( DataType::Utf8, - create_decimal_array(array128.clone(), 7, 3).unwrap(), + create_decimal128_array(array128.clone(), 7, 3).unwrap(), ); test_decimal_to_string::( DataType::LargeUtf8, - create_decimal_array(array128, 7, 3).unwrap(), + create_decimal128_array(array128, 7, 3).unwrap(), ); test_decimal_to_string::( @@ -9969,7 +9858,7 @@ mod tests { #[test] fn test_decimal_to_decimal_throw_error_on_precision_overflow_same_scale() { let array = vec![Some(123456789)]; - let array = create_decimal_array(array, 24, 2).unwrap(); + let array = create_decimal128_array(array, 24, 2).unwrap(); let input_type = DataType::Decimal128(24, 2); let output_type = DataType::Decimal128(6, 2); assert!(can_cast_types(&input_type, &output_type)); @@ -10015,7 +9904,7 @@ mod tests { #[test] fn test_decimal_to_decimal_throw_error_on_precision_overflow_lower_scale() { let array = vec![Some(123456789)]; - let array = create_decimal_array(array, 24, 4).unwrap(); + let array = create_decimal128_array(array, 24, 4).unwrap(); let input_type = DataType::Decimal128(24, 4); let output_type = DataType::Decimal128(6, 2); assert!(can_cast_types(&input_type, &output_type)); @@ -10032,7 +9921,7 @@ mod tests { #[test] fn test_decimal_to_decimal_throw_error_on_precision_overflow_greater_scale() { let array = vec![Some(123456789)]; - let array = create_decimal_array(array, 24, 2).unwrap(); + let array = create_decimal128_array(array, 24, 2).unwrap(); let input_type = DataType::Decimal128(24, 2); let output_type = DataType::Decimal128(6, 3); assert!(can_cast_types(&input_type, &output_type)); @@ -10049,7 +9938,7 @@ mod tests { #[test] fn test_decimal_to_decimal_throw_error_on_precision_overflow_diff_type() { let array = vec![Some(123456789)]; - let array = create_decimal_array(array, 24, 2).unwrap(); + let array = create_decimal128_array(array, 24, 2).unwrap(); let input_type = DataType::Decimal128(24, 2); let output_type = DataType::Decimal256(6, 2); assert!(can_cast_types(&input_type, &output_type)); diff --git a/arrow-data/src/data.rs b/arrow-data/src/data.rs index 4314b550d680..dce1a853c22f 100644 --- a/arrow-data/src/data.rs +++ b/arrow-data/src/data.rs @@ -84,6 +84,8 @@ pub(crate) fn new_buffers(data_type: &DataType, capacity: usize) -> [MutableBuff | DataType::Float16 | DataType::Float32 | DataType::Float64 + | DataType::Decimal128(_, _) + | DataType::Decimal256(_, _) | DataType::Date32 | DataType::Time32(_) | DataType::Date64 @@ -140,10 +142,6 @@ pub(crate) fn new_buffers(data_type: &DataType, capacity: usize) -> [MutableBuff DataType::FixedSizeList(_, _) | DataType::Struct(_) | DataType::RunEndEncoded(_, _) => { [empty_buffer, MutableBuffer::new(0)] } - DataType::Decimal128(_, _) | DataType::Decimal256(_, _) => [ - MutableBuffer::new(capacity * mem::size_of::()), - empty_buffer, - ], DataType::Union(_, mode) => { let type_ids = MutableBuffer::new(capacity * mem::size_of::()); match mode { diff --git a/arrow-integration-test/src/datatype.rs b/arrow-integration-test/src/datatype.rs index e45e94c24e07..24e02c8430c7 100644 --- a/arrow-integration-test/src/datatype.rs +++ b/arrow-integration-test/src/datatype.rs @@ -60,14 +60,12 @@ pub fn data_type_from_json(json: &serde_json::Value) -> Result { _ => 128, // Default bit width }; - if bit_width == 128 { - Ok(DataType::Decimal128(precision, scale)) - } else if bit_width == 256 { - Ok(DataType::Decimal256(precision, scale)) - } else { - Err(ArrowError::ParseError( + match bit_width { + 128 => Ok(DataType::Decimal128(precision, scale)), + 256 => Ok(DataType::Decimal256(precision, scale)), + _ => Err(ArrowError::ParseError( "Decimal bit_width invalid".to_string(), - )) + )), } } Some(s) if s == "floatingpoint" => match map.get("precision") { diff --git a/arrow-ipc/src/convert.rs b/arrow-ipc/src/convert.rs index 37c5a19439c1..aadeb7703371 100644 --- a/arrow-ipc/src/convert.rs +++ b/arrow-ipc/src/convert.rs @@ -454,18 +454,12 @@ pub(crate) fn get_data_type(field: crate::Field, may_be_dictionary: bool) -> Dat crate::Type::Decimal => { let fsb = field.type_as_decimal().unwrap(); let bit_width = fsb.bitWidth(); - if bit_width == 128 { - DataType::Decimal128( - fsb.precision().try_into().unwrap(), - fsb.scale().try_into().unwrap(), - ) - } else if bit_width == 256 { - DataType::Decimal256( - fsb.precision().try_into().unwrap(), - fsb.scale().try_into().unwrap(), - ) - } else { - panic!("Unexpected decimal bit width {bit_width}") + let precision: u8 = fsb.precision().try_into().unwrap(); + let scale: i8 = fsb.scale().try_into().unwrap(); + match bit_width { + 128 => DataType::Decimal128(precision, scale), + 256 => DataType::Decimal256(precision, scale), + _ => panic!("Unexpected decimal bit width {bit_width}"), } } crate::Type::Union => { diff --git a/arrow-ord/src/sort.rs b/arrow-ord/src/sort.rs index 92ee93e1b656..fd3603b3729f 100644 --- a/arrow-ord/src/sort.rs +++ b/arrow-ord/src/sort.rs @@ -795,10 +795,18 @@ mod tests { use rand::seq::SliceRandom; use rand::{Rng, RngCore, SeedableRng}; - fn create_decimal128_array(data: Vec>) -> Decimal128Array { + fn create_decimal_array( + data: Vec>, + precision: u8, + scale: i8, + ) -> PrimitiveArray { data.into_iter() - .collect::() - .with_precision_and_scale(23, 6) + .map(|x| match x { + None => None, + Some(y) => T::Native::from_usize(y), + }) + .collect::>() + .with_precision_and_scale(precision, scale) .unwrap() } @@ -809,13 +817,15 @@ mod tests { .unwrap() } - fn test_sort_to_indices_decimal128_array( - data: Vec>, + fn test_sort_to_indices_decimal_array( + data: Vec>, options: Option, limit: Option, expected_data: Vec, + precision: u8, + scale: i8, ) { - let output = create_decimal128_array(data); + let output = create_decimal_array::(data, precision, scale); let expected = UInt32Array::from(expected_data); let output = sort_to_indices(&(Arc::new(output) as ArrayRef), options, limit).unwrap(); assert_eq!(output, expected) @@ -833,14 +843,16 @@ mod tests { assert_eq!(output, expected) } - fn test_sort_decimal128_array( - data: Vec>, + fn test_sort_decimal_array( + data: Vec>, options: Option, limit: Option, - expected_data: Vec>, + expected_data: Vec>, + p: u8, + s: i8, ) { - let output = create_decimal128_array(data); - let expected = Arc::new(create_decimal128_array(expected_data)) as ArrayRef; + let output = create_decimal_array::(data, p, s); + let expected = Arc::new(create_decimal_array::(expected_data, p, s)) as ArrayRef; let output = match limit { Some(_) => sort_limit(&(Arc::new(output) as ArrayRef), options, limit).unwrap(), _ => sort(&(Arc::new(output) as ArrayRef), options).unwrap(), @@ -1921,17 +1933,18 @@ mod tests { ); } - #[test] - fn test_sort_indices_decimal128() { + fn test_sort_indices_decimal(precision: u8, scale: i8) { // decimal default - test_sort_to_indices_decimal128_array( + test_sort_to_indices_decimal_array::( vec![None, Some(5), Some(2), Some(3), Some(1), Some(4), None], None, None, vec![0, 6, 4, 2, 3, 5, 1], + precision, + scale, ); // decimal descending - test_sort_to_indices_decimal128_array( + test_sort_to_indices_decimal_array::( vec![None, Some(5), Some(2), Some(3), Some(1), Some(4), None], Some(SortOptions { descending: true, @@ -1939,9 +1952,11 @@ mod tests { }), None, vec![1, 5, 3, 2, 4, 0, 6], + precision, + scale, ); // decimal null_first and descending - test_sort_to_indices_decimal128_array( + test_sort_to_indices_decimal_array::( vec![None, Some(5), Some(2), Some(3), Some(1), Some(4), None], Some(SortOptions { descending: true, @@ -1949,9 +1964,11 @@ mod tests { }), None, vec![0, 6, 1, 5, 3, 2, 4], + precision, + scale, ); // decimal null_first - test_sort_to_indices_decimal128_array( + test_sort_to_indices_decimal_array::( vec![None, Some(5), Some(2), Some(3), Some(1), Some(4), None], Some(SortOptions { descending: false, @@ -1959,16 +1976,20 @@ mod tests { }), None, vec![0, 6, 4, 2, 3, 5, 1], + precision, + scale, ); // limit - test_sort_to_indices_decimal128_array( + test_sort_to_indices_decimal_array::( vec![None, Some(5), Some(2), Some(3), Some(1), Some(4), None], None, Some(3), vec![0, 6, 4], + precision, + scale, ); // limit descending - test_sort_to_indices_decimal128_array( + test_sort_to_indices_decimal_array::( vec![None, Some(5), Some(2), Some(3), Some(1), Some(4), None], Some(SortOptions { descending: true, @@ -1976,9 +1997,11 @@ mod tests { }), Some(3), vec![1, 5, 3], + precision, + scale, ); // limit descending null_first - test_sort_to_indices_decimal128_array( + test_sort_to_indices_decimal_array::( vec![None, Some(5), Some(2), Some(3), Some(1), Some(4), None], Some(SortOptions { descending: true, @@ -1986,9 +2009,11 @@ mod tests { }), Some(3), vec![0, 6, 1], + precision, + scale, ); // limit null_first - test_sort_to_indices_decimal128_array( + test_sort_to_indices_decimal_array::( vec![None, Some(5), Some(2), Some(3), Some(1), Some(4), None], Some(SortOptions { descending: false, @@ -1996,85 +2021,19 @@ mod tests { }), Some(3), vec![0, 6, 4], + precision, + scale, ); } #[test] - fn test_sort_indices_decimal256() { - let data = vec![ - None, - Some(i256::from_i128(5)), - Some(i256::from_i128(2)), - Some(i256::from_i128(3)), - Some(i256::from_i128(1)), - Some(i256::from_i128(4)), - None, - ]; + fn test_sort_indices_decimal128() { + test_sort_indices_decimal::(23, 6); + } - // decimal default - test_sort_to_indices_decimal256_array(data.clone(), None, None, vec![0, 6, 4, 2, 3, 5, 1]); - // decimal descending - test_sort_to_indices_decimal256_array( - data.clone(), - Some(SortOptions { - descending: true, - nulls_first: false, - }), - None, - vec![1, 5, 3, 2, 4, 0, 6], - ); - // decimal null_first and descending - test_sort_to_indices_decimal256_array( - data.clone(), - Some(SortOptions { - descending: true, - nulls_first: true, - }), - None, - vec![0, 6, 1, 5, 3, 2, 4], - ); - // decimal null_first - test_sort_to_indices_decimal256_array( - data.clone(), - Some(SortOptions { - descending: false, - nulls_first: true, - }), - None, - vec![0, 6, 4, 2, 3, 5, 1], - ); - // limit - test_sort_to_indices_decimal256_array(data.clone(), None, Some(3), vec![0, 6, 4]); - // limit descending - test_sort_to_indices_decimal256_array( - data.clone(), - Some(SortOptions { - descending: true, - nulls_first: false, - }), - Some(3), - vec![1, 5, 3], - ); - // limit descending null_first - test_sort_to_indices_decimal256_array( - data.clone(), - Some(SortOptions { - descending: true, - nulls_first: true, - }), - Some(3), - vec![0, 6, 1], - ); - // limit null_first - test_sort_to_indices_decimal256_array( - data, - Some(SortOptions { - descending: false, - nulls_first: true, - }), - Some(3), - vec![0, 6, 4], - ); + #[test] + fn test_sort_indices_decimal256() { + test_sort_indices_decimal::(53, 6); } #[test] @@ -2127,17 +2086,18 @@ mod tests { ); } - #[test] - fn test_sort_decimal128() { + fn test_sort_decimal(precision: u8, scale: i8) { // decimal default - test_sort_decimal128_array( + test_sort_decimal_array::( vec![None, Some(5), Some(2), Some(3), Some(1), Some(4), None], None, None, vec![None, None, Some(1), Some(2), Some(3), Some(4), Some(5)], + precision, + scale, ); // decimal descending - test_sort_decimal128_array( + test_sort_decimal_array::( vec![None, Some(5), Some(2), Some(3), Some(1), Some(4), None], Some(SortOptions { descending: true, @@ -2145,9 +2105,11 @@ mod tests { }), None, vec![Some(5), Some(4), Some(3), Some(2), Some(1), None, None], + precision, + scale, ); // decimal null_first and descending - test_sort_decimal128_array( + test_sort_decimal_array::( vec![None, Some(5), Some(2), Some(3), Some(1), Some(4), None], Some(SortOptions { descending: true, @@ -2155,9 +2117,11 @@ mod tests { }), None, vec![None, None, Some(5), Some(4), Some(3), Some(2), Some(1)], + precision, + scale, ); // decimal null_first - test_sort_decimal128_array( + test_sort_decimal_array::( vec![None, Some(5), Some(2), Some(3), Some(1), Some(4), None], Some(SortOptions { descending: false, @@ -2165,16 +2129,20 @@ mod tests { }), None, vec![None, None, Some(1), Some(2), Some(3), Some(4), Some(5)], + precision, + scale, ); // limit - test_sort_decimal128_array( + test_sort_decimal_array::( vec![None, Some(5), Some(2), Some(3), Some(1), Some(4), None], None, Some(3), vec![None, None, Some(1)], + precision, + scale, ); // limit descending - test_sort_decimal128_array( + test_sort_decimal_array::( vec![None, Some(5), Some(2), Some(3), Some(1), Some(4), None], Some(SortOptions { descending: true, @@ -2182,9 +2150,11 @@ mod tests { }), Some(3), vec![Some(5), Some(4), Some(3)], + precision, + scale, ); // limit descending null_first - test_sort_decimal128_array( + test_sort_decimal_array::( vec![None, Some(5), Some(2), Some(3), Some(1), Some(4), None], Some(SortOptions { descending: true, @@ -2192,9 +2162,11 @@ mod tests { }), Some(3), vec![None, None, Some(5)], + precision, + scale, ); // limit null_first - test_sort_decimal128_array( + test_sort_decimal_array::( vec![None, Some(5), Some(2), Some(3), Some(1), Some(4), None], Some(SortOptions { descending: false, @@ -2202,118 +2174,19 @@ mod tests { }), Some(3), vec![None, None, Some(1)], + precision, + scale, ); } + #[test] + fn test_sort_decimal128() { + test_sort_decimal::(23, 6); + } + #[test] fn test_sort_decimal256() { - let data = vec![ - None, - Some(i256::from_i128(5)), - Some(i256::from_i128(2)), - Some(i256::from_i128(3)), - Some(i256::from_i128(1)), - Some(i256::from_i128(4)), - None, - ]; - // decimal default - test_sort_decimal256_array( - data.clone(), - None, - None, - [None, None, Some(1), Some(2), Some(3), Some(4), Some(5)] - .iter() - .map(|v| v.map(i256::from_i128)) - .collect(), - ); - // decimal descending - test_sort_decimal256_array( - data.clone(), - Some(SortOptions { - descending: true, - nulls_first: false, - }), - None, - [Some(5), Some(4), Some(3), Some(2), Some(1), None, None] - .iter() - .map(|v| v.map(i256::from_i128)) - .collect(), - ); - // decimal null_first and descending - test_sort_decimal256_array( - data.clone(), - Some(SortOptions { - descending: true, - nulls_first: true, - }), - None, - [None, None, Some(5), Some(4), Some(3), Some(2), Some(1)] - .iter() - .map(|v| v.map(i256::from_i128)) - .collect(), - ); - // decimal null_first - test_sort_decimal256_array( - data.clone(), - Some(SortOptions { - descending: false, - nulls_first: true, - }), - None, - [None, None, Some(1), Some(2), Some(3), Some(4), Some(5)] - .iter() - .map(|v| v.map(i256::from_i128)) - .collect(), - ); - // limit - test_sort_decimal256_array( - data.clone(), - None, - Some(3), - [None, None, Some(1)] - .iter() - .map(|v| v.map(i256::from_i128)) - .collect(), - ); - // limit descending - test_sort_decimal256_array( - data.clone(), - Some(SortOptions { - descending: true, - nulls_first: false, - }), - Some(3), - [Some(5), Some(4), Some(3)] - .iter() - .map(|v| v.map(i256::from_i128)) - .collect(), - ); - // limit descending null_first - test_sort_decimal256_array( - data.clone(), - Some(SortOptions { - descending: true, - nulls_first: true, - }), - Some(3), - [None, None, Some(5)] - .iter() - .map(|v| v.map(i256::from_i128)) - .collect(), - ); - // limit null_first - test_sort_decimal256_array( - data, - Some(SortOptions { - descending: false, - nulls_first: true, - }), - Some(3), - [None, None, Some(1)] - .iter() - .map(|v| v.map(i256::from_i128)) - .collect(), - ); + test_sort_decimal::(53, 6); } #[test] diff --git a/arrow-schema/src/ffi.rs b/arrow-schema/src/ffi.rs index 96c80974982c..f6fcc8f275b3 100644 --- a/arrow-schema/src/ffi.rs +++ b/arrow-schema/src/ffi.rs @@ -510,9 +510,6 @@ impl TryFrom<&FFI_ArrowSchema> for DataType { DataType::Decimal128(parsed_precision, parsed_scale) }, [precision, scale, bits] => { - if *bits != "128" && *bits != "256" { - return Err(ArrowError::CDataInterface("Only 128/256 bit wide decimal is supported in the Rust implementation".to_string())); - } let parsed_precision = precision.parse::().map_err(|_| { ArrowError::CDataInterface( "The decimal type requires an integer precision".to_string(), @@ -523,10 +520,11 @@ impl TryFrom<&FFI_ArrowSchema> for DataType { "The decimal type requires an integer scale".to_string(), ) })?; - if *bits == "128" { - DataType::Decimal128(parsed_precision, parsed_scale) - } else { - DataType::Decimal256(parsed_precision, parsed_scale) + let parsed_bits = bits.parse::().unwrap_or(0); + match parsed_bits { + 128 => DataType::Decimal128(parsed_precision, parsed_scale), + 256 => DataType::Decimal256(parsed_precision, parsed_scale), + _ => return Err(ArrowError::CDataInterface("Only 128- and 256- bit wide decimals are supported in the Rust implementation".to_string())), } } _ => { diff --git a/arrow/benches/builder.rs b/arrow/benches/builder.rs index 2776924d8ee9..4f5f38eadfcb 100644 --- a/arrow/benches/builder.rs +++ b/arrow/benches/builder.rs @@ -126,7 +126,7 @@ fn bench_decimal128(c: &mut Criterion) { } fn bench_decimal256(c: &mut Criterion) { - c.bench_function("bench_decimal128_builder", |b| { + c.bench_function("bench_decimal256_builder", |b| { b.iter(|| { let mut rng = rand::rng(); let mut decimal_builder = Decimal256Builder::with_capacity(BATCH_SIZE); diff --git a/arrow/tests/array_cast.rs b/arrow/tests/array_cast.rs index ef5ca6041700..26ef4635c851 100644 --- a/arrow/tests/array_cast.rs +++ b/arrow/tests/array_cast.rs @@ -411,7 +411,7 @@ fn make_dictionary_utf8() -> ArrayRef { Arc::new(b.finish()) } -fn create_decimal_array( +fn create_decimal128_array( array: Vec>, precision: u8, scale: i8, From 49907a99c3111e01ad04f3e09ce81677f9d4b64f Mon Sep 17 00:00:00 2001 From: Curt Hagenlocher Date: Sat, 1 Feb 2025 10:58:03 -0800 Subject: [PATCH 2/6] Fixed symbol --- arrow/tests/array_cast.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/arrow/tests/array_cast.rs b/arrow/tests/array_cast.rs index 26ef4635c851..da7d37fc48a4 100644 --- a/arrow/tests/array_cast.rs +++ b/arrow/tests/array_cast.rs @@ -261,7 +261,7 @@ fn get_arrays_of_all_types() -> Vec { Arc::new(DurationMillisecondArray::from(vec![1000, 2000])), Arc::new(DurationMicrosecondArray::from(vec![1000, 2000])), Arc::new(DurationNanosecondArray::from(vec![1000, 2000])), - Arc::new(create_decimal_array(vec![Some(1), Some(2), Some(3)], 38, 0).unwrap()), + Arc::new(create_decimal128_array(vec![Some(1), Some(2), Some(3)], 38, 0).unwrap()), make_dictionary_primitive::(vec![1, 2]), make_dictionary_primitive::(vec![1, 2]), make_dictionary_primitive::(vec![1, 2]), From 5ffbdc91e8b6d61738122437fcfc4f98b8c099f0 Mon Sep 17 00:00:00 2001 From: Curt Hagenlocher Date: Thu, 6 Feb 2025 15:21:04 -0800 Subject: [PATCH 3/6] Apply PR feedback --- arrow-cast/src/cast/decimal.rs | 16 ++++++++-------- arrow-cast/src/cast/dictionary.rs | 11 +++++------ arrow-cast/src/cast/mod.rs | 4 ++-- arrow-schema/src/ffi.rs | 7 +++---- 4 files changed, 18 insertions(+), 20 deletions(-) diff --git a/arrow-cast/src/cast/decimal.rs b/arrow-cast/src/cast/decimal.rs index 94d990b31a8f..e47852a5e3cd 100644 --- a/arrow-cast/src/cast/decimal.rs +++ b/arrow-cast/src/cast/decimal.rs @@ -26,7 +26,7 @@ pub(crate) trait DecimalCast: Sized { fn from_decimal(n: T) -> Option; - fn from_f64(n: f64) -> Option; + fn from_f64(n: f64) -> Option; } impl DecimalCast for i128 { @@ -42,7 +42,7 @@ impl DecimalCast for i128 { n.to_i128() } - fn from_f64(n: f64) -> Option { + fn from_f64(n: f64) -> Option { n.to_i128() } } @@ -60,7 +60,7 @@ impl DecimalCast for i256 { n.to_i256() } - fn from_f64(n: f64) -> Option { + fn from_f64(n: f64) -> Option { i256::from_f64(n) } } @@ -473,7 +473,7 @@ where Ok(Arc::new(result)) } -pub(crate) fn cast_floating_point_to_decimal( +pub(crate) fn cast_floating_point_to_decimal( array: &PrimitiveArray, precision: u8, scale: i8, @@ -481,15 +481,15 @@ pub(crate) fn cast_floating_point_to_decimal( ) -> Result where ::Native: AsPrimitive, - D: DecimalType + ArrowPrimitiveType, - M: ArrowNativeTypeOp + DecimalCast, + D: DecimalType + ArrowPrimitiveType, + ::Native: DecimalCast, { let mul = 10_f64.powi(scale as i32); if cast_options.safe { array .unary_opt::<_, D>(|v| { - M::from_f64::((mul * v.as_()).round()) + D::Native::from_f64((mul * v.as_()).round()) .filter(|v| D::is_valid_decimal_precision(*v, precision)) }) .with_precision_and_scale(precision, scale) @@ -497,7 +497,7 @@ where } else { array .try_unary::<_, D, _>(|v| { - M::from_f64::((mul * v.as_()).round()) + D::Native::from_f64((mul * v.as_()).round()) .ok_or_else(|| { ArrowError::CastError(format!( "Cannot cast to {}({}, {}). Overflowing on {:?}", diff --git a/arrow-cast/src/cast/dictionary.rs b/arrow-cast/src/cast/dictionary.rs index 0fd8997555fb..dda4a93e5ead 100644 --- a/arrow-cast/src/cast/dictionary.rs +++ b/arrow-cast/src/cast/dictionary.rs @@ -214,14 +214,14 @@ pub(crate) fn cast_to_dictionary( UInt16 => pack_numeric_to_dictionary::(array, dict_value_type, cast_options), UInt32 => pack_numeric_to_dictionary::(array, dict_value_type, cast_options), UInt64 => pack_numeric_to_dictionary::(array, dict_value_type, cast_options), - Decimal128(p, s) => pack_decimal_to_dictionary::( + Decimal128(p, s) => pack_decimal_to_dictionary::( array, dict_value_type, p, s, cast_options, ), - Decimal256(p, s) => pack_decimal_to_dictionary::( + Decimal256(p, s) => pack_decimal_to_dictionary::( array, dict_value_type, p, @@ -329,7 +329,7 @@ where Ok(Arc::new(b.finish())) } -pub(crate) fn pack_decimal_to_dictionary( +pub(crate) fn pack_decimal_to_dictionary( array: &dyn Array, dict_value_type: &DataType, precision: u8, @@ -338,15 +338,14 @@ pub(crate) fn pack_decimal_to_dictionary( ) -> Result where K: ArrowDictionaryKeyType, - D: DecimalType + ArrowPrimitiveType, - M: ArrowNativeTypeOp + DecimalCast, + D: DecimalType + ArrowPrimitiveType, { let dict = pack_numeric_to_dictionary::(array, dict_value_type, cast_options)?; let dict = dict .as_dictionary::() .downcast_dict::>() .ok_or_else(|| { - ArrowError::ComputeError(format!("Internal Error: Cannot cast dict to {}", D::PREFIX)) + ArrowError::ComputeError(format!("Internal Error: Cannot cast dict to {}Array", D::PREFIX)) })?; let value = dict.values().clone(); // Set correct precision/scale diff --git a/arrow-cast/src/cast/mod.rs b/arrow-cast/src/cast/mod.rs index 3572357c2960..7cfed8485655 100644 --- a/arrow-cast/src/cast/mod.rs +++ b/arrow-cast/src/cast/mod.rs @@ -2051,13 +2051,13 @@ where base, cast_options, ), - Float32 => cast_floating_point_to_decimal::<_, D, _>( + Float32 => cast_floating_point_to_decimal::<_, D>( array.as_primitive::(), *precision, *scale, cast_options, ), - Float64 => cast_floating_point_to_decimal::<_, D, _>( + Float64 => cast_floating_point_to_decimal::<_, D>( array.as_primitive::(), *precision, *scale, diff --git a/arrow-schema/src/ffi.rs b/arrow-schema/src/ffi.rs index f6fcc8f275b3..d86fb66190b4 100644 --- a/arrow-schema/src/ffi.rs +++ b/arrow-schema/src/ffi.rs @@ -520,10 +520,9 @@ impl TryFrom<&FFI_ArrowSchema> for DataType { "The decimal type requires an integer scale".to_string(), ) })?; - let parsed_bits = bits.parse::().unwrap_or(0); - match parsed_bits { - 128 => DataType::Decimal128(parsed_precision, parsed_scale), - 256 => DataType::Decimal256(parsed_precision, parsed_scale), + match *bits { + "128" => DataType::Decimal128(parsed_precision, parsed_scale), + "256" => DataType::Decimal256(parsed_precision, parsed_scale), _ => return Err(ArrowError::CDataInterface("Only 128- and 256- bit wide decimals are supported in the Rust implementation".to_string())), } } From 0c90b05f2b2f965453cb9107b715c60c992e20bc Mon Sep 17 00:00:00 2001 From: Curt Hagenlocher Date: Thu, 6 Feb 2025 15:24:28 -0800 Subject: [PATCH 4/6] Fixed format problem --- arrow-cast/src/cast/dictionary.rs | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/arrow-cast/src/cast/dictionary.rs b/arrow-cast/src/cast/dictionary.rs index dda4a93e5ead..fd3be261592c 100644 --- a/arrow-cast/src/cast/dictionary.rs +++ b/arrow-cast/src/cast/dictionary.rs @@ -345,7 +345,10 @@ where .as_dictionary::() .downcast_dict::>() .ok_or_else(|| { - ArrowError::ComputeError(format!("Internal Error: Cannot cast dict to {}Array", D::PREFIX)) + ArrowError::ComputeError(format!( + "Internal Error: Cannot cast dict to {}Array", + D::PREFIX + )) })?; let value = dict.values().clone(); // Set correct precision/scale From b6c864d524f5c5fc3863f5be2a14a3538fc843a9 Mon Sep 17 00:00:00 2001 From: Curt Hagenlocher Date: Thu, 6 Feb 2025 15:44:54 -0800 Subject: [PATCH 5/6] Fixed logical merge conflicts --- arrow-cast/src/cast/mod.rs | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/arrow-cast/src/cast/mod.rs b/arrow-cast/src/cast/mod.rs index 7cfed8485655..d44117c6a52d 100644 --- a/arrow-cast/src/cast/mod.rs +++ b/arrow-cast/src/cast/mod.rs @@ -9875,7 +9875,7 @@ mod tests { #[test] fn test_decimal_to_decimal_same_scale() { let array = vec![Some(520)]; - let array = create_decimal_array(array, 4, 2).unwrap(); + let array = create_decimal128_array(array, 4, 2).unwrap(); let input_type = DataType::Decimal128(4, 2); let output_type = DataType::Decimal128(3, 2); assert!(can_cast_types(&input_type, &output_type)); @@ -9893,11 +9893,11 @@ mod tests { // Cast 0 of decimal(3, 0) type to decimal(2, 0) assert_eq!( &cast( - &create_decimal_array(vec![Some(0)], 3, 0).unwrap(), + &create_decimal128_array(vec![Some(0)], 3, 0).unwrap(), &DataType::Decimal128(2, 0) ) .unwrap(), - &(Arc::new(create_decimal_array(vec![Some(0)], 2, 0).unwrap()) as ArrayRef) + &(Arc::new(create_decimal128_array(vec![Some(0)], 2, 0).unwrap()) as ArrayRef) ); } From a665e543ae41a6aa491d8ec061570c11f71e1ed6 Mon Sep 17 00:00:00 2001 From: Curt Hagenlocher Date: Fri, 7 Feb 2025 09:22:58 -0800 Subject: [PATCH 6/6] PR feedback --- arrow-cast/src/cast/mod.rs | 173 +++++++++++++++++++------------------ arrow-ord/src/sort.rs | 5 +- 2 files changed, 88 insertions(+), 90 deletions(-) diff --git a/arrow-cast/src/cast/mod.rs b/arrow-cast/src/cast/mod.rs index d44117c6a52d..53069c1a1638 100644 --- a/arrow-cast/src/cast/mod.rs +++ b/arrow-cast/src/cast/mod.rs @@ -2770,91 +2770,92 @@ mod tests { ); } - macro_rules! generate_decimal_to_numeric_cast_test_case { - ($INPUT_ARRAY: expr) => { - // u8 - generate_cast_test_case!( - $INPUT_ARRAY, - UInt8Array, - &DataType::UInt8, - vec![Some(1_u8), Some(2_u8), Some(3_u8), None, Some(5_u8)] - ); - // u16 - generate_cast_test_case!( - $INPUT_ARRAY, - UInt16Array, - &DataType::UInt16, - vec![Some(1_u16), Some(2_u16), Some(3_u16), None, Some(5_u16)] - ); - // u32 - generate_cast_test_case!( - $INPUT_ARRAY, - UInt32Array, - &DataType::UInt32, - vec![Some(1_u32), Some(2_u32), Some(3_u32), None, Some(5_u32)] - ); - // u64 - generate_cast_test_case!( - $INPUT_ARRAY, - UInt64Array, - &DataType::UInt64, - vec![Some(1_u64), Some(2_u64), Some(3_u64), None, Some(5_u64)] - ); - // i8 - generate_cast_test_case!( - $INPUT_ARRAY, - Int8Array, - &DataType::Int8, - vec![Some(1_i8), Some(2_i8), Some(3_i8), None, Some(5_i8)] - ); - // i16 - generate_cast_test_case!( - $INPUT_ARRAY, - Int16Array, - &DataType::Int16, - vec![Some(1_i16), Some(2_i16), Some(3_i16), None, Some(5_i16)] - ); - // i32 - generate_cast_test_case!( - $INPUT_ARRAY, - Int32Array, - &DataType::Int32, - vec![Some(1_i32), Some(2_i32), Some(3_i32), None, Some(5_i32)] - ); - // i64 - generate_cast_test_case!( - $INPUT_ARRAY, - Int64Array, - &DataType::Int64, - vec![Some(1_i64), Some(2_i64), Some(3_i64), None, Some(5_i64)] - ); - // f32 - generate_cast_test_case!( - $INPUT_ARRAY, - Float32Array, - &DataType::Float32, - vec![ - Some(1.25_f32), - Some(2.25_f32), - Some(3.25_f32), - None, - Some(5.25_f32) - ] - ); - // f64 - generate_cast_test_case!( - $INPUT_ARRAY, - Float64Array, - &DataType::Float64, - vec![ - Some(1.25_f64), - Some(2.25_f64), - Some(3.25_f64), - None, - Some(5.25_f64) - ] - ); - }; + fn generate_decimal_to_numeric_cast_test_case(array: &PrimitiveArray) + where + T: ArrowPrimitiveType + DecimalType, + { + // u8 + generate_cast_test_case!( + array, + UInt8Array, + &DataType::UInt8, + vec![Some(1_u8), Some(2_u8), Some(3_u8), None, Some(5_u8)] + ); + // u16 + generate_cast_test_case!( + array, + UInt16Array, + &DataType::UInt16, + vec![Some(1_u16), Some(2_u16), Some(3_u16), None, Some(5_u16)] + ); + // u32 + generate_cast_test_case!( + array, + UInt32Array, + &DataType::UInt32, + vec![Some(1_u32), Some(2_u32), Some(3_u32), None, Some(5_u32)] + ); + // u64 + generate_cast_test_case!( + array, + UInt64Array, + &DataType::UInt64, + vec![Some(1_u64), Some(2_u64), Some(3_u64), None, Some(5_u64)] + ); + // i8 + generate_cast_test_case!( + array, + Int8Array, + &DataType::Int8, + vec![Some(1_i8), Some(2_i8), Some(3_i8), None, Some(5_i8)] + ); + // i16 + generate_cast_test_case!( + array, + Int16Array, + &DataType::Int16, + vec![Some(1_i16), Some(2_i16), Some(3_i16), None, Some(5_i16)] + ); + // i32 + generate_cast_test_case!( + array, + Int32Array, + &DataType::Int32, + vec![Some(1_i32), Some(2_i32), Some(3_i32), None, Some(5_i32)] + ); + // i64 + generate_cast_test_case!( + array, + Int64Array, + &DataType::Int64, + vec![Some(1_i64), Some(2_i64), Some(3_i64), None, Some(5_i64)] + ); + // f32 + generate_cast_test_case!( + array, + Float32Array, + &DataType::Float32, + vec![ + Some(1.25_f32), + Some(2.25_f32), + Some(3.25_f32), + None, + Some(5.25_f32) + ] + ); + // f64 + generate_cast_test_case!( + array, + Float64Array, + &DataType::Float64, + vec![ + Some(1.25_f64), + Some(2.25_f64), + Some(3.25_f64), + None, + Some(5.25_f64) + ] + ); } #[test] @@ -2862,7 +2863,7 @@ mod tests { let value_array: Vec> = vec![Some(125), Some(225), Some(325), None, Some(525)]; let array = create_decimal128_array(value_array, 38, 2).unwrap(); - generate_decimal_to_numeric_cast_test_case!(&array); + generate_decimal_to_numeric_cast_test_case(&array); // overflow test: out of range of max u8 let value_array: Vec> = vec![Some(51300)]; diff --git a/arrow-ord/src/sort.rs b/arrow-ord/src/sort.rs index fd3603b3729f..7894999157c7 100644 --- a/arrow-ord/src/sort.rs +++ b/arrow-ord/src/sort.rs @@ -801,10 +801,7 @@ mod tests { scale: i8, ) -> PrimitiveArray { data.into_iter() - .map(|x| match x { - None => None, - Some(y) => T::Native::from_usize(y), - }) + .map(|x| x.and_then(T::Native::from_usize)) .collect::>() .with_precision_and_scale(precision, scale) .unwrap()