Skip to content

Commit

Permalink
feat: teach DateTimePartsArray to cast (#1946)
Browse files Browse the repository at this point in the history
I also:
1. Moved all the canonicalization code into canoncial.rs
2. Split the big test into smaller tests and put the smaller tests into
the scrutinized files.
3. Used rstest for parameterized tests.
  • Loading branch information
danking authored Jan 16, 2025
1 parent 156b412 commit 55cf81a
Show file tree
Hide file tree
Showing 9 changed files with 353 additions and 177 deletions.
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions encodings/datetime-parts/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -26,4 +26,5 @@ vortex-error = { workspace = true }
vortex-scalar = { workspace = true }

[dev-dependencies]
rstest = { workspace = true }
vortex-array = { workspace = true, features = ["test-harness"] }
13 changes: 1 addition & 12 deletions encodings/datetime-parts/src/array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,10 @@ use vortex_array::stats::StatsSet;
use vortex_array::validity::{ArrayValidity, LogicalValidity, Validity, ValidityVTable};
use vortex_array::variants::{ExtensionArrayTrait, VariantsVTable};
use vortex_array::visitor::{ArrayVisitor, VisitorVTable};
use vortex_array::{
impl_encoding, ArrayDType, ArrayData, ArrayLen, ArrayTrait, Canonical, IntoArrayData,
IntoCanonical,
};
use vortex_array::{impl_encoding, ArrayDType, ArrayData, ArrayLen, ArrayTrait, IntoArrayData};
use vortex_dtype::{DType, PType};
use vortex_error::{vortex_bail, VortexExpect as _, VortexResult, VortexUnwrap};

use crate::compute::decode_to_temporal;

impl_encoding!("vortex.datetimeparts", ids::DATE_TIME_PARTS, DateTimeParts);

#[derive(Clone, Debug, Serialize, Deserialize)]
Expand Down Expand Up @@ -136,12 +131,6 @@ impl ExtensionArrayTrait for DateTimePartsArray {
}
}

impl IntoCanonical for DateTimePartsArray {
fn into_canonical(self) -> VortexResult<Canonical> {
Ok(Canonical::Extension(decode_to_temporal(&self)?.into()))
}
}

impl ValidityVTable<DateTimePartsArray> for DateTimePartsEncoding {
fn is_valid(&self, array: &DateTimePartsArray, index: usize) -> bool {
array.validity().is_valid(index)
Expand Down
146 changes: 146 additions & 0 deletions encodings/datetime-parts/src/canonical.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,146 @@
use vortex_array::array::{PrimitiveArray, TemporalArray};
use vortex_array::compute::try_cast;
use vortex_array::{
ArrayDType, Canonical, IntoArrayData as _, IntoArrayVariant as _, IntoCanonical,
};
use vortex_buffer::BufferMut;
use vortex_datetime_dtype::{TemporalMetadata, TimeUnit};
use vortex_dtype::Nullability::NonNullable;
use vortex_dtype::{DType, PType};
use vortex_error::{vortex_bail, VortexExpect as _, VortexResult};
use vortex_scalar::PrimitiveScalar;

use crate::DateTimePartsArray;

impl IntoCanonical for DateTimePartsArray {
fn into_canonical(self) -> VortexResult<Canonical> {
Ok(Canonical::Extension(decode_to_temporal(&self)?.into()))
}
}

/// Decode an [ArrayData] into a [TemporalArray].
///
/// Enforces that the passed array is actually a [DateTimePartsArray] with proper metadata.
pub fn decode_to_temporal(array: &DateTimePartsArray) -> VortexResult<TemporalArray> {
let DType::Extension(ext) = array.dtype().clone() else {
vortex_bail!(ComputeError: "expected dtype to be DType::Extension variant")
};

let Ok(temporal_metadata) = TemporalMetadata::try_from(ext.as_ref()) else {
vortex_bail!(ComputeError: "must decode TemporalMetadata from extension metadata");
};

let divisor = match temporal_metadata.time_unit() {
TimeUnit::Ns => 1_000_000_000,
TimeUnit::Us => 1_000_000,
TimeUnit::Ms => 1_000,
TimeUnit::S => 1,
TimeUnit::D => vortex_bail!(InvalidArgument: "cannot decode into TimeUnit::D"),
};

let days_buf = try_cast(
array.days(),
&DType::Primitive(PType::I64, array.dtype().nullability()),
)?
.into_primitive()?;

// We start with the days component, which is always present.
// And then add the seconds and subseconds components.
// We split this into separate passes because often the seconds and/org subsecond components
// are constant.
let mut values: BufferMut<i64> = days_buf
.into_buffer_mut::<i64>()
.map_each(|d| d * 86_400 * divisor);

if let Some(seconds) = array.seconds().as_constant() {
let seconds =
PrimitiveScalar::try_from(&seconds.cast(&DType::Primitive(PType::I64, NonNullable))?)?
.typed_value::<i64>()
.vortex_expect("non-nullable");
let seconds = seconds * divisor;
for v in values.iter_mut() {
*v += seconds;
}
} else {
let seconds_buf = try_cast(array.seconds(), &DType::Primitive(PType::U32, NonNullable))?
.into_primitive()?;
for (v, second) in values.iter_mut().zip(seconds_buf.as_slice::<u32>()) {
*v += (*second as i64) * divisor;
}
}

if let Some(subseconds) = array.subsecond().as_constant() {
let subseconds = PrimitiveScalar::try_from(
&subseconds.cast(&DType::Primitive(PType::I64, NonNullable))?,
)?
.typed_value::<i64>()
.vortex_expect("non-nullable");
for v in values.iter_mut() {
*v += subseconds;
}
} else {
let subsecond_buf = try_cast(
array.subsecond(),
&DType::Primitive(PType::I64, NonNullable),
)?
.into_primitive()?;
for (v, subsecond) in values.iter_mut().zip(subsecond_buf.as_slice::<i64>()) {
*v += *subsecond;
}
}

Ok(TemporalArray::new_timestamp(
PrimitiveArray::new(values.freeze(), array.validity()).into_array(),
temporal_metadata.time_unit(),
temporal_metadata.time_zone().map(ToString::to_string),
))
}

#[cfg(test)]
mod test {
use rstest::rstest;
use vortex_array::array::{PrimitiveArray, TemporalArray};
use vortex_array::validity::Validity;
use vortex_array::{IntoArrayData as _, IntoArrayVariant};
use vortex_buffer::buffer;
use vortex_datetime_dtype::TimeUnit;

use crate::canonical::decode_to_temporal;
use crate::DateTimePartsArray;

#[rstest]
#[case(Validity::NonNullable)]
#[case(Validity::AllValid)]
#[case(Validity::AllInvalid)]
#[case(Validity::from_iter([true, false, true]))]
fn test_decode_to_temporal(#[case] validity: Validity) {
let milliseconds = PrimitiveArray::new(
buffer![
86_400i64, // element with only day component
86_400i64 + 1000, // element with day + second components
86_400i64 + 1000 + 1, // element with day + second + sub-second components
],
validity.clone(),
);
let date_times = DateTimePartsArray::try_from(TemporalArray::new_timestamp(
milliseconds.clone().into_array(),
TimeUnit::Ms,
Some("UTC".to_string()),
))
.unwrap();

assert_eq!(date_times.validity(), validity);

let primitive_values = decode_to_temporal(&date_times)
.unwrap()
.temporal_values()
.into_primitive()
.unwrap();

assert_eq!(
primitive_values.as_slice::<i64>(),
milliseconds.as_slice::<i64>()
);
assert_eq!(primitive_values.validity(), validity);
}
}
63 changes: 62 additions & 1 deletion encodings/datetime-parts/src/compress.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@ use vortex_array::{ArrayDType as _, ArrayData, ArrayLen, IntoArrayData, IntoArra
use vortex_buffer::BufferMut;
use vortex_datetime_dtype::TimeUnit;
use vortex_dtype::{DType, PType};
use vortex_error::{vortex_bail, VortexResult};
use vortex_error::{vortex_bail, VortexError, VortexResult};

use crate::DateTimePartsArray;

pub struct TemporalParts {
pub days: ArrayData,
Expand Down Expand Up @@ -52,3 +54,62 @@ pub fn split_temporal(array: TemporalArray) -> VortexResult<TemporalParts> {
subseconds: subsecond.into_array(),
})
}

impl TryFrom<TemporalArray> for DateTimePartsArray {
type Error = VortexError;

fn try_from(array: TemporalArray) -> Result<Self, Self::Error> {
let ext_dtype = array.ext_dtype();
let TemporalParts {
days,
seconds,
subseconds,
} = split_temporal(array)?;
DateTimePartsArray::try_new(DType::Extension(ext_dtype), days, seconds, subseconds)
}
}

#[cfg(test)]
mod tests {
use rstest::rstest;
use vortex_array::array::{PrimitiveArray, TemporalArray};
use vortex_array::validity::Validity;
use vortex_array::{IntoArrayData as _, IntoArrayVariant as _};
use vortex_buffer::buffer;
use vortex_datetime_dtype::TimeUnit;

use crate::{split_temporal, TemporalParts};

#[rstest]
#[case(Validity::NonNullable)]
#[case(Validity::AllValid)]
#[case(Validity::AllInvalid)]
#[case(Validity::from_iter([true, false, true]))]
fn test_split_temporal(#[case] validity: Validity) {
let milliseconds = PrimitiveArray::new(
buffer![
86_400i64, // element with only day component
86_400i64 + 1000, // element with day + second components
86_400i64 + 1000 + 1, // element with day + second + sub-second components
],
validity.clone(),
)
.into_array();
let temporal_array =
TemporalArray::new_timestamp(milliseconds, TimeUnit::Ms, Some("UTC".to_string()));
let TemporalParts {
days,
seconds,
subseconds,
} = split_temporal(temporal_array).unwrap();
assert_eq!(days.into_primitive().unwrap().validity(), validity);
assert_eq!(
seconds.into_primitive().unwrap().validity(),
Validity::NonNullable
);
assert_eq!(
subseconds.into_primitive().unwrap().validity(),
Validity::NonNullable
);
}
}
103 changes: 103 additions & 0 deletions encodings/datetime-parts/src/compute/cast.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
use vortex_array::compute::{try_cast, CastFn};
use vortex_array::{ArrayDType, ArrayData, IntoArrayData};
use vortex_dtype::DType;
use vortex_error::{vortex_bail, VortexResult};

use crate::{DateTimePartsArray, DateTimePartsEncoding};

impl CastFn<DateTimePartsArray> for DateTimePartsEncoding {
fn cast(&self, array: &DateTimePartsArray, dtype: &DType) -> VortexResult<ArrayData> {
if !array.dtype().eq_ignore_nullability(dtype) {
vortex_bail!("cannot cast from {} to {}", array.dtype(), dtype);
};

Ok(DateTimePartsArray::try_new(
dtype.clone(),
try_cast(
array.days().as_ref(),
&array.days().dtype().with_nullability(dtype.nullability()),
)?,
array.seconds(),
array.subsecond(),
)?
.into_array())
}
}

#[cfg(test)]
mod tests {
use rstest::rstest;
use vortex_array::array::{PrimitiveArray, TemporalArray};
use vortex_array::compute::try_cast;
use vortex_array::validity::Validity;
use vortex_array::{ArrayDType as _, ArrayData, IntoArrayData as _};
use vortex_buffer::buffer;
use vortex_datetime_dtype::TimeUnit;
use vortex_dtype::{DType, Nullability};

use crate::DateTimePartsArray;

fn date_time_array(validity: Validity) -> ArrayData {
DateTimePartsArray::try_from(TemporalArray::new_timestamp(
PrimitiveArray::new(
buffer![
86_400i64, // element with only day component
86_400i64 + 1000, // element with day + second components
86_400i64 + 1000 + 1, // element with day + second + sub-second components
],
validity,
)
.into_array(),
TimeUnit::Ms,
Some("UTC".to_string()),
))
.unwrap()
.into_array()
}

#[rstest]
#[case(Validity::NonNullable, Nullability::Nullable)]
#[case(Validity::AllValid, Nullability::Nullable)]
#[case(Validity::AllInvalid, Nullability::Nullable)]
#[case(Validity::from_iter([true, false, true]), Nullability::Nullable)]
#[case(Validity::NonNullable, Nullability::NonNullable)]
#[case(Validity::AllValid, Nullability::NonNullable)]
#[case(Validity::from_iter([true, true, true]), Nullability::Nullable)]
fn test_cast_to_compatibile_nullability(
#[case] validity: Validity,
#[case] cast_to_nullability: Nullability,
) {
let array = date_time_array(validity);
let new_dtype = array.dtype().with_nullability(cast_to_nullability);
let result = try_cast(&array, &new_dtype);
assert!(result.is_ok(), "{:?}", result);
assert_eq!(result.unwrap().dtype(), &new_dtype);
}

#[rstest]
#[case(Validity::AllInvalid)]
#[case(Validity::from_iter([true, false, true]))]
fn test_bad_cast_fails(#[case] validity: Validity) {
let array = date_time_array(validity);
let result = try_cast(&array, &DType::Bool(Nullability::NonNullable));
assert!(
result
.as_ref()
.is_err_and(|err| err.to_string().contains("cannot cast from")),
"{:?}",
result
);

let result = try_cast(
&array,
&array.dtype().with_nullability(Nullability::NonNullable),
);
assert!(
result.as_ref().is_err_and(|err| err
.to_string()
.contains("invalid cast from nullable to non-nullable")),
"{:?}",
result
);
}
}
Loading

0 comments on commit 55cf81a

Please sign in to comment.